SagarVelamuri commited on
Commit
fffb78c
·
verified ·
1 Parent(s): a0b4a31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -124
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os, re, types, traceback, torch, gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from IndicTransToolkit import IndicProcessor
 
4
 
5
  # --------------------- Device ---------------------
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -13,64 +14,18 @@ TE_CODE = "tel_Telu"
13
 
14
  ip = IndicProcessor(inference=True)
15
 
16
- # --------------------- Regex / Helpers ---------------------
17
- TAG_REGEX = re.compile(
18
- r"(?:_src\S+)|(?:tgt\S+)|"
19
- r"(?:>>\s*\S+\s*<<)|"
20
- r"\b(?:eng_Latn|hin_Deva|hin_deva|tel_Telu|tel_telu)\b|"
21
- r"<ID\d*>"
22
- )
23
 
24
- def strip_lang_tags(text: str) -> str:
25
- s = TAG_REGEX.sub(" ", text)
26
- return re.sub(r"\s{2,}", " ", s).strip()
27
-
28
- def ensure_hindi_danda(s: str) -> str:
29
- s = re.sub(r"\.\s*$", "।", s)
30
- if not re.search(r"[।?!…]\s*$", s) and re.search(r"[\u0900-\u097F]\s*$", s):
31
- s += "।"
32
- return s
33
-
34
- # Sentence splitting (pysbd or fallback)
35
- try:
36
- import pysbd
37
- _SEGMENTER = pysbd.Segmenter(language="en", clean=True)
38
- except Exception:
39
- _SEGMENTER = None
40
-
41
- _LEGAL_JOIN_RE = re.compile(r'\b([A-Za-z]{1,6})\.\s*$')
42
- _NEXT_CONT_RE = re.compile(r'^\s*(?:[\(\[\{]|\d|[a-z])')
43
-
44
- def _merge_legal_abbrev_breaks(sents):
45
- merged, i = [], 0
46
- while i < len(sents):
47
- cur = sents[i].strip()
48
- while i + 1 < len(sents):
49
- nxt = sents[i + 1].lstrip()
50
- if _LEGAL_JOIN_RE.search(cur) and _NEXT_CONT_RE.match(nxt):
51
- cur = f"{cur} {nxt}"
52
- i += 1
53
- else:
54
- break
55
- merged.append(cur)
56
- i += 1
57
- return [s for s in merged if s]
58
-
59
- def split_into_sentences(text: str):
60
- if _SEGMENTER is not None:
61
- return _merge_legal_abbrev_breaks(_SEGMENTER.segment(text))
62
- PLACEHOLDER = "\uE000"
63
- protected = re.sub(
64
- r'\b([A-Za-z]{1,6})\.(?=\s*(?:[\(\[\{]|\d|[a-z]))',
65
- r'\1' + PLACEHOLDER, text.strip()
66
- )
67
- protected = re.sub(
68
- r'\b([A-Za-z]{1,5})\.(?=\s+[A-Z])',
69
- r'\1' + PLACEHOLDER, protected
70
- )
71
- parts = re.split(r'(?<=[.?!])\s+', protected)
72
- return _merge_legal_abbrev_breaks([p.replace(PLACEHOLDER, '.') for p in parts if p.strip()])
73
 
 
 
 
 
74
 
75
  # --------------------- Model Loader ---------------------
76
  MODELS = {
@@ -92,10 +47,10 @@ def load_model(model_name: str):
92
  )
93
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
94
  model_name, trust_remote_code=True,
95
- low_cpu_mem_usage=True, dtype=dtype, token = token
96
  ).to(device).eval()
97
 
98
- # Fix vocab (some HF models have mismatched config.vocab_size)
99
  try:
100
  mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
101
  except Exception:
@@ -104,100 +59,63 @@ def load_model(model_name: str):
104
  _model_cache[model_name] = (tok, mdl)
105
  return tok, mdl
106
 
107
-
108
- def build_bad_words_ids_from_vocab(tok):
109
- vocab = tok.get_vocab()
110
- candidates = [
111
- "eng_Latn","hin_Deva","hin_deva","tel_Telu","tel_telu",
112
- "_srceng_Latn","tgthin_Deva","tgt_tel_Telu",
113
- ">>hin_Deva<<",">>tel_Telu<<",
114
- ] + [f"<ID{i}>" for i in range(10)]
115
- out = []
116
- for c in candidates:
117
- if c in vocab:
118
- out.append([vocab[c]])
119
- continue
120
- sp_c = "▁" + c
121
- if sp_c in vocab:
122
- out.append([vocab[sp_c]])
123
- return out
124
-
125
-
126
  # --------------------- Streaming Translation ---------------------
127
- BATCH_SIZE = 6
128
-
129
  @torch.inference_mode()
130
  def translate_dual_stream(text, model_choice, num_beams, max_new):
131
- """
132
- Generator that yields (hindi_accumulated_text, telugu_accumulated_text)
133
- after each processed batch so the UI updates progressively.
134
- """
135
  if not text or not text.strip():
136
  yield "", ""
137
  return
138
 
139
- # Prepare once
140
  tok, mdl = load_model(MODELS[model_choice])
141
- BAD_WORDS_IDS = build_bad_words_ids_from_vocab(tok)
142
  sentences = split_into_sentences(text)
143
-
144
  hi_acc, te_acc = [], []
145
 
146
- # Clear outputs immediately for a snappy feel
147
  yield "", ""
148
 
149
- for i in range(0, len(sentences), BATCH_SIZE):
150
- batch = sentences[i:i + BATCH_SIZE]
151
-
152
- # --- Hindi batch ---
153
  try:
154
- proc_hi = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=HI_CODE)
155
- enc_hi = tok(
156
- proc_hi, padding=True, truncation=True, max_length=256, return_tensors="pt"
157
- ).to(device)
158
- out_hi = mdl.generate(
159
  **enc_hi,
160
- max_length=max_new, # keep semantics same as your original
161
  num_beams=int(num_beams),
 
162
  early_stopping=True,
163
  no_repeat_ngram_size=3,
164
- use_cache=False,
165
- bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
166
  )
167
- dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True)
168
- dec_hi = [strip_lang_tags(t) for t in dec_hi]
169
  post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
170
- post_hi = [ensure_hindi_danda(x) for x in post_hi]
171
- hi_acc.extend(p.strip() for p in post_hi)
172
  except Exception as e:
173
- hi_acc.append(f"⚠️ Hindi failed (batch {i//BATCH_SIZE+1}): {e}")
174
 
175
- # --- Telugu batch ---
176
  try:
177
- proc_te = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=TE_CODE)
178
- enc_te = tok(
179
- proc_te, padding=True, truncation=True, max_length=256, return_tensors="pt"
180
- ).to(device)
181
- out_te = mdl.generate(
182
  **enc_te,
183
- max_length=max_new,
184
  num_beams=int(num_beams),
 
185
  early_stopping=True,
186
  no_repeat_ngram_size=3,
187
- use_cache=False,
188
- bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
189
  )
190
- dec_te = tok.batch_decode(out_te, skip_special_tokens=True)
191
- dec_te = [strip_lang_tags(t) for t in dec_te]
192
  post_te = ip.postprocess_batch(dec_te, lang=TE_CODE)
193
- te_acc.extend(p.strip() for p in post_te)
194
  except Exception as e:
195
- te_acc.append(f"⚠️ Telugu failed (batch {i//BATCH_SIZE+1}): {e}")
196
 
197
- # Stream the accumulators so far
198
  yield (" ".join(hi_acc), " ".join(te_acc))
199
 
200
-
201
  # --------------------- Dark Theme ---------------------
202
  THEME = gr.themes.Soft(
203
  primary_hue="blue", neutral_hue="slate"
@@ -267,13 +185,13 @@ button { border-radius:8px !important; font-weight:600 !important; }
267
  with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
268
  with gr.Group(elem_id="hdr"):
269
  gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
270
- gr.Markdown("<p>IndicTrans2 with batch sentence decomposition</p>")
271
 
272
  model_choice = gr.Dropdown(
273
  label="Choose Model",
274
  choices=list(MODELS.keys()),
275
  value="Default (Public)",
276
- elem_id="model_dd" # <-- for targeted styling
277
  )
278
 
279
  with gr.Row():
@@ -296,10 +214,10 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as
296
  with gr.Column(scale=1):
297
  with gr.Group(elem_classes="panel"):
298
  gr.Markdown("<h2>Settings</h2>")
299
- num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
300
- max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
301
 
302
- # Use streaming generator
303
  translate_btn.click(
304
  translate_dual_stream,
305
  inputs=[src, model_choice, num_beams, max_new],
 
1
  import os, re, types, traceback, torch, gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from IndicTransToolkit import IndicProcessor
4
+ import spacy
5
 
6
  # --------------------- Device ---------------------
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
14
 
15
  ip = IndicProcessor(inference=True)
16
 
17
+ # --------------------- Sentence Splitting (spaCy) ---------------------
18
+ nlp = spacy.load("en_core_web_sm")
 
 
 
 
 
19
 
20
+ def split_into_sentences(text):
21
+ """Split English text into sentences using spaCy."""
22
+ doc = nlp(text.strip())
23
+ return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # --------------------- Cleanup Helper ---------------------
26
+ def clean_translation(text):
27
+ """Remove unresolved placeholder tags such as <ID1>, <ID2>."""
28
+ return re.sub(r"<ID\d+>", "", text).strip()
29
 
30
  # --------------------- Model Loader ---------------------
31
  MODELS = {
 
47
  )
48
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
49
  model_name, trust_remote_code=True,
50
+ low_cpu_mem_usage=True, dtype=dtype, token=token
51
  ).to(device).eval()
52
 
53
+ # Fix vocab mismatch if any
54
  try:
55
  mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
56
  except Exception:
 
59
  _model_cache[model_name] = (tok, mdl)
60
  return tok, mdl
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # --------------------- Streaming Translation ---------------------
 
 
63
  @torch.inference_mode()
64
  def translate_dual_stream(text, model_choice, num_beams, max_new):
65
+ """Generator that yields progressive Hindi & Telugu translations one sentence at a time."""
 
 
 
66
  if not text or not text.strip():
67
  yield "", ""
68
  return
69
 
 
70
  tok, mdl = load_model(MODELS[model_choice])
 
71
  sentences = split_into_sentences(text)
 
72
  hi_acc, te_acc = [], []
73
 
74
+ # Yield empty for immediate UI update
75
  yield "", ""
76
 
77
+ for i, sentence in enumerate(sentences, 1):
78
+ # --- Hindi Translation ---
 
 
79
  try:
80
+ batch_hi = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=HI_CODE)
81
+ enc_hi = tok(batch_hi, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
82
+ out_hi = mdl.generate(
 
 
83
  **enc_hi,
84
+ max_length=int(max_new),
85
  num_beams=int(num_beams),
86
+ do_sample=False,
87
  early_stopping=True,
88
  no_repeat_ngram_size=3,
89
+ use_cache=False
 
90
  )
91
+ dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
92
  post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
93
+ hi_acc.append(clean_translation(post_hi[0]))
 
94
  except Exception as e:
95
+ hi_acc.append(f"⚠️ Hindi failed (sentence {i}): {e}")
96
 
97
+ # --- Telugu Translation ---
98
  try:
99
+ batch_te = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=TE_CODE)
100
+ enc_te = tok(batch_te, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
101
+ out_te = mdl.generate(
 
 
102
  **enc_te,
103
+ max_length=int(max_new),
104
  num_beams=int(num_beams),
105
+ do_sample=False,
106
  early_stopping=True,
107
  no_repeat_ngram_size=3,
108
+ use_cache=False
 
109
  )
110
+ dec_te = tok.batch_decode(out_te, skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
111
  post_te = ip.postprocess_batch(dec_te, lang=TE_CODE)
112
+ te_acc.append(clean_translation(post_te[0]))
113
  except Exception as e:
114
+ te_acc.append(f"⚠️ Telugu failed (sentence {i}): {e}")
115
 
116
+ # Stream progressive output
117
  yield (" ".join(hi_acc), " ".join(te_acc))
118
 
 
119
  # --------------------- Dark Theme ---------------------
120
  THEME = gr.themes.Soft(
121
  primary_hue="blue", neutral_hue="slate"
 
185
  with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
186
  with gr.Group(elem_id="hdr"):
187
  gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
188
+ gr.Markdown("<p>IndicTrans2 with simplified preprocessing and sentence-wise translation</p>")
189
 
190
  model_choice = gr.Dropdown(
191
  label="Choose Model",
192
  choices=list(MODELS.keys()),
193
  value="Default (Public)",
194
+ elem_id="model_dd"
195
  )
196
 
197
  with gr.Row():
 
214
  with gr.Column(scale=1):
215
  with gr.Group(elem_classes="panel"):
216
  gr.Markdown("<h2>Settings</h2>")
217
+ num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
218
+ max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
219
 
220
+ # Stream generator connection
221
  translate_btn.click(
222
  translate_dual_stream,
223
  inputs=[src, model_choice, num_beams, max_new],