Update app.py
app.py CHANGED

--- a/app.py (before)
@@ -1,24 +1,27 @@
-import os, torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from IndicTransToolkit import IndicProcessor  # https://github.com/VarunGumma/IndicTransToolkit

-# ---
 TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
 MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
-
-
-TOKENIZER_REV = os.getenv("TOKENIZER_REV", None)  # e.g., "b1a2c3d"
-MODEL_REV = os.getenv("MODEL_REV", None)  # e.g., "e4f5a6b"

 SRC_CODE = "eng_Latn"
 HI_CODE = "hin_Deva"
 TE_CODE = "tel_Telu"

-# -------------------- Load model & tokenizer --------------------------
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float16 if torch.cuda.is_available() else torch.float32

 tok_kwargs = dict(trust_remote_code=True, use_fast=True)
 if TOKENIZER_REV: tok_kwargs["revision"] = TOKENIZER_REV
 tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
@@ -27,20 +30,29 @@ mdl_kwargs = dict(
     trust_remote_code=True,
     attn_implementation="eager",
     low_cpu_mem_usage=True,
-    dtype=dtype,
 )
 if MODEL_REV: mdl_kwargs["revision"] = MODEL_REV
 model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
 model.eval()

 ip = IndicProcessor(inference=True)

-#
 @torch.inference_mode()
 def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
                        temperature: float, top_p: float, top_k: int):
     batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)

     enc = tokenizer(
         batch,
         max_length=256,
@@ -50,8 +62,10 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
         return_attention_mask=True,
     ).to(device)

     do_sample = (temperature is not None) and (float(temperature) > 0)

     outputs = model.generate(
         **enc,
         max_new_tokens=int(max_new_tokens),
@@ -62,16 +76,13 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
         top_k=int(top_k) if do_sample else None,
         use_cache=True,
         early_stopping=False,
-        pad_token_id=
     )

-    with tokenizer.as_target_tokenizer():
-        decoded = tokenizer.batch_decode(
-            outputs.detach().cpu().tolist(),
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=True,
-        )

     final = ip.postprocess_batch(decoded, lang=tgt_code)
     return final[0].strip()

@@ -79,15 +90,25 @@ def translate_dual(text, num_beams, max_new_tokens, temperature, top_p, top_k):
     text = (text or "").strip()
     if not text:
         return "", ""
-    hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
-    te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
     return hi, te

-# -------------------
-THEME = gr.themes.Soft(
-    primary_hue="blue",
-    neutral_hue="slate",
-).set(
     body_background_fill="#0b1220",
     body_text_color_subdued="#cbd5e1",
     block_background_fill="#0f172a",
@@ -114,12 +135,7 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"

     with gr.Row():
         with gr.Column(scale=1):
-            src = gr.Textbox(
-                label="English Text",
-                placeholder="Type English here…",
-                lines=8,
-                autofocus=True,
-            )
             with gr.Accordion("Advanced settings", open=False):
                 num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam search: num_beams")
                 max_new = gr.Slider(16, 512, value=128, step=8, label="Max new tokens")
@@ -135,11 +151,9 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"
             te_out = gr.Textbox(label="Telugu (tel_Telu)", lines=8, show_copy_button=True)

     examples = gr.Examples(
-        examples=[
-
-
-            ["The agreement shall remain in force unless terminated by mutual consent in writing."],
-        ],
         inputs=[src],
         label="Quick examples",
     )
@@ -154,5 +168,4 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"

     gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B · Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')

-    # IMPORTANT: remove unsupported arg; keep queue to enable request buffering
 demo.queue(max_size=48).launch()
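For orientation before the updated file: `_translate_to_lang` implements the usual IndicTrans round-trip (preprocess → tokenize → generate → decode → postprocess). A minimal skeleton of that flow, following the IndicTransToolkit README and reusing the `tokenizer`, `model`, and `device` defined above; the padding/truncation flags are the toolkit's usual ones, standing in for tokenizer arguments this diff view elides:

ip = IndicProcessor(inference=True)
batch = ip.preprocess_batch(["A sample English sentence."], src_lang="eng_Latn", tgt_lang="hin_Deva")
enc = tokenizer(batch, truncation=True, padding="longest", max_length=256,
                return_tensors="pt", return_attention_mask=True).to(device)
out = model.generate(**enc, max_new_tokens=128, num_beams=4)
decoded = tokenizer.batch_decode(out, skip_special_tokens=True)
print(ip.postprocess_batch(decoded, lang="hin_Deva")[0])  # Hindi output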
+++ b/app.py (after)
@@ -1,24 +1,27 @@
+import os, traceback, torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

+# --- Robust import for IndicProcessor (fallback path per toolkit README) ---
+try:
+    from IndicTransToolkit import IndicProcessor  # preferred
+except Exception:
+    from IndicTransToolkit.IndicTransToolkit import IndicProcessor  # fallback
+
+# ------------------- Config -------------------
 TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
 MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
+TOKENIZER_REV = os.getenv("TOKENIZER_REV", None)  # optional pin
+MODEL_REV = os.getenv("MODEL_REV", None)  # optional pin

 SRC_CODE = "eng_Latn"
 HI_CODE = "hin_Deva"
 TE_CODE = "tel_Telu"

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float16 if torch.cuda.is_available() else torch.float32

+# ------------------- Load model/tokenizer -------------------
 tok_kwargs = dict(trust_remote_code=True, use_fast=True)
 if TOKENIZER_REV: tok_kwargs["revision"] = TOKENIZER_REV
 tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
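The `*_REV` variables are ordinary Hugging Face Hub `revision` pins. A sketch of the same mechanism in isolation; the hash is a placeholder carried over from the old comments, not a real commit:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B",
    trust_remote_code=True,
    use_fast=True,
    revision="b1a2c3d",  # placeholder hash; resolves a fixed commit instead of main
)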
@@ -27,20 +30,29 @@ mdl_kwargs = dict(
     trust_remote_code=True,
     attn_implementation="eager",
     low_cpu_mem_usage=True,
+    dtype=dtype,  # modern kw (no deprecation warning)
 )
 if MODEL_REV: mdl_kwargs["revision"] = MODEL_REV
+
 model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
 model.eval()

 ip = IndicProcessor(inference=True)

+# Ensure pad/eos ids are set to avoid edge-case crashes
+if getattr(model.generation_config, "pad_token_id", None) is None:
+    model.generation_config.pad_token_id = getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", 0)
+if getattr(model.generation_config, "eos_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
+    model.generation_config.eos_token_id = tokenizer.eos_token_id
+
+# ------------------- Inference -------------------
 @torch.inference_mode()
 def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
                        temperature: float, top_p: float, top_k: int):
+    # Preprocess
     batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)

+    # Tokenize
     enc = tokenizer(
         batch,
         max_length=256,
@@ -50,8 +62,10 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
         return_attention_mask=True,
     ).to(device)

+    # Sampling toggles
     do_sample = (temperature is not None) and (float(temperature) > 0)

+    # Generate
     outputs = model.generate(
         **enc,
         max_new_tokens=int(max_new_tokens),
@@ -62,16 +76,13 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
         top_k=int(top_k) if do_sample else None,
         use_cache=True,
         early_stopping=False,
+        pad_token_id=model.generation_config.pad_token_id,
     )

+    # Decode (no deprecated as_target_tokenizer)
+    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

+    # Postprocess
     final = ip.postprocess_batch(decoded, lang=tgt_code)
     return final[0].strip()

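A quick sanity check one might append temporarily below this function while debugging (a sketch; the `SMOKE_TEST` guard is hypothetical, and `temperature=0.0` keeps the deterministic beam-search path):

if os.getenv("SMOKE_TEST"):  # hypothetical guard, not part of the app
    sample = "The appeal is dismissed with costs."
    print(_translate_to_lang(sample, HI_CODE, num_beams=4, max_new_tokens=128,
                             temperature=0.0, top_p=0.9, top_k=50))
    print(_translate_to_lang(sample, TE_CODE, num_beams=4, max_new_tokens=128,
                             temperature=0.0, top_p=0.9, top_k=50))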
@@ -79,15 +90,25 @@ def translate_dual(text, num_beams, max_new_tokens, temperature, top_p, top_k):
     text = (text or "").strip()
     if not text:
         return "", ""
+    try:
+        hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
+    except Exception as e:
+        # Surface a friendly error instead of Gradio's generic "Error"
+        msg = f"⚠️ Hindi translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
+        print("HI ERROR:\n", traceback.format_exc())
+        hi = msg
+
+    try:
+        te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
+    except Exception as e:
+        msg = f"⚠️ Telugu translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
+        print("TE ERROR:\n", traceback.format_exc())
+        te = msg
+
     return hi, te

+# ------------------- UI -------------------
+THEME = gr.themes.Soft(primary_hue="blue", neutral_hue="slate").set(
     body_background_fill="#0b1220",
     body_text_color_subdued="#cbd5e1",
     block_background_fill="#0f172a",
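The error strings in `translate_dual` keep both output boxes usable instead of surfacing Gradio's generic "Error". Their shape can be checked in isolation; `_friendly` below is an illustrative helper, not part of the app:

def _friendly(label: str, e: Exception) -> str:
    # Same shape as the app's inline f-strings.
    return f"⚠️ {label} translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"

try:
    raise RuntimeError("CUDA out of memory")  # hypothetical failure
except Exception as e:
    print(_friendly("Hindi", e))  # ⚠️ Hindi translation failed: RuntimeError: CUDA out of memory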
@@ -114,12 +135,7 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"

     with gr.Row():
         with gr.Column(scale=1):
+            src = gr.Textbox(label="English Text", placeholder="Type English here…", lines=8, autofocus=True)
             with gr.Accordion("Advanced settings", open=False):
                 num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam search: num_beams")
                 max_new = gr.Slider(16, 512, value=128, step=8, label="Max new tokens")
@@ -135,11 +151,9 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"
             te_out = gr.Textbox(label="Telugu (tel_Telu)", lines=8, show_copy_button=True)

     examples = gr.Examples(
+        examples=[["The Constitution guarantees fundamental rights to every citizen of India."],
+                  ["Maintenance proceedings shall commence within thirty days from the date of application."],
+                  ["The agreement shall remain in force unless terminated by mutual consent in writing."]],
         inputs=[src],
         label="Quick examples",
     )
@@ -154,5 +168,4 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"

     gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B · Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')

 demo.queue(max_size=48).launch()
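On the removed "# IMPORTANT: remove unsupported arg" comment: the diff alone does not show which `queue()` argument was dropped. A plausible reading, an assumption rather than something this page confirms, is Gradio 3's `concurrency_count`, which Gradio 4 no longer accepts:

# Gradio 3 style, raises TypeError on Gradio 4:
#   demo.queue(max_size=48, concurrency_count=2).launch()

# Gradio 4 equivalent, if a per-event limit is wanted alongside max_size:
demo.queue(max_size=48, default_concurrency_limit=2).launch()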