import os, json, joblib
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from sklearn.preprocessing import normalize
from sklearn.neighbors import NearestNeighbors
import gradio as gr
from PIL import Image
import pickle
from skimage.color import rgb2lab, lab2rgb
from skimage.feature import local_binary_pattern, hog
from sklearn.cluster import KMeans
# ---------------- CONFIG ----------------
ARTIFACTS_DIR = "."
FEATURES_PATH = os.path.join(ARTIFACTS_DIR, "features.npy")
PATHS_PATH = os.path.join(ARTIFACTS_DIR, "image_paths.json")
PALETTES_PATH = os.path.join(ARTIFACTS_DIR, "palettes.json")
INDEX_PATH = os.path.join(ARTIFACTS_DIR, "nn_index.joblib")
MODEL_PATH = os.path.join(ARTIFACTS_DIR, "resnet50_multilayer_ssl.pt")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
GDRIVE_FOLDER = "https://drive.google.com/drive/folders/10EXzo27vuTjyG9FXHWWO4J5AhVQhyUp9?usp=sharing"
# ---------------- LOAD ARTIFACTS ----------------
features = np.load(FEATURES_PATH)
with open(PATHS_PATH, "r") as f:
    IMG_PATHS = json.load(f)
with open(PALETTES_PATH, "r") as f:
    DATA_PALETTES = json.load(f)
nn_index = joblib.load(INDEX_PATH)
| with open("kmeans.pkl", "rb") as f: | |
| kmeans = pickle.load(f) | |
| with open("kmeans.pkl", "rb") as f: | |
| fitted_kmeans = pickle.load(f) | |
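# Assumed artifact contents (produced by the matching offline pipeline; not
# verified here): features -> (N, D) fused descriptors, IMG_PATHS -> N dataset
# paths, nn_index -> a fitted sklearn NearestNeighbors over `features`,
# kmeans -> the fitted 64-word ORB vocabulary.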
# ---------------- FEATURE CLASSES ----------------
class AutoColor:
    """Dominant-color palette extractor: KMeans over pixels in CIELAB space."""

    def __init__(self, n_colors=5, sample_px=150000, random_state=42):
        self.n_colors = n_colors
        self.sample_px = sample_px
        self.random_state = random_state

    def extract(self, arr: np.ndarray):
        lab = rgb2lab(arr / 255.0).reshape(-1, 3)
        if lab.shape[0] > self.sample_px:
            # Subsample pixels for speed on large images.
            idx = np.random.RandomState(self.random_state).choice(
                lab.shape[0], self.sample_px, replace=False
            )
            lab = lab[idx]
        kmeans = KMeans(n_clusters=self.n_colors, random_state=self.random_state, n_init=8)
        kmeans.fit(lab)
        centers = kmeans.cluster_centers_
        labels = kmeans.labels_
        counts = np.bincount(labels, minlength=self.n_colors).astype(np.float32)
        props = counts / counts.sum()
        return centers, props

    def vectorize(self, centers, props):
        return np.concatenate([centers.flatten(), props]).astype(np.float32)
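# With the defaults, vectorize() yields n_colors * 3 Lab coordinates plus
# n_colors proportions: a 20-D palette descriptor.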
class TextureBank:
    """Hand-crafted texture features: multi-scale LBP + Gabor filter bank + HOG."""

    def __init__(self):
        self.lbp_settings = [(8, 1), (8, 2), (16, 3)]
        self.gabor_kernels = []
        for theta in np.linspace(0, np.pi, 6, endpoint=False):
            for sigma in (1.0, 2.0, 3.0):
                for lambd in (3.0, 6.0, 9.0):
                    kern = cv2.getGaborKernel((9, 9), sigma, theta, lambd, gamma=0.5, psi=0)
                    self.gabor_kernels.append(kern)

    def extract(self, arr: np.ndarray):
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
        gray = cv2.resize(gray, (512, 512), interpolation=cv2.INTER_AREA)
        feats = []
        for (P, R) in self.lbp_settings:
            lbp = local_binary_pattern(gray, P=P, R=R, method="uniform")
            n_bins = P + 2
            hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True)
            feats.append(hist.astype(np.float32))
        for k in self.gabor_kernels:
            resp = cv2.filter2D(gray, cv2.CV_32F, k)
            feats.append(np.array([resp.mean(), resp.std()], dtype=np.float32))
        h = hog(
            gray,
            pixels_per_cell=(16, 16),
            cells_per_block=(2, 2),
            orientations=9,
            visualize=False,
            feature_vector=True,
        )
        feats.append(h.astype(np.float32))
        return np.concatenate(feats, axis=0)
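# Per image: uniform-LBP histograms of 10 + 10 + 18 bins, 54 Gabor kernels
# x (mean, std) = 108 values, plus a dense HOG vector over the 512x512 gray image.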
class ORBBoVW:
    """ORB local descriptors quantized into a bag-of-visual-words histogram."""

    def __init__(self, n_words=64):
        self.n_words = n_words
        self.kmeans = None  # fitted vocabulary must be attached before transform()
        self.orb = cv2.ORB_create(nfeatures=3000)

    def _orb_des(self, arr: np.ndarray):
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
        kps, des = self.orb.detectAndCompute(gray, None)
        if des is None:
            return np.zeros((0, 32), dtype=np.uint8)
        return des

    def transform(self, arr: np.ndarray):
        d = self._orb_des(arr)
        if d.shape[0] == 0:
            bow = np.zeros((self.n_words,), dtype=np.float32)
        else:
            idx = self.kmeans.predict(d.astype(np.float32))
            bow, _ = np.histogram(idx, bins=np.arange(self.n_words + 1))
            bow = bow.astype(np.float32)
        bow /= np.linalg.norm(bow) + 1e-8
        return bow
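# Note: the vocabulary (self.kmeans) is attached below from kmeans.pkl; the
# resulting 64-D histogram is currently excluded from the query feature
# (see the commented-out block in extract_single_feature).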
class ResNetMultiLayer(nn.Module):
    """ResNet-50 backbone that pools and concatenates layer2/3/4 activations."""

    def __init__(self):
        super().__init__()
        base = torchvision.models.resnet50(weights=None)
        self.conv1 = base.conv1; self.bn1 = base.bn1
        self.relu = base.relu; self.maxpool = base.maxpool
        self.layer1 = base.layer1; self.layer2 = base.layer2
        self.layer3 = base.layer3; self.layer4 = base.layer4
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x); x = self.bn1(x); x = self.relu(x); x = self.maxpool(x)
        x = self.layer1(x); x2 = self.layer2(x)
        x3 = self.layer3(x2); x4 = self.layer4(x3)
        f2 = self.gap(x2).flatten(1)
        f3 = self.gap(x3).flatten(1)
        f4 = self.gap(x4).flatten(1)
        return torch.cat([f2, f3, f4], dim=1)
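# layer2/3/4 of ResNet-50 output 512/1024/2048 channels, so the pooled,
# concatenated embedding is 3584-D.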
# ---------------- LOAD MODELS ----------------
backbone = ResNetMultiLayer().to(DEVICE)
backbone.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
backbone.eval()
autocolor = AutoColor()
texturebank = TextureBank()
bovw = ORBBoVW(n_words=64)
bovw.kmeans = kmeans  # attach the fitted vocabulary so transform() can quantize
TF_INFER = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
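# The Normalize constants are the standard ImageNet mean/std used with
# torchvision ResNets.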
# ---------------- FEATURE EXTRACTION ----------------
def extract_single_feature(img):
    if isinstance(img, str):
        img = Image.open(img).convert("RGB")
    else:
        img = img.convert("RGB")
    arr = np.array(img)
    # Deep branch: multi-layer ResNet embedding.
    x = TF_INFER(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        fcnn = backbone(x).cpu().numpy()
    fcnn = normalize(fcnn, norm="l2") * 0.50
    # ORB branch disabled; see the padding logic in recommend_gradio().
    # forb = bovw.transform(arr)[None, :]
    # forb = normalize(forb, norm="l2") * 0.10
    ftex = texturebank.extract(arr)[None, :]
    ftex = normalize(ftex, norm="l2") * 0.30
    centers, props = autocolor.extract(arr)
    fcol = autocolor.vectorize(centers, props)[None, :]
    fcol = normalize(fcol, norm="l2") * 0.10
    # Weighted fusion: each branch is L2-normalized, scaled, and concatenated,
    # then the fused vector is renormalized (so the weights act relatively).
    feats = np.hstack([fcnn, ftex, fcol]).astype(np.float32)
    feats = normalize(feats, norm="l2")
    return feats
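# Example (hypothetical file): q = extract_single_feature("query.jpg")
# -> a (1, D) float32 row with unit L2 norm.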
def adjust_path(colab_path: str):
    # Map a Colab-style dataset path to the shared Drive folder. Caveat: Drive
    # folder URLs are not addressable by filename, so this link likely opens
    # the folder view rather than the individual file.
    fname = os.path.basename(colab_path)
    return f"{GDRIVE_FOLDER}/{fname}"
def recommend_gradio(img, top_k=5):
    qf = extract_single_feature(img)
    qf = np.array(qf).reshape(1, -1)
    # Pad if dimensions mismatch: the index was presumably built with an extra
    # block included (likely the now-disabled 64-D ORB histogram), so query
    # vectors can come up short of the indexed dimensionality.
    expected_dim = nn_index._fit_X.shape[1]  # dimension nn_index was trained on
    if qf.shape[1] < expected_dim:
        padding = np.zeros((1, expected_dim - qf.shape[1]), dtype=qf.dtype)
        qf = np.hstack([qf, padding])
    elif qf.shape[1] > expected_dim:
        qf = qf[:, :expected_dim]  # just in case it's larger (rare)
    dists, idxs = nn_index.kneighbors(qf)
    idxs = idxs[0].tolist()
    results = []
    for i in idxs[:top_k]:
        cand = IMG_PATHS[i]
        adjusted = adjust_path(cand)
        results.append(f"[View Image]({adjusted})")
    # Blank line between entries so Markdown renders each link on its own line.
    return "\n\n".join(results)
# ---------------- GRADIO APP ----------------
interface = gr.Interface(
    fn=recommend_gradio,
    inputs=gr.Image(type="filepath", label="Upload an Image"),
    # outputs=gr.Gallery(label="Top Matches", columns=5, rows=2),
    outputs=gr.Markdown(),
    title="Image Similarity Search",
    description="Upload an image and find the most similar images from the dataset.",
)
if __name__ == "__main__":
    interface.launch()