Spaces:
Sleeping
Sleeping
Commit
·
19d62ff
1
Parent(s):
050b5e3
Upd RAG schema to QAC format
Browse files- app.py +11 -10
- utils/augment.py +1 -1
- utils/llm.py +3 -3
- utils/rag.py +17 -37
- utils/schema.py +54 -1
app.py
CHANGED
|
@@ -16,7 +16,7 @@ from utils.processor import process_file_into_sft
|
|
| 16 |
from utils.rag import process_file_into_rag
|
| 17 |
from utils.drive_saver import DriveSaver
|
| 18 |
from utils.llm import Paraphraser
|
| 19 |
-
from utils.schema import CentralisedWriter
|
| 20 |
from utils.token import get_credentials, exchange_code, build_auth_url
|
| 21 |
from vi.translator import VietnameseTranslator
|
| 22 |
|
|
@@ -71,14 +71,14 @@ STATE: Dict[str, object] = {
|
|
| 71 |
|
| 72 |
class AugmentOptions(BaseModel):
|
| 73 |
# ratios are 0..1
|
| 74 |
-
paraphrase_ratio: float = 0.
|
| 75 |
-
paraphrase_outputs: bool =
|
| 76 |
-
backtranslate_ratio: float = 0.
|
| 77 |
style_standardize: bool = True
|
| 78 |
deidentify: bool = True
|
| 79 |
dedupe: bool = True
|
| 80 |
max_chars: int = 5000 # cap extremely long contexts
|
| 81 |
-
consistency_check_ratio: float = 0.
|
| 82 |
# KD / distillation (optional, keeps default off)
|
| 83 |
distill_fraction: float = 0.0 # for unlabeled only
|
| 84 |
expand: bool = True # Enable back-translation and complex augmentation
|
|
@@ -178,15 +178,16 @@ def root():
|
|
| 178 |
headers: {{ "Content-Type": "application/json" }},
|
| 179 |
body: JSON.stringify({{
|
| 180 |
augment: {{
|
| 181 |
-
paraphrase_ratio: 0.
|
| 182 |
-
backtranslate_ratio: 0.
|
| 183 |
-
paraphrase_outputs:
|
| 184 |
style_standardize: true,
|
| 185 |
deidentify: true,
|
| 186 |
dedupe: true,
|
| 187 |
max_chars: 5000,
|
| 188 |
expand: true,
|
| 189 |
-
max_aug_per_sample: 2
|
|
|
|
| 190 |
}},
|
| 191 |
sample_limit: null, // Sample down (currently disabled)
|
| 192 |
seed: 42,
|
|
@@ -382,7 +383,7 @@ def _run_job(dataset_key: str, params: ProcessParams):
|
|
| 382 |
set_state(message="processing", progress=0.05)
|
| 383 |
|
| 384 |
# Writer
|
| 385 |
-
writer = CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
|
| 386 |
|
| 387 |
# Load translator if Vietnamese translation is requested
|
| 388 |
translator = None
|
|
|
|
| 16 |
from utils.rag import process_file_into_rag
|
| 17 |
from utils.drive_saver import DriveSaver
|
| 18 |
from utils.llm import Paraphraser
|
| 19 |
+
from utils.schema import CentralisedWriter, RAGWriter
|
| 20 |
from utils.token import get_credentials, exchange_code, build_auth_url
|
| 21 |
from vi.translator import VietnameseTranslator
|
| 22 |
|
|
|
|
| 71 |
|
| 72 |
class AugmentOptions(BaseModel):
|
| 73 |
# ratios are 0..1
|
| 74 |
+
paraphrase_ratio: float = 0.2
|
| 75 |
+
paraphrase_outputs: bool = True
|
| 76 |
+
backtranslate_ratio: float = 0.1
|
| 77 |
style_standardize: bool = True
|
| 78 |
deidentify: bool = True
|
| 79 |
dedupe: bool = True
|
| 80 |
max_chars: int = 5000 # cap extremely long contexts
|
| 81 |
+
consistency_check_ratio: float = 0.05 # small ratio e.g. 0.01
|
| 82 |
# KD / distillation (optional, keeps default off)
|
| 83 |
distill_fraction: float = 0.0 # for unlabeled only
|
| 84 |
expand: bool = True # Enable back-translation and complex augmentation
|
|
|
|
| 178 |
headers: {{ "Content-Type": "application/json" }},
|
| 179 |
body: JSON.stringify({{
|
| 180 |
augment: {{
|
| 181 |
+
paraphrase_ratio: 0.2,
|
| 182 |
+
backtranslate_ratio: 0.1,
|
| 183 |
+
paraphrase_outputs: true,
|
| 184 |
style_standardize: true,
|
| 185 |
deidentify: true,
|
| 186 |
dedupe: true,
|
| 187 |
max_chars: 5000,
|
| 188 |
expand: true,
|
| 189 |
+
max_aug_per_sample: 2,
|
| 190 |
+
consistency_check_ratio: 0.05
|
| 191 |
}},
|
| 192 |
sample_limit: null, // Sample down (currently disabled)
|
| 193 |
seed: 42,
|
|
|
|
| 383 |
set_state(message="processing", progress=0.05)
|
| 384 |
|
| 385 |
# Writer
|
| 386 |
+
writer = RAGWriter(jsonl_path=jsonl_path, csv_path=csv_path) if params.rag_processing else CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
|
| 387 |
|
| 388 |
# Load translator if Vietnamese translation is requested
|
| 389 |
translator = None
|
utils/augment.py
CHANGED
|
@@ -93,7 +93,7 @@ def maybe_paraphrase(text: str, ratio: float, paraphraser, difficulty: str) -> T
|
|
| 93 |
def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool]:
|
| 94 |
if ratio <= 0 or not text: return text, False
|
| 95 |
if random.random() < ratio:
|
| 96 |
-
bt = paraphraser.backtranslate(text, via_lang="
|
| 97 |
return bt if bt else text, bool(bt)
|
| 98 |
return text, False
|
| 99 |
|
|
|
|
| 93 |
def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool]:
    """Randomly replace *text* with a back-translated variant.

    With probability ``ratio`` the text is passed to
    ``paraphraser.backtranslate(text, via_lang="vi")``.

    Args:
        text: Source text; returned untouched when empty.
        ratio: Probability of attempting back-translation. Values ``<= 0``
            disable the step entirely.
        paraphraser: Object exposing ``backtranslate(text, via_lang=...)``.

    Returns:
        ``(result, changed)``. ``changed`` is ``True`` only when a usable
        back-translation was produced and returned in place of ``text``.
    """
    if ratio <= 0 or not text:
        return text, False
    if random.random() >= ratio:
        return text, False

    candidate = paraphraser.backtranslate(text, via_lang="vi")
    # A blank (whitespace-only) or non-string reply is not a usable
    # augmentation: keep the original text instead of emitting an empty sample.
    if not isinstance(candidate, str) or not candidate.strip():
        return text, False
    return candidate, True
|
| 99 |
|
utils/llm.py
CHANGED
|
@@ -154,18 +154,18 @@ class Paraphraser:
|
|
| 154 |
return self._clean_resp(out) if out else text
|
| 155 |
|
| 156 |
# ————— Translate & Backtranslate —————
|
| 157 |
-
def translate(self, text: str, target_lang: str = "
|
| 158 |
if not text: return text
|
| 159 |
prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
|
| 160 |
out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
|
| 161 |
if out: return out.strip()
|
| 162 |
return self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
|
| 163 |
|
| 164 |
-
def backtranslate(self, text: str, via_lang: str = "
|
| 165 |
if not text: return text
|
| 166 |
mid = self.translate(text, target_lang=via_lang)
|
| 167 |
if not mid: return None
|
| 168 |
-
prompt = f"Translate the following
|
| 169 |
out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
|
| 170 |
if out: return out.strip()
|
| 171 |
res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
|
|
|
|
| 154 |
return self._clean_resp(out) if out else text
|
| 155 |
|
| 156 |
# ————— Translate & Backtranslate —————
|
| 157 |
+
def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
    """Translate *text* into ``target_lang``.

    ``self.nv`` is asked first (temperature 0.0); ``self.gm_easy`` is used
    as a fallback whenever the first reply is missing or blank.

    Args:
        text: Text to translate. Falsy input is returned unchanged.
        target_lang: Language name/code interpolated into the prompt.

    Returns:
        The stripped translation, or ``None`` when neither backend
        produced any non-blank text.
    """
    if not text:
        return text

    prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
    # NOTE(review): the budget is derived from the character count of the
    # input but is passed as a token limit -- confirm this is intended.
    budget = min(800, len(text) + 100)

    primary = self.nv.generate(prompt, temperature=0.0, max_tokens=budget)
    # A whitespace-only reply is treated like no reply so the fallback runs.
    if isinstance(primary, str) and primary.strip():
        return primary.strip()

    fallback = self.gm_easy.generate(prompt, max_output_tokens=budget)
    # Normalise the fallback the same way as the primary result.
    if isinstance(fallback, str) and fallback.strip():
        return fallback.strip()
    return None
|
| 163 |
|
| 164 |
+
def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
|
| 165 |
if not text: return text
|
| 166 |
mid = self.translate(text, target_lang=via_lang)
|
| 167 |
if not mid: return None
|
| 168 |
+
prompt = f"Translate the following Vietnamese text back to English, preserving the exact meaning:\n\n{mid}"
|
| 169 |
out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
|
| 170 |
if out: return out.strip()
|
| 171 |
res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
|
utils/rag.py
CHANGED
|
@@ -5,9 +5,9 @@ import hashlib
|
|
| 5 |
import random
|
| 6 |
from typing import Dict, List, Tuple, Optional, Callable
|
| 7 |
|
| 8 |
-
from utils.schema import sft_row
|
| 9 |
from utils.llm import NvidiaClient, KeyRotator
|
| 10 |
-
from vi.processing import
|
| 11 |
|
| 12 |
# Logger
|
| 13 |
logger = logging.getLogger("rag_processor")
|
|
@@ -188,18 +188,8 @@ class RAGProcessor:
|
|
| 188 |
if not question or not answer:
|
| 189 |
continue
|
| 190 |
|
| 191 |
-
#
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
# Format user input as QCA
|
| 195 |
-
if context:
|
| 196 |
-
rag_user = f"Question: {question}\n\nContext: {context}"
|
| 197 |
-
else:
|
| 198 |
-
rag_user = f"Question: {question}"
|
| 199 |
-
|
| 200 |
-
# Commit the RAG-formatted row
|
| 201 |
-
if self._commit_rag_row(writer, source, rid, "rag_medical_qa",
|
| 202 |
-
rag_instruction, rag_user, answer,
|
| 203 |
stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
|
| 204 |
written += 1
|
| 205 |
|
|
@@ -256,16 +246,8 @@ class RAGProcessor:
|
|
| 256 |
context = self.generate_context_from_qa(question, answer)
|
| 257 |
|
| 258 |
rid = str(k)
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
if context:
|
| 262 |
-
rag_user = f"Question: {question}\n\nContext: {context}"
|
| 263 |
-
else:
|
| 264 |
-
rag_user = f"Question: {question}"
|
| 265 |
-
|
| 266 |
-
# Commit the RAG-formatted row
|
| 267 |
-
if self._commit_rag_row(writer, source, rid, "rag_biomedical_qa",
|
| 268 |
-
rag_instruction, rag_user, answer,
|
| 269 |
stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
|
| 270 |
written += 1
|
| 271 |
|
|
@@ -286,30 +268,28 @@ class RAGProcessor:
|
|
| 286 |
logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
|
| 287 |
return count
|
| 288 |
|
| 289 |
-
def _commit_rag_row(self, writer,
|
| 290 |
-
instruction: str, user_input: str, output: str,
|
| 291 |
stats: Dict, dedupe_seen: set = None, translator=None, opts=None) -> bool:
|
| 292 |
-
"""Commit a RAG-formatted row to the writer"""
|
| 293 |
# Simple deduplication based on content hash
|
| 294 |
if dedupe_seen is not None:
|
| 295 |
-
content_hash = hashlib.md5(f"{
|
| 296 |
if content_hash in dedupe_seen:
|
| 297 |
stats["dedup_skipped"] = stats.get("dedup_skipped", 0) + 1
|
| 298 |
return False
|
| 299 |
dedupe_seen.add(content_hash)
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
# Apply Vietnamese translation if requested
|
| 305 |
if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
|
| 306 |
try:
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
except Exception as e:
|
| 311 |
logger.error(f"Failed to translate RAG row: {e}")
|
| 312 |
-
|
| 313 |
writer.write(row)
|
| 314 |
stats["written"] = stats.get("written", 0) + 1
|
| 315 |
return True
|
|
|
|
| 5 |
import random
|
| 6 |
from typing import Dict, List, Tuple, Optional, Callable
|
| 7 |
|
| 8 |
+
from utils.schema import sft_row, rag_row
|
| 9 |
from utils.llm import NvidiaClient, KeyRotator
|
| 10 |
+
from vi.processing import should_translate
|
| 11 |
|
| 12 |
# Logger
|
| 13 |
logger = logging.getLogger("rag_processor")
|
|
|
|
| 188 |
if not question or not answer:
|
| 189 |
continue
|
| 190 |
|
| 191 |
+
# Commit the RAG-formatted row (QAC)
|
| 192 |
+
if self._commit_rag_row(writer, rid, question, context, answer,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
|
| 194 |
written += 1
|
| 195 |
|
|
|
|
| 246 |
context = self.generate_context_from_qa(question, answer)
|
| 247 |
|
| 248 |
rid = str(k)
|
| 249 |
+
# Commit the RAG-formatted row (QAC)
|
| 250 |
+
if self._commit_rag_row(writer, rid, question, context, answer,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
|
| 252 |
written += 1
|
| 253 |
|
|
|
|
| 268 |
logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
|
| 269 |
return count
|
| 270 |
|
| 271 |
+
def _commit_rag_row(self, writer, rid: str, question: str, context: str, answer: str,
                    stats: Dict, dedupe_seen: set = None, translator=None, opts=None) -> bool:
    """Build a RAG (QAC) row, optionally translate it, and pass it to ``writer``.

    Args:
        writer: Object exposing ``write(row)``.
        rid: Row identifier stored with the row.
        question: Question text.
        context: Supporting context (may be empty).
        answer: Answer text.
        stats: Mutable counter dict; ``"dedup_skipped"`` and ``"written"``
            are incremented here.
        dedupe_seen: Set of content hashes already committed. ``None``
            disables de-duplication.
        translator: Optional translator exposing ``translate_dict``.
        opts: Optional mapping; ``"vietnamese_translation"`` is read from it.

    Returns:
        ``False`` when the row was dropped as a duplicate, ``True`` once the
        row has been handed to ``writer.write``.
    """
    if dedupe_seen is not None:
        # Length-prefix each field so that ("ab", "c") and ("a", "bc") cannot
        # hash to the same value, and normalise None to "" so the key matches
        # what is actually stored in the row.
        digest = hashlib.md5()
        for part in (question, context, answer):
            encoded = str(part or "").encode()
            digest.update(f"{len(encoded)}:".encode())
            digest.update(encoded)
        content_hash = digest.hexdigest()
        if content_hash in dedupe_seen:
            stats["dedup_skipped"] = stats.get("dedup_skipped", 0) + 1
            return False
        dedupe_seen.add(content_hash)

    row = rag_row(question=question, context=context, answer=answer, rid=rid)

    # Apply Vietnamese translation if requested (translate Q/A/C fields directly)
    wants_translation = opts.get("vietnamese_translation", False) if opts else False
    if should_translate(wants_translation, translator):
        try:
            if translator:
                # Translate a copy: if the translator fails part-way or
                # returns nothing, the untranslated row is still intact.
                translated = translator.translate_dict(dict(row), ["question", "answer", "context"])
                if isinstance(translated, dict):
                    translated["vi_translated"] = True
                    row = translated
        except Exception as e:
            # Best effort: fall back to writing the untranslated row.
            logger.error(f"Failed to translate RAG row: {e}")

    # NOTE(review): the writer may silently skip rows it considers invalid,
    # in which case "written" over-counts -- confirm against the writer used.
    writer.write(row)
    stats["written"] = stats.get("written", 0) + 1
    return True
|
utils/schema.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# Centralized SFT writer (JSONL + CSV)
|
| 2 |
import csv
|
| 3 |
import orjson
|
| 4 |
from typing import Optional, Dict
|
|
@@ -66,3 +66,56 @@ class CentralisedWriter:
|
|
| 66 |
self.jsonl_fp.close()
|
| 67 |
finally:
|
| 68 |
self.csv_fp.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Centralized SFT writer (JSONL + CSV) and RAG writer
|
| 2 |
import csv
|
| 3 |
import orjson
|
| 4 |
from typing import Optional, Dict
|
|
|
|
| 66 |
self.jsonl_fp.close()
|
| 67 |
finally:
|
| 68 |
self.csv_fp.close()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# —— RAG (QAC) schema ——
|
| 72 |
+
|
| 73 |
+
def rag_row(question: str, context: str, answer: str, rid: str):
    """Assemble one RAG record in QAC layout.

    Falsy ``question``/``answer``/``context`` values are stored as ``""``;
    ``rid`` is stored as given under ``"id"``.
    """
    record = {"id": rid}
    text_fields = (("question", question), ("answer", answer), ("context", context))
    for key, value in text_fields:
        record[key] = value or ""
    return record
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def is_valid_rag_row(row: Dict, max_chars: int = 20000) -> bool:
    """Return whether *row* is a writable RAG (QAC) record.

    A row is valid when ``question`` and ``answer`` are non-empty strings,
    ``context`` is a string (possibly empty), and none of the three exceeds
    ``max_chars`` characters. Missing or ``None`` fields count as empty.

    Args:
        row: Mapping with ``question``, ``answer`` and ``context`` keys.
        max_chars: Maximum allowed length of any single field.

    Returns:
        ``True`` if the row passes all checks, ``False`` otherwise.
    """
    # Treat an explicit None like a missing key so len() is never called on it.
    values = tuple(
        "" if row.get(key) is None else row.get(key)
        for key in ("question", "answer", "context")
    )
    # Non-string payloads cannot be length-checked reliably; reject them
    # instead of raising TypeError.
    if not all(isinstance(value, str) for value in values):
        return False

    question, answer, _context = values
    if not (question and answer):
        return False
    return all(len(value) <= max_chars for value in values)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class RAGWriter:
    """Streams JSONL + CSV for RAG (QAC) format with columns: id, question, answer, context.

    Can be used as a context manager; both output files are closed on exit.
    """

    _FIELDS = ("id", "question", "answer", "context")

    def __init__(self, jsonl_path: str, csv_path: str):
        """Open both output files and write the CSV header row."""
        self.jsonl_fp = open(jsonl_path, "wb")
        try:
            self.csv_fp = open(csv_path, "w", newline="", encoding="utf-8")
        except Exception:
            # Do not leak the JSONL handle when the CSV cannot be opened.
            self.jsonl_fp.close()
            raise
        self.csv_wr = csv.DictWriter(self.csv_fp, fieldnames=list(self._FIELDS))
        self.csv_wr.writeheader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    @staticmethod
    def _text_len(row: dict, key: str) -> int:
        """Length of ``row[key]`` for logging; 0 when it is missing or not a string."""
        value = row.get(key)
        return len(value) if isinstance(value, str) else 0

    def write(self, row: dict):
        """Append *row* to both outputs; rows failing validation are logged and skipped."""
        if not is_valid_rag_row(row):
            # The lengths are computed defensively: an invalid row is exactly
            # the case where a field may be None, and logging must not crash.
            # NOTE(review): relies on a module-level `logger` -- confirm it is
            # defined in this module.
            logger.warning(
                f"[RAG-WRITER] Skipping invalid row id={row.get('id')} "
                f"(len q={self._text_len(row, 'question')}, a={self._text_len(row, 'answer')}, c={self._text_len(row, 'context')})"
            )
            return
        self.jsonl_fp.write(orjson.dumps(row))
        self.jsonl_fp.write(b"\n")
        # Only the four schema columns go to CSV; extra keys stay JSONL-only.
        self.csv_wr.writerow({key: row.get(key, "") for key in self._FIELDS})

    def close(self):
        """Close both files; the CSV is closed even if closing the JSONL fails."""
        try:
            self.jsonl_fp.close()
        finally:
            self.csv_fp.close()
|