# Dataset-specific parsers + paraphrasing flow
import json
import random
import hashlib
import logging
from typing import Callable, Optional, Dict, Tuple

from utils.schema import sft_row
from utils import augment as A
from vi.processing import translate_sft_row, should_translate, log_translation_stats

# Logger
logger = logging.getLogger("processor")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())


def _hash_id(*parts) -> str:
    """Return a 16-hex-character SHA-256 digest identifying ``parts``.

    Each part is stringified and UTF-8 encoded. Parts are separated by an
    ASCII unit-separator byte so that different splits of the same characters
    (e.g. ``(1, 23)`` vs ``(12, 3)``) do not produce the same id.
    """
    h = hashlib.sha256()
    for idx, p in enumerate(parts):
        if idx:
            h.update(b"\x1f")  # boundary marker between parts
        h.update(str(p).encode("utf-8"))
    return h.hexdigest()[:16]


def _iter_json_or_jsonl(path: str):
    """Yield records from either a JSON-array file or a JSON Lines file.

    The format is sniffed from the first non-whitespace character: ``[`` means
    the whole file is one JSON array; anything else is read as one JSON value
    per non-blank line. A leading UTF-8 BOM is tolerated (``utf-8-sig``).
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        first = f.read(1)
        while first and first.isspace():
            first = f.read(1)
        f.seek(0)
        if first == "[":
            yield from json.load(f)
        else:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)


def process_file_into_sft(
    dataset_key: str,
    input_path: str,
    writer,
    paraphraser,
    augment_opts: Dict,
    sample_limit: Optional[int],
    seed: int,
    progress_cb: Optional[Callable[[float, str], None]],
    translator=None,
) -> Tuple[int, Dict]:
    """Convert one raw dataset file into SFT rows written through ``writer``.

    Args:
        dataset_key: Dataset name (case-insensitive): ``healthcaremagic``,
            ``icliniq``, ``pubmedqa_l``, ``pubmedqa_u`` or ``pubmedqa_map``.
        input_path: Path to the raw dataset file.
        writer: Object with a ``write(row)`` method receiving each SFT row.
        paraphraser: Model wrapper used for augmentation / retries.
        augment_opts: Augmentation options (ratios, flags, limits).
        sample_limit: Stop after this many successfully processed source items
            (``None``/0 means no limit).
        seed: Seed for the module-global ``random`` generator.
        progress_cb: Optional ``cb(fraction, message)`` progress callback.
        translator: Optional translator used for Vietnamese variants/rows.

    Returns:
        ``(count, stats)``: number of processed source items and a dict of
        counters collected during processing.

    Raises:
        ValueError: if ``dataset_key`` is not a known dataset.
    """
    random.seed(seed)
    stats = {
        "written": 0,
        "paraphrased_input": 0,
        "paraphrased_output": 0,
        "backtranslated_input": 0,
        "backtranslated_output": 0,
        "dedup_skipped": 0,
        "consistency_failed": 0,
        "medical_accuracy_failed": 0,
        "clinical_scenarios_created": 0,
        "enhanced_terminology": 0,
        "vietnamese_variants": 0,
        "dropped_invalid": 0,
    }
    # Start processing SFT
    key_summary = {k: augment_opts.get(k) for k in (
        "paraphrase_ratio", "backtranslate_ratio", "paraphrase_outputs",
        "style_standardize", "deidentify", "dedupe",
        "consistency_check_ratio", "distill_fraction",
    )}
    logger.info(
        f"[PROC] Begin dataset={dataset_key} sample_limit={sample_limit} opts={key_summary}"
    )
    # Fingerprint set shared across all rows of this file when dedup is enabled
    dedupe_seen = set() if augment_opts.get("dedupe", True) else None

    key = dataset_key.lower()
    if key in ("healthcaremagic", "icliniq"):
        count = _proc_med_dialog(
            source=key, path=input_path, writer=writer, paraphraser=paraphraser,
            opts=augment_opts, sample_limit=sample_limit, stats=stats,
            cb=progress_cb, dedupe_seen=dedupe_seen, translator=translator,
        )
    else:
        processors = {
            "pubmedqa_l": _proc_pubmedqa_l,
            "pubmedqa_u": _proc_pubmedqa_u,
            "pubmedqa_map": _proc_pubmedqa_map,
        }
        proc = processors.get(key)
        if proc is None:
            raise ValueError(f"Unknown dataset: {dataset_key}")
        count = proc(
            input_path, writer, paraphraser, augment_opts, sample_limit, stats,
            progress_cb, dedupe_seen=dedupe_seen, translator=translator,
        )

    logger.info(f"[PROC] End dataset={dataset_key} stats={stats}")
    return count, stats


# ——————————— helpers ———————————

def _build_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict):
    """Return a list of (user_variant, out_variant, applied_tags) not including the original.

    Up to ``opts["max_aug_per_sample"]`` attempts are made; an attempt that
    applies no augmentation is discarded. Counters in ``stats`` are bumped for
    every paraphrase / back-translation that was applied.
    """
    variants = []
    max_k = max(0, int(opts.get("max_aug_per_sample", 1)))
    for _ in range(max_k):
        applied = []
        u2, did_p = A.maybe_paraphrase(user, opts.get("paraphrase_ratio", 0.0), paraphraser, "easy")
        if did_p:
            applied.append("paraphrase_input")
            stats["paraphrased_input"] += 1
        u3, did_bt = A.maybe_backtranslate(u2, opts.get("backtranslate_ratio", 0.0), paraphraser)
        if did_bt:
            applied.append("backtranslate_input")
            stats["backtranslated_input"] += 1

        o3 = out
        if opts.get("paraphrase_outputs", False):
            o2, did_p2 = A.maybe_paraphrase(out, opts.get("paraphrase_ratio", 0.0), paraphraser, "hard")
            if did_p2:
                applied.append("paraphrase_output")
                stats["paraphrased_output"] += 1
            o3b, did_bt2 = A.maybe_backtranslate(o2, opts.get("backtranslate_ratio", 0.0), paraphraser)
            if did_bt2:
                applied.append("backtranslate_output")
                stats["backtranslated_output"] += 1
            o3 = o3b

        # If nothing applied, skip this variant
        if not applied:
            continue

        # Style standardize and punctuation for the variant too
        if opts.get("style_standardize", True):
            o3 = A.style_standardize_answer(o3)
        u3 = A.ensure_terminal_punct(u3) if u3 else u3
        o3 = A.ensure_terminal_punct(o3) if o3 else o3
        variants.append((u3, o3, applied))
    return variants


def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict, translator=None):
    """Build multiple paraphrased variants for SFT enrichment.

    Generates styled rewrites of the answer (concise, detailed, clinical,
    patient-friendly) and of the question (clarifying, follow-up,
    symptom-focused, treatment-focused), then crosses every question variant
    with every answer variant. The (original question, original answer) pair
    is NOT returned, because callers commit it separately as the base row.
    When ``translator`` is loaded, Vietnamese translations of the first five
    combinations (the original pair included) are appended.

    Returns:
        List of ``(question, answer, tags)`` tuples.
    """
    # Answer variants: the original plus one rewrite per style
    answer_variants = [(out, ["original_answer"])]
    for strategy in ("concise", "detailed", "clinical", "patient_friendly"):
        try:
            style_prompt = _get_answer_style_prompt(strategy, user, out)
            enhanced_out = paraphraser.paraphrase(out, difficulty="hard", custom_prompt=style_prompt)
            if not enhanced_out or A.is_invalid_response(enhanced_out):
                continue
            # Clean conversational elements; skip if nothing substantive remains
            enhanced_out = A.clean_conversational_elements(enhanced_out)
            if not enhanced_out:
                continue
            if opts.get("style_standardize", True):
                enhanced_out = A.style_standardize_answer(enhanced_out)
            enhanced_out = A.ensure_terminal_punct(enhanced_out)
            answer_variants.append((enhanced_out, [f"{strategy}_answer"]))
            stats["paraphrased_output"] += 1
        except Exception as e:
            logger.warning(f"Failed to generate {strategy} answer variant: {e}")

    # Question variants: the original plus one rewrite per question type
    question_variants = [(user, ["original_question"])]
    for strategy in ("clarifying", "follow_up", "symptom_focused", "treatment_focused"):
        try:
            style_prompt = _get_question_style_prompt(strategy, user, out)
            enhanced_user = paraphraser.paraphrase(user, difficulty="hard", custom_prompt=style_prompt)
            if not enhanced_user or A.is_invalid_response(enhanced_user):
                continue
            enhanced_user = A.clean_conversational_elements(enhanced_user)
            if not enhanced_user:
                continue
            enhanced_user = A.ensure_terminal_punct(enhanced_user)
            question_variants.append((enhanced_user, [f"{strategy}_question"]))
            stats["paraphrased_input"] += 1
        except Exception as e:
            logger.warning(f"Failed to generate {strategy} question variant: {e}")

    # Cross product: each question variant with each answer variant.
    # combos[0] is always (original question, original answer).
    combos = [
        (q_user, a_out, q_tags + a_tags)
        for q_user, q_tags in question_variants
        for a_out, a_tags in answer_variants
    ]
    # Drop the original pair: the caller already committed it as the base row,
    # so emitting it again would only duplicate that row.
    variants = combos[1:]

    # Add Vietnamese variants if translator is available
    if translator and translator.is_loaded():
        vi_variants = []
        for q_user, a_out, tags in combos[:5]:  # Limit to first 5 to avoid too many variants
            try:
                vi_q = translator.translate_text(q_user)
                vi_a = translator.translate_text(a_out)
                if vi_q and vi_a and not A.is_invalid_response(vi_q) and not A.is_invalid_response(vi_a):
                    vi_variants.append((vi_q, vi_a, tags + ["vietnamese_translated"]))
                    stats["vietnamese_variants"] = stats.get("vietnamese_variants", 0) + 1
            except Exception as e:
                logger.warning(f"Failed to create Vietnamese variant: {e}")
        variants.extend(vi_variants)

    return variants


def _get_answer_style_prompt(strategy: str, question: str, original_answer: str) -> str:
    """Return the rewrite prompt for an answer-style ``strategy``.

    Unknown strategies fall back to a generic paraphrase prompt. ``question``
    is currently unused.
    """
    prompts = {
        "concise": (
            "Rewrite this medical answer to be more concise while preserving all key medical information, clinical facts, and diagnostic details. Return only the rewritten answer without any introduction or commentary:\n\n"
            f"{original_answer}"
        ),
        "detailed": (
            "Expand this medical answer with more detailed explanations, clinical context, and additional medical information while maintaining accuracy. Return only the expanded answer without any introduction or commentary:\n\n"
            f"{original_answer}"
        ),
        "clinical": (
            "Rewrite this answer using more formal clinical language, precise medical terminology, and professional medical communication style. Return only the rewritten answer without any introduction or commentary:\n\n"
            f"{original_answer}"
        ),
        "patient_friendly": (
            "Rewrite this medical answer in simpler, more patient-friendly language while keeping it medically accurate and informative. Return only the rewritten answer without any introduction or commentary:\n\n"
            f"{original_answer}"
        ),
    }
    return prompts.get(strategy, f"Paraphrase this medical answer: {original_answer}")


def _get_question_style_prompt(strategy: str, original_question: str, answer: str) -> str:
    """Return the rewrite prompt for a question-style ``strategy``.

    Unknown strategies fall back to a generic paraphrase prompt. ``answer`` is
    currently unused.
    """
    prompts = {
        "clarifying": (
            "Rewrite this medical question to ask for clarification or more specific medical information. Return only the rewritten question without any introduction or commentary:\n\n"
            f"{original_question}"
        ),
        "follow_up": (
            "Create a follow-up question that a patient might ask after this medical question, focusing on related medical concerns. Return only the follow-up question without any introduction or commentary:\n\n"
            f"{original_question}"
        ),
        "symptom_focused": (
            "Rewrite this question to focus more on symptoms, their characteristics, and clinical presentation. Return only the rewritten question without any introduction or commentary:\n\n"
            f"{original_question}"
        ),
        "treatment_focused": (
            "Rewrite this question to focus more on treatment options, management strategies, and therapeutic approaches. Return only the rewritten question without any introduction or commentary:\n\n"
            f"{original_question}"
        ),
    }
    return prompts.get(strategy, f"Paraphrase this medical question: {original_question}")


def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
    """Clean and normalise one (instruction, input, output) triple.

    Returns ``(instr, user, out, applied_tags)``. An empty ``out`` signals the
    caller to drop the sample: either the answer was empty after cleaning, or
    it was invalid and the retry did not recover it. ``source`` and ``stats``
    are accepted for interface compatibility but are not used here.
    """
    # Base cleanup & caps (returns cleaned strings)
    user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
    out = A.base_cleanup(out, opts.get("max_chars", 5000), opts.get("deidentify", True))
    instr = A.base_cleanup(instr, opts.get("max_chars", 5000), False)

    # Language sanity (mostly English—skip aggressive transforms if not)
    if not A.lang_is_english(user):  # very rare
        return instr, user, out, []

    # Stack list of entries that has been applied augmentation and stylings
    applied = []

    # Clean conversational elements first
    out = A.clean_conversational_elements(out)
    user = A.clean_conversational_elements(user)

    # Nothing to repair: retrying an empty answer only burns paraphraser calls
    if not out:
        return instr, user, "", applied

    # Clean invalid responses with retry logic
    if A.is_invalid_response(out):
        out = A.retry_invalid_response(out, paraphraser, max_retries=3)
        if not out:
            # If retry failed, return empty to indicate drop
            return instr, user, "", applied
        applied.append("invalid_response_retried")

    # Style standardizing the answer
    if opts.get("style_standardize", True):
        out = A.style_standardize_answer(out)
        applied.append("style_standardize")

    # Ensure punctuation/whitespace
    user = A.ensure_terminal_punct(user) if user else user
    out = A.ensure_terminal_punct(out) if out else out
    return instr, user, out, applied


def _commit_row(writer, source, rid, task, instr, user, out, opts, stats, aug_applied,
                extra_meta=None, dedupe_seen=None, translator=None):
    """Build an SFT row, optionally dedupe/translate it, and write it.

    Returns ``True`` if the row was written, ``False`` if it was skipped as a
    duplicate. Translation failures are logged and the untranslated row is
    written. Rows already tagged ``vietnamese_translated`` are not translated
    again.
    """
    # Dedup entry (fingerprint is taken on the untranslated text)
    if dedupe_seen is not None:
        fp = A.fingerprint(instr, user, out)
        if fp in dedupe_seen:
            stats["dedup_skipped"] += 1
            return False
        dedupe_seen.add(fp)

    meta = {"augmentations": aug_applied}
    if extra_meta:
        meta.update(extra_meta)
    row = sft_row(instr, user, out, source=source, rid=rid, task=task, meta=meta)

    # Apply Vietnamese translation if requested (but never re-translate a row
    # whose text is already Vietnamese)
    already_vietnamese = "vietnamese_translated" in (aug_applied or [])
    if should_translate(opts.get("vietnamese_translation", False), translator) and not already_vietnamese:
        try:
            row = translate_sft_row(row, translator)
            meta["vietnamese_translated"] = True
            row["meta"] = meta
        except Exception as e:
            logger.error(f"Failed to translate SFT row: {e}")

    writer.write(row)
    stats["written"] += 1
    return True


def _commit_enriched_variants(writer, source, rid, task, instr, user, out, paraphraser, opts, stats,
                              dedupe_seen=None, translator=None):
    """Build enriched variants for one cleaned sample and commit each of them.

    Variant ids are ``{rid}-enriched{j}`` with ``j`` the 1-based variant index,
    so ids are unique within a sample.
    """
    variants = _build_enriched_variants(user, out, paraphraser, opts, stats, translator)
    for j, (u_aug, o_aug, aug_tags) in enumerate(variants, start=1):
        _commit_row(writer, source, f"{rid}-enriched{j}", task, instr, u_aug, o_aug, opts, stats,
                    aug_tags, dedupe_seen=dedupe_seen, translator=translator)


# ——————————— dataset processors ———————————

def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stats, cb,
                     dedupe_seen=None, translator=None):
    """Process a HealthCareMagic / iCliniq dialogue file (JSON array or JSONL).

    Each record's ``instruction``/``input``/``output`` fields are read. The
    instruction falls back to a default clinician prompt. Returns the number of
    source items processed successfully.
    """
    count = 0
    for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
        try:
            instr_raw = obj.get("instruction") or "Answer the patient's question like a clinician. Be concise and safe."
            user_raw = obj.get("input") or ""
            out_raw = obj.get("output") or ""
            # Ensure we have string values
            instr = str(instr_raw).strip()
            user = str(user_raw).strip()
            out = str(out_raw).strip()
            rid = _hash_id(source, i, len(user), len(out))
        except Exception as e:
            logger.warning(f"[PROC] {source} error processing item {i}: {e}, item: {obj}")
            continue

        try:
            instr, user, out, applied = _apply_aug(instr, user, out, source, opts, paraphraser, stats)

            # Skip if retry failed (empty output) - DO NOT RECORD FAILED RESPONSES
            if not out:
                stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
                logger.warning(f"[PROC] {source} dropped invalid response for item {i} - will retry in next batch")
                continue

            # Enhanced medical accuracy validation; failures are tagged, not dropped
            if not A.validate_medical_accuracy(user, out, paraphraser):
                stats["medical_accuracy_failed"] = stats.get("medical_accuracy_failed", 0) + 1
                applied.append("medical_accuracy_flag")

            # Optional consistency spot-check (cheap)
            if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
                stats["consistency_failed"] += 1
                # keep the sample but tag it
                applied.append("consistency_flag")

            # 1) ALWAYS write the original (cleaned/style-standardised only)
            _commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats,
                        ["base"] + applied, dedupe_seen=dedupe_seen, translator=translator)

            # 2) If expansion is enabled, add enriched variants for SFT
            if opts.get("expand", True):
                _commit_enriched_variants(writer, source, rid, "medical_dialogue", instr, user, out,
                                          paraphraser, opts, stats,
                                          dedupe_seen=dedupe_seen, translator=translator)

            # 3) Add clinical scenarios for enhanced diversity
            if opts.get("clinical_scenarios", True):
                # Prefer the paraphraser's dedicated method when it provides one
                if hasattr(paraphraser, 'create_clinical_scenarios'):
                    clinical_scenarios = paraphraser.create_clinical_scenarios(user, out)
                else:
                    clinical_scenarios = A.create_clinical_scenarios(user, out, paraphraser)
                for j, (scenario_q, scenario_a, scenario_tag) in enumerate(clinical_scenarios, start=1):
                    _commit_row(writer, source, f"{rid}-scenario{j}", "medical_dialogue", instr,
                                scenario_q, scenario_a, opts, stats, [scenario_tag],
                                dedupe_seen=dedupe_seen, translator=translator)
                    stats["clinical_scenarios_created"] += 1

            # Increment count only on success
            count += 1
        except Exception as e:
            logger.warning(f"[PROC] {source} error in processing/augmentation for item {i}: {e}")
            continue

        if sample_limit and count >= sample_limit:
            break
        if cb and i % 1000 == 0:
            cb(min(0.9, 0.05 + i / 200000), f"{source}: processed {i} rows")

    if cb:
        cb(0.92, f"{source} done ({count})")
    logger.info(f"[PROC] {source} done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
    return count


def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb,
                     dedupe_seen=None, translator=None):
    """Process the labeled PubMedQA JSON file (a dict keyed by item id).

    The answer is ``LONG_ANSWER``, falling back to ``final_decision``. Returns
    the number of items processed successfully.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    count = 0
    for k, v in data.items():
        try:
            q_raw = v.get("QUESTION") or ""
            ctx_list = v.get("CONTEXTS") or []
            long_ans_raw = v.get("LONG_ANSWER") or ""
            final_raw = v.get("final_decision") or ""
            # Ensure we have string values
            q = str(q_raw).strip() if q_raw else ""
            if isinstance(ctx_list, list):
                context = "\n".join(str(ctx) for ctx in ctx_list).strip()
            else:
                context = str(ctx_list).strip()
            long_ans = str(long_ans_raw).strip() if long_ans_raw else ""
            final = str(final_raw).strip() if final_raw else ""
        except Exception as e:
            logger.warning(f"[PROC] pubmedqa_l error processing item {k}: {e}, item: {v}")
            continue

        try:
            instr = "Answer the biomedical question using the provided context. Include a concise rationale if possible."
            user = f"Question: {q}\n\nContext:\n{context}" if context else f"Question: {q}"
            out = long_ans if long_ans else final
            rid = str(k)

            instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)

            # Skip if retry failed (empty output)
            if not out:
                stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
                continue

            _commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
                        extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")},
                        dedupe_seen=dedupe_seen, translator=translator)

            if opts.get("expand", True):
                _commit_enriched_variants(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out,
                                          paraphraser, opts, stats,
                                          dedupe_seen=dedupe_seen, translator=translator)

            # Increment count only on success
            count += 1
        except Exception as e:
            logger.warning(f"[PROC] pubmedqa_l error in processing/augmentation for item {k}: {e}")
            continue

        if sample_limit and count >= sample_limit:
            break
        if cb and count % 1000 == 0:
            cb(min(0.9, 0.05 + count / 60000), f"pubmedqa_l processed {count}")

    if cb:
        cb(0.93, f"pubmedqa_l done ({count})")
    logger.info(f"[PROC] pubmedqa_l done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
    return count


def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb,
                     dedupe_seen=None, translator=None):
    """Process the unlabeled PubMedQA JSON file (a dict keyed by item id).

    Items have no reference answer. For a random ``distill_fraction`` of items,
    an answer is distilled with one paraphraser call. Items left without an
    answer are dropped. Returns the number of items processed successfully.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Tolerate None / numeric strings in the option instead of failing every item
    try:
        distill_fraction = float(opts.get("distill_fraction") or 0.0)
    except (TypeError, ValueError):
        distill_fraction = 0.0

    count = 0
    for k, v in data.items():
        try:
            q_raw = v.get("QUESTION") or ""
            ctx_list = v.get("CONTEXTS") or []
            # Ensure we have string values
            q = str(q_raw).strip() if q_raw else ""
            if isinstance(ctx_list, list):
                context = "\n".join(str(ctx) for ctx in ctx_list).strip()
            else:
                context = str(ctx_list).strip()
        except Exception as e:
            logger.warning(f"[PROC] pubmedqa_u error processing item {k}: {e}, item: {v}")
            continue

        try:
            instr = "Rewrite the context into a succinct note, then answer the question. If unknown, say 'insufficient evidence'."
            user = f"Question: {q}\n\nContext:\n{context}" if context else f"Question: {q}"
            out = ""  # unlabeled
            rid = str(k)

            # Optional KD/distillation for a small fraction
            if distill_fraction > 0.0 and random.random() < distill_fraction:
                prompt = f"{instr}\n\n{user}\n\nAnswer briefly and safely."
                guess = paraphraser.paraphrase(prompt, difficulty="hard")  # cheap single call
                if guess and len(guess) < 2000:
                    out = guess.strip()

            instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)

            # Skip if retry failed (empty output)
            if not out:
                stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
                continue

            _commit_row(writer, "pubmedqa_u", rid, "biomedical_qa_unlabeled", instr, user, out, opts, stats,
                        applied, dedupe_seen=dedupe_seen, translator=translator)

            if opts.get("expand", True):
                # Variants share the base row's task label
                _commit_enriched_variants(writer, "pubmedqa_u", rid, "biomedical_qa_unlabeled", instr, user,
                                          out, paraphraser, opts, stats,
                                          dedupe_seen=dedupe_seen, translator=translator)

            # Increment count only on success
            count += 1
        except Exception as e:
            logger.warning(f"[PROC] pubmedqa_u error in processing/augmentation for item {k}: {e}")
            continue

        if sample_limit and count >= sample_limit:
            break
        if cb and count % 2000 == 0:
            cb(min(0.9, 0.05 + count / 80000), f"pubmedqa_u processed {count}")

    if cb:
        cb(0.94, f"pubmedqa_u done ({count})")
    logger.info(f"[PROC] pubmedqa_u done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
    return count


def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb,
                       dedupe_seen=None, translator=None):
    """Process a PubMedQA "map" file whose top-level layout varies.

    Accepted layouts:
    - a list of dicts;
    - a dict of parallel ``question``/``context``/``answer`` lists;
    - a dict of per-item dicts (or plain values).

    Returns the number of items processed successfully.
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)

    # Log the structure for debugging (key list truncated: large files can have
    # hundreds of thousands of keys)
    logger.info(f"[PROC] pubmedqa_map data type: {type(obj)}")
    if isinstance(obj, dict):
        logger.info(f"[PROC] pubmedqa_map dict keys: {list(obj.keys())[:20]} (total={len(obj)})")
        if len(obj) > 0:
            sample_key = next(iter(obj.keys()))
            sample_value = obj[sample_key]
            logger.info(f"[PROC] pubmedqa_map sample value type: {type(sample_value)}")
            if isinstance(sample_value, dict):
                logger.info(f"[PROC] pubmedqa_map sample value keys: {list(sample_value.keys())}")

    def iter_items():
        """Yield ``{"question", "context", "answer"}``-style dicts from ``obj``."""
        try:
            if isinstance(obj, list):
                for it in obj:
                    if isinstance(it, dict):
                        yield it
                    else:
                        logger.warning(f"[PROC] pubmedqa_map skipping non-dict list item: {type(it)}")
            elif isinstance(obj, dict):
                qs, cs, ans = obj.get("question"), obj.get("context"), obj.get("answer")
                if isinstance(qs, list) and isinstance(cs, list) and isinstance(ans, list):
                    # Parallel columns; zip stops at the shortest
                    for q_i, c_i, a_i in zip(qs, cs, ans):
                        yield {"question": q_i, "context": c_i, "answer": a_i}
                else:
                    # Handle case where values might be dictionaries or other objects
                    for k, v in obj.items():
                        if isinstance(v, dict):
                            if "question" in v and "context" in v and "answer" in v:
                                yield v
                            else:
                                # Try to map the keys to expected structure
                                yield {
                                    "question": v.get("question") or v.get("QUESTION") or str(k),
                                    "context": v.get("context") or v.get("CONTEXT") or "",
                                    "answer": v.get("answer") or v.get("ANSWER") or "",
                                }
                        else:
                            # If v is not a dict, create a simple structure
                            yield {"question": str(k), "context": str(v) if v else "", "answer": ""}
            else:
                logger.warning(f"[PROC] pubmedqa_map unexpected data type: {type(obj)}")
        except Exception as e:
            logger.error(f"[PROC] pubmedqa_map error in iter_items: {e}")
            return

    count = 0
    for i, v in enumerate(iter_items(), start=1):
        try:
            # Ensure we have string values, convert if necessary
            q_raw = v.get("question") or ""
            c_raw = v.get("context") or ""
            a_raw = v.get("answer") or ""
            q = str(q_raw).strip() if q_raw else ""
            c = str(c_raw).strip() if c_raw else ""
            a = str(a_raw).strip() if a_raw else ""

            instr = "Answer the biomedical question based on the context. Justify briefly."
            user = f"Question: {q}\n\nContext:\n{c}" if c else f"Question: {q}"
            out = a
            rid = _hash_id("pubmedqa_map", i, len(q))

            instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)

            # Skip if retry failed (empty output)
            if not out:
                stats["dropped_invalid"] = stats.get("dropped_invalid", 0) + 1
                continue

            _commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
                        dedupe_seen=dedupe_seen, translator=translator)

            # Handle expansion if enabled
            if opts.get("expand", True):
                _commit_enriched_variants(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out,
                                          paraphraser, opts, stats,
                                          dedupe_seen=dedupe_seen, translator=translator)

            # Increment count only on success
            count += 1
        except Exception as e:
            logger.warning(f"[PROC] pubmedqa_map error processing item {i}: {e}, item: {v}")
            continue

        # Check sample limit
        if sample_limit and count >= sample_limit:
            break
        if cb and i % 2000 == 0:
            cb(min(0.9, 0.05 + i / 120000), f"pubmedqa_map processed {i}")

    if cb:
        cb(0.95, f"pubmedqa_map done ({count})")
    logger.info(f"[PROC] pubmedqa_map done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
    return count