Spaces:
Sleeping
Sleeping
import logging
import os
import re

import gradio as gr
import spacy
import torch
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# --------------------- Device ---------------------
# Run on the GPU in half precision when CUDA is available; otherwise fall back
# to the CPU in full precision.
_cuda_ready = torch.cuda.is_available()
if _cuda_ready:
    device = torch.device("cuda")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

# --------------------- Languages ------------------
# Language tags handed to the IndicProcessor: one source, two targets.
SRC_CODE = "eng_Latn"
HI_CODE = "hin_Deva"
TE_CODE = "tel_Telu"

# Single shared pre/post-processor, created in inference mode.
ip = IndicProcessor(inference=True)
# --------------------- spaCy Sentence Splitter ---------------------
# Load the English pipeline used for sentence splitting. spaCy raises OSError
# when the model package is missing, in which case it is downloaded and the
# load is retried.
_SPACY_MODEL = "en_core_web_sm"
try:
    nlp = spacy.load(_SPACY_MODEL)
except OSError:
    from spacy.cli import download

    download(_SPACY_MODEL)
    nlp = spacy.load(_SPACY_MODEL)
def split_into_sentences(text):
    """Return the non-empty, whitespace-trimmed sentences spaCy finds in *text*.

    The input is stripped before it is passed to the module-level ``nlp``
    pipeline; sentences that are empty after trimming are dropped.
    """
    sentences = []
    for span in nlp(text.strip()).sents:
        trimmed = span.text.strip()
        if trimmed:
            sentences.append(trimmed)
    return sentences
# --------------------- Abbreviation Expansion ---------------------
# Citation abbreviations (lower-case, trailing period included) mapped to the
# full word that replaces them before translation.
ABBREVIATION_MAP = {
    "subs.": "subsection",
    "cl.": "clause",
    "art.": "article",
    "sec.": "section",
    "s.": "section",
    "no.": "number",
    "sch.": "schedule",
    "para.": "paragraph",
    "r.": "rule",
    "reg.": "regulation",
    "dept.": "department",
}

_ABBR_PATTERN = re.compile(
    # Must not continue a word, and must not be the tail of a dotted
    # initialism such as the "S." in "U.S.".
    r"(?<![A-Za-z])(?<![A-Za-z]\.)("
    # Longest keys first so a short key can never shadow a longer one.
    + "|".join(
        re.escape(abbr)
        for abbr in sorted(ABBREVIATION_MAP, key=len, reverse=True)
    )
    # Must not be the head of a dotted initialism ("S.C."), and must be
    # followed (after optional whitespace) by "(", a digit or a letter.
    + r")(?![A-Za-z]\.)(?=\s*(?:\(|\d|[A-Z]|[a-z]))",
    flags=re.IGNORECASE,
)


def expand_abbreviations(text: str) -> str:
    """Replace known abbreviations in *text* with their full forms.

    Matching is case-insensitive and the casing of the abbreviation is carried
    over to the replacement:

    * all-caps with more than one letter (``SEC.``) -> upper-case (``SECTION``)
    * leading capital (``Sec.``, ``S.``) -> capitalized (``Section``)
    * anything else -> the lower-case form from ``ABBREVIATION_MAP``

    A single capital letter such as ``S.`` is treated as "capitalized" rather
    than "all caps", because one letter cannot distinguish the two. Components
    of dotted initialisms (``U.S.``, ``S.C.``) are left untouched.

    Args:
        text: English source text.

    Returns:
        The text with every matched abbreviation expanded.
    """

    def replacer(match):
        abbr = match.group(0)
        full = ABBREVIATION_MAP.get(abbr.lower(), abbr)
        letter_count = sum(ch.isalpha() for ch in abbr)
        if letter_count > 1 and abbr.isupper():
            return full.upper()
        if abbr[0].isupper():
            return full.capitalize()
        return full

    return _ABBR_PATTERN.sub(replacer, text)
# --------------------- Clean Up Placeholder Tags ---------------------
_PLACEHOLDER_TAG = re.compile(r"<ID\d+>")
_SPACE_RUN = re.compile(r"[ \t]{2,}")


def clean_translation(text):
    """Remove unresolved placeholder tags such as <ID1>, <ID2>.

    Deleting a tag that sat between two spaces would leave a double space, so
    runs of spaces/tabs are collapsed to a single space afterwards. Newlines
    are left alone and the result is stripped at both ends.
    """
    without_tags = _PLACEHOLDER_TAG.sub("", text)
    return _SPACE_RUN.sub(" ", without_tags).strip()
# --------------------- Model Loader ---------------------
# UI label -> Hugging Face repository id of the translation checkpoint.
MODELS = {
    "Default (Public)": "law-ai/InLegalTrans-En2Indic-1B",
    "Fine-tuned (Private)": "SagarVelamuri/InLegalTrans-En2Indic-FineTuned-Tel-Hin",
}

# Repository id -> (tokenizer, model), so each checkpoint is loaded only once.
_model_cache = {}


def load_model(model_name: str):
    """Return ``(tokenizer, model)`` for *model_name*, loading it on first use.

    The tokenizer always comes from ``ai4bharat/indictrans2-en-indic-1B``,
    whichever checkpoint is requested. The model is fetched with the access
    token read from the ``hf_token`` environment variable (``None`` if unset),
    moved to the module-level ``device`` and switched to eval mode.

    Args:
        model_name: Hugging Face repository id of the seq2seq checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple; repeated calls return the cached pair.
    """
    cached = _model_cache.get(model_name)
    if cached is not None:
        return cached

    token = os.getenv("hf_token")
    tok = AutoTokenizer.from_pretrained(
        "ai4bharat/indictrans2-en-indic-1B",
        trust_remote_code=True,
        use_fast=True,
    )
    # NOTE(review): `dtype=` is assumed to be accepted by the installed
    # transformers release (older releases spell it `torch_dtype=`) - confirm.
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        dtype=dtype,
        token=token,
    ).to(device).eval()

    # Best effort: make config.vocab_size agree with the number of rows in the
    # output embedding. A failure must not block loading, but it is logged
    # instead of being swallowed silently.
    try:
        mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not sync config.vocab_size for %s", model_name, exc_info=True
        )

    _model_cache[model_name] = (tok, mdl)
    return tok, mdl
# --------------------- Translation ---------------------
def _translate_sentence(tok, mdl, sentence, tgt_code, num_beams, max_new):
    """Translate one English sentence into *tgt_code* and return cleaned text."""
    batch = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=tgt_code)
    enc = tok(
        batch, max_length=256, truncation=True, padding=True, return_tensors="pt"
    ).to(device)
    # NOTE(review): the UI labels this value "Max New Tokens" but it is passed
    # as `max_length`; kept as-is to preserve behavior - confirm the intent.
    out = mdl.generate(
        **enc,
        max_length=int(max_new),
        num_beams=int(num_beams),
        do_sample=False,
        early_stopping=True,
        no_repeat_ngram_size=3,
        use_cache=False,
    )
    decoded = tok.batch_decode(
        out, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    post = ip.postprocess_batch(decoded, lang=tgt_code)
    return clean_translation(post[0])


def translate_dual_stream(text, model_choice, num_beams, max_new):
    """Stream Hindi and Telugu translations, one sentence at a time.

    Args:
        text: English input; blank or ``None`` input yields one empty pair.
        model_choice: Key into ``MODELS`` selecting the checkpoint.
        num_beams: Beam width, converted with ``int()`` before generation.
        max_new: Length limit, converted with ``int()`` and passed to
            ``generate`` as ``max_length``.

    Yields:
        ``(hindi, telugu)`` tuples holding everything translated so far, joined
        with single spaces. An empty pair is yielded first to clear the UI, and
        then one pair after each sentence. A sentence that fails in one
        language contributes a warning string for that language only.
    """
    if not text or not text.strip():
        yield "", ""
        return

    tok, mdl = load_model(MODELS[model_choice])

    # Expand known abbreviations before sentence splitting.
    sentences = split_into_sentences(expand_abbreviations(text))
    hi_acc, te_acc = [], []
    yield "", ""  # Clear UI early

    for i, sentence in enumerate(sentences, 1):
        # --- Hindi ---
        try:
            hi_text = _translate_sentence(
                tok, mdl, sentence, HI_CODE, num_beams, max_new
            )
            # Ensure a danda when terminal punctuation is missing, but never
            # emit a lone danda for an empty translation.
            if hi_text and not re.search(r"[।?!…]$", hi_text):
                hi_text += "।"
            hi_acc.append(hi_text)
        except Exception as e:
            hi_acc.append(f"⚠️ Hindi failed (sentence {i}): {e}")

        # --- Telugu ---
        try:
            te_acc.append(
                _translate_sentence(tok, mdl, sentence, TE_CODE, num_beams, max_new)
            )
        except Exception as e:
            te_acc.append(f"⚠️ Telugu failed (sentence {i}): {e}")

        yield (" ".join(hi_acc), " ".join(te_acc))
# --------------------- Dark Theme ---------------------
# Colour overrides applied on top of Gradio's Soft theme.
_DARK_OVERRIDES = {
    "body_background_fill": "#0b0f19",
    "body_text_color": "#f3f4f6",
    "block_background_fill": "#111827",
    "block_border_color": "#1f2937",
    "block_title_text_color": "#123456",
    "button_primary_background_fill": "#2563eb",
    "button_primary_text_color": "#ffffff",
}

_BASE_THEME = gr.themes.Soft(primary_hue="blue", neutral_hue="slate")
THEME = _BASE_THEME.set(**_DARK_OVERRIDES)
# Extra stylesheet passed to gr.Blocks. The #hdr and #model_dd selectors match
# elem_id values and .panel matches the elem_classes value used when the UI
# components are created.
CUSTOM_CSS = """
/* Header + Panels */
#hdr { text-align:center; padding:16px; }
#hdr h1 { font-size:24px; font-weight:700; color:#f9fafb; margin:0; }
#hdr p { font-size:14px; color:#9ca3af; margin-top:4px; }
.panel { border:1px solid #1f2937; border-radius:10px; padding:12px; background:#111827; box-shadow:0 1px 2px rgba(0,0,0,0.4);}
.panel h2 { font-size:16px; font-weight:600; margin-bottom:6px; color:#f3f4f6; }
/* Inputs */
textarea { background:#0b0f19 !important; color:#f9fafb !important; border-radius:8px !important; border:1px solid #374151 !important; font-size:15px !important; line-height:1.55; }
button { border-radius:8px !important; font-weight:600 !important; }
/* Labels */
.gradio-container label,
.gradio-container .label,
.gradio-container .block-title,
.gradio-container .prose h2,
.gradio-container .prose h3 {
color:#093999 !important;
}
/* Dropdown Styling */
#model_dd .wrap,
#model_dd .container {
background:#111827 !important;
border:1px solid #374151 !important;
border-radius:8px !important;
}
#model_dd input,
#model_dd .value,
#model_dd ::placeholder,
#model_dd select,
#model_dd option {
color:#ffffff!important;
background:#111827 !important;
}
#model_dd .options,
#model_dd .options .item {
background:#111827 !important;
color:#ffffff !important;
}
#model_dd label {
color:#efe4b0 !important;
}
/* Slider labels */
.gradio-container .range-block label,
.gradio-container .gr-slider label {
color:#efe4b0 !important;
}
"""
# --------------------- UI ---------------------
# NOTE(review): the nesting of the layout containers below is reconstructed
# from the statement order; confirm it against the deployed layout, in
# particular whether the button row belongs inside the input panel.
def _reset_fields():
    """Return empty strings for the input box and both output boxes."""
    return "", "", ""


with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
    # Page header.
    with gr.Group(elem_id="hdr"):
        gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
        gr.Markdown(
            "<p>IndicTrans2 with abbreviation expansion and sentence-wise translation</p>"
        )

    model_choice = gr.Dropdown(
        label="Choose Model",
        choices=list(MODELS),
        value="Default (Public)",
        elem_id="model_dd",
    )

    with gr.Row():
        # Left: English input plus action buttons.
        with gr.Column(scale=2):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("<h2>English Input</h2>")
                src = gr.Textbox(
                    lines=12,
                    placeholder="Enter English text...",
                    show_label=False,
                )
            with gr.Row():
                translate_btn = gr.Button("Translate", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

        # Middle: the two translation outputs.
        with gr.Column(scale=2):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("<h2>Hindi Translation</h2>")
                hi_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
            with gr.Group(elem_classes="panel"):
                gr.Markdown("<h2>Telugu Translation</h2>")
                te_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)

        # Right: generation settings.
        # NOTE(review): both sliders reuse elem_id="model_dd" from the dropdown,
        # so three elements share one id; presumably done to pick up the
        # "#model_dd label" colour rule - confirm before changing.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("<h2>Settings</h2>")
                num_beams = gr.Slider(
                    1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd"
                )
                max_new = gr.Slider(
                    32,
                    512,
                    value=128,
                    step=16,
                    label="Max New Tokens",
                    elem_id="model_dd",
                )

    # Translate streams into both output boxes; Clear empties all three boxes.
    translate_btn.click(
        translate_dual_stream,
        inputs=[src, model_choice, num_beams, max_new],
        outputs=[hi_out, te_out],
    )
    clear_btn.click(_reset_fields, outputs=[src, hi_out, te_out])

demo.queue(max_size=48)
demo.launch()