SagarVelamuri committed on
Commit
a6cfb1d
·
verified ·
1 Parent(s): 95a28b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -36
app.py CHANGED
@@ -1,24 +1,27 @@
1
- import os, torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
- from IndicTransToolkit import IndicProcessor # https://github.com/VarunGumma/IndicTransToolkit
5
 
6
- # --------- Config (override via Space Variables if you like) ----------
 
 
 
 
 
 
7
  TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
8
  MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
9
-
10
- # (Optional) pin revisions to avoid surprise upstream changes
11
- TOKENIZER_REV = os.getenv("TOKENIZER_REV", None) # e.g., "b1a2c3d"
12
- MODEL_REV = os.getenv("MODEL_REV", None) # e.g., "e4f5a6b"
13
 
14
  SRC_CODE = "eng_Latn"
15
  HI_CODE = "hin_Deva"
16
  TE_CODE = "tel_Telu"
17
 
18
- # -------------------- Load model & tokenizer --------------------------
19
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
21
 
 
22
  tok_kwargs = dict(trust_remote_code=True, use_fast=True)
23
  if TOKENIZER_REV: tok_kwargs["revision"] = TOKENIZER_REV
24
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
@@ -27,20 +30,29 @@ mdl_kwargs = dict(
27
  trust_remote_code=True,
28
  attn_implementation="eager",
29
  low_cpu_mem_usage=True,
30
- dtype=dtype, # <- fixes the torch_dtype deprecation warning
31
  )
32
  if MODEL_REV: mdl_kwargs["revision"] = MODEL_REV
 
33
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
34
  model.eval()
35
 
36
  ip = IndicProcessor(inference=True)
37
 
38
- # -------------------- Inference helpers -------------------------------
 
 
 
 
 
 
39
  @torch.inference_mode()
40
  def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
41
  temperature: float, top_p: float, top_k: int):
 
42
  batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
43
 
 
44
  enc = tokenizer(
45
  batch,
46
  max_length=256,
@@ -50,8 +62,10 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
50
  return_attention_mask=True,
51
  ).to(device)
52
 
 
53
  do_sample = (temperature is not None) and (float(temperature) > 0)
54
 
 
55
  outputs = model.generate(
56
  **enc,
57
  max_new_tokens=int(max_new_tokens),
@@ -62,16 +76,13 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
62
  top_k=int(top_k) if do_sample else None,
63
  use_cache=True,
64
  early_stopping=False,
65
- pad_token_id=tokenizer.pad_token_id or 0,
66
  )
67
 
68
- with tokenizer.as_target_tokenizer():
69
- decoded = tokenizer.batch_decode(
70
- outputs.detach().cpu().tolist(),
71
- skip_special_tokens=True,
72
- clean_up_tokenization_spaces=True,
73
- )
74
 
 
75
  final = ip.postprocess_batch(decoded, lang=tgt_code)
76
  return final[0].strip()
77
 
@@ -79,15 +90,25 @@ def translate_dual(text, num_beams, max_new_tokens, temperature, top_p, top_k):
79
  text = (text or "").strip()
80
  if not text:
81
  return "", ""
82
- hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
83
- te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return hi, te
85
 
86
- # -------------------- UI (professional, clean) ------------------------
87
- THEME = gr.themes.Soft(
88
- primary_hue="blue",
89
- neutral_hue="slate",
90
- ).set(
91
  body_background_fill="#0b1220",
92
  body_text_color_subdued="#cbd5e1",
93
  block_background_fill="#0f172a",
@@ -114,12 +135,7 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"
114
 
115
  with gr.Row():
116
  with gr.Column(scale=1):
117
- src = gr.Textbox(
118
- label="English Text",
119
- placeholder="Type English here…",
120
- lines=8,
121
- autofocus=True,
122
- )
123
  with gr.Accordion("Advanced settings", open=False):
124
  num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam search: num_beams")
125
  max_new = gr.Slider(16, 512, value=128, step=8, label="Max new tokens")
@@ -135,11 +151,9 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"
135
  te_out = gr.Textbox(label="Telugu (tel_Telu)", lines=8, show_copy_button=True)
136
 
137
  examples = gr.Examples(
138
- examples=[
139
- ["The Constitution guarantees fundamental rights to every citizen of India."],
140
- ["Maintenance proceedings shall commence within thirty days from the date of application."],
141
- ["The agreement shall remain in force unless terminated by mutual consent in writing."],
142
- ],
143
  inputs=[src],
144
  label="Quick examples",
145
  )
@@ -154,5 +168,4 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"
154
 
155
  gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B · Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
156
 
157
- # IMPORTANT: remove unsupported arg; keep queue to enable request buffering
158
  demo.queue(max_size=48).launch()
 
1
+ import os, traceback, torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
4
 
5
+ # --- Robust import for IndicProcessor (fallback path per toolkit README) ---
6
+ try:
7
+ from IndicTransToolkit import IndicProcessor # preferred
8
+ except Exception:
9
+ from IndicTransToolkit.IndicTransToolkit import IndicProcessor # fallback
10
+
11
+ # ------------------- Config -------------------
12
  TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
13
  MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
14
+ TOKENIZER_REV = os.getenv("TOKENIZER_REV", None) # optional pin
15
+ MODEL_REV = os.getenv("MODEL_REV", None) # optional pin
 
 
16
 
17
  SRC_CODE = "eng_Latn"
18
  HI_CODE = "hin_Deva"
19
  TE_CODE = "tel_Telu"
20
 
 
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
23
 
24
+ # ------------------- Load model/tokenizer -------------------
25
  tok_kwargs = dict(trust_remote_code=True, use_fast=True)
26
  if TOKENIZER_REV: tok_kwargs["revision"] = TOKENIZER_REV
27
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
 
30
  trust_remote_code=True,
31
  attn_implementation="eager",
32
  low_cpu_mem_usage=True,
33
+ dtype=dtype, # modern kw (no deprecation warning)
34
  )
35
  if MODEL_REV: mdl_kwargs["revision"] = MODEL_REV
36
+
37
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
38
  model.eval()
39
 
40
  ip = IndicProcessor(inference=True)
41
 
42
+ # Ensure pad/eos ids are set to avoid edge-case crashes
43
+ if getattr(model.generation_config, "pad_token_id", None) is None:
44
+ model.generation_config.pad_token_id = getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", 0)
45
+ if getattr(model.generation_config, "eos_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
46
+ model.generation_config.eos_token_id = tokenizer.eos_token_id
47
+
48
+ # ------------------- Inference -------------------
49
  @torch.inference_mode()
50
  def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
51
  temperature: float, top_p: float, top_k: int):
52
+ # Preprocess
53
  batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
54
 
55
+ # Tokenize
56
  enc = tokenizer(
57
  batch,
58
  max_length=256,
 
62
  return_attention_mask=True,
63
  ).to(device)
64
 
65
+ # Sampling toggles
66
  do_sample = (temperature is not None) and (float(temperature) > 0)
67
 
68
+ # Generate
69
  outputs = model.generate(
70
  **enc,
71
  max_new_tokens=int(max_new_tokens),
 
76
  top_k=int(top_k) if do_sample else None,
77
  use_cache=True,
78
  early_stopping=False,
79
+ pad_token_id=model.generation_config.pad_token_id,
80
  )
81
 
82
+ # Decode (no deprecated as_target_tokenizer)
83
+ decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
 
 
 
84
 
85
+ # Postprocess
86
  final = ip.postprocess_batch(decoded, lang=tgt_code)
87
  return final[0].strip()
88
 
 
90
  text = (text or "").strip()
91
  if not text:
92
  return "", ""
93
+ try:
94
+ hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
95
+ except Exception as e:
96
+ # Surface a friendly error instead of Gradio's generic "Error"
97
+ msg = f"⚠️ Hindi translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
98
+ print("HI ERROR:\n", traceback.format_exc())
99
+ hi = msg
100
+
101
+ try:
102
+ te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
103
+ except Exception as e:
104
+ msg = f"⚠️ Telugu translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
105
+ print("TE ERROR:\n", traceback.format_exc())
106
+ te = msg
107
+
108
  return hi, te
109
 
110
+ # ------------------- UI -------------------
111
+ THEME = gr.themes.Soft(primary_hue="blue", neutral_hue="slate").set(
 
 
 
112
  body_background_fill="#0b1220",
113
  body_text_color_subdued="#cbd5e1",
114
  block_background_fill="#0f172a",
 
135
 
136
  with gr.Row():
137
  with gr.Column(scale=1):
138
+ src = gr.Textbox(label="English Text", placeholder="Type English here…", lines=8, autofocus=True)
 
 
 
 
 
139
  with gr.Accordion("Advanced settings", open=False):
140
  num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam search: num_beams")
141
  max_new = gr.Slider(16, 512, value=128, step=8, label="Max new tokens")
 
151
  te_out = gr.Textbox(label="Telugu (tel_Telu)", lines=8, show_copy_button=True)
152
 
153
  examples = gr.Examples(
154
+ examples=[["The Constitution guarantees fundamental rights to every citizen of India."],
155
+ ["Maintenance proceedings shall commence within thirty days from the date of application."],
156
+ ["The agreement shall remain in force unless terminated by mutual consent in writing."]],
 
 
157
  inputs=[src],
158
  label="Quick examples",
159
  )
 
168
 
169
  gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B · Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
170
 
 
171
  demo.queue(max_size=48).launch()