"""
@app.get("/status")
def status():
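    """Return a snapshot of the shared job state (read under STATE_LOCK)."""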
with STATE_LOCK:
return JSONResponse(STATE)
# ──────── Google Drive OAuth (only in cloud mode) ────────
@app.get("/oauth2/start")
def oauth2_start(request: Request):
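    """Build the Google OAuth authorization URL for this Space and return it as JSON."""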
if IS_LOCAL:
raise HTTPException(400, "OAuth is not available in local mode. Google Drive integration is disabled.")
# Compute redirect URI dynamically from the actual host the Space is using
host = request.headers.get("x-forwarded-host") or request.headers.get("host")
scheme = "https" # Spaces are HTTPS at the edge
redirect_uri = f"{scheme}://{host}/oauth2/callback"
try:
url = build_auth_url(redirect_uri)
return JSONResponse({"authorize_url": url})
except Exception as e:
raise HTTPException(500, f"OAuth init failed: {e}")
# OAuth callback: exchange the code and display the refresh token (only in cloud mode)
@app.get("/oauth2/callback")
def oauth2_callback(request: Request, code: str = "", state: str = ""):
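    """Exchange the OAuth authorization code for credentials and render the refresh token."""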
if IS_LOCAL:
raise HTTPException(400, "OAuth is not available in local mode. Google Drive integration is disabled.")
if not code:
raise HTTPException(400, "Missing 'code'")
    # Rebuild the redirect URI exactly as /oauth2/start does
host = request.headers.get("x-forwarded-host") or request.headers.get("host")
scheme = "https"
redirect_uri = f"{scheme}://{host}/oauth2/callback"
    # Exchange the authorization code and display the refresh token
try:
creds = exchange_code(code, redirect_uri)
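        # Google only returns a refresh_token on the first consent;
        # fall back to a previously saved one if this grant omitted it.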
refresh = creds.refresh_token or os.getenv("GDRIVE_REFRESH_TOKEN", "")
# UI
html = f"""
✅ Google Drive Authorized
Your refresh token is:
{refresh}
👉 Copy this token and save it into your Hugging Face Space Secrets
as GDRIVE_REFRESH_TOKEN.
This ensures persistence across rebuilds.
"""
return HTMLResponse(html)
except Exception as e:
raise HTTPException(500, f"OAuth exchange failed: {e}")
@app.get("/files")
def files():
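    """List every file under OUTPUT_DIR as paths relative to it."""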
out = []
for root, _, fns in os.walk(OUTPUT_DIR):
for fn in fns:
out.append(os.path.relpath(os.path.join(root, fn), OUTPUT_DIR))
return {"output_dir": OUTPUT_DIR, "files": sorted(out)}
@app.post("/process/{dataset_key}")
def process_dataset(dataset_key: str, params: ProcessParams, background: BackgroundTasks):
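    """Queue a background processing job for the given dataset."""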
with STATE_LOCK:
if STATE["running"]:
logger.warning(
f"[JOB] Rejecting new job dataset={dataset_key} "
f"current={STATE['dataset']} started_at={STATE['started_at']}"
)
raise HTTPException(409, detail="Another job is running.")
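        # Claim the job slot while the lock is still held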
STATE["running"] = True
STATE["dataset"] = dataset_key
STATE["started_at"] = now_iso()
STATE["progress"] = 0.0
STATE["message"] = "starting"
STATE["last_result"] = None
logger.info(
f"[JOB] Queued dataset={dataset_key} "
f"params={{'sample_limit': {params.sample_limit}, 'seed': {params.seed}, "
f"'rag_processing': {params.rag_processing}, 'augment': {params.augment.dict()} }}"
)
    # Hand the job off to the background runner thread
logger.info(f"[JOB] Started dataset={dataset_key}")
background.add_task(_run_job, dataset_key, params)
return {"ok": True, "message": f"Job for '{dataset_key}' started."}
@app.post("/rag/{dataset_key}")
def process_rag_dataset(dataset_key: str, params: ProcessParams, background: BackgroundTasks):
"""Dedicated RAG processing endpoint"""
# Force RAG processing mode
params.rag_processing = True
with STATE_LOCK:
if STATE["running"]:
logger.warning(
f"[RAG] Rejecting new RAG job dataset={dataset_key} "
f"current={STATE['dataset']} started_at={STATE['started_at']}"
)
raise HTTPException(409, detail="Another job is running.")
STATE["running"] = True
STATE["dataset"] = dataset_key
STATE["started_at"] = now_iso()
STATE["progress"] = 0.0
STATE["message"] = "starting RAG processing"
STATE["last_result"] = None
logger.info(
f"[RAG] Queued RAG dataset={dataset_key} "
f"params={{'sample_limit': {params.sample_limit}, 'seed': {params.seed} }}"
)
    # Hand the job off to the background runner thread
logger.info(f"[RAG] Started RAG dataset={dataset_key}")
background.add_task(_run_job, dataset_key, params)
return {"ok": True, "message": f"RAG processing job for '{dataset_key}' started."}
def _run_job(dataset_key: str, params: ProcessParams):
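    """Background worker: download the dataset, run SFT or RAG processing,
    write the outputs, then upload (cloud) or keep them locally."""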
t0 = time.time()
try:
ds = resolve_dataset(dataset_key)
if not ds:
set_state(running=False, message="unknown dataset")
return
        # Download the HF dataset before processing begins
set_state(message="downloading")
local_path = hf_download_dataset(ds["repo_id"], ds["filename"], ds["repo_type"])
logger.info(f"[JOB] Downloaded {ds['repo_id']}/{ds['filename']} → {local_path}")
        # Timestamp the output file names
        ts = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d-%H%M%S")
mode_suffix = "rag" if params.rag_processing else "sft"
stem = f"{dataset_key}-{mode_suffix}-{ts}"
jsonl_path = os.path.join(OUTPUT_DIR, f"{stem}.jsonl")
csv_path = os.path.join(OUTPUT_DIR, f"{stem}.csv")
        # Mark processing as underway
        set_state(message="processing", progress=0.05)
        # Choose the writer for the selected processing mode
        writer = (RAGWriter(jsonl_path=jsonl_path, csv_path=csv_path)
                  if params.rag_processing
                  else CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path))
# Load translator if Vietnamese translation is requested
translator = None
if params.vietnamese_translation:
set_state(message="Loading Vietnamese translator", progress=0.05)
try:
# Ensure cache directories are set up properly
cache_dir = os.path.abspath("cache/huggingface")
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir
# Pass paraphraser to translator for LLM-based translation
vietnamese_translator.paraphraser = paraphraser
vietnamese_translator.load_model()
translator = vietnamese_translator
logger.info("✅ Vietnamese translator loaded successfully with LLM models")
except Exception as e:
logger.error(f"❌ Failed to load Vietnamese translator: {e}")
logger.warning("Continuing without Vietnamese translation...")
set_state(message=f"Warning: Vietnamese translation disabled - {e}", progress=0.1)
# Don't fail the entire job, just disable translation
translator = None
if params.rag_processing:
# RAG processing mode
set_state(message="RAG processing", progress=0.1)
count, stats = process_file_into_rag(
dataset_key=dataset_key,
input_path=local_path,
writer=writer,
nvidia_model=os.getenv("NVIDIA_MODEL", "meta/llama-3.1-8b-instruct"),
sample_limit=params.sample_limit,
seed=params.seed,
progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
translator=translator,
paraphraser=paraphraser,
is_local=IS_LOCAL,
hf_token=os.getenv("HF_TOKEN")
)
else:
# Standard SFT processing mode
set_state(message="SFT processing", progress=0.1)
# Add Vietnamese translation flag to augment options
augment_opts = params.augment.dict()
augment_opts["vietnamese_translation"] = params.vietnamese_translation
count, stats = process_file_into_sft(
dataset_key=dataset_key,
input_path=local_path,
writer=writer,
paraphraser=paraphraser,
augment_opts=augment_opts,
sample_limit=params.sample_limit,
seed=params.seed,
progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
translator=translator
)
# Log translation statistics if translator was used
if translator and hasattr(translator, 'get_stats'):
translation_stats = translator.get_stats()
logger.info(f"[JOB] Translation stats: {translation_stats}")
stats["translation_stats"] = translation_stats
logger.info(f"[JOB] Processed dataset={dataset_key} rows={count} stats={stats}")
writer.close()
# Upload to GDrive (only in cloud mode) or save locally
if IS_LOCAL:
set_state(message="saving files locally", progress=0.95)
logger.info(f"[JOB] Files saved locally: jsonl={jsonl_path} csv={csv_path}")
up1 = up2 = True # Local mode always "succeeds"
else:
set_state(message="uploading to Google Drive", progress=0.95)
up1 = drive.upload_file_to_drive(jsonl_path, mimetype="application/json")
up2 = drive.upload_file_to_drive(csv_path, mimetype="text/csv")
logger.info(
f"[JOB] Uploads complete uploaded={bool(up1 and up2)} "
f"jsonl={jsonl_path} csv={csv_path}"
)
        # Finalize the job result
result = {
"dataset": dataset_key,
"processing_mode": "RAG" if params.rag_processing else "SFT",
"processed_rows": count,
"stats": stats,
"artifacts": {"jsonl": jsonl_path, "csv": csv_path},
"uploaded": bool(up1 and up2),
"duration_sec": round(time.time() - t0, 2)
}
set_state(message="done", progress=1.0, last_result=result, running=False)
logger.info(
f"[JOB] Finished dataset={dataset_key} "
f"duration_sec={round(time.time()-t0, 2)}"
)
except Exception as e:
logger.exception(f"[JOB] Error for dataset={dataset_key}: {e}")
set_state(message=f"error: {e}", running=False)