InsectNetV2 / app.py
BGLab's picture
Update app.py
3c6e044 verified
#!/usr/bin/env python3
import os
from pathlib import Path
import streamlit as st
import torch
import timm
import pandas as pd
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
# =========================
# Config
# =========================
# Where the fine-tuned checkpoint lives on the Hugging Face Hub; both values
# can be overridden through environment variables of the same name.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "Shivani98/ViT-L_Insect_Classifier")
MODEL_FILE = os.environ.get("MODEL_FILE", "vit_l_518.pth")
# Hub access token, read from the `bglab_hf` environment variable (None when unset).
HF_TOKEN = os.environ.get("bglab_hf")

# Environment values arrive as strings, hence the explicit int() conversions.
NUM_CLASSES = int(os.environ.get("NUM_CLASSES", "3747"))
IMG_SIZE = int(os.environ.get("IMG_SIZE", "518"))
CPU_THREADS = int(os.environ.get("CPU_THREADS", "2"))

# Per-channel statistics handed to transforms.Normalize during preprocessing.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Spreadsheet mapping class indices to taxonomy.
# expects: class_idx, Scientific Name, Common Name, Order, Family
MAPPING_XLSX = Path("class_mapping_4k.xlsx")
# =========================
# Streamlit basics
# =========================
# Page configuration is issued before any other Streamlit call in this script;
# NOTE(review): older Streamlit releases require set_page_config to run first,
# so keep it at the top when reordering.
st.set_page_config(page_title="InsectNetv2 Classifier", layout="centered")
st.title("🪲 InsectNetv2 Classifier")
# Limit the number of CPU threads torch uses to the configured CPU_THREADS.
torch.set_num_threads(CPU_THREADS)
# The app only runs inference, so gradient tracking is switched off globally.
torch.set_grad_enabled(False)
# =========================
# Cached: Load model + preprocess
# =========================
@st.cache_resource
def load_model_and_preprocess():
    """Build the classifier and its input transform; cached by Streamlit.

    Steps: download the checkpoint named by MODEL_REPO_ID / MODEL_FILE, create
    the timm ViT-L DINOv2 model with NUM_CLASSES outputs, load the checkpoint
    weights, attempt int8 dynamic quantization of the Linear layers, and run a
    single warm-up forward pass.

    Returns:
        tuple: ``(model, preprocess)`` where ``model`` is in eval mode and
        ``preprocess`` is a torchvision ``Compose`` that resizes, center-crops
        to IMG_SIZE, converts to a tensor and normalizes with the ImageNet
        statistics.
    """
    st.caption("✨ App loaded from `app.py` (Streamlit)")

    # Download checkpoint (cached by HF)
    ckpt_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILE,
        token=HF_TOKEN,
        cache_dir=str(Path.home() / ".cache" / "huggingface"),
    )

    # Build model
    model = timm.create_model(
        "vit_large_patch14_reg4_dinov2.lvd142m",
        pretrained=True,
        num_classes=NUM_CLASSES,
    )

    # Load checkpoint.
    # NOTE(security): torch.load may unpickle arbitrary objects (depending on
    # the torch version's `weights_only` default); only point MODEL_REPO_ID at
    # repositories you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if isinstance(ckpt, dict):
        # Accept either a wrapped checkpoint ({"model": ...} / {"state_dict": ...})
        # or a bare state dict.
        state = ckpt.get("model", ckpt.get("state_dict", ckpt))
    else:
        state = ckpt

    # strict=False tolerates key mismatches, but ignoring them silently would
    # leave layers at their non-fine-tuned values without anyone noticing.
    load_result = model.load_state_dict(state, strict=False)
    missing_keys = list(load_result.missing_keys)
    unexpected_keys = list(load_result.unexpected_keys)
    if missing_keys or unexpected_keys:
        st.warning(
            "Checkpoint does not match the model exactly "
            f"({len(missing_keys)} missing, {len(unexpected_keys)} unexpected keys); "
            "predictions may be unreliable."
        )

    # CPU speedup: dynamic quantization (best effort; fall back to the
    # unquantized model if the backend cannot quantize).
    try:
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    except Exception:
        import logging

        logging.getLogger(__name__).warning(
            "Dynamic quantization failed; continuing with the unquantized model.",
            exc_info=True,
        )
    model.eval()

    preprocess = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # Warmup: one dummy forward pass so the first real request is not the slow one.
    with torch.inference_mode():
        _ = model(torch.zeros(1, 3, IMG_SIZE, IMG_SIZE))

    return model, preprocess


model, preprocess = load_model_and_preprocess()
# =========================
# Cached: Load mapping (xlsx)
# =========================
@st.cache_resource
def load_mapping_table(mapping_path: Path):
    """Load the class-index → taxonomy spreadsheet; cached by Streamlit.

    Expects columns (matched case-insensitively, ignoring surrounding spaces):
    - class_idx
    - Scientific Name
    - Common Name
    - Order
    - Family

    Args:
        mapping_path: Location of the ``.xlsx`` mapping file.

    Returns:
        ``None`` when the file does not exist or a required column is missing
        (a Streamlit warning is shown in the latter case). Otherwise a dict with
        ``"df"`` (the table indexed by class index) and ``"cols"`` (the original
        header names under the keys ``scientific``, ``common``, ``order``,
        ``family``).
    """
    if not mapping_path.exists():
        return None

    df = pd.read_excel(mapping_path)

    # Normalized header -> original header. str() guards against non-string
    # headers (e.g. numeric column names); setdefault keeps the first match
    # when two headers normalize to the same text.
    by_normalized = {}
    for col in df.columns:
        by_normalized.setdefault(str(col).lower().strip(), col)

    wanted = ("class_idx", "scientific name", "common name", "order", "family")
    required = {key: by_normalized.get(key) for key in wanted}

    missing = [k for k, v in required.items() if v is None]
    if missing:
        st.warning(f"Mapping file found but missing columns: {missing}. Will fall back to raw indices.")
        return None

    # Index by class_idx so callers can look rows up directly with df.loc[idx].
    df = df.set_index(required["class_idx"])
    # A repeated class_idx would make df.loc[idx] return several rows instead
    # of one; keep only the first occurrence so every lookup yields one row.
    df = df[~df.index.duplicated(keep="first")]

    return {
        "df": df,
        "cols": {
            "scientific": required["scientific name"],
            "common": required["common name"],
            "order": required["order"],
            "family": required["family"],
        },
    }


mapping_store = load_mapping_table(MAPPING_XLSX)
# =========================
# Prediction util
# =========================
@torch.inference_mode()
def predict_indices(img: Image.Image, topk: int = 5):
    """Classify one image and return its best class indices with probabilities.

    Args:
        img: Image to classify; it is passed through the module-level
            ``preprocess`` transform.
        topk: How many top classes to return (capped at NUM_CLASSES).

    Returns:
        tuple: ``(best_index, best_probability, indices, probabilities)`` where
        the two lists hold the top-k entries as returned by ``torch.topk``.
    """
    batch = preprocess(img).unsqueeze(0)
    probabilities = torch.softmax(model(batch), dim=1).squeeze(0)

    k = min(topk, NUM_CLASSES)
    best_probs, best_idx = torch.topk(probabilities, k=k)

    ranked_indices = [int(value) for value in best_idx.tolist()]
    ranked_probs = [float(value) for value in best_probs.tolist()]
    return ranked_indices[0], ranked_probs[0], ranked_indices, ranked_probs
# =========================
# Helpers to format rows
# =========================
def fmt_top1(idx: int, p: float):
    """Render the top-1 prediction with its full taxonomy.

    Args:
        idx: Predicted class index.
        p: Probability of that class.

    Without a mapping table only the raw index and probability are shown,
    together with a hint naming the mapping file the app looks for. When the
    index is absent from the table, a warning and the confidence are shown.
    """
    if mapping_store is None:
        # Name the file the app actually loads (MAPPING_XLSX), so the hint
        # cannot drift from the configured filename.
        st.info(
            f"Top-1 index: **{idx}** — p={p:.3f}\n\n"
            f"(Upload a `{MAPPING_XLSX.name}` to show names/taxonomy.)"
        )
        return

    df = mapping_store["df"]
    cols = mapping_store["cols"]

    if idx not in df.index:
        st.warning(f"Top-1 index {idx} not found in mapping; showing raw index only.")
        st.write(f"Confidence: `{p:.3f}`")
        return

    row = df.loc[idx]
    sci = row[cols["scientific"]]
    com = row[cols["common"]]
    odr = row[cols["order"]]
    fam = row[cols["family"]]

    # No index displayed here by design
    st.subheader("🦋 Top-1 Prediction")
    fields = [
        f"**Scientific Name:** *{sci}*",
        f"**Common Name:** {com}",
        f"**Order:** {odr}",
        f"**Family:** {fam}",
        f"**Confidence:** `{p:.3f}`",
    ]
    # Markdown merges plain newlines into one paragraph; two trailing spaces
    # before the newline force a hard line break so each field gets its own line.
    st.markdown("  \n".join(fields))
def fmt_top5(idxs, ps):
    """Render the ranked predictions as a bulleted list.

    Args:
        idxs: Class indices, best first.
        ps: Probabilities aligned with ``idxs``.

    Entries found in the mapping table show scientific and common name; all
    others (or every entry when no table is loaded) show the raw index.
    """
    st.markdown("### 🌿 Top-5 Predictions")
    ranked = zip(idxs, ps)

    if mapping_store is None:
        for class_idx, prob in ranked:
            st.write(f"- Index **{class_idx}** — p={prob:.3f}")
        return

    table = mapping_store["df"]
    columns = mapping_store["cols"]

    for class_idx, prob in ranked:
        if class_idx not in table.index:
            st.markdown(f"- Index **{class_idx}** — `{prob:.3f}`")
            continue
        record = table.loc[class_idx]
        # Only scientific + common for top-5
        scientific_name = record[columns["scientific"]]
        common_name = record[columns["common"]]
        st.markdown(f"- **{scientific_name}** (*{common_name}*) — `{prob:.3f}`")
# =========================
# UI
# =========================
# Sidebar: static model information plus a notice when no mapping table loaded.
with st.sidebar:
    st.header("Settings")
    st.caption(f"Model: ViT-L DINOv2 head · Image size: {IMG_SIZE}")
    if mapping_store is None:
        # Use the configured filename so the notice matches what the app loads.
        st.warning(f"No `{MAPPING_XLSX.name}` found. Top-1/Top-5 will show indices only.")

uploaded = st.file_uploader("Upload a JPG/PNG", type=["jpg", "jpeg", "png"])

if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except Exception as e:
        # UI boundary: report any decode failure to the user and end this run.
        st.error(f"Failed to read image: {e}")
        st.stop()

    st.image(img, caption="Input", use_container_width=True)

    with st.spinner("Predicting…"):
        top1_idx, top1_prob, top5_idx, top5_prob = predict_indices(img, topk=5)

    # Render: Top-1 (all attributes, no index), then Top-5 (name + common only)
    fmt_top1(top1_idx, top1_prob)
    fmt_top5(top5_idx, top5_prob)

    with st.expander("Advanced • Raw indices & probabilities"):
        st.write(f"Top-1 index: **{top1_idx}** — p={top1_prob:.4f}")
        for i, p in zip(top5_idx, top5_prob):
            st.write(f"- {i} : {p:.4f}")
else:
    st.info("Upload an image to see predictions.")
    # NOTE(review): the tip is shown only while no image is uploaded; confirm
    # this placement matches the intended layout.
    st.caption(
        f"Tip: place `{MAPPING_XLSX.name}` next to this script with columns: "
        "`class_idx, Scientific Name, Common Name, Order, Family`."
    )