LiamKhoaLe committed on
Commit
19d62ff
·
1 Parent(s): 050b5e3

Update RAG schema to QAC format

Browse files
Files changed (5) hide show
  1. app.py +11 -10
  2. utils/augment.py +1 -1
  3. utils/llm.py +3 -3
  4. utils/rag.py +17 -37
  5. utils/schema.py +54 -1
app.py CHANGED
@@ -16,7 +16,7 @@ from utils.processor import process_file_into_sft
16
  from utils.rag import process_file_into_rag
17
  from utils.drive_saver import DriveSaver
18
  from utils.llm import Paraphraser
19
- from utils.schema import CentralisedWriter
20
  from utils.token import get_credentials, exchange_code, build_auth_url
21
  from vi.translator import VietnameseTranslator
22
 
@@ -71,14 +71,14 @@ STATE: Dict[str, object] = {
71
 
72
  class AugmentOptions(BaseModel):
73
  # ratios are 0..1
74
- paraphrase_ratio: float = 0.0
75
- paraphrase_outputs: bool = False
76
- backtranslate_ratio: float = 0.0
77
  style_standardize: bool = True
78
  deidentify: bool = True
79
  dedupe: bool = True
80
  max_chars: int = 5000 # cap extremely long contexts
81
- consistency_check_ratio: float = 0.0 # small ratio e.g. 0.01
82
  # KD / distillation (optional, keeps default off)
83
  distill_fraction: float = 0.0 # for unlabeled only
84
  expand: bool = True # Enable back-translation and complex augmentation
@@ -178,15 +178,16 @@ def root():
178
  headers: {{ "Content-Type": "application/json" }},
179
  body: JSON.stringify({{
180
  augment: {{
181
- paraphrase_ratio: 0.1,
182
- backtranslate_ratio: 0.00, // Increase to 0.05-0.1 for back-translation
183
- paraphrase_outputs: false,
184
  style_standardize: true,
185
  deidentify: true,
186
  dedupe: true,
187
  max_chars: 5000,
188
  expand: true,
189
- max_aug_per_sample: 2
 
190
  }},
191
  sample_limit: null, // Sample down (currently disabled)
192
  seed: 42,
@@ -382,7 +383,7 @@ def _run_job(dataset_key: str, params: ProcessParams):
382
  set_state(message="processing", progress=0.05)
383
 
384
  # Writer
385
- writer = CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
386
 
387
  # Load translator if Vietnamese translation is requested
388
  translator = None
 
16
  from utils.rag import process_file_into_rag
17
  from utils.drive_saver import DriveSaver
18
  from utils.llm import Paraphraser
19
+ from utils.schema import CentralisedWriter, RAGWriter
20
  from utils.token import get_credentials, exchange_code, build_auth_url
21
  from vi.translator import VietnameseTranslator
22
 
 
71
 
72
  class AugmentOptions(BaseModel):
73
  # ratios are 0..1
74
+ paraphrase_ratio: float = 0.2
75
+ paraphrase_outputs: bool = True
76
+ backtranslate_ratio: float = 0.1
77
  style_standardize: bool = True
78
  deidentify: bool = True
79
  dedupe: bool = True
80
  max_chars: int = 5000 # cap extremely long contexts
81
+ consistency_check_ratio: float = 0.05 # small ratio e.g. 0.01
82
  # KD / distillation (optional, keeps default off)
83
  distill_fraction: float = 0.0 # for unlabeled only
84
  expand: bool = True # Enable back-translation and complex augmentation
 
178
  headers: {{ "Content-Type": "application/json" }},
179
  body: JSON.stringify({{
180
  augment: {{
181
+ paraphrase_ratio: 0.2,
182
+ backtranslate_ratio: 0.1,
183
+ paraphrase_outputs: true,
184
  style_standardize: true,
185
  deidentify: true,
186
  dedupe: true,
187
  max_chars: 5000,
188
  expand: true,
189
+ max_aug_per_sample: 2,
190
+ consistency_check_ratio: 0.05
191
  }},
192
  sample_limit: null, // Sample down (currently disabled)
193
  seed: 42,
 
383
  set_state(message="processing", progress=0.05)
384
 
385
  # Writer
386
+ writer = RAGWriter(jsonl_path=jsonl_path, csv_path=csv_path) if params.rag_processing else CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
387
 
388
  # Load translator if Vietnamese translation is requested
389
  translator = None
utils/augment.py CHANGED
@@ -93,7 +93,7 @@ def maybe_paraphrase(text: str, ratio: float, paraphraser, difficulty: str) -> T
93
  def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool]:
94
  if ratio <= 0 or not text: return text, False
95
  if random.random() < ratio:
96
- bt = paraphraser.backtranslate(text, via_lang="de")
97
  return bt if bt else text, bool(bt)
98
  return text, False
99
 
 
93
def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool]:
    """Randomly replace *text* with its back-translation through Vietnamese.

    Args:
        text: Source text; returned untouched when empty.
        ratio: Probability (0..1) of attempting a back-translation. Values
            at or below zero disable the step.
        paraphraser: Object exposing ``backtranslate(text, via_lang=...)``.

    Returns:
        ``(result, changed)`` where ``changed`` is True only when the
        paraphraser produced a non-empty back-translation.
    """
    if ratio <= 0 or not text:
        return text, False

    # One random draw decides whether this sample is selected at all.
    selected = random.random() < ratio
    if not selected:
        return text, False

    candidate = paraphraser.backtranslate(text, via_lang="vi")
    if candidate:
        return candidate, True
    # Empty/None result: fall back to the original text and report no change.
    return text, False
99
 
utils/llm.py CHANGED
@@ -154,18 +154,18 @@ class Paraphraser:
154
  return self._clean_resp(out) if out else text
155
 
156
  # ————— Translate & Backtranslate —————
157
- def translate(self, text: str, target_lang: str = "de") -> Optional[str]:
158
  if not text: return text
159
  prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
160
  out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
161
  if out: return out.strip()
162
  return self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
163
 
164
- def backtranslate(self, text: str, via_lang: str = "de") -> Optional[str]:
165
  if not text: return text
166
  mid = self.translate(text, target_lang=via_lang)
167
  if not mid: return None
168
- prompt = f"Translate the following {via_lang} text back to English, preserving the exact meaning:\n\n{mid}"
169
  out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
170
  if out: return out.strip()
171
  res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
 
154
  return self._clean_resp(out) if out else text
155
 
156
  # ————— Translate & Backtranslate —————
157
def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
    """Translate *text* into ``target_lang``, preferring the primary client.

    The prompt is sent to ``self.nv`` first; if that yields nothing, the same
    prompt is sent to ``self.gm_easy``. Output from either client is stripped
    of surrounding whitespace so callers (e.g. back-translation prompts) get
    the same shape regardless of which client answered.

    Args:
        text: Text to translate. Falsy input is returned unchanged.
        target_lang: Language identifier interpolated into the prompt.

    Returns:
        The translated text, or the fallback client's falsy result
        (``None``/``""``) when neither client produced output.
    """
    if not text:
        return text

    prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
    # Same budget for both clients: input length plus headroom, capped at 800.
    budget = min(800, len(text) + 100)

    out = self.nv.generate(prompt, temperature=0.0, max_tokens=budget)
    if out:
        return out.strip()

    fallback = self.gm_easy.generate(prompt, max_output_tokens=budget)
    # Strip here too: previously only the primary path was normalised.
    if isinstance(fallback, str):
        return fallback.strip()
    return fallback
163
 
164
+ def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
165
  if not text: return text
166
  mid = self.translate(text, target_lang=via_lang)
167
  if not mid: return None
168
+ prompt = f"Translate the following Vietnamese text back to English, preserving the exact meaning:\n\n{mid}"
169
  out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
170
  if out: return out.strip()
171
  res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
utils/rag.py CHANGED
@@ -5,9 +5,9 @@ import hashlib
5
  import random
6
  from typing import Dict, List, Tuple, Optional, Callable
7
 
8
- from utils.schema import sft_row
9
  from utils.llm import NvidiaClient, KeyRotator
10
- from vi.processing import translate_rag_row, should_translate, log_translation_stats
11
 
12
  # Logger
13
  logger = logging.getLogger("rag_processor")
@@ -188,18 +188,8 @@ class RAGProcessor:
188
  if not question or not answer:
189
  continue
190
 
191
- # Create RAG-specific instruction
192
- rag_instruction = "Answer the medical question based on the provided context. If the context is insufficient, provide the best available medical information."
193
-
194
- # Format user input as QCA
195
- if context:
196
- rag_user = f"Question: {question}\n\nContext: {context}"
197
- else:
198
- rag_user = f"Question: {question}"
199
-
200
- # Commit the RAG-formatted row
201
- if self._commit_rag_row(writer, source, rid, "rag_medical_qa",
202
- rag_instruction, rag_user, answer,
203
  stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
204
  written += 1
205
 
@@ -256,16 +246,8 @@ class RAGProcessor:
256
  context = self.generate_context_from_qa(question, answer)
257
 
258
  rid = str(k)
259
- rag_instruction = "Answer the biomedical question based on the provided context."
260
-
261
- if context:
262
- rag_user = f"Question: {question}\n\nContext: {context}"
263
- else:
264
- rag_user = f"Question: {question}"
265
-
266
- # Commit the RAG-formatted row
267
- if self._commit_rag_row(writer, source, rid, "rag_biomedical_qa",
268
- rag_instruction, rag_user, answer,
269
  stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
270
  written += 1
271
 
@@ -286,30 +268,28 @@ class RAGProcessor:
286
  logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
287
  return count
288
 
289
- def _commit_rag_row(self, writer, source: str, rid: str, task: str,
290
- instruction: str, user_input: str, output: str,
291
  stats: Dict, dedupe_seen: set = None, translator=None, opts=None) -> bool:
292
- """Commit a RAG-formatted row to the writer"""
293
  # Simple deduplication based on content hash
294
  if dedupe_seen is not None:
295
- content_hash = hashlib.md5(f"{user_input}{output}".encode()).hexdigest()
296
  if content_hash in dedupe_seen:
297
  stats["dedup_skipped"] = stats.get("dedup_skipped", 0) + 1
298
  return False
299
  dedupe_seen.add(content_hash)
300
-
301
- meta = {"rag_processing": True, "format": "qca"}
302
- row = sft_row(instruction, user_input, output, source=source, rid=rid, task=task, meta=meta)
303
-
304
- # Apply Vietnamese translation if requested
305
  if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
306
  try:
307
- row = translate_rag_row(row, translator)
308
- meta["vietnamese_translated"] = True
309
- row["meta"] = meta
310
  except Exception as e:
311
  logger.error(f"Failed to translate RAG row: {e}")
312
-
313
  writer.write(row)
314
  stats["written"] = stats.get("written", 0) + 1
315
  return True
 
5
  import random
6
  from typing import Dict, List, Tuple, Optional, Callable
7
 
8
+ from utils.schema import sft_row, rag_row
9
  from utils.llm import NvidiaClient, KeyRotator
10
+ from vi.processing import should_translate
11
 
12
  # Logger
13
  logger = logging.getLogger("rag_processor")
 
188
  if not question or not answer:
189
  continue
190
 
191
+ # Commit the RAG-formatted row (QAC)
192
+ if self._commit_rag_row(writer, rid, question, context, answer,
 
 
 
 
 
 
 
 
 
 
193
  stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
194
  written += 1
195
 
 
246
  context = self.generate_context_from_qa(question, answer)
247
 
248
  rid = str(k)
249
+ # Commit the RAG-formatted row (QAC)
250
+ if self._commit_rag_row(writer, rid, question, context, answer,
 
 
 
 
 
 
 
 
251
  stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
252
  written += 1
253
 
 
268
  logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
269
  return count
270
 
271
+ def _commit_rag_row(self, writer, rid: str, question: str, context: str, answer: str,
 
272
  stats: Dict, dedupe_seen: set = None, translator=None, opts=None) -> bool:
273
+ """Commit a RAG-formatted row (QAC) to the writer"""
274
  # Simple deduplication based on content hash
275
  if dedupe_seen is not None:
276
+ content_hash = hashlib.md5(f"{question}{context}{answer}".encode()).hexdigest()
277
  if content_hash in dedupe_seen:
278
  stats["dedup_skipped"] = stats.get("dedup_skipped", 0) + 1
279
  return False
280
  dedupe_seen.add(content_hash)
281
+
282
+ row = rag_row(question=question, context=context, answer=answer, rid=rid)
283
+
284
+ # Apply Vietnamese translation if requested (translate Q/A/C fields directly)
 
285
  if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
286
  try:
287
+ if translator:
288
+ row = translator.translate_dict(row, ["question", "answer", "context"])
289
+ row["vi_translated"] = True
290
  except Exception as e:
291
  logger.error(f"Failed to translate RAG row: {e}")
292
+
293
  writer.write(row)
294
  stats["written"] = stats.get("written", 0) + 1
295
  return True
utils/schema.py CHANGED
@@ -1,4 +1,4 @@
1
- # Centralized SFT writer (JSONL + CSV)
2
  import csv
3
  import orjson
4
  from typing import Optional, Dict
@@ -66,3 +66,56 @@ class CentralisedWriter:
66
  self.jsonl_fp.close()
67
  finally:
68
  self.csv_fp.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Centralized SFT writer (JSONL + CSV) and RAG writer
2
  import csv
3
  import orjson
4
  from typing import Optional, Dict
 
66
  self.jsonl_fp.close()
67
  finally:
68
  self.csv_fp.close()
69
+
70
+
71
+ # —— RAG (QAC) schema ——
72
+
73
def rag_row(question: str, context: str, answer: str, rid: str):
    """Assemble one RAG record in QAC (question / answer / context) layout.

    Falsy text fields (``None`` or ``""``) are normalised to the empty
    string; ``rid`` is stored unchanged under ``"id"``. Keys are inserted in
    the order id, question, answer, context.
    """
    record = {"id": rid}
    text_fields = (
        ("question", question),
        ("answer", answer),
        ("context", context),
    )
    for field, value in text_fields:
        record[field] = value if value else ""
    return record
80
+
81
+
82
def is_valid_rag_row(row: Dict, max_chars: int = 20000) -> bool:
    """Return True when *row* is a usable RAG (QAC) record.

    A row is valid when it has a non-empty question and a non-empty answer,
    and none of question, answer or context is longer than ``max_chars``
    characters. Context is optional.

    Missing or ``None`` fields are treated as empty. A non-dict row, or a
    field holding a non-string value, makes the row invalid instead of
    raising (``len()`` on such values used to raise ``TypeError``).

    Args:
        row: Candidate record with ``question`` / ``answer`` / ``context`` keys.
        max_chars: Inclusive per-field length limit.
    """
    if not isinstance(row, dict):
        return False

    texts = []
    for key in ("question", "answer", "context"):
        value = row.get(key)
        if value is None:
            value = ""
        if not isinstance(value, str):
            return False
        texts.append(value)

    question, answer, _context = texts
    if not (question and answer):
        return False
    return all(len(text) <= max_chars for text in texts)
91
+
92
+
93
class RAGWriter:
    """Streams JSONL + CSV for RAG (QAC) format with columns: id, question, answer, context.

    Both output files are opened on construction and the CSV header is
    written immediately. Call :meth:`close` when done, or use the instance as
    a context manager so both files are closed on exit.
    """

    # Column order shared by the CSV header and every CSV row.
    FIELDNAMES = ("id", "question", "answer", "context")

    def __init__(self, jsonl_path: str, csv_path: str):
        """Open both outputs; on failure nothing is left open.

        Args:
            jsonl_path: Destination for the JSONL stream (opened in binary mode).
            csv_path: Destination for the CSV stream (UTF-8 text).
        """
        self.jsonl_fp = open(jsonl_path, "wb")
        try:
            self.csv_fp = open(csv_path, "w", newline="", encoding="utf-8")
        except Exception:
            # Don't leak the JSONL handle if the CSV file can't be opened.
            self.jsonl_fp.close()
            raise
        try:
            self.csv_wr = csv.DictWriter(self.csv_fp, fieldnames=list(self.FIELDNAMES))
            self.csv_wr.writeheader()
        except Exception:
            self.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returns None so any in-flight exception keeps propagating.
        self.close()

    @staticmethod
    def _text_len(row: dict, key: str) -> int:
        """Length of a text field for log output; 0 when missing or not text."""
        value = row.get(key)
        return len(value) if isinstance(value, (str, bytes)) else 0

    def write(self, row: dict) -> bool:
        """Validate *row* and append it to both outputs.

        Returns:
            True when the row was written, False when it failed validation
            and was skipped (a warning is logged in that case).
        """
        if not is_valid_rag_row(row):
            # NOTE(review): relies on a module-level `logger`; confirm it is
            # defined in this module.
            logger.warning(
                f"[RAG-WRITER] Skipping invalid row id={row.get('id')} "
                f"(len q={self._text_len(row, 'question')}, "
                f"a={self._text_len(row, 'answer')}, "
                f"c={self._text_len(row, 'context')})"
            )
            return False
        self.jsonl_fp.write(orjson.dumps(row))
        self.jsonl_fp.write(b"\n")
        # The CSV gets only the fixed columns; extra keys stay JSONL-only.
        self.csv_wr.writerow({name: row.get(name, "") for name in self.FIELDNAMES})
        return True

    def close(self):
        """Close both files; the CSV is closed even if closing the JSONL fails."""
        try:
            self.jsonl_fp.close()
        finally:
            self.csv_fp.close()