Update app.py

app.py
CHANGED
@@ -1,6 +1,7 @@
 import os, re, types, traceback, torch, gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from IndicTransToolkit import IndicProcessor
+import spacy
 
 # --------------------- Device ---------------------
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -13,64 +14,18 @@ TE_CODE = "tel_Telu"
 
 ip = IndicProcessor(inference=True)
 
-# ---------------------
-
-    r"(?:_src\S+)|(?:tgt\S+)|"
-    r"(?:>>\s*\S+\s*<<)|"
-    r"\b(?:eng_Latn|hin_Deva|hin_deva|tel_Telu|tel_telu)\b|"
-    r"<ID\d*>"
-)
-
-def
-
-
-
-def ensure_hindi_danda(s: str) -> str:
-    s = re.sub(r"\.\s*$", "।", s)
-    if not re.search(r"[।?!…]\s*$", s) and re.search(r"[\u0900-\u097F]\s*$", s):
-        s += "।"
-    return s
-
-# Sentence splitting (pysbd or fallback)
-try:
-    import pysbd
-    _SEGMENTER = pysbd.Segmenter(language="en", clean=True)
-except Exception:
-    _SEGMENTER = None
-
-_LEGAL_JOIN_RE = re.compile(r'\b([A-Za-z]{1,6})\.\s*$')
-_NEXT_CONT_RE = re.compile(r'^\s*(?:[\(\[\{]|\d|[a-z])')
-
-def _merge_legal_abbrev_breaks(sents):
-    merged, i = [], 0
-    while i < len(sents):
-        cur = sents[i].strip()
-        while i + 1 < len(sents):
-            nxt = sents[i + 1].lstrip()
-            if _LEGAL_JOIN_RE.search(cur) and _NEXT_CONT_RE.match(nxt):
-                cur = f"{cur} {nxt}"
-                i += 1
-            else:
-                break
-        merged.append(cur)
-        i += 1
-    return [s for s in merged if s]
-
-def split_into_sentences(text: str):
-    if _SEGMENTER is not None:
-        return _merge_legal_abbrev_breaks(_SEGMENTER.segment(text))
-    PLACEHOLDER = "\uE000"
-    protected = re.sub(
-        r'\b([A-Za-z]{1,6})\.(?=\s*(?:[\(\[\{]|\d|[a-z]))',
-        r'\1' + PLACEHOLDER, text.strip()
-    )
-    protected = re.sub(
-        r'\b([A-Za-z]{1,5})\.(?=\s+[A-Z])',
-        r'\1' + PLACEHOLDER, protected
-    )
-    parts = re.split(r'(?<=[.?!])\s+', protected)
-    return _merge_legal_abbrev_breaks([p.replace(PLACEHOLDER, '.') for p in parts if p.strip()])
+# --------------------- Sentence Splitting (spaCy) ---------------------
+nlp = spacy.load("en_core_web_sm")
+
+def split_into_sentences(text):
+    """Split English text into sentences using spaCy."""
+    doc = nlp(text.strip())
+    return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
+
+# --------------------- Cleanup Helper ---------------------
+def clean_translation(text):
+    """Remove unresolved placeholder tags such as <ID1>, <ID2>."""
+    return re.sub(r"<ID\d+>", "", text).strip()
 
 # --------------------- Model Loader ---------------------
 MODELS = {
@@ -92,10 +47,10 @@ def load_model(model_name: str):
     )
     mdl = AutoModelForSeq2SeqLM.from_pretrained(
         model_name, trust_remote_code=True,
-        low_cpu_mem_usage=True, dtype=dtype, token
+        low_cpu_mem_usage=True, dtype=dtype, token=token
    ).to(device).eval()
 
-    # Fix vocab
+    # Fix vocab mismatch if any
     try:
         mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
     except Exception:
@@ -104,100 +59,63 @@ def load_model(model_name: str):
     _model_cache[model_name] = (tok, mdl)
     return tok, mdl
 
-
-def build_bad_words_ids_from_vocab(tok):
-    vocab = tok.get_vocab()
-    candidates = [
-        "eng_Latn","hin_Deva","hin_deva","tel_Telu","tel_telu",
-        "_srceng_Latn","tgthin_Deva","tgt_tel_Telu",
-        ">>hin_Deva<<",">>tel_Telu<<",
-    ] + [f"<ID{i}>" for i in range(10)]
-    out = []
-    for c in candidates:
-        if c in vocab:
-            out.append([vocab[c]])
-            continue
-        sp_c = "▁" + c
-        if sp_c in vocab:
-            out.append([vocab[sp_c]])
-    return out
-
-
 # --------------------- Streaming Translation ---------------------
-BATCH_SIZE = 6
-
 @torch.inference_mode()
 def translate_dual_stream(text, model_choice, num_beams, max_new):
-    """
-    Generator that yields (hindi_accumulated_text, telugu_accumulated_text)
-    after each processed batch so the UI updates progressively.
-    """
+    """Generator that yields progressive Hindi & Telugu translations one sentence at a time."""
     if not text or not text.strip():
         yield "", ""
         return
 
-    # Prepare once
     tok, mdl = load_model(MODELS[model_choice])
-    BAD_WORDS_IDS = build_bad_words_ids_from_vocab(tok)
     sentences = split_into_sentences(text)
-
     hi_acc, te_acc = [], []
 
-    #
+    # Yield empty for immediate UI update
     yield "", ""
 
-    for i in
-
-
-        # --- Hindi batch ---
+    for i, sentence in enumerate(sentences, 1):
+        # --- Hindi Translation ---
         try:
-
-            enc_hi
-
-            ).to(device)
-            out_hi = mdl.generate(
+            batch_hi = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=HI_CODE)
+            enc_hi = tok(batch_hi, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
+            out_hi = mdl.generate(
                 **enc_hi,
-                max_length=max_new,
+                max_length=int(max_new),
                 num_beams=int(num_beams),
+                do_sample=False,
                 early_stopping=True,
                 no_repeat_ngram_size=3,
-                use_cache=False
-                bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
+                use_cache=False
             )
-            dec_hi
-            dec_hi = [strip_lang_tags(t) for t in dec_hi]
+            dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True, clean_up_tokenization_spaces=True)
             post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
-            post_hi
-            hi_acc.extend(p.strip() for p in post_hi)
+            hi_acc.append(clean_translation(post_hi[0]))
         except Exception as e:
-            hi_acc.append(f"⚠️ Hindi failed (
+            hi_acc.append(f"⚠️ Hindi failed (sentence {i}): {e}")
 
-        # --- Telugu
+        # --- Telugu Translation ---
         try:
-
-            enc_te
-
-            ).to(device)
-            out_te = mdl.generate(
+            batch_te = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=TE_CODE)
+            enc_te = tok(batch_te, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
+            out_te = mdl.generate(
                 **enc_te,
-                max_length=max_new,
+                max_length=int(max_new),
                 num_beams=int(num_beams),
+                do_sample=False,
                 early_stopping=True,
                 no_repeat_ngram_size=3,
-                use_cache=False
-                bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
+                use_cache=False
            )
-            dec_te
-            dec_te = [strip_lang_tags(t) for t in dec_te]
+            dec_te = tok.batch_decode(out_te, skip_special_tokens=True, clean_up_tokenization_spaces=True)
             post_te = ip.postprocess_batch(dec_te, lang=TE_CODE)
-            te_acc.
+            te_acc.append(clean_translation(post_te[0]))
         except Exception as e:
-            te_acc.append(f"⚠️ Telugu failed (
+            te_acc.append(f"⚠️ Telugu failed (sentence {i}): {e}")
 
-        # Stream
+        # Stream progressive output
         yield (" ".join(hi_acc), " ".join(te_acc))
 
-
 # --------------------- Dark Theme ---------------------
 THEME = gr.themes.Soft(
     primary_hue="blue", neutral_hue="slate"
@@ -267,13 +185,13 @@ button { border-radius:8px !important; font-weight:600 !important; }
 with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
     with gr.Group(elem_id="hdr"):
         gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
-        gr.Markdown("<p>IndicTrans2 with
+        gr.Markdown("<p>IndicTrans2 with simplified preprocessing and sentence-wise translation</p>")
 
     model_choice = gr.Dropdown(
         label="Choose Model",
         choices=list(MODELS.keys()),
         value="Default (Public)",
-        elem_id="model_dd"
+        elem_id="model_dd"
     )
 
     with gr.Row():
@@ -296,10 +214,10 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as
         with gr.Column(scale=1):
             with gr.Group(elem_classes="panel"):
                 gr.Markdown("<h2>Settings</h2>")
-                num_beams
-                max_new
+                num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
+                max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
 
-    #
+    # Stream generator connection
     translate_btn.click(
         translate_dual_stream,
        inputs=[src, model_choice, num_beams, max_new],
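Notes on the new version follow; each code sketch is illustrative, not part of the commit.

spacy.load("en_core_web_sm") raises OSError at startup when the model package is not installed (on Spaces it is usually pinned in requirements.txt). A minimal guarded loader, assuming the rule-based sentencizer is an acceptable fallback:

import spacy

try:
    nlp = spacy.load("en_core_web_sm")   # requires the model package
except OSError:
    nlp = spacy.blank("en")              # rule-based fallback, no model download
    nlp.add_pipe("sentencizer")          # enables doc.sents

def split_into_sentences(text):
    doc = nlp(text.strip())
    return [sent.text.strip() for sent in doc.sents if sent.text.strip()]

print(split_into_sentences("The petition was filed. The court adjourned till Monday."))
# ['The petition was filed.', 'The court adjourned till Monday.']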
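clean_translation deletes each <IDn> placeholder but keeps the spaces around it, so a tag removed mid-sentence leaves a double space; the final .strip() only trims the ends. A variant that also collapses the gap — a small extension, not what the commit does:

import re

def clean_translation(text):
    # Replace the placeholder and its surrounding whitespace with one space.
    return re.sub(r"\s*<ID\d+>\s*", " ", text).strip()

print(clean_translation("नमस्ते <ID1> दुनिया"))   # नमस्ते दुनिया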
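The slider bound to max_new is labelled "Max New Tokens", yet generate() receives it as max_length, which caps the total length of the generated sequence. For this encoder-decoder model the difference is small, but generate() also accepts max_new_tokens, which matches the label exactly; the Hindi call with that one change:

out_hi = mdl.generate(
    **enc_hi,
    max_new_tokens=int(max_new),   # bounds newly generated tokens, not total length
    num_beams=int(num_beams),
    do_sample=False,
    early_stopping=True,
    no_repeat_ngram_size=3,
    use_cache=False,
)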
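The commit also drops the BATCH_SIZE = 6 batching in favour of one generate() call per sentence, which streams at finer granularity but costs throughput on long inputs. A sketch of a chunked variant for one target language, assuming tok, mdl, ip, device, SRC_CODE, HI_CODE and clean_translation are in scope as in app.py:

def translate_batched(sentences, tgt_lang=HI_CODE, batch_size=6):
    # batch_size mirrors the constant the commit removed (hypothetical here).
    acc = []
    for start in range(0, len(sentences), batch_size):
        chunk = sentences[start:start + batch_size]
        batch = ip.preprocess_batch(chunk, src_lang=SRC_CODE, tgt_lang=tgt_lang)
        enc = tok(batch, max_length=256, truncation=True, padding=True,
                  return_tensors="pt").to(device)
        out = mdl.generate(**enc, max_length=128, num_beams=4)
        dec = tok.batch_decode(out, skip_special_tokens=True)
        acc.extend(clean_translation(t) for t in ip.postprocess_batch(dec, lang=tgt_lang))
        yield " ".join(acc)   # stream after each chunk, like the per-sentence loop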
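Because translate_dual_stream is a generator, it can also be exercised without the UI; each iteration yields the accumulated (Hindi, Telugu) pair after one more sentence. Assuming the "Default (Public)" entry shown in the dropdown:

text = "India has many languages. Most states also define official ones of their own."
for hi, te in translate_dual_stream(text, "Default (Public)", num_beams=4, max_new=128):
    print("HI:", hi)
    print("TE:", te)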