LiamKhoaLe commited on
Commit
5dcfc82
·
1 Parent(s): a7fd3ba

Enrich augmentation with different QA variants. Ensure Vietnamese output, add graceful fallback

Browse files
Files changed (5) hide show
  1. app.py +2 -1
  2. utils/augment.py +49 -0
  3. utils/processor.py +120 -10
  4. utils/rag.py +24 -5
  5. vi/processing.py +44 -4
app.py CHANGED
@@ -408,7 +408,8 @@ def _run_job(dataset_key: str, params: ProcessParams):
408
  sample_limit=params.sample_limit,
409
  seed=params.seed,
410
  progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
411
- translator=translator
 
412
  )
413
  else:
414
  # Standard SFT processing mode
 
408
  sample_limit=params.sample_limit,
409
  seed=params.seed,
410
  progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
411
+ translator=translator,
412
+ paraphraser=paraphraser
413
  )
414
  else:
415
  # Standard SFT processing mode
utils/augment.py CHANGED
@@ -118,3 +118,52 @@ def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
118
  if random.random() >= ratio:
119
  return True
120
  return paraphraser.consistency_check(user, out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  if random.random() >= ratio:
119
  return True
120
  return paraphraser.consistency_check(user, out)
121
+
122
+ def is_invalid_response(text: str) -> bool:
123
+ """Check if model response is invalid (Fail, Invalid, etc.)"""
124
+ if not text or not isinstance(text, str):
125
+ return True
126
+
127
+ text_lower = text.lower().strip()
128
+ invalid_patterns = [
129
+ "fail", "invalid", "i couldn't", "i can't", "i cannot", "unable to",
130
+ "sorry", "error", "not available", "no answer", "insufficient",
131
+ "don't know", "do not know", "not sure", "cannot determine",
132
+ "unable to provide", "not possible", "not applicable", "n/a"
133
+ ]
134
+
135
+ # Check if response is too short or matches invalid patterns
136
+ if len(text_lower) < 3:
137
+ return True
138
+
139
+ for pattern in invalid_patterns:
140
+ if pattern in text_lower:
141
+ return True
142
+
143
+ return False
144
+
145
+ def clean_invalid_response(text: str, fallback: str = "") -> str:
146
+ """Clean invalid responses by returning fallback or empty string"""
147
+ if is_invalid_response(text):
148
+ return fallback
149
+ return text
150
+
151
+ def retry_invalid_response(text: str, paraphraser, max_retries: int = 3) -> str:
152
+ """Retry generating valid response for invalid text, max 3 retries"""
153
+ if not is_invalid_response(text):
154
+ return text
155
+
156
+ for attempt in range(max_retries):
157
+ try:
158
+ # Try paraphrasing with different difficulty levels
159
+ difficulty = "easy" if attempt == 0 else "hard" if attempt == 1 else "easy"
160
+ retry_text = paraphraser.paraphrase(text, difficulty=difficulty)
161
+
162
+ if retry_text and not is_invalid_response(retry_text):
163
+ return retry_text
164
+ except Exception as e:
165
+ logger.warning(f"Retry attempt {attempt + 1} failed: {e}")
166
+ continue
167
+
168
+ # If all retries failed, return empty string to indicate drop
169
+ return ""
utils/processor.py CHANGED
@@ -113,6 +113,77 @@ def _build_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict):
113
  variants.append((u3, o3, applied))
114
  return variants
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
117
  # Base cleanup & caps (returns cleaned strings)
118
  user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
@@ -126,6 +197,13 @@ def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphr
126
  # Stack list of entries that has been applied augmentation and stylings
127
  applied = []
128
 
 
 
 
 
 
 
 
129
  # Style standardizing the answer
130
  if opts.get("style_standardize", True):
131
  out = A.style_standardize_answer(out)
@@ -188,6 +266,11 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
188
  try:
189
  instr, user, out, applied = _apply_aug(instr, user, out, source, opts, paraphraser, stats)
190
 
 
 
 
 
 
191
  # 1) ALWAYS write the original (cleaned/style-standardised only)
192
  # Optional consistency spot-check (cheap)
193
  if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
@@ -195,12 +278,15 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
195
  # keep the sample but tag it
196
  applied.append("consistency_flag")
197
 
198
- # 2) If expansion is enabled, add augmented copies
199
  _commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats, ["base"] + applied, dedupe_seen=dedupe_seen, translator=translator)
200
- # Add augmented copies if expand
 
201
  if opts.get("expand", True):
202
- for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
203
- rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
 
 
204
  _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
205
 
206
  # Increment count only on success
@@ -247,11 +333,19 @@ def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, d
247
  rid = str(k)
248
 
249
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)
 
 
 
 
 
 
250
  _commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
251
  extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")}, dedupe_seen=dedupe_seen, translator=translator)
252
  if opts.get("expand", True):
253
- for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
254
- rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
 
 
255
  _commit_row(writer, "pubmedqa_l", rid_aug, "biomedical_qa",
256
  instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
257
 
@@ -302,10 +396,18 @@ def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, d
302
  out = guess.strip()
303
 
304
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)
 
 
 
 
 
 
305
  _commit_row(writer, "pubmedqa_u", str(k), "biomedical_qa_unlabeled", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
306
  if opts.get("expand", True):
307
- for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
308
- rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
 
 
309
  _commit_row(writer, "pubmedqa_u", rid_aug, "biomedical_qa",
310
  instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
311
 
@@ -395,12 +497,20 @@ def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb,
395
 
396
  # Process the item
397
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)
 
 
 
 
 
 
398
  _commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
399
 
400
  # Handle expansion if enabled
401
  if opts.get("expand", True):
402
- for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
403
- rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
 
 
404
  _commit_row(writer, "pubmedqa_map", rid_aug, "biomedical_qa",
405
  instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
406
 
 
113
  variants.append((u3, o3, applied))
114
  return variants
115
 
116
+ def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict, translator=None):
117
+ """Build multiple paraphrased variants for SFT enrichment (2-3 answers per question, 2-3 questions per answer)"""
118
+ variants = []
119
+
120
+ # Generate 2-3 different answers for the same question
121
+ answer_variants = []
122
+ for i in range(3):
123
+ if i == 0:
124
+ # Original answer
125
+ answer_variants.append((out, ["original_answer"]))
126
+ else:
127
+ # Paraphrased answers with different difficulties
128
+ difficulty = "easy" if i == 1 else "hard"
129
+ try:
130
+ paraphrased_out = paraphraser.paraphrase(out, difficulty=difficulty)
131
+ if paraphrased_out and not A.is_invalid_response(paraphrased_out):
132
+ if opts.get("style_standardize", True):
133
+ paraphrased_out = A.style_standardize_answer(paraphrased_out)
134
+ paraphrased_out = A.ensure_terminal_punct(paraphrased_out)
135
+ answer_variants.append((paraphrased_out, [f"paraphrase_answer_{difficulty}"]))
136
+ stats["paraphrased_output"] += 1
137
+ except Exception as e:
138
+ logger.warning(f"Failed to paraphrase answer variant {i}: {e}")
139
+ continue
140
+
141
+ # Generate 2-3 different questions for the same answer
142
+ question_variants = []
143
+ for i in range(3):
144
+ if i == 0:
145
+ # Original question
146
+ question_variants.append((user, ["original_question"]))
147
+ else:
148
+ # Paraphrased questions with different difficulties
149
+ difficulty = "easy" if i == 1 else "hard"
150
+ try:
151
+ paraphrased_user = paraphraser.paraphrase(user, difficulty=difficulty)
152
+ if paraphrased_user and not A.is_invalid_response(paraphrased_user):
153
+ paraphrased_user = A.ensure_terminal_punct(paraphrased_user)
154
+ question_variants.append((paraphrased_user, [f"paraphrase_question_{difficulty}"]))
155
+ stats["paraphrased_input"] += 1
156
+ except Exception as e:
157
+ logger.warning(f"Failed to paraphrase question variant {i}: {e}")
158
+ continue
159
+
160
+ # Create combinations: each question variant with each answer variant
161
+ for q_user, q_tags in question_variants:
162
+ for a_out, a_tags in answer_variants:
163
+ combined_tags = q_tags + a_tags
164
+ variants.append((q_user, a_out, combined_tags))
165
+
166
+ # Add Vietnamese variants if translator is available
167
+ if translator and translator.is_loaded():
168
+ vi_variants = []
169
+ for q_user, a_out, tags in variants[:3]: # Limit to first 3 to avoid too many variants
170
+ try:
171
+ # Translate question and answer
172
+ vi_q = translator.translate_text(q_user)
173
+ vi_a = translator.translate_text(a_out)
174
+
175
+ if vi_q and vi_a and not A.is_invalid_response(vi_q) and not A.is_invalid_response(vi_a):
176
+ vi_tags = tags + ["vietnamese_translated"]
177
+ vi_variants.append((vi_q, vi_a, vi_tags))
178
+ stats["vietnamese_variants"] = stats.get("vietnamese_variants", 0) + 1
179
+ except Exception as e:
180
+ logger.warning(f"Failed to create Vietnamese variant: {e}")
181
+ continue
182
+
183
+ variants.extend(vi_variants)
184
+
185
+ return variants
186
+
187
  def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
188
  # Base cleanup & caps (returns cleaned strings)
189
  user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
 
197
  # Stack list of entries that has been applied augmentation and stylings
198
  applied = []
199
 
200
+ # Clean invalid responses with retry logic
201
+ if A.is_invalid_response(out):
202
+ out = A.retry_invalid_response(out, paraphraser, max_retries=3)
203
+ if not out: # If retry failed, return empty to indicate drop
204
+ return instr, user, "", applied
205
+ applied.append("invalid_response_retried")
206
+
207
  # Style standardizing the answer
208
  if opts.get("style_standardize", True):
209
  out = A.style_standardize_answer(out)
 
266
  try:
267
  instr, user, out, applied = _apply_aug(instr, user, out, source, opts, paraphraser, stats)
268
 
269
+ # Skip if retry failed (empty output)
270
+ if not out:
271
+ stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
272
+ continue
273
+
274
  # 1) ALWAYS write the original (cleaned/style-standardised only)
275
  # Optional consistency spot-check (cheap)
276
  if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
 
278
  # keep the sample but tag it
279
  applied.append("consistency_flag")
280
 
281
+ # 2) If expansion is enabled, add enriched variants for SFT
282
  _commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats, ["base"] + applied, dedupe_seen=dedupe_seen, translator=translator)
283
+
284
+ # Add enriched variants if expand is enabled
285
  if opts.get("expand", True):
286
+ # Use enriched variants for SFT (multiple Q&A combinations)
287
+ enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
288
+ for (u_aug, o_aug, aug_tags) in enriched_variants:
289
+ rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
290
  _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
291
 
292
  # Increment count only on success
 
333
  rid = str(k)
334
 
335
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)
336
+
337
+ # Skip if retry failed (empty output)
338
+ if not out:
339
+ stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
340
+ continue
341
+
342
  _commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
343
  extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")}, dedupe_seen=dedupe_seen, translator=translator)
344
  if opts.get("expand", True):
345
+ # Use enriched variants for SFT (multiple Q&A combinations)
346
+ enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
347
+ for (u_aug, o_aug, aug_tags) in enriched_variants:
348
+ rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
349
  _commit_row(writer, "pubmedqa_l", rid_aug, "biomedical_qa",
350
  instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
351
 
 
396
  out = guess.strip()
397
 
398
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)
399
+
400
+ # Skip if retry failed (empty output)
401
+ if not out:
402
+ stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
403
+ continue
404
+
405
  _commit_row(writer, "pubmedqa_u", str(k), "biomedical_qa_unlabeled", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
406
  if opts.get("expand", True):
407
+ # Use enriched variants for SFT (multiple Q&A combinations)
408
+ enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
409
+ for (u_aug, o_aug, aug_tags) in enriched_variants:
410
+ rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
411
  _commit_row(writer, "pubmedqa_u", rid_aug, "biomedical_qa",
412
  instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
413
 
 
497
 
498
  # Process the item
499
  instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)
500
+
501
+ # Skip if retry failed (empty output)
502
+ if not out:
503
+ stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
504
+ continue
505
+
506
  _commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
507
 
508
  # Handle expansion if enabled
509
  if opts.get("expand", True):
510
+ # Use enriched variants for SFT (multiple Q&A combinations)
511
+ enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
512
+ for (u_aug, o_aug, aug_tags) in enriched_variants:
513
+ rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
514
  _commit_row(writer, "pubmedqa_map", rid_aug, "biomedical_qa",
515
  instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
516
 
utils/rag.py CHANGED
@@ -7,7 +7,8 @@ from typing import Dict, List, Tuple, Optional, Callable
7
 
8
  from utils.schema import sft_row, rag_row
9
  from utils.llm import NvidiaClient, KeyRotator
10
- from vi.processing import should_translate
 
11
 
12
  # Logger
13
  logger = logging.getLogger("rag_processor")
@@ -190,6 +191,15 @@ class RAGProcessor:
190
  # Convert to QCA format
191
  question, context, answer = self.convert_to_qca_format(instr, user, out)
192
 
 
 
 
 
 
 
 
 
 
193
  if not question or not answer:
194
  continue
195
 
@@ -246,6 +256,15 @@ class RAGProcessor:
246
  context = self.clean_conversational_content(context)
247
  answer = self.clean_conversational_content(answer)
248
 
 
 
 
 
 
 
 
 
 
249
  # Generate context if missing
250
  if not context:
251
  context = self.generate_context_from_qa(question, answer)
@@ -289,9 +308,8 @@ class RAGProcessor:
289
  # Apply Vietnamese translation if requested (translate Q/A/C fields directly)
290
  if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
291
  try:
292
- if translator:
293
- row = translator.translate_dict(row, ["question", "answer", "context"])
294
- row["vi_translated"] = True
295
  except Exception as e:
296
  logger.error(f"Failed to translate RAG row: {e}")
297
 
@@ -307,7 +325,8 @@ def process_file_into_rag(
307
  sample_limit: Optional[int],
308
  seed: int,
309
  progress_cb: Optional[Callable[[float, str], None]],
310
- translator=None
 
311
  ) -> Tuple[int, Dict]:
312
  """Main entry point for RAG processing"""
313
  random.seed(seed)
 
7
 
8
  from utils.schema import sft_row, rag_row
9
  from utils.llm import NvidiaClient, KeyRotator
10
+ from vi.processing import should_translate, translate_rag_row
11
+ from utils import augment as A
12
 
13
  # Logger
14
  logger = logging.getLogger("rag_processor")
 
191
  # Convert to QCA format
192
  question, context, answer = self.convert_to_qca_format(instr, user, out)
193
 
194
+ # Clean invalid responses with retry logic
195
+ if A.is_invalid_response(answer):
196
+ if paraphraser:
197
+ answer = A.retry_invalid_response(answer, paraphraser, max_retries=3)
198
+ else:
199
+ answer = A.clean_invalid_response(answer, "")
200
+ if not answer: # If retry failed, skip this sample
201
+ continue
202
+
203
  if not question or not answer:
204
  continue
205
 
 
256
  context = self.clean_conversational_content(context)
257
  answer = self.clean_conversational_content(answer)
258
 
259
+ # Clean invalid responses with retry logic
260
+ if A.is_invalid_response(answer):
261
+ if paraphraser:
262
+ answer = A.retry_invalid_response(answer, paraphraser, max_retries=3)
263
+ else:
264
+ answer = A.clean_invalid_response(answer, "")
265
+ if not answer: # If retry failed, skip this sample
266
+ continue
267
+
268
  # Generate context if missing
269
  if not context:
270
  context = self.generate_context_from_qa(question, answer)
 
308
  # Apply Vietnamese translation if requested (translate Q/A/C fields directly)
309
  if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
310
  try:
311
+ row = translate_rag_row(row, translator, ["question", "answer", "context"])
312
+ row["vi_translated"] = True
 
313
  except Exception as e:
314
  logger.error(f"Failed to translate RAG row: {e}")
315
 
 
325
  sample_limit: Optional[int],
326
  seed: int,
327
  progress_cb: Optional[Callable[[float, str], None]],
328
+ translator=None,
329
+ paraphraser=None
330
  ) -> Tuple[int, Dict]:
331
  """Main entry point for RAG processing"""
332
  random.seed(seed)
vi/processing.py CHANGED
@@ -31,6 +31,34 @@ def _vi_sanitize_text(s: str) -> str:
31
  t = " ".join(filtered)
32
  return t
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
35
  """
36
  Translate specific text fields in an SFT row from English to Vietnamese.
@@ -53,10 +81,16 @@ def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] =
53
 
54
  try:
55
  translated_row = translator.translate_dict(row, text_fields)
56
- # Sanitize translated fields
57
  for f in text_fields:
58
  if f in translated_row.get("sft", {}):
59
- translated_row["sft"][f] = _vi_sanitize_text(translated_row["sft"][f])
 
 
 
 
 
 
60
  logger.debug(f"Translated SFT row with fields: {text_fields}")
61
  return translated_row
62
  except Exception as e:
@@ -85,10 +119,16 @@ def translate_rag_row(row: Dict[str, Any], translator, text_fields: List[str] =
85
 
86
  try:
87
  translated_row = translator.translate_dict(row, text_fields)
88
- # Sanitize translated fields
89
  for f in text_fields:
90
  if f in translated_row:
91
- translated_row[f] = _vi_sanitize_text(translated_row[f])
 
 
 
 
 
 
92
  logger.debug(f"Translated RAG row with fields: {text_fields}")
93
  return translated_row
94
  except Exception as e:
 
31
  t = " ".join(filtered)
32
  return t
33
 
34
+ def _validate_vi_translation(original: str, translated: str) -> bool:
35
+ """Validate Vietnamese translation quality"""
36
+ if not translated or not isinstance(translated, str):
37
+ return False
38
+
39
+ # Check if translation is too short or too different in length
40
+ if len(translated.strip()) < 3:
41
+ return False
42
+
43
+ # Check if translation contains too much English (should be mostly Vietnamese)
44
+ import re
45
+ english_chars = len(re.findall(r'[a-zA-Z]', translated))
46
+ total_chars = len(re.sub(r'\s', '', translated))
47
+ if total_chars > 0 and english_chars / total_chars > 0.7:
48
+ return False
49
+
50
+ # Check for common translation failure patterns
51
+ failure_patterns = [
52
+ "translation", "error", "failed", "unable", "cannot",
53
+ "not available", "not found", "invalid", "error"
54
+ ]
55
+ translated_lower = translated.lower()
56
+ for pattern in failure_patterns:
57
+ if pattern in translated_lower:
58
+ return False
59
+
60
+ return True
61
+
62
  def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
63
  """
64
  Translate specific text fields in an SFT row from English to Vietnamese.
 
81
 
82
  try:
83
  translated_row = translator.translate_dict(row, text_fields)
84
+ # Validate and sanitize translated fields
85
  for f in text_fields:
86
  if f in translated_row.get("sft", {}):
87
+ original = row.get("sft", {}).get(f, "")
88
+ translated = translated_row["sft"][f]
89
+ if _validate_vi_translation(original, translated):
90
+ translated_row["sft"][f] = _vi_sanitize_text(translated)
91
+ else:
92
+ logger.warning(f"Invalid Vietnamese translation for field {f}, keeping original")
93
+ translated_row["sft"][f] = original
94
  logger.debug(f"Translated SFT row with fields: {text_fields}")
95
  return translated_row
96
  except Exception as e:
 
119
 
120
  try:
121
  translated_row = translator.translate_dict(row, text_fields)
122
+ # Validate and sanitize translated fields
123
  for f in text_fields:
124
  if f in translated_row:
125
+ original = row.get(f, "")
126
+ translated = translated_row[f]
127
+ if _validate_vi_translation(original, translated):
128
+ translated_row[f] = _vi_sanitize_text(translated)
129
+ else:
130
+ logger.warning(f"Invalid Vietnamese translation for field {f}, keeping original")
131
+ translated_row[f] = original
132
  logger.debug(f"Translated RAG row with fields: {text_fields}")
133
  return translated_row
134
  except Exception as e: