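# Gradio Space: English → Hindi & Telugu legal-text translator built on
# IndicTrans2 / InLegalTrans, with legal-abbreviation expansion and
# sentence-by-sentence streaming of both translations.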
import os, re, torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from IndicTransToolkit import IndicProcessor
import spacy
# --------------------- Device ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# --------------------- Languages ------------------
SRC_CODE = "eng_Latn"
HI_CODE = "hin_Deva"
TE_CODE = "tel_Telu"
ip = IndicProcessor(inference=True)
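# IndicProcessor handles the IndicTrans2-specific pre/post-processing:
# preprocess_batch() adds source/target language tags and normalizes the input,
# postprocess_batch() detokenizes and restores the target script after decoding.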
# --------------------- spaCy Sentence Splitter ---------------------
try:
nlp = spacy.load("en_core_web_sm")
except OSError:
from spacy.cli import download
download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
def split_into_sentences(text):
"""Split English text into sentences using spaCy."""
doc = nlp(text.strip())
return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
# --------------------- Abbreviation Expansion ---------------------
ABBREVIATION_MAP = {
"subs.": "subsection",
"cl.": "clause",
"art.": "article",
"sec.": "section",
"s.": "section",
"no.": "number",
"sch.": "schedule",
"para.": "paragraph",
"r.": "rule",
"reg.": "regulation",
"dept.": "department",
}
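# Match an abbreviation only when it is not preceded by a letter and is
# followed (after optional whitespace) by "(", a digit, or a letter,
# e.g. "sec. 12" or "cl.(a)". Dict order is preserved in the alternation,
# so longer forms such as "sec." are tried before "s.".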
_ABBR_PATTERN = re.compile(
r'(?<![A-Za-z])(' + '|'.join(re.escape(k) for k in ABBREVIATION_MAP.keys()) + r')(?=\s*(?:\(|\d|[A-Z]|[a-z]))',
flags=re.IGNORECASE
)
def expand_abbreviations(text: str) -> str:
"""Replace known abbreviations with full forms safely (without affecting natural words)."""
def replacer(match):
key = match.group(0)
repl = ABBREVIATION_MAP.get(key.lower(), key)
if key.isupper():
return repl.upper()
elif key[0].isupper():
return repl.capitalize()
return repl
return _ABBR_PATTERN.sub(replacer, text)
# --------------------- Clean Up Placeholder Tags ---------------------
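# The preprocessing step can protect entities with <IDn> placeholders; any
# that survive postprocessing are stripped here.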
def clean_translation(text):
"""Remove unresolved placeholder tags such as <ID1>, <ID2>."""
return re.sub(r"<ID\d+>", "", text).strip()
# --------------------- Model Loader ---------------------
MODELS = {
"Default (Public)": "law-ai/InLegalTrans-En2Indic-1B",
"Fine-tuned (Private)": "SagarVelamuri/InLegalTrans-En2Indic-FineTuned-Tel-Hin"
}
_model_cache = {}
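# Cache loaded (tokenizer, model) pairs so switching models in the UI does not
# reload the weights.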
def load_model(model_name: str):
if model_name in _model_cache:
return _model_cache[model_name]
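    # A read token is only needed for the private fine-tuned checkpoint; it is
    # taken from the hf_token environment variable (e.g. a Space secret).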
token = os.getenv("hf_token")
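    # Both checkpoints are assumed to share the base IndicTrans2 tokenizer,
    # so it is always loaded from the ai4bharat repo.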
tok = AutoTokenizer.from_pretrained(
"ai4bharat/indictrans2-en-indic-1B",
trust_remote_code=True, use_fast=True
)
mdl = AutoModelForSeq2SeqLM.from_pretrained(
model_name, trust_remote_code=True,
low_cpu_mem_usage=True, dtype=dtype, token=token
).to(device).eval()
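    # Keep config.vocab_size in sync with the actual output-embedding size;
    # some remote-code checkpoints report a mismatched value, which can break
    # generation.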
try:
mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
except Exception:
pass
_model_cache[model_name] = (tok, mdl)
return tok, mdl
# --------------------- Translation ---------------------
@torch.inference_mode()
def translate_dual_stream(text, model_choice, num_beams, max_new):
"""Stream Hindi and Telugu translations, one sentence at a time."""
if not text or not text.strip():
yield "", ""
return
tok, mdl = load_model(MODELS[model_choice])
# Expand known abbreviations
text = expand_abbreviations(text)
sentences = split_into_sentences(text)
hi_acc, te_acc = [], []
yield "", "" # Clear UI early
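    # Translate one sentence at a time with deterministic beam search; each
    # loop iteration yields the accumulated text so the UI updates live.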
for i, sentence in enumerate(sentences, 1):
# --- Hindi ---
try:
batch_hi = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=HI_CODE)
enc_hi = tok(batch_hi, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
out_hi = mdl.generate(
**enc_hi,
                max_new_tokens=int(max_new),
num_beams=int(num_beams),
do_sample=False,
early_stopping=True,
no_repeat_ngram_size=3,
use_cache=False
)
dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True, clean_up_tokenization_spaces=True)
post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
hi_text = clean_translation(post_hi[0])
# Optionally ensure danda for Hindi if missing
            if not re.search(r"[।.?!…]$", hi_text):
hi_text += "।"
hi_acc.append(hi_text)
except Exception as e:
hi_acc.append(f"⚠️ Hindi failed (sentence {i}): {e}")
# --- Telugu ---
try:
batch_te = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=TE_CODE)
enc_te = tok(batch_te, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
out_te = mdl.generate(
**enc_te,
                max_new_tokens=int(max_new),
num_beams=int(num_beams),
do_sample=False,
early_stopping=True,
no_repeat_ngram_size=3,
use_cache=False
)
dec_te = tok.batch_decode(out_te, skip_special_tokens=True, clean_up_tokenization_spaces=True)
post_te = ip.postprocess_batch(dec_te, lang=TE_CODE)
te_acc.append(clean_translation(post_te[0]))
except Exception as e:
te_acc.append(f"⚠️ Telugu failed (sentence {i}): {e}")
yield (" ".join(hi_acc), " ".join(te_acc))
# --------------------- Dark Theme ---------------------
THEME = gr.themes.Soft(
primary_hue="blue", neutral_hue="slate"
).set(
body_background_fill="#0b0f19",
body_text_color="#f3f4f6",
block_background_fill="#111827",
block_border_color="#1f2937",
block_title_text_color="#123456",
button_primary_background_fill="#2563eb",
button_primary_text_color="#ffffff",
)
CUSTOM_CSS = """
/* Header + Panels */
#hdr { text-align:center; padding:16px; }
#hdr h1 { font-size:24px; font-weight:700; color:#f9fafb; margin:0; }
#hdr p { font-size:14px; color:#9ca3af; margin-top:4px; }
.panel { border:1px solid #1f2937; border-radius:10px; padding:12px; background:#111827; box-shadow:0 1px 2px rgba(0,0,0,0.4);}
.panel h2 { font-size:16px; font-weight:600; margin-bottom:6px; color:#f3f4f6; }
/* Inputs */
textarea { background:#0b0f19 !important; color:#f9fafb !important; border-radius:8px !important; border:1px solid #374151 !important; font-size:15px !important; line-height:1.55; }
button { border-radius:8px !important; font-weight:600 !important; }
/* Labels */
.gradio-container label,
.gradio-container .label,
.gradio-container .block-title,
.gradio-container .prose h2,
.gradio-container .prose h3 {
color:#093999 !important;
}
/* Dropdown Styling */
#model_dd .wrap,
#model_dd .container {
background:#111827 !important;
border:1px solid #374151 !important;
border-radius:8px !important;
}
#model_dd input,
#model_dd .value,
#model_dd ::placeholder,
#model_dd select,
#model_dd option {
color:#ffffff!important;
background:#111827 !important;
}
#model_dd .options,
#model_dd .options .item {
background:#111827 !important;
color:#ffffff !important;
}
#model_dd label {
color:#efe4b0 !important;
}
/* Slider labels */
.gradio-container .range-block label,
.gradio-container .gr-slider label {
color:#efe4b0 !important;
}
"""
# --------------------- UI ---------------------
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
with gr.Group(elem_id="hdr"):
gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
gr.Markdown("<p>IndicTrans2 with abbreviation expansion and sentence-wise translation</p>")
model_choice = gr.Dropdown(
label="Choose Model",
choices=list(MODELS.keys()),
value="Default (Public)",
elem_id="model_dd"
)
with gr.Row():
with gr.Column(scale=2):
with gr.Group(elem_classes="panel"):
gr.Markdown("<h2>English Input</h2>")
src = gr.Textbox(lines=12, placeholder="Enter English text...", show_label=False)
with gr.Row():
translate_btn = gr.Button("Translate", variant="primary")
clear_btn = gr.Button("Clear", variant="secondary")
with gr.Column(scale=2):
with gr.Group(elem_classes="panel"):
gr.Markdown("<h2>Hindi Translation</h2>")
hi_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
with gr.Group(elem_classes="panel"):
gr.Markdown("<h2>Telugu Translation</h2>")
te_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
with gr.Column(scale=1):
with gr.Group(elem_classes="panel"):
gr.Markdown("<h2>Settings</h2>")
num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
translate_btn.click(
translate_dual_stream,
inputs=[src, model_choice, num_beams, max_new],
outputs=[hi_out, te_out]
)
clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
demo.queue(max_size=48).launch()