LiamKhoaLe commited on
Commit
1d46eb9
·
1 Parent(s): cfa5d44

Update Vietnamese translation

Browse files
Files changed (12) hide show
  1. Dockerfile +3 -0
  2. README.md +12 -12
  3. app.py +45 -9
  4. requirements.txt +2 -0
  5. trans_test.py +78 -0
  6. utils/processor.py +30 -18
  7. utils/rag.py +23 -8
  8. vi/README.md +95 -0
  9. vi/__init__.py +10 -0
  10. vi/download.py +89 -0
  11. vi/processing.py +95 -0
  12. vi/translator.py +266 -0
Dockerfile CHANGED
@@ -16,6 +16,9 @@ RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt
16
  # Copy the application
17
  COPY --chown=user . .
18
 
 
 
 
19
  # Hugging Face cache setup
20
  ENV HF_HOME="$HOME/.cache/huggingface"
21
  ENV SENTENCE_TRANSFORMERS_HOME="$HOME/.cache/huggingface/sentence-transformers"
 
16
  # Copy the application
17
  COPY --chown=user . .
18
 
19
+ # Download Vietnamese translation model
20
+ RUN python vi/download.py
21
+
22
  # Hugging Face cache setup
23
  ENV HF_HOME="$HOME/.cache/huggingface"
24
  ENV SENTENCE_TRANSFORMERS_HOME="$HOME/.cache/huggingface/sentence-transformers"
README.md CHANGED
@@ -1,32 +1,32 @@
1
  ---
2
- title: MedAI Processing
3
  emoji: ⚕️
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  license: apache-2.0
9
- short_description: Process and centralise medical doc for llm finetuning
10
  ---
11
 
12
  ## Quick Access:
13
 
14
- [HF Space](https://huggingface.co/spaces/BinKhoaLe1812/MedAI_Processing)
15
 
16
- [MedDialog-100k](https://huggingface.co/datasets/BinKhoaLe1812/MedDialog-EN-100k)
17
 
18
- [MedDialog-100k](https://huggingface.co/datasets/BinKhoaLe1812/MedDialog-EN-10k)
19
 
20
- [PubMedQA-Labelled](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-L)
21
 
22
- [PubMedQA-Unlabelled](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-U)
23
 
24
- [PubMedQA-Mapper](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-MAP)
25
 
26
 
27
  ## CURL Request Instruction
28
- [Request Doc](https://huggingface.co/spaces/MedAI-COS30018/MedAI_Processing/blob/main/REQUEST.md)
29
 
30
  ## License
31
- [Apache-2.0 LICENSE](https://huggingface.co/spaces/MedAI-COS30018/MedAI_Processing/blob/main/LICENSE.txt)
32
 
 
1
  ---
2
+ title: MedVietAI Processing
3
  emoji: ⚕️
4
+ colorFrom: green
5
+ colorTo: pink
6
  sdk: docker
7
  pinned: false
8
  license: apache-2.0
9
+ short_description: Data processing with en-vi translation. Derived from 500k mi
10
  ---
11
 
12
  ## Quick Access:
13
 
14
+ [HF Space](https://huggingface.co/spaces/MedVietAI/processing)
15
 
16
+ [MedDialog-100k](https://huggingface.co/datasets/MedAI-COS30018/MedDialog-EN-100k)
17
 
18
+ [MedDialog-10k](https://huggingface.co/datasets/MedAI-COS30018/MedDialog-EN-10k)
19
 
20
+ [PubMedQA-Labelled](https://huggingface.co/datasets/MedAI-COS30018/PubMedQA-L)
21
 
22
+ [PubMedQA-Unlabelled](https://huggingface.co/datasets/MedAI-COS30018/PubMedQA-U)
23
 
24
+ [PubMedQA-Mapper](https://huggingface.co/datasets/MedAI-COS30018/PubMedQA-MAP)
25
 
26
 
27
  ## CURL Request Instruction
28
+ [Request Doc](https://huggingface.co/spaces/MedVietAI/processing/blob/main/REQUEST.md)
29
 
30
  ## License
31
+ [Apache-2.0 LICENSE](https://huggingface.co/spaces/MedVietAI/processing/blob/main/LICENSE.txt)
32
 
app.py CHANGED
@@ -18,6 +18,7 @@ from utils.drive_saver import DriveSaver
18
  from utils.llm import Paraphraser
19
  from utils.schema import CentralisedWriter
20
  from utils.token import get_credentials, exchange_code, build_auth_url
 
21
 
22
  # ────────── Log ───────────
23
  logger = logging.getLogger("app")
@@ -53,6 +54,9 @@ paraphraser = Paraphraser(
53
  gemini_model_hard=os.getenv("GEMINI_MODEL_HARD", "gemini-2.5-flash"),
54
  )
55
 
 
 
 
56
  app = FastAPI(title="Medical Dataset Augmenter", version="1.1.0")
57
 
58
  STATE_LOCK = threading.Lock()
@@ -85,6 +89,7 @@ class ProcessParams(BaseModel):
85
  sample_limit: Optional[int] = None # Set data sampling if needed
86
  seed: int = 42
87
  rag_processing: bool = False # Enable RAG-specific processing
 
88
 
89
  def set_state(**kwargs):
90
  with STATE_LOCK:
@@ -122,6 +127,14 @@ def root():
122
  <div class="section">
123
  <h2>⚡ Quick Actions</h2>
124
  <p>Click a button below to start processing a dataset with default augmentation parameters.</p>
 
 
 
 
 
 
 
 
125
  <button onclick="startJob('healthcaremagic')">▶ProcAugment HealthCareMagic (100k)</button><br>
126
  <button onclick="startJob('icliniq')">▶ProcAugment iCliniq (10k-derived)</button><br>
127
  <button onclick="startJob('pubmedqa_l')">▶ProcAugment PubMedQA (Labelled)</button><br>
@@ -155,10 +168,10 @@ def root():
155
  <script>
156
  async function startJob(dataset) {{
157
  const log = document.getElementById("log");
158
- const ragToggle = document.getElementById("ragToggle");
159
- const isRagMode = ragToggle.checked;
160
 
161
- log.innerHTML = "⏳ Starting " + (isRagMode ? "RAG " : "") + "job for <b>" + dataset + "</b>...";
162
  try {{
163
  const resp = await fetch("/process/" + dataset, {{
164
  method: "POST",
@@ -177,7 +190,8 @@ def root():
177
  }},
178
  sample_limit: null, // Sample down (currently disabled)
179
  seed: 42,
180
- rag_processing: isRagMode
 
181
  }})
182
  }});
183
  const data = await resp.json();
@@ -193,14 +207,18 @@ def root():
193
 
194
  async function startRagJob(dataset) {{
195
  const log = document.getElementById("log");
196
- log.innerHTML = "⏳ Starting RAG processing for <b>" + dataset + "</b>...";
 
 
 
197
  try {{
198
  const resp = await fetch("/rag/" + dataset, {{
199
  method: "POST",
200
  headers: {{ "Content-Type": "application/json" }},
201
  body: JSON.stringify({{
202
  sample_limit: null,
203
- seed: 42
 
204
  }})
205
  }});
206
  const data = await resp.json();
@@ -366,6 +384,18 @@ def _run_job(dataset_key: str, params: ProcessParams):
366
  # Writer
367
  writer = CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
368
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  if params.rag_processing:
370
  # RAG processing mode
371
  set_state(message="RAG processing", progress=0.1)
@@ -376,20 +406,26 @@ def _run_job(dataset_key: str, params: ProcessParams):
376
  nvidia_model=os.getenv("NVIDIA_MODEL", "meta/llama-3.1-8b-instruct"),
377
  sample_limit=params.sample_limit,
378
  seed=params.seed,
379
- progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"])
 
380
  )
381
  else:
382
  # Standard SFT processing mode
383
  set_state(message="SFT processing", progress=0.1)
 
 
 
 
384
  count, stats = process_file_into_sft(
385
  dataset_key=dataset_key,
386
  input_path=local_path,
387
  writer=writer,
388
  paraphraser=paraphraser,
389
- augment_opts=params.augment.dict(),
390
  sample_limit=params.sample_limit,
391
  seed=params.seed,
392
- progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"])
 
393
  )
394
  logger.info(f"[JOB] Processed dataset={dataset_key} rows={count} stats={stats}")
395
  writer.close()
 
18
  from utils.llm import Paraphraser
19
  from utils.schema import CentralisedWriter
20
  from utils.token import get_credentials, exchange_code, build_auth_url
21
+ from vi.translator import VietnameseTranslator
22
 
23
  # ────────── Log ───────────
24
  logger = logging.getLogger("app")
 
54
  gemini_model_hard=os.getenv("GEMINI_MODEL_HARD", "gemini-2.5-flash"),
55
  )
56
 
57
+ # Vietnamese translator
58
+ vietnamese_translator = VietnameseTranslator()
59
+
60
  app = FastAPI(title="Medical Dataset Augmenter", version="1.1.0")
61
 
62
  STATE_LOCK = threading.Lock()
 
89
  sample_limit: Optional[int] = None # Set data sampling if needed
90
  seed: int = 42
91
  rag_processing: bool = False # Enable RAG-specific processing
92
+ vietnamese_translation: bool = False # Enable Vietnamese translation
93
 
94
  def set_state(**kwargs):
95
  with STATE_LOCK:
 
127
  <div class="section">
128
  <h2>⚡ Quick Actions</h2>
129
  <p>Click a button below to start processing a dataset with default augmentation parameters.</p>
130
+
131
+ <div style="margin-bottom: 15px; padding: 10px; background: #f8f9fa; border-radius: 5px; border-left: 4px solid #2d89ef;">
132
+ <label style="display: flex; align-items: center; cursor: pointer;">
133
+ <input type="checkbox" id="vietnameseTranslation" style="margin-right: 8px; transform: scale(1.2);">
134
+ <strong>🇻🇳 Vietnamese Translation</strong> - Translate all content to Vietnamese before processing
135
+ </label>
136
+ </div>
137
+
138
  <button onclick="startJob('healthcaremagic')">▶ProcAugment HealthCareMagic (100k)</button><br>
139
  <button onclick="startJob('icliniq')">▶ProcAugment iCliniq (10k-derived)</button><br>
140
  <button onclick="startJob('pubmedqa_l')">▶ProcAugment PubMedQA (Labelled)</button><br>
 
168
  <script>
169
  async function startJob(dataset) {{
170
  const log = document.getElementById("log");
171
+ const vietnameseToggle = document.getElementById("vietnameseTranslation");
172
+ const isVietnameseMode = vietnameseToggle.checked;
173
 
174
+ log.innerHTML = "⏳ Starting job for <b>" + dataset + "</b>" + (isVietnameseMode ? " with Vietnamese translation" : "") + "...";
175
  try {{
176
  const resp = await fetch("/process/" + dataset, {{
177
  method: "POST",
 
190
  }},
191
  sample_limit: null, // Sample down (currently disabled)
192
  seed: 42,
193
+ rag_processing: false,
194
+ vietnamese_translation: isVietnameseMode
195
  }})
196
  }});
197
  const data = await resp.json();
 
207
 
208
  async function startRagJob(dataset) {{
209
  const log = document.getElementById("log");
210
+ const vietnameseToggle = document.getElementById("vietnameseTranslation");
211
+ const isVietnameseMode = vietnameseToggle.checked;
212
+
213
+ log.innerHTML = "⏳ Starting RAG processing for <b>" + dataset + "</b>" + (isVietnameseMode ? " with Vietnamese translation" : "") + "...";
214
  try {{
215
  const resp = await fetch("/rag/" + dataset, {{
216
  method: "POST",
217
  headers: {{ "Content-Type": "application/json" }},
218
  body: JSON.stringify({{
219
  sample_limit: null,
220
+ seed: 42,
221
+ vietnamese_translation: isVietnameseMode
222
  }})
223
  }});
224
  const data = await resp.json();
 
384
  # Writer
385
  writer = CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
386
 
387
+ # Load translator if Vietnamese translation is requested
388
+ translator = None
389
+ if params.vietnamese_translation:
390
+ set_state(message="Loading Vietnamese translator", progress=0.05)
391
+ try:
392
+ vietnamese_translator.load_model()
393
+ translator = vietnamese_translator
394
+ logger.info("✅ Vietnamese translator loaded successfully")
395
+ except Exception as e:
396
+ logger.error(f"❌ Failed to load Vietnamese translator: {e}")
397
+ set_state(message=f"Warning: Vietnamese translation failed - {e}", progress=0.1)
398
+
399
  if params.rag_processing:
400
  # RAG processing mode
401
  set_state(message="RAG processing", progress=0.1)
 
406
  nvidia_model=os.getenv("NVIDIA_MODEL", "meta/llama-3.1-8b-instruct"),
407
  sample_limit=params.sample_limit,
408
  seed=params.seed,
409
+ progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
410
+ translator=translator
411
  )
412
  else:
413
  # Standard SFT processing mode
414
  set_state(message="SFT processing", progress=0.1)
415
+ # Add Vietnamese translation flag to augment options
416
+ augment_opts = params.augment.dict()
417
+ augment_opts["vietnamese_translation"] = params.vietnamese_translation
418
+
419
  count, stats = process_file_into_sft(
420
  dataset_key=dataset_key,
421
  input_path=local_path,
422
  writer=writer,
423
  paraphraser=paraphraser,
424
+ augment_opts=augment_opts,
425
  sample_limit=params.sample_limit,
426
  seed=params.seed,
427
+ progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
428
+ translator=translator
429
  )
430
  logger.info(f"[JOB] Processed dataset={dataset_key} rows={count} stats={stats}")
431
  writer.close()
requirements.txt CHANGED
@@ -11,3 +11,5 @@ google-auth-oauthlib
11
  orjson
12
  ftfy
13
  langid
 
 
 
11
  orjson
12
  ftfy
13
  langid
14
+ transformers
15
+ torch
trans_test.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for Vietnamese translation functionality
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import logging
9
+ from dotenv import load_dotenv
10
+
11
+ # Add the current directory to Python path
12
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
13
+
14
+ from vi.translator import VietnameseTranslator
15
+
16
+ # Setup logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ def test_translation():
21
+ """Test the Vietnamese translation functionality"""
22
+ load_dotenv()
23
+
24
+ # Initialize translator
25
+ translator = VietnameseTranslator()
26
+
27
+ try:
28
+ # Load the model
29
+ logger.info("Loading translation model...")
30
+ translator.load_model()
31
+ logger.info("✅ Model loaded successfully")
32
+
33
+ # Test single text translation
34
+ test_text = "Hello, how are you today? I hope you are feeling well."
35
+ logger.info(f"Original text: {test_text}")
36
+
37
+ translated = translator.translate_text(test_text)
38
+ logger.info(f"Translated text: {translated}")
39
+
40
+ # Test batch translation
41
+ test_texts = [
42
+ "What are the symptoms of diabetes?",
43
+ "How do I treat a headache?",
44
+ "What is the recommended dosage for this medication?"
45
+ ]
46
+
47
+ logger.info("Testing batch translation...")
48
+ batch_translated = translator.translate_batch(test_texts)
49
+
50
+ for i, (original, translated) in enumerate(zip(test_texts, batch_translated)):
51
+ logger.info(f"Batch {i+1}:")
52
+ logger.info(f" Original: {original}")
53
+ logger.info(f" Translated: {translated}")
54
+
55
+ # Test dictionary translation
56
+ test_dict = {
57
+ "instruction": "Answer the medical question",
58
+ "input": "What are the side effects of aspirin?",
59
+ "output": "Common side effects include stomach irritation and bleeding."
60
+ }
61
+
62
+ logger.info("Testing dictionary translation...")
63
+ dict_translated = translator.translate_dict(test_dict, ["instruction", "input", "output"])
64
+
65
+ logger.info("Dictionary translation result:")
66
+ for key, value in dict_translated.items():
67
+ logger.info(f" {key}: {value}")
68
+
69
+ logger.info("🎉 All translation tests completed successfully!")
70
+ return True
71
+
72
+ except Exception as e:
73
+ logger.error(f"❌ Translation test failed: {e}")
74
+ return False
75
+
76
+ if __name__ == "__main__":
77
+ success = test_translation()
78
+ sys.exit(0 if success else 1)
utils/processor.py CHANGED
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Dict, Tuple
7
 
8
  from utils.schema import sft_row
9
  from utils import augment as A
 
10
 
11
  # Logger
12
  logger = logging.getLogger("processor")
@@ -40,7 +41,8 @@ def process_file_into_sft(
40
  augment_opts: Dict,
41
  sample_limit: Optional[int],
42
  seed: int,
43
- progress_cb: Optional[Callable[[float, str], None]]
 
44
  ) -> Tuple[int, Dict]:
45
  random.seed(seed)
46
  stats = {
@@ -68,13 +70,13 @@ def process_file_into_sft(
68
  if key in ("healthcaremagic", "icliniq"):
69
  count = _proc_med_dialog(source=key, path=input_path, writer=writer,
70
  paraphraser=paraphraser, opts=augment_opts,
71
- sample_limit=sample_limit, stats=stats, cb=progress_cb, dedupe_seen=dedupe_seen)
72
  elif key == "pubmedqa_l":
73
- count = _proc_pubmedqa_l(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
74
  elif key == "pubmedqa_u":
75
- count = _proc_pubmedqa_u(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
76
  elif key == "pubmedqa_map":
77
- count = _proc_pubmedqa_map(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
78
  else:
79
  raise ValueError(f"Unknown dataset: {dataset_key}")
80
  logger.info(f"[PROC] End dataset={dataset_key} stats={stats}")
@@ -135,7 +137,7 @@ def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphr
135
 
136
  return instr, user, out, applied
137
 
138
- def _commit_row(writer, source, rid, task, instr, user, out, opts, stats, aug_applied, extra_meta=None, dedupe_seen=None):
139
  # Dedup entry
140
  if dedupe_seen is not None:
141
  fp = A.fingerprint(instr, user, out)
@@ -149,13 +151,23 @@ def _commit_row(writer, source, rid, task, instr, user, out, opts, stats, aug_ap
149
  meta.update(extra_meta)
150
 
151
  row = sft_row(instr, user, out, source=source, rid=rid, task=task, meta=meta)
 
 
 
 
 
 
 
 
 
 
152
  writer.write(row)
153
  stats["written"] += 1
154
  return True
155
 
156
  # ——————————— dataset processors ———————————
157
 
158
- def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
159
  count = 0
160
  written = 0
161
  for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
@@ -184,12 +196,12 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
184
  applied.append("consistency_flag")
185
 
186
  # 2) If expansion is enabled, add augmented copies
187
- _commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats, ["base"] + applied, dedupe_seen=dedupe_seen)
188
  # Add augmented copies if expand
189
  if opts.get("expand", True):
190
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
191
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
192
- _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
193
 
194
  # Increment count only on success
195
  count += 1
@@ -205,7 +217,7 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
205
  logger.info(f"[PROC] {source} done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
206
  return count
207
 
208
- def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
209
  with open(path, "r", encoding="utf-8") as f:
210
  data = json.load(f)
211
  count = 0
@@ -236,12 +248,12 @@ def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, d
236
 
237
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)
238
  _commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
239
- extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")}, dedupe_seen=dedupe_seen)
240
  if opts.get("expand", True):
241
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
242
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
243
  _commit_row(writer, "pubmedqa_l", rid_aug, "biomedical_qa",
244
- instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
245
 
246
  # Increment count only on success
247
  count += 1
@@ -257,7 +269,7 @@ def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, d
257
  logger.info(f"[PROC] pubmedqa_l done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
258
  return count
259
 
260
- def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
261
  with open(path, "r", encoding="utf-8") as f:
262
  data = json.load(f)
263
  count = 0
@@ -290,12 +302,12 @@ def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, d
290
  out = guess.strip()
291
 
292
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)
293
- _commit_row(writer, "pubmedqa_u", str(k), "biomedical_qa_unlabeled", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen)
294
  if opts.get("expand", True):
295
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
296
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
297
  _commit_row(writer, "pubmedqa_u", rid_aug, "biomedical_qa",
298
- instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
299
 
300
  # Increment count only on success
301
  count += 1
@@ -311,7 +323,7 @@ def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, d
311
  logger.info(f"[PROC] pubmedqa_u done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
312
  return count
313
 
314
- def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
315
  with open(path, "r", encoding="utf-8") as f:
316
  obj = json.load(f)
317
 
@@ -383,14 +395,14 @@ def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb,
383
 
384
  # Process the item
385
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)
386
- _commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen)
387
 
388
  # Handle expansion if enabled
389
  if opts.get("expand", True):
390
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
391
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
392
  _commit_row(writer, "pubmedqa_map", rid_aug, "biomedical_qa",
393
- instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
394
 
395
  # Increment count only on success
396
  count += 1
 
7
 
8
  from utils.schema import sft_row
9
  from utils import augment as A
10
+ from vi.processing import translate_sft_row, should_translate, log_translation_stats
11
 
12
  # Logger
13
  logger = logging.getLogger("processor")
 
41
  augment_opts: Dict,
42
  sample_limit: Optional[int],
43
  seed: int,
44
+ progress_cb: Optional[Callable[[float, str], None]],
45
+ translator=None
46
  ) -> Tuple[int, Dict]:
47
  random.seed(seed)
48
  stats = {
 
70
  if key in ("healthcaremagic", "icliniq"):
71
  count = _proc_med_dialog(source=key, path=input_path, writer=writer,
72
  paraphraser=paraphraser, opts=augment_opts,
73
+ sample_limit=sample_limit, stats=stats, cb=progress_cb, dedupe_seen=dedupe_seen, translator=translator)
74
  elif key == "pubmedqa_l":
75
+ count = _proc_pubmedqa_l(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen, translator=translator)
76
  elif key == "pubmedqa_u":
77
+ count = _proc_pubmedqa_u(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen, translator=translator)
78
  elif key == "pubmedqa_map":
79
+ count = _proc_pubmedqa_map(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen, translator=translator)
80
  else:
81
  raise ValueError(f"Unknown dataset: {dataset_key}")
82
  logger.info(f"[PROC] End dataset={dataset_key} stats={stats}")
 
137
 
138
  return instr, user, out, applied
139
 
140
+ def _commit_row(writer, source, rid, task, instr, user, out, opts, stats, aug_applied, extra_meta=None, dedupe_seen=None, translator=None):
141
  # Dedup entry
142
  if dedupe_seen is not None:
143
  fp = A.fingerprint(instr, user, out)
 
151
  meta.update(extra_meta)
152
 
153
  row = sft_row(instr, user, out, source=source, rid=rid, task=task, meta=meta)
154
+
155
+ # Apply Vietnamese translation if requested
156
+ if should_translate(opts.get("vietnamese_translation", False), translator):
157
+ try:
158
+ row = translate_sft_row(row, translator)
159
+ meta["vietnamese_translated"] = True
160
+ row["meta"] = meta
161
+ except Exception as e:
162
+ logger.error(f"Failed to translate SFT row: {e}")
163
+
164
  writer.write(row)
165
  stats["written"] += 1
166
  return True
167
 
168
  # ——————————— dataset processors ———————————
169
 
170
+ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None, translator=None):
171
  count = 0
172
  written = 0
173
  for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
 
196
  applied.append("consistency_flag")
197
 
198
  # 2) If expansion is enabled, add augmented copies
199
+ _commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats, ["base"] + applied, dedupe_seen=dedupe_seen, translator=translator)
200
  # Add augmented copies if expand
201
  if opts.get("expand", True):
202
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
203
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
204
+ _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
205
 
206
  # Increment count only on success
207
  count += 1
 
217
  logger.info(f"[PROC] {source} done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
218
  return count
219
 
220
+ def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None, translator=None):
221
  with open(path, "r", encoding="utf-8") as f:
222
  data = json.load(f)
223
  count = 0
 
248
 
249
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)
250
  _commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
251
+ extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")}, dedupe_seen=dedupe_seen, translator=translator)
252
  if opts.get("expand", True):
253
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
254
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
255
  _commit_row(writer, "pubmedqa_l", rid_aug, "biomedical_qa",
256
+ instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
257
 
258
  # Increment count only on success
259
  count += 1
 
269
  logger.info(f"[PROC] pubmedqa_l done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
270
  return count
271
 
272
+ def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None, translator=None):
273
  with open(path, "r", encoding="utf-8") as f:
274
  data = json.load(f)
275
  count = 0
 
302
  out = guess.strip()
303
 
304
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)
305
+ _commit_row(writer, "pubmedqa_u", str(k), "biomedical_qa_unlabeled", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
306
  if opts.get("expand", True):
307
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
308
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
309
  _commit_row(writer, "pubmedqa_u", rid_aug, "biomedical_qa",
310
+ instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
311
 
312
  # Increment count only on success
313
  count += 1
 
323
  logger.info(f"[PROC] pubmedqa_u done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
324
  return count
325
 
326
+ def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None, translator=None):
327
  with open(path, "r", encoding="utf-8") as f:
328
  obj = json.load(f)
329
 
 
395
 
396
  # Process the item
397
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)
398
+ _commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
399
 
400
  # Handle expansion if enabled
401
  if opts.get("expand", True):
402
  for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
403
  rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
404
  _commit_row(writer, "pubmedqa_map", rid_aug, "biomedical_qa",
405
+ instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
406
 
407
  # Increment count only on success
408
  count += 1
utils/rag.py CHANGED
@@ -7,6 +7,7 @@ from typing import Dict, List, Tuple, Optional, Callable
7
 
8
  from utils.schema import sft_row
9
  from utils.llm import NvidiaClient, KeyRotator
 
10
 
11
  # Logger
12
  logger = logging.getLogger("rag_processor")
@@ -165,7 +166,7 @@ class RAGProcessor:
165
  return ""
166
 
167
  def process_medical_dialog(self, source: str, path: str, writer, sample_limit: Optional[int],
168
- stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None) -> int:
169
  """Process medical dialogue datasets into RAG format"""
170
  count = 0
171
  written = 0
@@ -199,7 +200,7 @@ class RAGProcessor:
199
  # Commit the RAG-formatted row
200
  if self._commit_rag_row(writer, source, rid, "rag_medical_qa",
201
  rag_instruction, rag_user, answer,
202
- stats, dedupe_seen=dedupe_seen):
203
  written += 1
204
 
205
  count += 1
@@ -220,7 +221,7 @@ class RAGProcessor:
220
  return count
221
 
222
  def process_pubmedqa(self, source: str, path: str, writer, sample_limit: Optional[int],
223
- stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None) -> int:
224
  """Process PubMedQA datasets into RAG format"""
225
  with open(path, "r", encoding="utf-8") as f:
226
  data = json.load(f)
@@ -265,7 +266,7 @@ class RAGProcessor:
265
  # Commit the RAG-formatted row
266
  if self._commit_rag_row(writer, source, rid, "rag_biomedical_qa",
267
  rag_instruction, rag_user, answer,
268
- stats, dedupe_seen=dedupe_seen):
269
  written += 1
270
 
271
  count += 1
@@ -287,7 +288,7 @@ class RAGProcessor:
287
 
288
  def _commit_rag_row(self, writer, source: str, rid: str, task: str,
289
  instruction: str, user_input: str, output: str,
290
- stats: Dict, dedupe_seen: set = None) -> bool:
291
  """Commit a RAG-formatted row to the writer"""
292
  # Simple deduplication based on content hash
293
  if dedupe_seen is not None:
@@ -299,6 +300,16 @@ class RAGProcessor:
299
 
300
  meta = {"rag_processing": True, "format": "qca"}
301
  row = sft_row(instruction, user_input, output, source=source, rid=rid, task=task, meta=meta)
 
 
 
 
 
 
 
 
 
 
302
  writer.write(row)
303
  stats["written"] = stats.get("written", 0) + 1
304
  return True
@@ -310,7 +321,8 @@ def process_file_into_rag(
310
  nvidia_model: str,
311
  sample_limit: Optional[int],
312
  seed: int,
313
- progress_cb: Optional[Callable[[float, str], None]]
 
314
  ) -> Tuple[int, Dict]:
315
  """Main entry point for RAG processing"""
316
  random.seed(seed)
@@ -326,17 +338,20 @@ def process_file_into_rag(
326
  dedupe_seen = set()
327
 
328
  key = dataset_key.lower()
 
 
 
329
  if key in ("healthcaremagic", "icliniq"):
330
  count = rag_processor.process_medical_dialog(
331
  source=key, path=input_path, writer=writer,
332
  sample_limit=sample_limit, stats=stats,
333
- progress_cb=progress_cb, dedupe_seen=dedupe_seen
334
  )
335
  elif key in ("pubmedqa_l", "pubmedqa_u", "pubmedqa_map"):
336
  count = rag_processor.process_pubmedqa(
337
  source=key, path=input_path, writer=writer,
338
  sample_limit=sample_limit, stats=stats,
339
- progress_cb=progress_cb, dedupe_seen=dedupe_seen
340
  )
341
  else:
342
  raise ValueError(f"Unknown dataset for RAG processing: {dataset_key}")
 
7
 
8
  from utils.schema import sft_row
9
  from utils.llm import NvidiaClient, KeyRotator
10
+ from vi.processing import translate_rag_row, should_translate, log_translation_stats
11
 
12
  # Logger
13
  logger = logging.getLogger("rag_processor")
 
166
  return ""
167
 
168
  def process_medical_dialog(self, source: str, path: str, writer, sample_limit: Optional[int],
169
+ stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None, translator=None, opts=None) -> int:
170
  """Process medical dialogue datasets into RAG format"""
171
  count = 0
172
  written = 0
 
200
  # Commit the RAG-formatted row
201
  if self._commit_rag_row(writer, source, rid, "rag_medical_qa",
202
  rag_instruction, rag_user, answer,
203
+ stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
204
  written += 1
205
 
206
  count += 1
 
221
  return count
222
 
223
  def process_pubmedqa(self, source: str, path: str, writer, sample_limit: Optional[int],
224
+ stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None, translator=None, opts=None) -> int:
225
  """Process PubMedQA datasets into RAG format"""
226
  with open(path, "r", encoding="utf-8") as f:
227
  data = json.load(f)
 
266
  # Commit the RAG-formatted row
267
  if self._commit_rag_row(writer, source, rid, "rag_biomedical_qa",
268
  rag_instruction, rag_user, answer,
269
+ stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
270
  written += 1
271
 
272
  count += 1
 
288
 
289
  def _commit_rag_row(self, writer, source: str, rid: str, task: str,
290
  instruction: str, user_input: str, output: str,
291
+ stats: Dict, dedupe_seen: set = None, translator=None, opts=None) -> bool:
292
  """Commit a RAG-formatted row to the writer"""
293
  # Simple deduplication based on content hash
294
  if dedupe_seen is not None:
 
300
 
301
  meta = {"rag_processing": True, "format": "qca"}
302
  row = sft_row(instruction, user_input, output, source=source, rid=rid, task=task, meta=meta)
303
+
304
+ # Apply Vietnamese translation if requested
305
+ if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
306
+ try:
307
+ row = translate_rag_row(row, translator)
308
+ meta["vietnamese_translated"] = True
309
+ row["meta"] = meta
310
+ except Exception as e:
311
+ logger.error(f"Failed to translate RAG row: {e}")
312
+
313
  writer.write(row)
314
  stats["written"] = stats.get("written", 0) + 1
315
  return True
 
321
  nvidia_model: str,
322
  sample_limit: Optional[int],
323
  seed: int,
324
+ progress_cb: Optional[Callable[[float, str], None]],
325
+ translator=None
326
  ) -> Tuple[int, Dict]:
327
  """Main entry point for RAG processing"""
328
  random.seed(seed)
 
338
  dedupe_seen = set()
339
 
340
  key = dataset_key.lower()
341
+ # Create opts with Vietnamese translation flag
342
+ opts = {"vietnamese_translation": translator is not None}
343
+
344
  if key in ("healthcaremagic", "icliniq"):
345
  count = rag_processor.process_medical_dialog(
346
  source=key, path=input_path, writer=writer,
347
  sample_limit=sample_limit, stats=stats,
348
+ progress_cb=progress_cb, dedupe_seen=dedupe_seen, translator=translator, opts=opts
349
  )
350
  elif key in ("pubmedqa_l", "pubmedqa_u", "pubmedqa_map"):
351
  count = rag_processor.process_pubmedqa(
352
  source=key, path=input_path, writer=writer,
353
  sample_limit=sample_limit, stats=stats,
354
+ progress_cb=progress_cb, dedupe_seen=dedupe_seen, translator=translator, opts=opts
355
  )
356
  else:
357
  raise ValueError(f"Unknown dataset for RAG processing: {dataset_key}")
vi/README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vietnamese Translation Module
2
+
3
+ This module provides Vietnamese translation functionality for the MedAI Processing application using the Helsinki-NLP/opus-mt-en-vi model.
4
+
5
+ ## Features
6
+
7
+ - **English to Vietnamese Translation**: Translates English text to Vietnamese using the Helsinki-NLP/opus-mt-en-vi model
8
+ - **Batch Processing**: Efficiently translates multiple texts at once
9
+ - **Dictionary Translation**: Translates specific fields in data dictionaries
10
+ - **Integration**: Seamlessly integrates with both SFT and RAG processing workflows
11
+ - **Error Handling**: Graceful fallback to original text if translation fails
12
+ - **Logging**: Comprehensive logging for debugging and monitoring
13
+
14
+ ## Configuration
15
+
16
+ Add the following environment variable to your `.env` file:
17
+
18
+ ```bash
19
+ EN_VI=Helsinki-NLP/opus-mt-en-vi
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ ### Basic Translation
25
+
26
+ ```python
27
+ from vi.translator import VietnameseTranslator
28
+
29
+ # Initialize translator
30
+ translator = VietnameseTranslator()
31
+
32
+ # Load the model
33
+ translator.load_model()
34
+
35
+ # Translate single text
36
+ translated = translator.translate_text("Hello, how are you?")
37
+
38
+ # Translate batch of texts
39
+ texts = ["Text 1", "Text 2", "Text 3"]
40
+ translated_batch = translator.translate_batch(texts)
41
+ ```
42
+
43
+ ### Dictionary Translation
44
+
45
+ ```python
46
+ # Translate specific fields in a dictionary
47
+ data = {
48
+ "instruction": "Answer the question",
49
+ "input": "What is diabetes?",
50
+ "output": "Diabetes is a metabolic disorder..."
51
+ }
52
+
53
+ translated_data = translator.translate_dict(data, ["instruction", "input", "output"])
54
+ ```
55
+
56
+ ## Integration
57
+
58
+ The translation functionality is automatically integrated into the processing workflows:
59
+
60
+ 1. **UI Toggle**: Users can enable Vietnamese translation via the checkbox in the web interface
61
+ 2. **SFT Processing**: All text fields in SFT format are translated when enabled
62
+ 3. **RAG Processing**: All text fields in RAG format are translated when enabled
63
+ 4. **Metadata**: Translated rows are marked with `vietnamese_translated: true` in metadata
64
+
65
+ ## Model Information
66
+
67
+ - **Model**: Helsinki-NLP/opus-mt-en-vi
68
+ - **Source Language**: English
69
+ - **Target Language**: Vietnamese
70
+ - **BLEU Score**: 37.2
71
+ - **chrF Score**: 0.542
72
+ - **License**: Apache 2.0
73
+
74
+ ## Testing
75
+
76
+ Run the test script to verify translation functionality:
77
+
78
+ ```bash
79
+ python trans_test.py
80
+ ```
81
+
82
+ ## Files
83
+
84
+ - `translator.py`: Main translation class
85
+ - `download.py`: Model download script for Docker
86
+ - `processing.py`: Utility functions for processing integration
87
+ - `__init__.py`: Module initialization
88
+ - `README.md`: This documentation
89
+
90
+ ## Notes
91
+
92
+ - The model is automatically downloaded during Docker build
93
+ - Translation is performed on the CPU by default, but can use GPU if available
94
+ - The model requires the target language token `>>vie<<` for proper translation
95
+ - All translation operations include comprehensive error handling and logging
vi/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vietnamese Translation Module
3
+
4
+ This module provides utilities for translating English text to Vietnamese
5
+ using the Helsinki-NLP/opus-mt-en-vi model from Hugging Face.
6
+ """
7
+
8
+ from .translator import VietnameseTranslator
9
+
10
+ __all__ = ['VietnameseTranslator']
vi/download.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Download Script for Vietnamese Translation
3
+
4
+ This script downloads the Helsinki-NLP/opus-mt-en-vi model
5
+ and saves it to the Hugging Face cache directory.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import logging
11
+ from pathlib import Path
12
+ from transformers import MarianMTModel, MarianTokenizer
13
+
14
+ # Setup logging
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
18
+ )
19
+ logger = logging.getLogger(__name__)
20
+
21
def download_model(model_name: str = "Helsinki-NLP/opus-mt-en-vi", cache_dir: str = None):
    """
    Download the translation model and tokenizer, then run a smoke test.

    Args:
        model_name: Hugging Face model name
        cache_dir: Cache directory for the model. If None, uses HF_HOME env var

    Returns:
        True if the download and smoke test succeeded, False otherwise.
    """
    if cache_dir is None:
        cache_dir = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

    logger.info(f"Downloading model: {model_name}")
    logger.info(f"Cache directory: {cache_dir}")

    try:
        # Ensure cache directory exists
        os.makedirs(cache_dir, exist_ok=True)

        # Download tokenizer
        logger.info("Downloading tokenizer...")
        tokenizer = MarianTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir
        )
        logger.info("✅ Tokenizer downloaded successfully")

        # Download model
        logger.info("Downloading model...")
        model = MarianMTModel.from_pretrained(
            model_name,
            cache_dir=cache_dir
        )
        logger.info("✅ Model downloaded successfully")

        # Smoke-test the model.
        # BUG FIX: the original used `with model.eval():`, which raises
        # TypeError because nn.Module.eval() returns the module itself,
        # not a context manager. eval() is called as a statement instead;
        # generate() already runs without gradient tracking.
        logger.info("Testing model...")
        test_text = "Hello, how are you?"
        inputs = tokenizer(f">>vie<< {test_text}", return_tensors="pt")
        model.eval()
        outputs = model.generate(**inputs, max_length=50, num_beams=4)
        translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test translation: '{test_text}' -> '{translated}'")

        logger.info("🎉 Model download and test completed successfully!")
        return True

    except Exception as e:
        logger.error(f"❌ Failed to download model: {e}")
        return False
70
+
71
def main():
    """Entry point: download the configured translation model; exit nonzero on failure."""
    # Model name comes from the EN_VI env var, with a sensible default
    model_name = os.getenv("EN_VI", "Helsinki-NLP/opus-mt-en-vi")

    logger.info("Starting model download process...")
    logger.info(f"Model: {model_name}")

    if download_model(model_name):
        logger.info("Model download completed successfully!")
        sys.exit(0)

    logger.error("Model download failed!")
    sys.exit(1)
87
+
88
+ if __name__ == "__main__":
89
+ main()
vi/processing.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processing utilities for Vietnamese translation integration
3
+ """
4
+
5
+ import logging
6
+ from typing import Dict, Any, List, Optional, Callable
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
def translate_sft_row(row: Dict[str, Any], translator, text_fields: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Translate specific text fields in an SFT row from English to Vietnamese.

    Args:
        row: SFT row dictionary
        translator: VietnameseTranslator instance (may be None)
        text_fields: Field names to translate. Defaults to
            ["instruction", "input", "output"].

    Returns:
        A new dict with the requested fields translated; the original row is
        returned unchanged when the translator is unavailable or translation
        fails (best-effort semantics).
    """
    if not translator or not translator.is_loaded():
        logger.warning("Translator not available, skipping translation")
        return row

    if text_fields is None:
        # Default fields to translate in SFT format
        text_fields = ["instruction", "input", "output"]

    try:
        translated_row = translator.translate_dict(row, text_fields)
        logger.debug(f"Translated SFT row with fields: {text_fields}")
        return translated_row
    except Exception as e:
        # Keep the pipeline running: fall back to the untranslated row
        logger.error(f"Failed to translate SFT row: {e}")
        return row
37
+
38
def translate_rag_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
    """
    Translate specific text fields in a RAG row from English to Vietnamese.

    Args:
        row: RAG row dictionary
        translator: VietnameseTranslator instance
        text_fields: List of field names to translate. If None, uses default fields.

    Returns:
        Translated RAG row dictionary
    """
    # Without a ready translator this is a no-op pass-through
    if not translator or not translator.is_loaded():
        logger.warning("Translator not available, skipping translation")
        return row

    # Default fields to translate in RAG format
    fields = ["instruction", "input", "output"] if text_fields is None else text_fields

    try:
        translated = translator.translate_dict(row, fields)
    except Exception as e:
        # Best-effort: keep the untranslated row on any failure
        logger.error(f"Failed to translate RAG row: {e}")
        return row

    logger.debug(f"Translated RAG row with fields: {fields}")
    return translated
65
+
66
def should_translate(vietnamese_translation: bool, translator) -> bool:
    """
    Decide whether Vietnamese translation should be performed.

    Args:
        vietnamese_translation: Flag from user input
        translator: VietnameseTranslator instance

    Returns:
        True if translation should be performed
    """
    # Translation runs only when requested AND a loaded translator exists
    if vietnamese_translation and translator and translator.is_loaded():
        return True

    if vietnamese_translation:
        # Requested but no usable translator — surface that in the logs
        logger.warning("Vietnamese translation requested but translator not available")

    return False
85
+
86
def log_translation_stats(stats: Dict[str, Any], translated_count: int) -> None:
    """
    Record the number of translated items in *stats* (in place) and log it.

    Args:
        stats: Statistics dictionary to update
        translated_count: Number of items translated
    """
    stats.update(vietnamese_translated=translated_count)
    logger.info(f"Vietnamese translation completed: {translated_count} items translated")
vi/translator.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vietnamese Translator using Helsinki-NLP/opus-mt-en-vi model
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ from typing import List, Dict, Any, Optional, Union
8
+ from transformers import MarianMTModel, MarianTokenizer
9
+ import torch
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
class VietnameseTranslator:
    """
    English-to-Vietnamese translator using the Helsinki-NLP/opus-mt-en-vi
    MarianMT model.

    The model is loaded lazily: constructing the translator is cheap and the
    weights are only fetched on the first translation call (or an explicit
    ``load_model()``). All ``translate_*`` methods are best-effort and fall
    back to returning the original English text when translation fails.
    """

    def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
        """
        Initialize the Vietnamese translator.

        Args:
            model_name: Hugging Face model name. Defaults to the EN_VI env var
                or "Helsinki-NLP/opus-mt-en-vi".
            device: Device to run the model on ('cpu', 'cuda', 'mps', 'auto').
                None or 'auto' auto-detects the best available device.
        """
        self.model_name = model_name or os.getenv("EN_VI", "Helsinki-NLP/opus-mt-en-vi")
        self.device = self._get_device(device)
        self.model = None
        self.tokenizer = None
        self._is_loaded = False

        logger.info(f"VietnameseTranslator initialized with model: {self.model_name}")
        logger.info(f"Using device: {self.device}")

    def _get_device(self, device: Optional[str]) -> str:
        """Return an explicit device string, or auto-detect the best one.

        BUG FIX: __init__ documents 'auto' as a valid value, but the original
        returned the literal string "auto" (which torch device APIs reject).
        'auto' is now treated the same as None (auto-detect).
        """
        if device and device != "auto":
            return device

        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def load_model(self) -> None:
        """Load the translation model and tokenizer (idempotent).

        Raises:
            RuntimeError: if the model or tokenizer cannot be loaded.
        """
        if self._is_loaded:
            logger.debug("Model already loaded, skipping...")
            return

        try:
            logger.info(f"Loading translation model: {self.model_name}")
            logger.info(f"Loading on device: {self.device}")

            cache_dir = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

            # Load tokenizer and model from the shared HF cache
            self.tokenizer = MarianTokenizer.from_pretrained(
                self.model_name,
                cache_dir=cache_dir
            )
            self.model = MarianMTModel.from_pretrained(
                self.model_name,
                cache_dir=cache_dir
            )

            # Move model to device and switch to inference mode
            self.model = self.model.to(self.device)
            self.model.eval()

            self._is_loaded = True
            logger.info("✅ Translation model loaded successfully")

        except Exception as e:
            logger.error(f"❌ Failed to load translation model: {e}")
            # Chain the original cause for easier debugging
            raise RuntimeError(f"Failed to load translation model: {e}") from e

    def translate_text(self, text: str) -> str:
        """
        Translate a single text from English to Vietnamese.

        Args:
            text: English text to translate

        Returns:
            Translated Vietnamese text, or the original text if the input is
            empty/whitespace or translation fails.
        """
        if not self._is_loaded:
            self.load_model()

        if not text or not text.strip():
            return text

        try:
            # The opus-mt-en-vi model requires a target-language token
            # in the form >>id<< at the start of the input
            input_text = f">>vie<< {text.strip()}"

            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=512,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False
                )

            translated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            logger.debug(f"Translated: '{text[:50]}...' -> '{translated[:50]}...'")
            return translated.strip()

        except Exception as e:
            logger.error(f"Translation failed for text: '{text[:100]}...' - Error: {e}")
            # Return original text if translation fails
            return text

    def translate_batch(self, texts: List[str], batch_size: int = 8) -> List[str]:
        """
        Translate a batch of texts from English to Vietnamese.

        Args:
            texts: List of English texts to translate
            batch_size: Number of texts to process in each batch

        Returns:
            List of translated Vietnamese texts, same length and order as the
            input. Any batch that fails is passed through untranslated.
        """
        if not self._is_loaded:
            self.load_model()

        if not texts:
            return []

        results: List[str] = []
        num_batches = (len(texts) + batch_size - 1) // batch_size

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            logger.debug(f"Processing batch {i // batch_size + 1}/{num_batches}")

            try:
                # Prepend the required target-language token to each text
                batch_inputs = [f">>vie<< {text.strip()}" for text in batch]

                inputs = self.tokenizer(
                    batch_inputs,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                ).to(self.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_length=512,
                        num_beams=4,
                        early_stopping=True,
                        do_sample=False
                    )

                results.extend(
                    self.tokenizer.decode(output, skip_special_tokens=True).strip()
                    for output in outputs
                )

            except Exception as e:
                # BUG FIX: the original discarded ALL partial results on any
                # failure (`results = texts`) and still logged success. Fall
                # back to the original texts for the failed batch only.
                logger.error(f"Batch translation failed: {e}")
                results.extend(batch)

        logger.info(f"Translated {len(texts)} texts successfully")
        return results

    def translate_dict(self, data: Dict[str, Any], text_fields: List[str]) -> Dict[str, Any]:
        """
        Translate specific text fields in a dictionary from English to Vietnamese.

        Args:
            data: Dictionary containing the data
            text_fields: List of field names to translate

        Returns:
            Shallow copy of ``data`` with the listed fields translated; fields
            that are missing, non-string or empty are left untouched.
        """
        if not self._is_loaded:
            self.load_model()

        result = data.copy()

        for field in text_fields:
            value = data.get(field)
            if isinstance(value, str) and value.strip():
                try:
                    result[field] = self.translate_text(value)
                    logger.debug(f"Translated field '{field}': '{value[:50]}...' -> '{result[field][:50]}...'")
                except Exception as e:
                    logger.error(f"Failed to translate field '{field}': {e}")
                    # Keep original text if translation fails
                    result[field] = value

        return result

    def translate_list_of_dicts(self, data_list: List[Dict[str, Any]], text_fields: List[str]) -> List[Dict[str, Any]]:
        """
        Translate specific text fields in a list of dictionaries.

        Args:
            data_list: List of dictionaries containing the data
            text_fields: List of field names to translate in each dictionary

        Returns:
            List of dictionaries (same order) with translated text fields;
            items that fail to translate are passed through unchanged.
        """
        if not data_list:
            return []

        logger.info(f"Translating {len(data_list)} items with fields: {text_fields}")

        results = []
        for i, data in enumerate(data_list):
            try:
                results.append(self.translate_dict(data, text_fields))
            except Exception as e:
                logger.error(f"Failed to translate item {i}: {e}")
                results.append(data)  # keep original data if translation fails

            # Periodic progress log for long runs
            if (i + 1) % 100 == 0:
                logger.info(f"Translated {i + 1}/{len(data_list)} items")

        logger.info(f"Completed translation of {len(data_list)} items")
        return results

    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        return self._is_loaded

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the configured model.

        Annotation fix: 'is_loaded' is a bool, so the return type is
        Dict[str, Any] rather than the original Dict[str, str].
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "is_loaded": self._is_loaded
        }