Commit 5dcfc82
Parent: a7fd3ba
Enrich augmentation with different QA variants. Ensure Vietnamese output, add graceful fallback.
Files changed:
- app.py (+2, -1)
- utils/augment.py (+49, -0)
- utils/processor.py (+120, -10)
- utils/rag.py (+24, -5)
- vi/processing.py (+44, -4)
app.py CHANGED

@@ -408,7 +408,8 @@ def _run_job(dataset_key: str, params: ProcessParams):
             sample_limit=params.sample_limit,
             seed=params.seed,
             progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
-            translator=translator
+            translator=translator,
+            paraphraser=paraphraser
         )
     else:
         # Standard SFT processing mode
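The new paraphraser argument is threaded through alongside translator. Judging only from the call sites in this commit, the two objects need the methods below; this is an inferred sketch, not the repo's actual class definitions:

from typing import Protocol

class ParaphraserLike(Protocol):
    # Methods assumed from utils/augment.py and utils/processor.py call sites.
    def paraphrase(self, text: str, difficulty: str = "easy") -> str: ...
    def consistency_check(self, user: str, out: str) -> bool: ...

class TranslatorLike(Protocol):
    # Methods assumed from utils/processor.py and vi/processing.py call sites.
    def is_loaded(self) -> bool: ...
    def translate_text(self, text: str) -> str: ...
    def translate_dict(self, row: dict, text_fields: list) -> dict: ...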
utils/augment.py CHANGED

@@ -118,3 +118,52 @@ def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
     if random.random() >= ratio:
         return True
     return paraphraser.consistency_check(user, out)
+
+def is_invalid_response(text: str) -> bool:
+    """Check if model response is invalid (Fail, Invalid, etc.)"""
+    if not text or not isinstance(text, str):
+        return True
+
+    text_lower = text.lower().strip()
+    invalid_patterns = [
+        "fail", "invalid", "i couldn't", "i can't", "i cannot", "unable to",
+        "sorry", "error", "not available", "no answer", "insufficient",
+        "don't know", "do not know", "not sure", "cannot determine",
+        "unable to provide", "not possible", "not applicable", "n/a"
+    ]
+
+    # Check if response is too short or matches invalid patterns
+    if len(text_lower) < 3:
+        return True
+
+    for pattern in invalid_patterns:
+        if pattern in text_lower:
+            return True
+
+    return False
+
+def clean_invalid_response(text: str, fallback: str = "") -> str:
+    """Clean invalid responses by returning fallback or empty string"""
+    if is_invalid_response(text):
+        return fallback
+    return text
+
+def retry_invalid_response(text: str, paraphraser, max_retries: int = 3) -> str:
+    """Retry generating valid response for invalid text, max 3 retries"""
+    if not is_invalid_response(text):
+        return text
+
+    for attempt in range(max_retries):
+        try:
+            # Try paraphrasing with different difficulty levels
+            difficulty = "easy" if attempt == 0 else "hard" if attempt == 1 else "easy"
+            retry_text = paraphraser.paraphrase(text, difficulty=difficulty)
+
+            if retry_text and not is_invalid_response(retry_text):
+                return retry_text
+        except Exception as e:
+            logger.warning(f"Retry attempt {attempt + 1} failed: {e}")
+            continue
+
+    # If all retries failed, return empty string to indicate drop
+    return ""
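A minimal usage sketch of the new fallback helpers. DummyParaphraser is a stand-in invented for illustration; the only assumption, taken from the call sites above, is a paraphrase(text, difficulty=...) method:

from utils.augment import is_invalid_response, clean_invalid_response, retry_invalid_response

class DummyParaphraser:
    # Hypothetical stand-in for the repo's real paraphraser.
    def paraphrase(self, text: str, difficulty: str = "easy") -> str:
        # Pretend the model recovers a usable answer on retry.
        return "Metformin is a first-line treatment for type 2 diabetes."

raw = "Sorry, I cannot determine the answer."
print(is_invalid_response(raw))                          # True ("sorry", "cannot determine")
print(clean_invalid_response(raw, fallback=""))          # "" -> caller treats this as a drop
print(retry_invalid_response(raw, DummyParaphraser()))   # recovered text, or "" after max_retries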
utils/processor.py CHANGED

@@ -113,6 +113,77 @@ def _build_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict):
     variants.append((u3, o3, applied))
     return variants
 
+def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict, translator=None):
+    """Build multiple paraphrased variants for SFT enrichment (2-3 answers per question, 2-3 questions per answer)"""
+    variants = []
+
+    # Generate 2-3 different answers for the same question
+    answer_variants = []
+    for i in range(3):
+        if i == 0:
+            # Original answer
+            answer_variants.append((out, ["original_answer"]))
+        else:
+            # Paraphrased answers with different difficulties
+            difficulty = "easy" if i == 1 else "hard"
+            try:
+                paraphrased_out = paraphraser.paraphrase(out, difficulty=difficulty)
+                if paraphrased_out and not A.is_invalid_response(paraphrased_out):
+                    if opts.get("style_standardize", True):
+                        paraphrased_out = A.style_standardize_answer(paraphrased_out)
+                    paraphrased_out = A.ensure_terminal_punct(paraphrased_out)
+                    answer_variants.append((paraphrased_out, [f"paraphrase_answer_{difficulty}"]))
+                    stats["paraphrased_output"] += 1
+            except Exception as e:
+                logger.warning(f"Failed to paraphrase answer variant {i}: {e}")
+                continue
+
+    # Generate 2-3 different questions for the same answer
+    question_variants = []
+    for i in range(3):
+        if i == 0:
+            # Original question
+            question_variants.append((user, ["original_question"]))
+        else:
+            # Paraphrased questions with different difficulties
+            difficulty = "easy" if i == 1 else "hard"
+            try:
+                paraphrased_user = paraphraser.paraphrase(user, difficulty=difficulty)
+                if paraphrased_user and not A.is_invalid_response(paraphrased_user):
+                    paraphrased_user = A.ensure_terminal_punct(paraphrased_user)
+                    question_variants.append((paraphrased_user, [f"paraphrase_question_{difficulty}"]))
+                    stats["paraphrased_input"] += 1
+            except Exception as e:
+                logger.warning(f"Failed to paraphrase question variant {i}: {e}")
+                continue
+
+    # Create combinations: each question variant with each answer variant
+    for q_user, q_tags in question_variants:
+        for a_out, a_tags in answer_variants:
+            combined_tags = q_tags + a_tags
+            variants.append((q_user, a_out, combined_tags))
+
+    # Add Vietnamese variants if translator is available
+    if translator and translator.is_loaded():
+        vi_variants = []
+        for q_user, a_out, tags in variants[:3]:  # Limit to first 3 to avoid too many variants
+            try:
+                # Translate question and answer
+                vi_q = translator.translate_text(q_user)
+                vi_a = translator.translate_text(a_out)
+
+                if vi_q and vi_a and not A.is_invalid_response(vi_q) and not A.is_invalid_response(vi_a):
+                    vi_tags = tags + ["vietnamese_translated"]
+                    vi_variants.append((vi_q, vi_a, vi_tags))
+                    stats["vietnamese_variants"] = stats.get("vietnamese_variants", 0) + 1
+            except Exception as e:
+                logger.warning(f"Failed to create Vietnamese variant: {e}")
+                continue
+
+        variants.extend(vi_variants)
+
+    return variants
+
 def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
     # Base cleanup & caps (returns cleaned strings)
     user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))

@@ -126,6 +197,13 @@ def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphr
     # Stack list of entries that has been applied augmentation and stylings
     applied = []
 
+    # Clean invalid responses with retry logic
+    if A.is_invalid_response(out):
+        out = A.retry_invalid_response(out, paraphraser, max_retries=3)
+        if not out:  # If retry failed, return empty to indicate drop
+            return instr, user, "", applied
+        applied.append("invalid_response_retried")
+
     # Style standardizing the answer
     if opts.get("style_standardize", True):
         out = A.style_standardize_answer(out)

@@ -188,6 +266,11 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
         try:
             instr, user, out, applied = _apply_aug(instr, user, out, source, opts, paraphraser, stats)
 
+            # Skip if retry failed (empty output)
+            if not out:
+                stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
+                continue
+
             # 1) ALWAYS write the original (cleaned/style-standardised only)
             # Optional consistency spot-check (cheap)
             if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):

@@ -195,12 +278,15 @@
                 # keep the sample but tag it
                 applied.append("consistency_flag")
 
-            # 2) If expansion is enabled, add
+            # 2) If expansion is enabled, add enriched variants for SFT
             _commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats, ["base"] + applied, dedupe_seen=dedupe_seen, translator=translator)
-
+
+            # Add enriched variants if expand is enabled
             if opts.get("expand", True):
-
-
+                # Use enriched variants for SFT (multiple Q&A combinations)
+                enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
+                for (u_aug, o_aug, aug_tags) in enriched_variants:
+                    rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
                     _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
 
             # Increment count only on success

@@ -247,11 +333,19 @@ def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, d
         rid = str(k)
 
         instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)
+
+        # Skip if retry failed (empty output)
+        if not out:
+            stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
+            continue
+
         _commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
                     extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")}, dedupe_seen=dedupe_seen, translator=translator)
         if opts.get("expand", True):
-
-
+            # Use enriched variants for SFT (multiple Q&A combinations)
+            enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
+            for (u_aug, o_aug, aug_tags) in enriched_variants:
+                rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
                 _commit_row(writer, "pubmedqa_l", rid_aug, "biomedical_qa",
                             instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
 

@@ -302,10 +396,18 @@ def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, d
         out = guess.strip()
 
         instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)
+
+        # Skip if retry failed (empty output)
+        if not out:
+            stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
+            continue
+
         _commit_row(writer, "pubmedqa_u", str(k), "biomedical_qa_unlabeled", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
         if opts.get("expand", True):
-
-
+            # Use enriched variants for SFT (multiple Q&A combinations)
+            enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
+            for (u_aug, o_aug, aug_tags) in enriched_variants:
+                rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
                 _commit_row(writer, "pubmedqa_u", rid_aug, "biomedical_qa",
                             instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
 

@@ -395,12 +497,20 @@ def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb,
 
         # Process the item
        instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)
+
+        # Skip if retry failed (empty output)
+        if not out:
+            stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
+            continue
+
         _commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen, translator=translator)
 
         # Handle expansion if enabled
         if opts.get("expand", True):
-
-
+            # Use enriched variants for SFT (multiple Q&A combinations)
+            enriched_variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
+            for (u_aug, o_aug, aug_tags) in enriched_variants:
+                rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
                 _commit_row(writer, "pubmedqa_map", rid_aug, "biomedical_qa",
                             instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
 
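Rough fan-out per source sample implied by _build_enriched_variants when every paraphrase and translation succeeds (an upper bound reasoned from the code above, not a measured figure, since invalid or failed variants are skipped):

# Upper-bound estimate, assuming all paraphrases and translations come back valid.
question_variants = 3          # original + "easy" + "hard" paraphrase
answer_variants = 3            # original + "easy" + "hard" paraphrase
english_combos = question_variants * answer_variants       # 9 Q/A combinations
vietnamese_combos = 3          # only variants[:3] are translated
rows_per_sample = 1 + english_combos + vietnamese_combos   # 1 base row + 12 enriched rows
print(rows_per_sample)         # 13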
utils/rag.py CHANGED

@@ -7,7 +7,8 @@ from typing import Dict, List, Tuple, Optional, Callable
 
 from utils.schema import sft_row, rag_row
 from utils.llm import NvidiaClient, KeyRotator
-from vi.processing import should_translate
+from vi.processing import should_translate, translate_rag_row
+from utils import augment as A
 
 # Logger
 logger = logging.getLogger("rag_processor")

@@ -190,6 +191,15 @@ class RAGProcessor:
             # Convert to QCA format
             question, context, answer = self.convert_to_qca_format(instr, user, out)
 
+            # Clean invalid responses with retry logic
+            if A.is_invalid_response(answer):
+                if paraphraser:
+                    answer = A.retry_invalid_response(answer, paraphraser, max_retries=3)
+                else:
+                    answer = A.clean_invalid_response(answer, "")
+                if not answer:  # If retry failed, skip this sample
+                    continue
+
             if not question or not answer:
                 continue
 

@@ -246,6 +256,15 @@ class RAGProcessor:
         context = self.clean_conversational_content(context)
         answer = self.clean_conversational_content(answer)
 
+        # Clean invalid responses with retry logic
+        if A.is_invalid_response(answer):
+            if paraphraser:
+                answer = A.retry_invalid_response(answer, paraphraser, max_retries=3)
+            else:
+                answer = A.clean_invalid_response(answer, "")
+            if not answer:  # If retry failed, skip this sample
+                continue
+
         # Generate context if missing
         if not context:
             context = self.generate_context_from_qa(question, answer)

@@ -289,9 +308,8 @@ class RAGProcessor:
         # Apply Vietnamese translation if requested (translate Q/A/C fields directly)
         if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
             try:
-
-
-                row["vi_translated"] = True
+                row = translate_rag_row(row, translator, ["question", "answer", "context"])
+                row["vi_translated"] = True
             except Exception as e:
                 logger.error(f"Failed to translate RAG row: {e}")
 

@@ -307,7 +325,8 @@ def process_file_into_rag(
     sample_limit: Optional[int],
     seed: int,
     progress_cb: Optional[Callable[[float, str], None]],
-    translator=None
+    translator=None,
+    paraphraser=None
 ) -> Tuple[int, Dict]:
     """Main entry point for RAG processing"""
     random.seed(seed)
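The RAG path applies the same fallback in two places; condensed into a standalone sketch (the helper name _fallback_answer is hypothetical, introduced only to summarise the branch above):

from utils import augment as A

def _fallback_answer(answer: str, paraphraser=None) -> str:
    # Mirror of the logic added to RAGProcessor: retry through the paraphraser
    # when one is available, otherwise blank the answer so the caller skips it.
    if not A.is_invalid_response(answer):
        return answer
    if paraphraser:
        return A.retry_invalid_response(answer, paraphraser, max_retries=3)
    return A.clean_invalid_response(answer, "")  # "" signals "drop this sample"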
vi/processing.py CHANGED

@@ -31,6 +31,34 @@ def _vi_sanitize_text(s: str) -> str:
     t = " ".join(filtered)
     return t
 
+def _validate_vi_translation(original: str, translated: str) -> bool:
+    """Validate Vietnamese translation quality"""
+    if not translated or not isinstance(translated, str):
+        return False
+
+    # Check if translation is too short or too different in length
+    if len(translated.strip()) < 3:
+        return False
+
+    # Check if translation contains too much English (should be mostly Vietnamese)
+    import re
+    english_chars = len(re.findall(r'[a-zA-Z]', translated))
+    total_chars = len(re.sub(r'\s', '', translated))
+    if total_chars > 0 and english_chars / total_chars > 0.7:
+        return False
+
+    # Check for common translation failure patterns
+    failure_patterns = [
+        "translation", "error", "failed", "unable", "cannot",
+        "not available", "not found", "invalid", "error"
+    ]
+    translated_lower = translated.lower()
+    for pattern in failure_patterns:
+        if pattern in translated_lower:
+            return False
+
+    return True
+
 def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
     """
     Translate specific text fields in an SFT row from English to Vietnamese.

@@ -53,10 +81,16 @@ def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] =
 
     try:
         translated_row = translator.translate_dict(row, text_fields)
-        #
+        # Validate and sanitize translated fields
         for f in text_fields:
             if f in translated_row.get("sft", {}):
-
+                original = row.get("sft", {}).get(f, "")
+                translated = translated_row["sft"][f]
+                if _validate_vi_translation(original, translated):
+                    translated_row["sft"][f] = _vi_sanitize_text(translated)
+                else:
+                    logger.warning(f"Invalid Vietnamese translation for field {f}, keeping original")
+                    translated_row["sft"][f] = original
         logger.debug(f"Translated SFT row with fields: {text_fields}")
         return translated_row
     except Exception as e:

@@ -85,10 +119,16 @@ def translate_rag_row(row: Dict[str, Any], translator, text_fields: List[str] =
 
     try:
         translated_row = translator.translate_dict(row, text_fields)
-        #
+        # Validate and sanitize translated fields
         for f in text_fields:
             if f in translated_row:
-
+                original = row.get(f, "")
+                translated = translated_row[f]
+                if _validate_vi_translation(original, translated):
+                    translated_row[f] = _vi_sanitize_text(translated)
+                else:
+                    logger.warning(f"Invalid Vietnamese translation for field {f}, keeping original")
+                    translated_row[f] = original
         logger.debug(f"Translated RAG row with fields: {text_fields}")
         return translated_row
     except Exception as e:
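A quick illustration of how _validate_vi_translation behaves on a few made-up strings (expected results reasoned from the heuristics above, not from running the Space):

from vi.processing import _validate_vi_translation

ok = "Bệnh tiểu đường tuýp 2 thường được điều trị bằng metformin."   # mostly Vietnamese
too_short = "OK"                                                      # fails the length check
too_english = "Type 2 diabetes is usually treated with metformin."   # >70% ASCII letters
failure_marker = "Lỗi: translation failed"                            # hits a failure pattern

for s in (ok, too_short, too_english, failure_marker):
    print(_validate_vi_translation("", s))
# Expected: True, False, False, False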