# MobileNetDemo / app.py
# Jahnavibh's picture
# Add informational footer box with links to documentation
# 4600993
import gradio as gr
import numpy as np
import requests
from io import BytesIO
from PIL import Image
import tensorflow as tf
from huggingface_hub import hf_hub_download
# os is only needed for the local-file check below.
import os

# Hugging Face repository that hosts the TFLite person-classification models.
MODEL_REPO = "JahnaviBhansali/person-classification-tflite"
# Flash variant: the model actually used for inference (chosen for better accuracy).
MODEL_FILE = "person_classification_flash(448x640).tflite"
# SRAM variant for memory-constrained devices; fetched for reference only.
SRAM_MODEL_FILE = "person_classification_sram(256x448).tflite"

print("Downloading model from Hugging Face...")


def _resolve_model_file(filename):
    """Return a usable path for *filename*, downloading it from MODEL_REPO if it is not present locally."""
    if not os.path.exists(filename):
        return hf_hub_download(repo_id=MODEL_REPO, filename=filename)
    # A copy in the working directory takes precedence over the hub download.
    return filename


model_path = _resolve_model_file(MODEL_FILE)
sram_model_path = _resolve_model_file(SRAM_MODEL_FILE)

# Build the interpreter once at import time; every request reuses it.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Tensor metadata (index, shape, dtype, quantization) used when feeding and reading the model.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Binary classification: index 0 = "No Person", index 1 = "Person".
class_labels = ["No Person", "Person"]

# Startup diagnostics, printed in a fixed order.
for _startup_line in (
    f"Model loaded successfully! Input shape: {input_details[0]['shape']}",
    f"Input dtype: {input_details[0]['dtype']}",
    f"Output shape: {output_details[0]['shape']}",
    f"Output dtype: {output_details[0]['dtype']}",
    f"Number of classes: {len(class_labels)}",
    f"SRAM-optimized model also available: {SRAM_MODEL_FILE}",
    f"Repository: {MODEL_REPO}",
):
    print(_startup_line)
def preprocess_image(image, target_size=(640, 448)):
    """
    Convert an image into the INT8 tensor fed to the person-classification model.

    Args:
        image: PIL image (any object whose ``resize((width, height))`` result
            can be read by ``np.asarray``). ``classify_image`` converts the
            image to RGB before calling this function.
        target_size: ``(width, height)`` handed to ``image.resize``. PIL takes
            width first, so the default ``(640, 448)`` produces a 448x640
            (height x width) array.

    Returns:
        ``np.int8`` array of shape ``(1, height, width, channels)``.
    """
    resized = image.resize(target_size)

    # Work in float32 so the subtraction below cannot wrap around in uint8.
    pixels = np.asarray(resized, dtype=np.float32)

    # Shift the 0..255 pixel range to -128..127. Clipping keeps any
    # out-of-range value away from an undefined float -> int8 conversion.
    # NOTE(review): this assumes the model's input quantization is
    # scale=1, zero_point=-128 -- confirm against input_details.
    shifted = np.clip(pixels - 128.0, -128.0, 127.0).astype(np.int8)

    # Add the leading batch dimension.
    return np.expand_dims(shifted, axis=0)
def classify_image(image):
    """
    Classify an image as "Person" or "No Person" and format the result as Markdown.

    Args:
        image: ``None``, a URL string, a NumPy array, or a PIL image.

    Returns:
        Markdown text with the predicted class and its confidence. Returns a
        prompt string when ``image`` is ``None`` and an
        ``"Error processing image: ..."`` string when anything fails; the
        function does not raise.
    """
    if image is None:
        return "Please upload an image."

    try:
        # Normalise every accepted input type to an RGB PIL image.
        if isinstance(image, str):
            # Strings are treated as URLs. The timeout stops a stalled server
            # from blocking the request forever, and raise_for_status() keeps
            # an HTTP error page from being handed to the image decoder.
            # NOTE(review): fetching a caller-supplied URL is only safe when
            # callers are trusted -- confirm nothing untrusted reaches this path.
            response = requests.get(image, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        else:
            image = image.convert("RGB")

        # Run the shared module-level interpreter on the preprocessed tensor.
        input_data = preprocess_image(image)
        interpreter.set_tensor(input_details[0]['index'], input_data)
        interpreter.invoke()
        raw_output = interpreter.get_tensor(output_details[0]['index'])

        # Drop the batch dimension and flatten so scalar, (1,) and (N,)
        # outputs are all handled the same way.
        scores = np.asarray(raw_output[0], dtype=np.float32).reshape(-1)

        # TFLite reports (scale, zero_point); scale == 0 marks a tensor that
        # is NOT quantized. Dequantizing with it would zero every score.
        scale, zero_point = output_details[0].get('quantization') or (0.0, 0)
        if scale:
            scores = scale * (scores - zero_point)

        if scores.size == 1:
            # Single output: interpreted directly as the person probability.
            person_prob = float(scores[0])
        else:
            # Multiple outputs: softmax, then take index 1 ("Person").
            # NOTE(review): this assumes the model emits logits; if it already
            # emits probabilities the softmax flattens them -- confirm.
            probabilities = tf.nn.softmax(scores).numpy()
            person_prob = float(probabilities[1])

        is_person = person_prob > 0.5
        # class_labels is ordered ["No Person", "Person"].
        class_name = class_labels[int(is_person)]
        # Report the confidence of the class that was actually chosen.
        confidence = person_prob if is_person else (1 - person_prob)

        return f"**Detection Result**\n\n**{class_name}**: {confidence:.1%}"

    except Exception as e:
        # UI boundary: log the full traceback to the console, show a short message.
        import traceback
        error_msg = f"Error processing image: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return f"Error processing image: {str(e)}"
def load_example_image(example_path):
    """
    Fetch one of the built-in demonstration images.

    Args:
        example_path: Example name -- one of "Person", "Group", "Empty Room"
            or "Landscape". Despite the parameter name this is a lookup key,
            not a file path.

    Returns:
        The downloaded PIL image, or ``None`` when the name is unknown or the
        download or decoding fails.
    """
    example_urls = {
        "Person": "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?w=500",
        "Group": "https://images.unsplash.com/photo-1529156069898-49953e39b3ac?w=500",
        "Empty Room": "https://images.unsplash.com/photo-1486304873000-235643847519?w=500",
        "Landscape": "https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=500",
    }

    url = example_urls.get(example_path)
    if url is None:
        return None

    try:
        # Bound the wait and reject HTTP error responses before decoding.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as exc:
        # Still best-effort (the caller gets None), but the failure is logged
        # and KeyboardInterrupt/SystemExit are no longer swallowed.
        print(f"Could not load example image {example_path!r}: {exc}")
        return None
# ---------------------------------------------------------------------------
# Gradio interface
#
# The large CSS/HTML literals live in named constants so the layout code
# below stays readable. Their text is unchanged.
# ---------------------------------------------------------------------------

# Page-wide stylesheet handed to gr.Blocks(css=...).
_PAGE_CSS = """
body {
background: #fafafa !important;
}
.gradio-container {
max-width: none !important;
margin: 0 !important;
background-color: #fafafa !important;
font-family: 'Inter', 'Segoe UI', -apple-system, sans-serif !important;
width: 100vw !important;
}
.main-header {
text-align: center;
margin: 0 !important;
color: #3b82f6 !important;
font-weight: 600;
font-size: 2.5rem;
letter-spacing: -0.025em;
}
.card {
background: #fafafa !important;
border-radius: 12px !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
border: 1px solid #e5e7eb !important;
margin-bottom: 1.5rem !important;
transition: all 0.2s ease-in-out !important;
overflow: hidden !important;
}
.card > * {
padding: 0 !important;
margin: 0 !important;
}
.card:hover {
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05) !important;
transform: translateY(-1px) !important;
}
.card-header {
background: linear-gradient(135deg, #1975cf 0%, #1557b0 100%) !important;
color: white !important;
padding: 1rem 1.5rem !important;
border-radius: 12px 12px 0 0 !important;
font-weight: 600 !important;
font-size: 1.1rem !important;
}
.card-header * {
color: white !important;
}
.card-content {
padding: 1.5rem !important;
color: #4b5563 !important;
line-height: 1.6 !important;
background: #fafafa !important;
}
.stats-grid {
display: grid !important;
grid-template-columns: 1fr 1fr !important;
gap: 1.5rem !important;
margin-top: 1.5rem !important;
}
.stat-item {
background: #f8fafc !important;
padding: 1rem !important;
border-radius: 8px !important;
border-left: 4px solid #1975cf !important;
}
.stat-label {
font-weight: 600 !important;
color: #4b5563 !important;
font-size: 0.9rem !important;
margin-bottom: 0.5rem !important;
}
.stat-value {
color: #4b5563 !important;
font-size: 0.85rem !important;
}
.btn-example {
background: #f1f5f9 !important;
border: 1px solid #cbd5e1 !important;
color: #4b5563 !important;
border-radius: 6px !important;
transition: all 0.2s ease !important;
margin: 0.35rem !important;
padding: 0.5rem 1rem !important;
}
.btn-example:hover {
background: #1975cf !important;
border-color: #1975cf !important;
color: white !important;
}
.btn-primary {
background: #1975cf !important;
border-color: #1975cf !important;
color: white !important;
}
.btn-primary:hover {
background: #1557b0 !important;
border-color: #1557b0 !important;
}
.markdown {
color: #374151 !important;
}
.results-text {
color: #4b5563 !important;
font-weight: 500 !important;
padding: 0 !important;
margin: 0 !important;
}
.results-text p {
color: #4b5563 !important;
margin: 0.5rem 0 !important;
}
.results-text * {
color: #4b5563 !important;
}
div[data-testid="markdown"] p {
color: #4b5563 !important;
}
.prose {
color: #4b5563 !important;
}
.prose * {
color: #4b5563 !important;
}
.card-header,
.card-header * {
color: white !important;
}
.example-grid {
display: grid !important;
grid-template-columns: 1fr !important;
gap: 1.5rem !important;
margin-top: 1.5rem !important;
}
.example-item {
background: #f8fafc !important;
padding: 1rem !important;
border-radius: 8px !important;
border-left: 4px solid #1975cf !important;
}
.example-label {
font-weight: 600 !important;
color: #1975cf !important;
font-size: 0.9rem !important;
margin-bottom: 0.5rem !important;
}
.example-buttons {
color: #374151 !important;
font-size: 0.85rem !important;
}
.results-grid {
display: grid !important;
grid-template-columns: 1fr !important;
gap: 1.5rem !important;
margin-top: 1.5rem !important;
}
.results-item {
background: #f8fafc !important;
padding: 1rem !important;
border-radius: 8px !important;
border-left: 4px solid #1975cf !important;
}
.results-label {
font-weight: 600 !important;
color: #1975cf !important;
font-size: 0.9rem !important;
margin-bottom: 0.5rem !important;
}
.results-content {
color: #374151 !important;
font-size: 0.85rem !important;
}
.custom-footer {
max-width: 800px !important;
margin: 2rem auto !important;
background: white !important;
border-radius: 12px !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
border: 1px solid #e5e7eb !important;
padding: 1.5rem !important;
text-align: center !important;
}
.custom-footer a {
color: #1975cf !important;
text-decoration: none !important;
font-weight: 600 !important;
}
.custom-footer a:hover {
text-decoration: underline !important;
}
"""

# Page title banner.
_HEADER_HTML = """
<div class="main-header">
<h1>Person Classification</h1>
</div>
"""

# Static "Model Performance" card; the figures are fixed text, not computed here.
_PERFORMANCE_CARD_HTML = """
<div class="card">
<div class="card-header">
<span style="color: white; font-weight: 600;">Model Performance</span>
</div>
<div class="card-content">
<div class="stats-grid">
<div class="stat-item">
<div class="stat-label">Accelerator</div>
<div class="stat-value">
Configuration: Ethos_U55_128<br>
Clock: 400 MHz
</div>
</div>
<div class="stat-item">
<div class="stat-label">Memory Usage</div>
<div class="stat-value">
Total SRAM: 1205.00 KiB<br>
Total Flash: 1460.69 KiB
</div>
</div>
<div class="stat-item">
<div class="stat-label">Operator Distribution</div>
<div class="stat-value">
CPU Operators: 0 (0.0%)<br>
NPU Operators: 87 (100.0%)
</div>
</div>
<div class="stat-item">
<div class="stat-label">Performance</div>
<div class="stat-value">
Inference time: 37.20 ms
</div>
</div>
</div>
</div>
</div>
"""

# Informational footer with documentation links.
# NOTE(review): both links point at http://localhost:3000 -- confirm they
# resolve for users of the deployed demo.
_FOOTER_HTML = """
<div class="custom-footer">
<div style="margin-bottom: 0.5rem;">
For a detailed walkthrough, please see our
<a href="http://localhost:3000/sr/evaluate-sr" target="_blank">Evaluate Model Guide</a>.
</div>
<div>
To get started quickly, visit our
<a href="http://localhost:3000/sr/quick-start" target="_blank">SR Quick Start page</a>.
</div>
</div>
"""


def _card_header_html(title):
    """Return the blue title bar markup used at the top of each card."""
    return f'<div class="card-header"><span style="color: white; font-weight: 600;">{title}</span></div>'


def _example_button(label):
    """Create one small example-image button with the shared styling."""
    return gr.Button(label, size="sm", elem_classes=["btn-example"])


# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost -- confirm that the results card belongs in the
# right-hand column and that handlers/footer sit directly under gr.Blocks.
with gr.Blocks(
    theme=gr.themes.Default(primary_hue="blue", neutral_hue="gray"),
    title="Person Classification",
    css=_PAGE_CSS,
) as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Left column: image input, classify button, example shortcuts.
        with gr.Column(scale=1):
            input_image = gr.Image(label="", type="pil", height=280)
            classify_btn = gr.Button(
                "Classify Image",
                variant="primary",
                size="lg",
                elem_classes=["btn-primary"],
            )

            with gr.Group(elem_classes=["card"]):
                gr.HTML(_card_header_html("Example Images"))
                with gr.Column(elem_classes=["card-content"]):
                    with gr.Row():
                        example_person = _example_button("Person")
                        example_group = _example_button("Group")
                    with gr.Row():
                        example_empty = _example_button("Empty Room")
                        example_landscape = _example_button("Landscape")

        # Right column: static performance card and the live result.
        with gr.Column(scale=1):
            gr.HTML(_PERFORMANCE_CARD_HTML)

            with gr.Group(elem_classes=["card"]):
                gr.HTML(_card_header_html("Classification Results"))
                with gr.Column(elem_classes=["card-content"]):
                    output_text = gr.Markdown(
                        value="Upload an image to see predictions...",
                        label="",
                        elem_classes=["results-text"],
                    )

    # Both the button and an image change run the same classification call.
    _classify_wiring = dict(fn=classify_image, inputs=input_image, outputs=output_text)

    # Explicit classification on button press.
    classify_btn.click(**_classify_wiring)

    # Each example button loads its image into the input component.
    example_person.click(lambda: load_example_image("Person"), outputs=input_image)
    example_group.click(lambda: load_example_image("Group"), outputs=input_image)
    example_empty.click(lambda: load_example_image("Empty Room"), outputs=input_image)
    example_landscape.click(lambda: load_example_image("Landscape"), outputs=input_image)

    # Automatic classification whenever the input image changes.
    input_image.change(**_classify_wiring)

    gr.HTML(_FOOTER_HTML)
# Launch the demo only when this file is executed as a script, so importing
# the module builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch(show_api=False)