Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import os | |
| from pathlib import Path | |
| import streamlit as st | |
| import torch | |
| import timm | |
| import pandas as pd | |
| from PIL import Image | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download | |
# =========================
# Config
# =========================
# Every setting below can be overridden through an environment variable of the same name.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Shivani98/ViT-L_Insect_Classifier")
MODEL_FILE = os.getenv("MODEL_FILE", "vit_l_518.pth")
NUM_CLASSES = int(os.getenv("NUM_CLASSES", "3747"))
IMG_SIZE = int(os.getenv("IMG_SIZE", "518"))
CPU_THREADS = int(os.getenv("CPU_THREADS", "2"))
# Hub token is read from the `bglab_hf` env var; None lets huggingface_hub use its default auth.
HF_TOKEN = os.getenv("bglab_hf")
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Class-index -> taxonomy workbook. Expects columns:
# class_idx, Scientific Name, Common Name, Order, Family.
# Resolved next to this script rather than the current working directory, so the
# app finds it no matter where `streamlit run` is launched from.
_APP_DIR = Path(__file__).resolve().parent
MAPPING_XLSX = Path(os.getenv("MAPPING_XLSX", str(_APP_DIR / "class_mapping_4k.xlsx")))
# =========================
# Streamlit basics
# =========================
# Page config has to be the first Streamlit call of the script run.
st.set_page_config(layout="centered", page_title="InsectNetv2 Classifier")
st.title("🪲 InsectNetv2 Classifier")

# Inference-only process: cap intra-op CPU threads and switch autograd tracking off.
torch.set_num_threads(CPU_THREADS)
torch.set_grad_enabled(False)
# =========================
# Cached: Load model + preprocess
# =========================
def _extract_state_dict(ckpt):
    """Return the parameter mapping held by a loaded checkpoint object.

    Training checkpoints may nest the weights under ``"model"`` or
    ``"state_dict"``; anything else is returned unchanged. If every key carries a
    ``"module."`` prefix (as saved from a DataParallel/DDP wrapper), the prefix is
    stripped so the keys line up with a plain timm model -- otherwise
    ``load_state_dict(strict=False)`` would silently load nothing.
    """
    if not isinstance(ckpt, dict):
        return ckpt
    state = ckpt.get("model", ckpt.get("state_dict", ckpt))
    prefix = "module."
    if isinstance(state, dict) and state and all(
        isinstance(key, str) and key.startswith(prefix) for key in state
    ):
        state = {key[len(prefix):]: value for key, value in state.items()}
    return state


@st.cache_resource
def load_model_and_preprocess():
    """Download the checkpoint, build the classifier and its preprocessing pipeline.

    Wrapped in ``st.cache_resource`` so the ViT-L model is constructed once per
    server process and reused across reruns, instead of being rebuilt on every
    widget interaction.

    Returns:
        (model, preprocess): the model in eval mode (dynamically int8-quantized
        when the torch build supports it) and the torchvision transform that turns
        an RGB PIL image into a normalized ``(3, IMG_SIZE, IMG_SIZE)`` tensor.
    """
    # Download checkpoint (cached on disk by huggingface_hub)
    ckpt_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILE,
        token=HF_TOKEN,
        cache_dir=str(Path.home() / ".cache" / "huggingface"),
    )
    # Build the architecture; pretrained=True fetches backbone weights, which are
    # then overwritten by whatever the fine-tuned checkpoint provides.
    model = timm.create_model(
        "vit_large_patch14_reg4_dinov2.lvd142m",
        pretrained=True,
        num_classes=NUM_CLASSES,
    )
    # NOTE(review): torch.load unpickles the file, so MODEL_REPO_ID must be a trusted
    # repo. On torch >= 2.6 weights_only defaults to True.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    load_result = model.load_state_dict(_extract_state_dict(ckpt), strict=False)
    if load_result.missing_keys:
        # strict=False hides mismatches; surface them, because unloaded parameters
        # (e.g. the classifier head) keep whatever create_model initialised them to.
        st.warning(
            f"Checkpoint did not provide {len(load_result.missing_keys)} parameter(s), "
            f"e.g. {load_result.missing_keys[:3]}; predictions may be unreliable."
        )
    # CPU speedup: dynamic int8 quantization of Linear layers (best effort).
    try:
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    except Exception as exc:  # e.g. no quantized CPU engine in this torch build
        st.caption(f"Dynamic quantization unavailable ({exc}); running unquantized.")
    model.eval()
    preprocess = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    # Warmup: one dummy forward pass so the first real prediction is not the slow one.
    with torch.inference_mode():
        model(torch.zeros(1, 3, IMG_SIZE, IMG_SIZE))
    return model, preprocess


st.caption("✨ App loaded from `app.py` (Streamlit)")
model, preprocess = load_model_and_preprocess()
# =========================
# Cached: Load mapping (xlsx)
# =========================
# Canonical (normalized) names of the columns the mapping workbook must provide.
_MAPPING_COLUMNS = ("class_idx", "scientific name", "common name", "order", "family")


def _normalize_column_name(name) -> str:
    """Canonical header form: stripped, lower-cased, '_' treated as a space, runs of spaces collapsed."""
    return " ".join(str(name).replace("_", " ").lower().split())


@st.cache_data(show_spinner=False)
def _read_mapping_table(mapping_path: str, mtime: float):
    """Parse the mapping workbook into the ``{"df", "cols"}`` store (or None on failure).

    ``mtime`` is not used in the body; it is part of the cache key so that an
    edited or replaced file is re-read instead of served stale.
    """
    try:
        df = pd.read_excel(mapping_path)
    except (ImportError, OSError, ValueError) as exc:
        # e.g. the Excel engine (openpyxl) is not installed, or the file is corrupt.
        st.warning(f"Could not read mapping file ({exc}). Will fall back to raw indices.")
        return None

    # First original column for each normalized header (first match wins).
    first_match = {}
    for col in df.columns:
        first_match.setdefault(_normalize_column_name(col), col)
    required = {key: first_match.get(_normalize_column_name(key)) for key in _MAPPING_COLUMNS}

    missing = [k for k, v in required.items() if v is None]
    if missing:
        st.warning(f"Mapping file found but missing columns: {missing}. Will fall back to raw indices.")
        return None

    # Coerce class_idx to int so lookups with the model's int predictions succeed even
    # when Excel stored the column as float/text; drop rows without a usable index and
    # duplicate indices (a duplicate would make df.loc[idx] return a DataFrame).
    idx_col = required["class_idx"]
    idx_values = pd.to_numeric(df[idx_col], errors="coerce")
    keep = idx_values.notna()
    df = df.loc[keep].copy()
    df[idx_col] = idx_values[keep].astype("int64")
    df = df.drop_duplicates(subset=idx_col, keep="first").set_index(idx_col)

    return {
        "df": df,
        "cols": {
            "scientific": required["scientific name"],
            "common": required["common name"],
            "order": required["order"],
            "family": required["family"],
        },
    }


def load_mapping_table(mapping_path: Path):
    """Load the class-index -> taxonomy table, or return None if it is unavailable.

    Expects columns (matched case-insensitively; '_' and ' ' are interchangeable):
      - class_idx
      - Scientific Name
      - Common Name
      - Order
      - Family

    Returns:
        None when the file is absent, unreadable or lacks a column; otherwise a dict
        with ``"df"`` (DataFrame indexed by int class_idx) and ``"cols"`` (the
        original header for "scientific", "common", "order" and "family").
    """
    try:
        mtime = mapping_path.stat().st_mtime
    except OSError:  # missing (or inaccessible) file -> callers show raw indices
        return None
    return _read_mapping_table(str(mapping_path), mtime)


mapping_store = load_mapping_table(MAPPING_XLSX)
# =========================
# Prediction util
# =========================
def predict_indices(img: "Image.Image", topk: int = 5):
    """Classify one image and return its highest-probability class indices.

    Args:
        img: input image; the app passes an RGB PIL image.
        topk: number of ranked predictions to return, clamped to [1, NUM_CLASSES].

    Returns:
        (top1_idx, top1_prob, topk_idx, topk_prob): the best class index and its
        softmax probability, followed by lists of the top-k indices and their
        probabilities in descending order of probability.
    """
    k = max(1, min(topk, NUM_CLASSES))
    # Explicit inference_mode instead of relying on the module-level
    # set_grad_enabled(False): PyTorch grad mode is thread-local.
    with torch.inference_mode():
        batch = preprocess(img).unsqueeze(0)
        logits = model(batch)
        probs = torch.softmax(logits, dim=1).squeeze(0)
        top_probs, top_idx = torch.topk(probs, k=k)
    idx_list = [int(i) for i in top_idx.tolist()]
    prob_list = [float(p) for p in top_probs.tolist()]
    return idx_list[0], prob_list[0], idx_list, prob_list
# =========================
# Helpers to format rows
# =========================
def fmt_top1(idx: int, p: float):
    """Render the top-1 prediction card.

    With a mapping table loaded, shows scientific name, common name, order, family
    and confidence (the class index is deliberately not shown). Without a table,
    or when ``idx`` is absent from it, falls back to the raw index.

    Args:
        idx: predicted class index.
        p: its softmax probability.
    """
    if mapping_store is None:
        st.info(
            f"Top-1 index: **{idx}** — p={p:.3f}\n\n"
            f"(Place `{MAPPING_XLSX.name}` next to the app to show names/taxonomy.)"
        )
        return
    df = mapping_store["df"]
    cols = mapping_store["cols"]
    if idx not in df.index:
        st.warning(f"Top-1 index {idx} not found in mapping; showing raw index only.")
        st.write(f"Confidence: `{p:.3f}`")
        return
    row = df.loc[idx]
    if isinstance(row, pd.DataFrame):  # duplicated class_idx in the sheet: use the first row
        row = row.iloc[0]
    # Empty spreadsheet cells come back as NaN; show a dash instead of "nan".
    fields = {}
    for key in ("scientific", "common", "order", "family"):
        value = row[cols[key]]
        fields[key] = "—" if pd.isna(value) else value

    # No index displayed here by design
    st.subheader("🦋 Top-1 Prediction")
    # "  \n" is a Markdown hard line break, so each field renders on its own line.
    st.markdown("  \n".join([
        f"**Scientific Name:** *{fields['scientific']}*",
        f"**Common Name:** {fields['common']}",
        f"**Order:** {fields['order']}",
        f"**Family:** {fields['family']}",
        f"**Confidence:** `{p:.3f}`",
    ]))
def fmt_top5(idxs, ps):
    """Render the ranked prediction list below the top-1 card.

    Each entry shows scientific and common name with its probability when the
    index is in the mapping table; otherwise (or with no table) the raw index.

    Args:
        idxs: ranked class indices.
        ps: probabilities aligned with ``idxs``.
    """
    st.markdown("### 🌿 Top-5 Predictions")
    ranked = list(zip(idxs, ps))

    if mapping_store is None:
        for class_idx, prob in ranked:
            st.write(f"- Index **{class_idx}** — p={prob:.3f}")
        return

    table = mapping_store["df"]
    sci_col = mapping_store["cols"]["scientific"]
    common_col = mapping_store["cols"]["common"]
    for class_idx, prob in ranked:
        if class_idx not in table.index:
            st.markdown(f"- Index **{class_idx}** — `{prob:.3f}`")
            continue
        entry = table.loc[class_idx]
        # Only scientific + common names here; full taxonomy lives in the top-1 card.
        st.markdown(f"- **{entry[sci_col]}** (*{entry[common_col]}*) — `{prob:.3f}`")
# =========================
# UI
# =========================
from PIL import ImageOps  # EXIF-aware orientation handling for camera/phone uploads

with st.sidebar:
    st.header("Settings")
    st.caption(f"Model: ViT-L DINOv2 head · Image size: {IMG_SIZE}")
    if mapping_store is None:
        st.warning(f"No `{MAPPING_XLSX.name}` found. Top-1/Top-5 will show indices only.")

uploaded = st.file_uploader("Upload a JPG/PNG", type=["jpg", "jpeg", "png"])
if uploaded is not None:
    try:
        # Apply the EXIF Orientation tag so phone photos are not classified sideways;
        # convert() forces the (lazy) decode inside the try so corrupt files are caught.
        img = ImageOps.exif_transpose(Image.open(uploaded)).convert("RGB")
    except Exception as e:
        st.error(f"Failed to read image: {e}")
        st.stop()
    st.image(img, caption="Input", use_container_width=True)
    with st.spinner("Predicting…"):
        top1_idx, top1_prob, top5_idx, top5_prob = predict_indices(img, topk=5)
    # Render: Top-1 (all attributes, no index), then Top-5 (name + common only)
    fmt_top1(top1_idx, top1_prob)
    fmt_top5(top5_idx, top5_prob)
    with st.expander("Advanced • Raw indices & probabilities"):
        st.write(f"Top-1 index: **{top1_idx}** — p={top1_prob:.4f}")
        for i, p in zip(top5_idx, top5_prob):
            st.write(f"- {i} : {p:.4f}")
else:
    st.info("Upload an image to see predictions.")

# The tip only helps when the mapping table could not be loaded.
if mapping_store is None:
    st.caption(f"Tip: place `{MAPPING_XLSX.name}` next to this script with columns: "
               "`class_idx, Scientific Name, Common Name, Order, Family`.")