SagarVelamuri committed on
Commit
fc7b4e3
verified
1 Parent(s): 5b0fadd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -3,10 +3,14 @@ import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from IndicTransToolkit import IndicProcessor # https://github.com/VarunGumma/IndicTransToolkit
5
 
6
- # --------- Config (override via Space variables if you like) ----------
7
  TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
8
  MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
9
 
 
 
 
 
10
  SRC_CODE = "eng_Latn"
11
  HI_CODE = "hin_Deva"
12
  TE_CODE = "tel_Telu"
@@ -15,18 +19,18 @@ TE_CODE = "tel_Telu"
15
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
17
 
18
- tokenizer = AutoTokenizer.from_pretrained(
19
- TOKENIZER_ID, trust_remote_code=True, use_fast=True
20
- )
21
 
22
- model = AutoModelForSeq2SeqLM.from_pretrained(
23
- MODEL_ID,
24
  trust_remote_code=True,
25
  attn_implementation="eager",
26
  low_cpu_mem_usage=True,
27
- torch_dtype=dtype,
28
  )
29
- model = model.to(device)
 
30
  model.eval()
31
 
32
  ip = IndicProcessor(inference=True)
@@ -35,7 +39,6 @@ ip = IndicProcessor(inference=True)
35
  @torch.inference_mode()
36
  def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
37
  temperature: float, top_p: float, top_k: int):
38
- """Runs IndicTrans2-style preprocess -> generate -> postprocess for a single target language."""
39
  batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
40
 
41
  enc = tokenizer(
@@ -62,7 +65,6 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
62
  pad_token_id=tokenizer.pad_token_id or 0,
63
  )
64
 
65
- # decode
66
  with tokenizer.as_target_tokenizer():
67
  decoded = tokenizer.batch_decode(
68
  outputs.detach().cpu().tolist(),
@@ -70,7 +72,6 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
70
  clean_up_tokenization_spaces=True,
71
  )
72
 
73
- # postprocess
74
  final = ip.postprocess_batch(decoded, lang=tgt_code)
75
  return final[0].strip()
76
 
@@ -78,7 +79,6 @@ def translate_dual(text, num_beams, max_new_tokens, temperature, top_p, top_k):
78
  text = (text or "").strip()
79
  if not text:
80
  return "", ""
81
-
82
  hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
83
  te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
84
  return hi, te
@@ -154,4 +154,5 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN→HI / EN→TE Translator"
154
 
155
  gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B · Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
156
 
157
- demo.queue(concurrency_count=4, max_size=48).launch()
 
 
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from IndicTransToolkit import IndicProcessor # https://github.com/VarunGumma/IndicTransToolkit
5
 
6
+ # --------- Config (override via Space Variables if you like) ----------
7
  TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
8
  MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
9
 
10
+ # (Optional) pin revisions to avoid surprise upstream changes
11
+ TOKENIZER_REV = os.getenv("TOKENIZER_REV", None) # e.g., "b1a2c3d"
12
+ MODEL_REV = os.getenv("MODEL_REV", None) # e.g., "e4f5a6b"
13
+
14
  SRC_CODE = "eng_Latn"
15
  HI_CODE = "hin_Deva"
16
  TE_CODE = "tel_Telu"
 
19
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
21
 
22
+ tok_kwargs = dict(trust_remote_code=True, use_fast=True)
23
+ if TOKENIZER_REV: tok_kwargs["revision"] = TOKENIZER_REV
24
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
25
 
26
+ mdl_kwargs = dict(
 
27
  trust_remote_code=True,
28
  attn_implementation="eager",
29
  low_cpu_mem_usage=True,
30
+ dtype=dtype, # <- fixes the torch_dtype deprecation warning
31
  )
32
+ if MODEL_REV: mdl_kwargs["revision"] = MODEL_REV
33
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
34
  model.eval()
35
 
36
  ip = IndicProcessor(inference=True)
 
39
  @torch.inference_mode()
40
  def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
41
  temperature: float, top_p: float, top_k: int):
 
42
  batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
43
 
44
  enc = tokenizer(
 
65
  pad_token_id=tokenizer.pad_token_id or 0,
66
  )
67
 
 
68
  with tokenizer.as_target_tokenizer():
69
  decoded = tokenizer.batch_decode(
70
  outputs.detach().cpu().tolist(),
 
72
  clean_up_tokenization_spaces=True,
73
  )
74
 
 
75
  final = ip.postprocess_batch(decoded, lang=tgt_code)
76
  return final[0].strip()
77
 
 
79
  text = (text or "").strip()
80
  if not text:
81
  return "", ""
 
82
  hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
83
  te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
84
  return hi, te
 
154
 
155
  gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B · Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
156
 
157
+ # IMPORTANT: remove unsupported arg; keep queue to enable request buffering
158
+ demo.queue(max_size=48).launch()