import gradio as gr
import numpy as np
import requests
from io import BytesIO
from PIL import Image
import tensorflow as tf
from huggingface_hub import hf_hub_download

# Download the TFLite model and labels from your Hugging Face repository
MODEL_REPO = "JahnaviBhansali/mobilenet-v2-ethos-u55"
MODEL_FILE = "mobilenet_v2_1.0_224_INT8.tflite"  # Using original INT8 model for Gradio compatibility
VELA_MODEL_FILE = "mobilenet_v2_1.0_224_INT8_vela.tflite"  # Vela-optimized model for Ethos-U55
LABELS_FILE = "labelmappings.txt"

print("Downloading model and labels from Hugging Face...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
vela_model_path = hf_hub_download(repo_id=MODEL_REPO, filename=VELA_MODEL_FILE)  # Download Vela model for reference
labels_path = hf_hub_download(repo_id=MODEL_REPO, filename=LABELS_FILE)

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Load class labels
with open(labels_path, 'r') as f:
    class_labels = [line.strip() for line in f.readlines()]

print(f"Model loaded successfully! Input shape: {input_details[0]['shape']}")
print(f"Number of classes: {len(class_labels)}")
print(f"Vela-optimized model also available: {VELA_MODEL_FILE}")  # Force rebuild with modern design
print(f"Repository: {MODEL_REPO}")


def preprocess_image(image):
    """Preprocess an image for the MobileNetV2 INT8 quantized model."""
    # Resize to 224x224 as expected by the model
    image = image.resize((224, 224))

    # Convert to numpy array
    img_array = np.array(image, dtype=np.float32)

    # Normalize to [0, 1], then scale to [-1, 1] as MobileNetV2 expects
    img_array = img_array / 255.0
    img_array = (img_array - 0.5) * 2.0

    # Quantize to the INT8 range [-128, 127]
    img_array = img_array * 127.0
    img_array = np.clip(img_array, -128, 127).astype(np.int8)

    # Add batch dimension
    img_array = np.expand_dims(img_array, axis=0)

    return img_array


def classify_image(image):
    """Classify the input image and return top-3 predictions with confidence scores."""
    if image is None:
        return "Please upload an image."

    try:
        # Handle different image inputs
        if isinstance(image, str):
            # Handle URL
            response = requests.get(image)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        else:
            image = image.convert("RGB")

        # Preprocess the image
        input_data = preprocess_image(image)

        # Set input tensor
        interpreter.set_tensor(input_details[0]['index'], input_data)

        # Run inference
        interpreter.invoke()

        # Get output tensor
        output_data = interpreter.get_tensor(output_details[0]['index'])
        predictions = output_data[0]  # Remove batch dimension

        # Convert the INT8 quantized output back to real values:
        # real_value = scale * (quantized_value - zero_point)
        scale = output_details[0]['quantization'][0]
        zero_point = output_details[0]['quantization'][1]
        predictions = scale * (predictions.astype(np.float32) - zero_point)

        # Apply softmax to get probabilities
        predictions = tf.nn.softmax(predictions).numpy()

        # Get top-3 predictions
        top3_indices = np.argsort(predictions)[-3:][::-1]

        # Format results
        results = []
        for idx in top3_indices:
            class_name = class_labels[idx] if idx < len(class_labels) else f"Class {idx}"
            confidence = predictions[idx]
            results.append(f"**{class_name}**: {confidence:.1%}")

        # Create formatted output
        result_text = "\n".join(f"{idx+1}. {result}" for idx, result in enumerate(results))
{result}" for idx, result in enumerate(results)) return result_text except Exception: return "Error processing image. Please try again." def load_example_image(example_path): """Load example images for demonstration.""" example_urls = { "Cat": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", "Dog": "https://images.unsplash.com/photo-1587300003388-59208cc962cb?w=500", "Car": "https://images.unsplash.com/photo-1494905998402-395d579af36f?w=500", "Food": "https://images.unsplash.com/photo-1565299624946-b28f40a0ca4b?w=500", "Nature": "https://images.unsplash.com/photo-1441974231531-c6227db76b6e?w=500" } if example_path in example_urls: try: response = requests.get(example_urls[example_path]) return Image.open(BytesIO(response.content)) except: return None return None # Create Gradio interface with gr.Blocks( theme=gr.themes.Default(), title="MobileNetV2 Classification", css=""" .gradio-container { max-width: 1200px !important; margin: auto !important; background-color: white !important; font-family: 'Inter', 'Segoe UI', -apple-system, sans-serif !important; } .main-header { text-align: center; margin: 2rem 0 3rem 0; color: #3b82f6 !important; font-weight: 600; font-size: 2.5rem; letter-spacing: -0.025em; } .card { background: white !important; border-radius: 12px !important; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important; border: 1px solid #e5e7eb !important; margin-bottom: 1.5rem !important; transition: all 0.2s ease-in-out !important; overflow: hidden !important; } .card > * { padding: 0 !important; margin: 0 !important; } .card:hover { box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05) !important; transform: translateY(-1px) !important; } .card-header { background: linear-gradient(135deg, #1975cf 0%, #1557b0 100%) !important; color: white !important; padding: 1rem 1.5rem !important; border-radius: 12px 12px 0 0 !important; font-weight: 600 !important; font-size: 1.1rem !important; } .card-header * { color: white !important; } .card-content { padding: 1.5rem !important; color: #4b5563 !important; line-height: 1.6 !important; background: white !important; } .stats-grid { display: grid !important; grid-template-columns: 1fr 1fr !important; gap: 1.5rem !important; margin-top: 1.5rem !important; } .stat-item { background: #f8fafc !important; padding: 1rem !important; border-radius: 8px !important; border-left: 4px solid #1975cf !important; } .stat-label { font-weight: 600 !important; color: #4b5563 !important; font-size: 0.9rem !important; margin-bottom: 0.5rem !important; } .stat-value { color: #4b5563 !important; font-size: 0.85rem !important; } .btn-example { background: #f1f5f9 !important; border: 1px solid #cbd5e1 !important; color: #4b5563 !important; border-radius: 6px !important; transition: all 0.2s ease !important; margin: 0.35rem !important; padding: 0.5rem 1rem !important; } .btn-example:hover { background: #1975cf !important; border-color: #1975cf !important; color: white !important; } .btn-primary { background: #1975cf !important; border-color: #1975cf !important; color: white !important; } .btn-primary:hover { background: #1557b0 !important; border-color: #1557b0 !important; } .markdown { color: #374151 !important; } .results-text { color: #4b5563 !important; font-weight: 500 !important; padding: 0 !important; margin: 0 !important; } .results-text p { color: #4b5563 !important; margin: 0.5rem 0 !important; } .results-text * { color: 
    div[data-testid="markdown"] p {
        color: #4b5563 !important;
    }
    .prose {
        color: #4b5563 !important;
    }
    .prose * {
        color: #4b5563 !important;
    }
    .card-header, .card-header * {
        color: white !important;
    }
    .example-grid {
        display: grid !important;
        grid-template-columns: 1fr !important;
        gap: 1.5rem !important;
        margin-top: 1.5rem !important;
    }
    .example-item {
        background: #f8fafc !important;
        padding: 1rem !important;
        border-radius: 8px !important;
        border-left: 4px solid #1975cf !important;
    }
    .example-label {
        font-weight: 600 !important;
        color: #1975cf !important;
        font-size: 0.9rem !important;
        margin-bottom: 0.5rem !important;
    }
    .example-buttons {
        color: #374151 !important;
        font-size: 0.85rem !important;
    }
    .results-grid {
        display: grid !important;
        grid-template-columns: 1fr !important;
        gap: 1.5rem !important;
        margin-top: 1.5rem !important;
    }
    .results-item {
        background: #f8fafc !important;
        padding: 1rem !important;
        border-radius: 8px !important;
        border-left: 4px solid #1975cf !important;
    }
    .results-label {
        font-weight: 600 !important;
        color: #1975cf !important;
        font-size: 0.9rem !important;
        margin-bottom: 0.5rem !important;
    }
    .results-content {
        color: #374151 !important;
        font-size: 0.85rem !important;
    }
    """
) as demo:
    gr.HTML("""