#!/usr/bin/env python3
"""Streamlit app: classify an uploaded insect photo with a fine-tuned ViT-L (DINOv2) model.

The checkpoint is downloaded from the Hugging Face Hub. An optional Excel
mapping table turns predicted class indices into scientific/common names and
taxonomy; without it the app shows raw class indices.
"""
import os
from pathlib import Path

import streamlit as st
import torch
import timm
import pandas as pd
from PIL import Image, ImageOps
from torchvision import transforms
from huggingface_hub import hf_hub_download

# =========================
# Config
# =========================
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Shivani98/ViT-L_Insect_Classifier")
MODEL_FILE = os.getenv("MODEL_FILE", "vit_l_518.pth")
NUM_CLASSES = int(os.getenv("NUM_CLASSES", "3747"))
IMG_SIZE = int(os.getenv("IMG_SIZE", "518"))
CPU_THREADS = int(os.getenv("CPU_THREADS", "2"))
HF_TOKEN = os.getenv("bglab_hf")  # HF access token from the `bglab_hf` env var (None if unset)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Relative paths resolve against the process working directory.
# Expects: class_idx, Scientific Name, Common Name, Order, Family
MAPPING_XLSX = Path(os.getenv("MAPPING_XLSX", "class_mapping_4k.xlsx"))
MAPPING_COLUMNS_HINT = "`class_idx, Scientific Name, Common Name, Order, Family`"

# =========================
# Streamlit basics
# =========================
st.set_page_config(page_title="InsectNetv2 Classifier", layout="centered")
st.title("🪲 InsectNetv2 Classifier")

torch.set_num_threads(CPU_THREADS)
torch.set_grad_enabled(False)


# =========================
# Cached: Load model + preprocess
# =========================
def _extract_state_dict(ckpt):
    """Return the weight dict contained in a loaded checkpoint object.

    Accepts a bare state dict or a dict wrapping one under "model" or
    "state_dict". A "module." prefix on every key (left behind when a model is
    saved while wrapped in DataParallel/DDP) is stripped so the keys match the
    timm model's parameter names.
    """
    state = ckpt.get("model", ckpt.get("state_dict", ckpt)) if isinstance(ckpt, dict) else ckpt
    prefix = "module."
    if (
        isinstance(state, dict)
        and state
        and all(isinstance(k, str) and k.startswith(prefix) for k in state)
    ):
        state = {k[len(prefix):]: v for k, v in state.items()}
    return state


@st.cache_resource
def load_model_and_preprocess():
    """Download the checkpoint, build the ViT-L model, and return ``(model, preprocess)``.

    Wrapped in ``st.cache_resource``, so the download, model build, and warmup
    are reused across reruns instead of being repeated.

    Returns:
        model: the timm ViT in eval mode (dynamically int8-quantized when possible).
        preprocess: torchvision transform mapping a PIL RGB image to a normalized
            ``(3, IMG_SIZE, IMG_SIZE)`` tensor.
    """
    st.caption("✨ App loaded from `app.py` (Streamlit)")

    # Download checkpoint (cached by HF under ~/.cache/huggingface)
    ckpt_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILE,
        token=HF_TOKEN,
        cache_dir=str(Path.home() / ".cache" / "huggingface"),
    )

    # Build the architecture. With strict=False below, any weight the checkpoint
    # does not provide keeps the value it was initialized with here.
    model = timm.create_model(
        "vit_large_patch14_reg4_dinov2.lvd142m",
        pretrained=True,
        num_classes=NUM_CLASSES,
    )

    # NOTE(review): torch.load unpickles the file; keep MODEL_REPO_ID pointed at a trusted repo.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = _extract_state_dict(ckpt)
    result = model.load_state_dict(state, strict=False)
    # strict=False tolerates partial checkpoints, but an unloaded layer (e.g. the
    # classifier head) silently ruins predictions, so make it visible.
    if result.missing_keys:
        preview = ", ".join(result.missing_keys[:5])
        more = "…" if len(result.missing_keys) > 5 else ""
        st.warning(
            f"Checkpoint is missing {len(result.missing_keys)} model weight(s) "
            f"(e.g. {preview}{more}); those layers keep their initial values."
        )

    # CPU speedup: dynamic int8 quantization of Linear layers (best effort).
    try:
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    except Exception as exc:  # presumably an unsupported quantization backend; keep float model
        st.caption(f"Dynamic quantization skipped ({exc}); running unquantized.")

    model.eval()

    preprocess = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # Warmup: run one dummy forward pass so the first real request is not the slow one.
    with torch.inference_mode():
        _ = model(torch.zeros(1, 3, IMG_SIZE, IMG_SIZE))

    return model, preprocess


model, preprocess = load_model_and_preprocess()


# =========================
# Cached: Load mapping (xlsx)
# =========================
@st.cache_resource
def load_mapping_table(mapping_path: Path):
    """Load the class-index → name/taxonomy table from an Excel file.

    Expects these columns (matched case-insensitively, surrounding whitespace ignored):
      - class_idx
      - Scientific Name
      - Common Name
      - Order
      - Family

    Returns:
        ``None`` if the file is absent, cannot be read, or lacks a required column.
        Otherwise a dict ``{"df": ..., "cols": ...}`` where ``df`` is indexed by
        integer class_idx (first row kept per index) and ``cols`` maps
        "scientific"/"common"/"order"/"family" to the original column names.
    """
    if not mapping_path.exists():
        return None

    try:
        df = pd.read_excel(mapping_path)
    except Exception as exc:  # e.g. missing Excel engine or a corrupt file
        st.warning(f"Could not read mapping file `{mapping_path}`: {exc}. Will fall back to raw indices.")
        return None

    # Map each normalized header to its original spelling; the first match wins.
    normalized = {}
    for col in df.columns:
        normalized.setdefault(str(col).lower().strip(), col)

    wanted = ("class_idx", "scientific name", "common name", "order", "family")
    required = {key: normalized.get(key) for key in wanted}

    missing = [k for k, v in required.items() if v is None]
    if missing:
        st.warning(f"Mapping file found but missing columns: {missing}. Will fall back to raw indices.")
        return None

    # Coerce indices to int so lookups with the model's int outputs match even if
    # Excel stored them as floats or text; rows without a numeric index are dropped.
    idx_col = required["class_idx"]
    idx_values = pd.to_numeric(df[idx_col], errors="coerce")
    keep = idx_values.notna()
    df = df.loc[keep].drop(columns=[idx_col])
    df.index = pd.Index(idx_values[keep].astype(int).to_numpy(), name=idx_col)

    # Keep one row per class so df.loc[idx] always yields a single row (a Series).
    df = df[~df.index.duplicated(keep="first")]

    return {
        "df": df,
        "cols": {
            "scientific": required["scientific name"],
            "common": required["common name"],
            "order": required["order"],
            "family": required["family"],
        },
    }


mapping_store = load_mapping_table(MAPPING_XLSX)


# =========================
# Prediction util
# =========================
@torch.inference_mode()
def predict_indices(img: Image.Image, topk: int = 5):
    """Classify ``img`` and return its top-k class indices and probabilities.

    Returns:
        ``(top1_idx, top1_prob, topk_idx, topk_prob)``. The probabilities are a
        softmax over all logits, and the two lists are in descending order of
        probability. ``topk`` is capped at NUM_CLASSES.
    """
    x = preprocess(img).unsqueeze(0)
    logits = model(x)
    probs = torch.softmax(logits, dim=1).squeeze(0)

    topk = min(topk, NUM_CLASSES)
    topk_probs, topk_idx = torch.topk(probs, k=topk)

    top1_idx = int(topk_idx[0].item())
    top1_prob = float(topk_probs[0].item())
    top5_idx = [int(i) for i in topk_idx.tolist()]
    top5_prob = [float(p) for p in topk_probs.tolist()]
    return top1_idx, top1_prob, top5_idx, top5_prob


# =========================
# Helpers to format rows
# =========================
def _cell(row, col):
    """Return ``row[col]`` for display, or an em dash when the cell is empty (NaN)."""
    value = row[col]
    return "—" if pd.isna(value) else value


def fmt_top1(idx: int, p: float):
    """Render the top-1 prediction with full taxonomy (or the raw index without a mapping)."""
    if mapping_store is None:
        st.info(
            f"Top-1 index: **{idx}** — p={p:.3f}\n\n"
            f"(Add `{MAPPING_XLSX.name}` to show names/taxonomy.)"
        )
        return

    df = mapping_store["df"]
    cols = mapping_store["cols"]

    if idx not in df.index:
        st.warning(f"Top-1 index {idx} not found in mapping; showing raw index only.")
        st.write(f"Confidence: `{p:.3f}`")
        return

    row = df.loc[idx]
    sci = _cell(row, cols["scientific"])
    com = _cell(row, cols["common"])
    odr = _cell(row, cols["order"])
    fam = _cell(row, cols["family"])

    # No index displayed here by design
    st.subheader("🦋 Top-1 Prediction")
    st.markdown(
        f"""
**Scientific Name:** *{sci}*  
**Common Name:** {com}  
**Order:** {odr}  
**Family:** {fam}  
**Confidence:** `{p:.3f}`
""".strip()
    )


def fmt_top5(idxs, ps):
    """Render the top-k list: scientific + common name when mapped, raw index otherwise."""
    st.markdown("### 🌿 Top-5 Predictions")
    if mapping_store is None:
        for i, p in zip(idxs, ps):
            st.write(f"- Index **{i}** — p={p:.3f}")
        return

    df = mapping_store["df"]
    cols = mapping_store["cols"]
    for i, p in zip(idxs, ps):
        if i in df.index:
            row = df.loc[i]
            sci = _cell(row, cols["scientific"])
            com = _cell(row, cols["common"])
            # Only scientific + common for top-5
            st.markdown(f"- **{sci}** (*{com}*) — `{p:.3f}`")
        else:
            st.markdown(f"- Index **{i}** — `{p:.3f}`")


# =========================
# UI
# =========================
with st.sidebar:
    st.header("Settings")
    st.caption("Model: ViT-L DINOv2 head · Image size: {}".format(IMG_SIZE))
    if mapping_store is None:
        st.warning(f"No usable `{MAPPING_XLSX.name}` found. Top-1/Top-5 will show indices only.")

uploaded = st.file_uploader("Upload a JPG/PNG", type=["jpg", "jpeg", "png"])

if uploaded:
    try:
        # Apply the EXIF orientation tag so phone photos are not classified rotated.
        img = ImageOps.exif_transpose(Image.open(uploaded)).convert("RGB")
    except Exception as e:
        st.error(f"Failed to read image: {e}")
        st.stop()

    st.image(img, caption="Input", use_container_width=True)

    with st.spinner("Predicting…"):
        top1_idx, top1_prob, top5_idx, top5_prob = predict_indices(img, topk=5)

    # Render: Top-1 (all attributes, no index), then Top-5 (name + common only)
    fmt_top1(top1_idx, top1_prob)
    fmt_top5(top5_idx, top5_prob)

    with st.expander("Advanced • Raw indices & probabilities"):
        st.write(f"Top-1 index: **{top1_idx}** — p={top1_prob:.4f}")
        for i, p in zip(top5_idx, top5_prob):
            st.write(f"- {i} : {p:.4f}")
else:
    st.info("Upload an image to see predictions.")

st.caption(
    f"Tip: place `{MAPPING_XLSX}` in the app's working directory "
    f"(or set the `MAPPING_XLSX` env var) with columns: {MAPPING_COLUMNS_HINT}."
)