pii-nl-bench / app.py
jellewas's picture
Deploy to HuggingFace Space (a10g-small)
e7d9849
"""Gradio app for pii-nl-bench — run benchmarks on HuggingFace Spaces (GPU).
HuggingFace Space: Docker SDK + T4/A10G GPU
"""
from __future__ import annotations
import io
import sys
import time
import traceback
from contextlib import redirect_stdout, redirect_stderr
from datetime import datetime
from pathlib import Path
import gradio as gr
from preflight import run_preflight
# ── Helpers ────────────────────────────────────────────────────────
def capture_output(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` while collecting its stdout and stderr.

    Returns:
        A ``(result, text)`` tuple. ``text`` holds everything written to
        stdout/stderr during the call. If ``fn`` raises an ``Exception``, the
        formatted traceback is appended to ``text`` and ``result`` is ``None``.
    """
    sink = io.StringIO()
    outcome = None
    # Both streams share one buffer so the output keeps its original ordering.
    with redirect_stdout(sink), redirect_stderr(sink):
        try:
            outcome = fn(*args, **kwargs)
        except Exception:
            sink.write(traceback.format_exc())
    return outcome, sink.getvalue()
# ── Preflight ──────────────────────────────────────────────────────
def run_preflight_check(quick: bool = False):
    """Execute the preflight checks and format them for display.

    Args:
        quick: Forwarded to ``run_preflight`` as ``skip_model_load``.

    Returns:
        A ``(text, passed)`` tuple. ``text`` is the captured output followed by
        the report summary (or a crash message with the captured traceback),
        and ``passed`` is the report's ``all_critical_passed`` flag (``False``
        when the preflight run itself crashed).
    """
    report, captured = capture_output(run_preflight, skip_model_load=quick)

    # capture_output yields None as the result when run_preflight raised.
    if report is None:
        return f"Preflight crashed:\n```\n{captured}\n```", False

    summary = report.summary()
    if captured.strip():
        text = f"{captured}\n{summary}"
    else:
        text = summary
    return text, report.all_critical_passed
# ── Benchmark runner ───────────────────────────────────────────────
def run_benchmark(groups: list[str], mode: str, max_samples: int, progress=gr.Progress()):
    """Run the selected benchmark groups and build the log, report and charts.

    Args:
        groups: Labels picked in the UI: either ``"All groups"`` or entries of
            the form ``"<id>: <title>"`` (only the numeric id is used).
        mode: Span matching mode; converted with ``MatchMode(mode)``.
        max_samples: Per-dataset sample cap forwarded unchanged to every group
            runner (logged as "unlimited" when it is 0).
        progress: Gradio progress tracker used for coarse progress updates.

    Returns:
        A ``(log_text, report_markdown, chart_png_paths)`` tuple. The report is
        an empty string and the chart list is empty when no group produced
        any results.
    """
    # Benchmark modules are imported on demand rather than at module import.
    from benchmark.groups import (
        GROUP_RUNNERS,
        MatchMode,
        detect_device,
    )
    from benchmark.config import RESULTS_DIR

    progress(0, desc="Detecting device...")
    device = detect_device("auto")
    match_mode = MatchMode(mode)

    # Parse group selection; tolerate a missing selection from the UI.
    groups = groups or []
    if "All groups" in groups:
        group_ids = [1, 2, 3, 4]
    else:
        group_ids = [int(g.split(":")[0].strip()) for g in groups]

    all_results = []
    chart_pngs = []  # stays empty unless charts are generated below
    total_groups = len(group_ids)
    log_lines = [
        f"Device: {device}",
        f"Mode: {match_mode.value}",
        f"Groups: {group_ids}",
        f"Max samples: {max_samples or 'unlimited'}",
        "",
    ]
    if not group_ids:
        log_lines.append("No groups selected, nothing to run.")

    for i, gid in enumerate(group_ids):
        if gid not in GROUP_RUNNERS:
            log_lines.append(f"Unknown group {gid}, skipping")
            continue
        label, runner = GROUP_RUNNERS[gid]
        progress((i / total_groups), desc=f"Group {gid}: {label}...")
        log_lines.append(f"{'=' * 60}")
        log_lines.append(f" Running Group {gid}: {label}")
        log_lines.append(f"{'=' * 60}")

        # Capture group output; a crashing runner yields results=None and its
        # traceback ends up in `output`, so the remaining groups still run.
        results, output = capture_output(runner, device, match_mode, max_samples)
        if output:
            log_lines.append(output)
        if results:
            all_results.extend(results)
            for r in results:
                o = r.overall
                if o.support > 0 or o.fp > 0:
                    log_lines.append(
                        f" {r.model_name:<25s} [{r.dataset_name}] "
                        f"P={o.precision:.3f} R={o.recall:.3f} "
                        f"F1={o.f1:.3f} F2={o.f2:.3f}"
                    )
        log_lines.append("")

    progress(0.90, desc="Loading dataset statistics...")

    # Collect dataset samples for report statistics
    from benchmark.datasets.loader import load_dataset_by_name
    from benchmark.datasets.normalize import normalize_dataset

    # Only datasets that actually appear in the results are loaded.
    used_datasets = {r.dataset_name for r in all_results}
    loaded_datasets = {}
    for ds_name in ["ai4privacy", "gretel", "e3jsi", "conll2002", "article9"]:
        if ds_name not in used_datasets:
            continue
        try:
            raw = load_dataset_by_name(ds_name)
            loaded_datasets[ds_name] = normalize_dataset(raw)
        except Exception as exc:
            # Statistics are best-effort: keep going, but make the failure
            # visible in the log instead of dropping it silently.
            log_lines.append(
                f"Warning: could not load dataset '{ds_name}' "
                f"for report statistics: {exc!r}"
            )

    progress(0.95, desc="Generating report...")

    # Generate report
    report_md = ""
    if all_results:
        import json as _json
        from benchmark.evaluation.report import generate_report, export_results_json
        from benchmark.evaluation.charts import generate_all_charts

        RESULTS_DIR.mkdir(parents=True, exist_ok=True)
        ts = datetime.now().strftime("%Y-%m-%d_%H%M")

        # JSON results. The dump keeps non-ASCII characters, so the file is
        # written as UTF-8 explicitly instead of relying on the locale.
        json_data = export_results_json(all_results)
        json_path = RESULTS_DIR / f"groups_{ts}.json"
        json_path.write_text(
            _json.dumps(json_data, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        log_lines.append(f"JSON saved: {json_path}")

        # Charts
        charts_dir = RESULTS_DIR / f"charts_{ts}"
        chart_paths = generate_all_charts(all_results, charts_dir)
        chart_pngs = [str(p) for p in chart_paths if p.suffix == ".png"]
        if chart_paths:
            log_lines.append(f"Charts saved: {charts_dir}/ ({len(chart_paths)} files)")

        # Markdown report (after charts so it can embed them)
        report_md = generate_report(
            all_results,
            datasets=loaded_datasets,
            charts_dir=str(charts_dir) if chart_paths else None,
        )
        report_path = RESULTS_DIR / f"groups_{ts}.md"
        report_path.write_text(report_md, encoding="utf-8")
        log_lines.append(f"Report saved: {report_path}")

    progress(1.0, desc="Done")
    log_text = "\n".join(log_lines)
    return log_text, report_md, chart_pngs
# ── Gradio UI ──────────────────────────────────────────────────────
def build_ui():
    """Build the Gradio Blocks app with Preflight, Benchmark and About tabs.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """
    with gr.Blocks(
        title="pii-nl-bench — Dutch PII Detection Benchmark",
        theme=gr.themes.Soft(),
    ) as app:
        gr.Markdown(
            "# pii-nl-bench — Dutch PII Detection Benchmark\n"
            "Compare PII detection models on Dutch text. "
            "Proves `monsieur_regex + qwen_adapter` outperforms alternatives.\n\n"
            # Two trailing spaces force a Markdown hard line break between steps.
            "**Step 1**: Run preflight checks to validate GPU, models, and data.  \n"
            "**Step 2**: Run the benchmark."
        )

        with gr.Tab("Preflight Checks"):
            gr.Markdown(
                "Validates GPU, CUDA, bfloat16, model downloads, LoRA adapter, "
                "datasets, and disk space **before** running the benchmark."
            )
            with gr.Row():
                quick_check = gr.Checkbox(
                    label="Quick (skip model loading)", value=False,
                )
                preflight_btn = gr.Button("Run Preflight Checks", variant="primary")
            preflight_status = gr.Textbox(
                label="Preflight Status", lines=3, interactive=False,
            )
            preflight_output = gr.Code(
                label="Detailed Output", language=None, lines=25,
            )

            def on_preflight(quick):
                """Run the preflight checks and return (status line, details)."""
                result, passed = run_preflight_check(quick)
                status = (
                    "ALL CLEAR — ready to benchmark"
                    if passed
                    else "BLOCKED — see details below"
                )
                return status, result

            preflight_btn.click(
                fn=on_preflight,
                inputs=[quick_check],
                outputs=[preflight_status, preflight_output],
            )

        with gr.Tab("Benchmark"):
            with gr.Row():
                with gr.Column():
                    # Labels must keep the "<id>: <title>" shape: run_benchmark
                    # parses the numeric id in front of the colon.
                    group_select = gr.CheckboxGroup(
                        choices=[
                            "All groups",
                            "1: Structured PII",
                            "2: Named Entity Recognition",
                            "3: Full PII Coverage",
                            "4: Article 9 Special Categories",
                        ],
                        value=["All groups"],
                        label="Benchmark Groups",
                    )
                    mode_select = gr.Radio(
                        choices=["lenient", "strict", "label_only"],
                        value="lenient",
                        label="Span Matching Mode",
                    )
                    max_samples = gr.Slider(
                        minimum=0, maximum=5000, step=50, value=1000,
                        label="Max samples per dataset (0 = unlimited, default: 1000)",
                    )
                    run_btn = gr.Button("Run Benchmark", variant="primary")
            with gr.Row():
                log_output = gr.Code(
                    label="Benchmark Log", language=None, lines=30,
                )
            with gr.Row():
                chart_gallery = gr.Gallery(
                    label="Benchmark Charts",
                    columns=2,
                    height="auto",
                    object_fit="contain",
                )
            with gr.Row():
                report_output = gr.Markdown(label="Report")

            # Output order matches run_benchmark's (log, report, charts) tuple.
            run_btn.click(
                fn=run_benchmark,
                inputs=[group_select, mode_select, max_samples],
                outputs=[log_output, report_output, chart_gallery],
            )

        with gr.Tab("About"):
            gr.Markdown(
                "## Models Compared\n\n"
                "| Model | Type | Labels |\n"
                "|-------|------|--------|\n"
                "| **monsieur_regex** | Rule-based regex | 16 structured PII types |\n"
                "| **qwen_adapter** (jellewas/gdpr-lora) | Qwen3.5-4B LoRA | 23 types incl. Article 9 |\n"
                "| **regex+adapter** (combined) | Ensemble | All types |\n"
                "| **pii_ner_nl** (jellewas/pii-ner-nl) | RobBERT token classifier | BIO-tagged NER |\n"
                "| **flair** | BiLSTM-CRF | PERSON, LOCATION, ORG |\n"
                "| **gliner** | Zero-shot transformer | 9 types |\n"
                "| **deduce** | Dutch clinical rules | 11 types |\n"
                "| **presidio** | spaCy + regex | 12 types |\n\n"
                "## Evaluation\n\n"
                "- **Primary metric**: F2 (recall weighted 4x over precision)\n"
                "- **Rationale**: Missed PII = GDPR violation > false alarm\n"
                "- **Matching modes**: strict (exact span), lenient (50% overlap), label-only\n"
            )
    return app
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on 0.0.0.0:7860.
    app = build_ui()
    app.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )