JahnaviBhansali committed
Commit 29829ab · verified · 1 Parent(s): ee795f9

Upload app (1).py

Files changed (1):
  app (1).py  +404 -0
app (1).py ADDED
@@ -0,0 +1,404 @@
+ import gradio as gr
+ import numpy as np
+ import requests
+ from io import BytesIO
+ from PIL import Image
+ import tensorflow as tf
+ from huggingface_hub import hf_hub_download
+
+ # Download the TFLite model and labels from your Hugging Face repository
+ MODEL_REPO = "JahnaviBhansali/mobilenet-v2-ethos-u55"
+ MODEL_FILE = "mobilenet_v2_1.0_224_INT8.tflite"  # Using original INT8 model for Gradio compatibility
+ VELA_MODEL_FILE = "mobilenet_v2_1.0_224_INT8_vela.tflite"  # Vela-optimized model for Ethos-U55
+ LABELS_FILE = "labelmappings.txt"
+
+ print("Downloading model and labels from Hugging Face...")
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
+ vela_model_path = hf_hub_download(repo_id=MODEL_REPO, filename=VELA_MODEL_FILE)  # Download Vela model for reference
+ labels_path = hf_hub_download(repo_id=MODEL_REPO, filename=LABELS_FILE)
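+ # Note: the Vela-compiled model is an Arm Ethos-U55 deployment artifact and is not
+ # intended to run on the stock TFLite interpreter, so the demo below loads the plain
+ # INT8 model instead.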
+
+ # Load the TFLite model
+ interpreter = tf.lite.Interpreter(model_path=model_path)
+ interpreter.allocate_tensors()
+
+ # Get input and output details
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Load class labels
+ with open(labels_path, 'r') as f:
+     class_labels = [line.strip() for line in f.readlines()]
+
+ print(f"Model loaded successfully! Input shape: {input_details[0]['shape']}")
+ print(f"Number of classes: {len(class_labels)}")
+ print(f"Vela-optimized model also available: {VELA_MODEL_FILE}")
+ # Force rebuild with modern design
+ print(f"Repository: {MODEL_REPO}")
+
+ def preprocess_image(image):
+     """
+     Preprocess image for MobileNetV2 INT8 quantized model.
+     """
+     # Resize to 224x224 as expected by the model
+     image = image.resize((224, 224))
+
+     # Convert to numpy array
+     img_array = np.array(image, dtype=np.float32)
+
+     # Normalize to [0, 1] then scale to [-1, 1] for MobileNetV2
+     img_array = img_array / 255.0
+     img_array = (img_array - 0.5) * 2.0
+
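+     # Note: the manual quantization below assumes an input scale of roughly 1/127
+     # with a zero point of 0; reading input_details[0]['quantization'] instead would
+     # make the conversion robust to other quantization parameters.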
+     # Quantize to INT8 range [-128, 127]
+     img_array = img_array * 127.0
+     img_array = np.clip(img_array, -128, 127).astype(np.int8)
+
+     # Add batch dimension
+     img_array = np.expand_dims(img_array, axis=0)
+
+     return img_array
+
+ def classify_image(image):
+     """
+     Classify the input image and return top-3 predictions with confidence scores.
+     """
+     if image is None:
+         return "Please upload an image."
+
+     try:
+         # Handle different image inputs
+         if isinstance(image, str):
+             # Handle URL
+             response = requests.get(image)
+             image = Image.open(BytesIO(response.content)).convert("RGB")
+         elif isinstance(image, np.ndarray):
+             image = Image.fromarray(image).convert("RGB")
+         else:
+             image = image.convert("RGB")
+
+         # Preprocess the image
+         input_data = preprocess_image(image)
+
+         # Set input tensor
+         interpreter.set_tensor(input_details[0]['index'], input_data)
+
+         # Run inference
+         interpreter.invoke()
+
+         # Get output tensor
+         output_data = interpreter.get_tensor(output_details[0]['index'])
+         predictions = output_data[0]  # Remove batch dimension
+
+         # Convert from INT8 quantized output to probabilities
+         # Dequantize the output
+         scale = output_details[0]['quantization'][0]
+         zero_point = output_details[0]['quantization'][1]
+         predictions = scale * (predictions.astype(np.float32) - zero_point)
+
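+         # Softmax is applied in case the dequantized outputs are raw logits; since
+         # softmax is monotonic, the top-3 ordering is unchanged either way.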
+         # Apply softmax to get probabilities
+         predictions = tf.nn.softmax(predictions).numpy()
+
+         # Get top-3 predictions
+         top3_indices = np.argsort(predictions)[-3:][::-1]
+
+         # Format results
+         results = []
+         for i, idx in enumerate(top3_indices):
+             class_name = class_labels[idx] if idx < len(class_labels) else f"Class {idx}"
+             confidence = predictions[idx]
+             results.append(f"**{class_name}**: {confidence:.1%}")
+
+         # Create formatted output
+         result_text = "\n".join(f"{idx+1}. {result}" for idx, result in enumerate(results))
+
+         return result_text
+
+     except Exception:
+         return "Error processing image. Please try again."
+
+ def load_example_image(example_path):
+     """Load example images for demonstration."""
+     example_urls = {
+         "Cat": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+         "Dog": "https://images.unsplash.com/photo-1587300003388-59208cc962cb?w=500",
+         "Car": "https://images.unsplash.com/photo-1494905998402-395d579af36f?w=500",
+         "Food": "https://images.unsplash.com/photo-1565299624946-b28f40a0ca4b?w=500",
+         "Nature": "https://images.unsplash.com/photo-1441974231531-c6227db76b6e?w=500"
+     }
+
+     if example_path in example_urls:
+         try:
+             response = requests.get(example_urls[example_path])
+             return Image.open(BytesIO(response.content))
+         except Exception:
+             # Leave the image slot empty if the download fails
+             return None
+     return None
+
+ # Create Gradio interface
+ with gr.Blocks(
+     theme=gr.themes.Default(),
+     title="MobileNetV2 Classification",
+     css="""
+     .gradio-container {
+         max-width: 1200px !important;
+         margin: auto !important;
+         background-color: #fafafa !important;
+         font-family: 'Inter', 'Segoe UI', -apple-system, sans-serif !important;
+     }
+     .main-header {
+         text-align: center;
+         margin: 2rem 0 3rem 0;
+         color: #3b82f6 !important;
+         font-weight: 600;
+         font-size: 2.5rem;
+         letter-spacing: -0.025em;
+     }
+     .card {
+         background: white !important;
+         border-radius: 12px !important;
+         box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
+         border: 1px solid #e5e7eb !important;
+         margin-bottom: 1.5rem !important;
+         transition: all 0.2s ease-in-out !important;
+         overflow: hidden !important;
+     }
+     .card > * {
+         padding: 0 !important;
+         margin: 0 !important;
+     }
+     .card:hover {
+         box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05) !important;
+         transform: translateY(-1px) !important;
+     }
+     .card-header {
+         background: linear-gradient(135deg, #1975cf 0%, #1557b0 100%) !important;
+         color: white !important;
+         padding: 1rem 1.5rem !important;
+         border-radius: 12px 12px 0 0 !important;
+         font-weight: 600 !important;
+         font-size: 1.1rem !important;
+     }
+     .card-header * {
+         color: white !important;
+     }
+     .card-content {
+         padding: 1.5rem !important;
+         color: #4b5563 !important;
+         line-height: 1.6 !important;
+         background: white !important;
+     }
+     .stats-grid {
+         display: grid !important;
+         grid-template-columns: 1fr 1fr !important;
+         gap: 1.5rem !important;
+         margin-top: 1.5rem !important;
+     }
+     .stat-item {
+         background: #f8fafc !important;
+         padding: 1rem !important;
+         border-radius: 8px !important;
+         border-left: 4px solid #1975cf !important;
+     }
+     .stat-label {
+         font-weight: 600 !important;
+         color: #4b5563 !important;
+         font-size: 0.9rem !important;
+         margin-bottom: 0.5rem !important;
+     }
+     .stat-value {
+         color: #4b5563 !important;
+         font-size: 0.85rem !important;
+     }
+     .btn-example {
+         background: #f1f5f9 !important;
+         border: 1px solid #cbd5e1 !important;
+         color: #4b5563 !important;
+         border-radius: 6px !important;
+         transition: all 0.2s ease !important;
+         margin: 0.35rem !important;
+         padding: 0.5rem 1rem !important;
+     }
+     .btn-example:hover {
+         background: #1975cf !important;
+         border-color: #1975cf !important;
+         color: white !important;
+     }
+     .btn-primary {
+         background: #1975cf !important;
+         border-color: #1975cf !important;
+         color: white !important;
+     }
+     .btn-primary:hover {
+         background: #1557b0 !important;
+         border-color: #1557b0 !important;
+     }
+     .markdown {
+         color: #374151 !important;
+     }
+     .results-text {
+         color: #4b5563 !important;
+         font-weight: 500 !important;
+         padding: 0 !important;
+         margin: 0 !important;
+     }
+     .results-text p {
+         color: #4b5563 !important;
+         margin: 0.5rem 0 !important;
+     }
+     .results-text * {
+         color: #4b5563 !important;
+     }
+     div[data-testid="markdown"] p {
+         color: #4b5563 !important;
+     }
+     .prose {
+         color: #4b5563 !important;
+     }
+     .prose * {
+         color: #4b5563 !important;
+     }
+     .example-grid {
+         display: grid !important;
+         grid-template-columns: 1fr !important;
+         gap: 1.5rem !important;
+         margin-top: 1.5rem !important;
+     }
+     .example-item {
+         background: #f8fafc !important;
+         padding: 1rem !important;
+         border-radius: 8px !important;
+         border-left: 4px solid #1975cf !important;
+     }
+     .example-label {
+         font-weight: 600 !important;
+         color: #1975cf !important;
+         font-size: 0.9rem !important;
+         margin-bottom: 0.5rem !important;
+     }
+     .example-buttons {
+         color: #374151 !important;
+         font-size: 0.85rem !important;
+     }
+     .results-grid {
+         display: grid !important;
+         grid-template-columns: 1fr !important;
+         gap: 1.5rem !important;
+         margin-top: 1.5rem !important;
+     }
+     .results-item {
+         background: #f8fafc !important;
+         padding: 1rem !important;
+         border-radius: 8px !important;
+         border-left: 4px solid #1975cf !important;
+     }
+     .results-label {
+         font-weight: 600 !important;
+         color: #1975cf !important;
+         font-size: 0.9rem !important;
+         margin-bottom: 0.5rem !important;
+     }
+     .results-content {
+         color: #374151 !important;
+         font-size: 0.85rem !important;
+     }
+     """
+ ) as demo:
+
+     gr.HTML("""
+     <div class="main-header">
+         <h1>MobileNetV2 Classification</h1>
+     </div>
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+
+             input_image = gr.Image(
+                 label="",
+                 type="pil",
+                 height=280
+             )
+
+             classify_btn = gr.Button(
+                 "Classify Image",
+                 variant="primary",
+                 size="lg",
+                 elem_classes=["btn-primary"]
+             )
+
+             with gr.Group(elem_classes=["card"]):
+                 gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Example Images</span></div>')
+
+                 with gr.Column(elem_classes=["card-content"]):
+                     with gr.Row():
+                         example_cat = gr.Button("Cat", size="sm", elem_classes=["btn-example"])
+                         example_dog = gr.Button("Dog", size="sm", elem_classes=["btn-example"])
+
+                     with gr.Row():
+                         example_car = gr.Button("Car", size="sm", elem_classes=["btn-example"])
+                         example_food = gr.Button("Food", size="sm", elem_classes=["btn-example"])
+
+         with gr.Column(scale=1):
+             gr.HTML("""
+             <div class="card">
+                 <div class="card-header">
+                     <span style="color: white; font-weight: 600;">Model Performance</span>
+                 </div>
+                 <div class="card-content">
+                     <div class="stats-grid">
+                         <div class="stat-item">
+                             <div class="stat-label">Performance</div>
+                             <div class="stat-value">
+                                 6M cycles/inference<br>
+                                 15.14ms @ 400MHz<br>
+                                 NPU Coverage: 100%<br>
+                                 ImageNet Top-1: 69.7%
+                             </div>
+                         </div>
+                         <div class="stat-item">
+                             <div class="stat-label">Memory Usage</div>
+                             <div class="stat-value">
+                                 SRAM: 353.5 KiB<br>
+                                 Flash: 3.6 MiB<br>
+                                 Model: MobileNetV2<br>
+                                 Input: 224×224×3
+                             </div>
+                         </div>
+                     </div>
+                 </div>
+             </div>
+             """)
+
+             with gr.Group(elem_classes=["card"]):
+                 gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Classification Results</span></div>')
+
+                 with gr.Column(elem_classes=["card-content"]):
+                     output_text = gr.Markdown(
+                         value="Upload an image to see predictions...",
+                         label="",
+                         elem_classes=["results-text"]
+                     )
+
+     # Set up event handlers
+     classify_btn.click(
+         fn=classify_image,
+         inputs=input_image,
+         outputs=output_text
+     )
+
+     # Example image handlers
+     example_cat.click(lambda: load_example_image("Cat"), outputs=input_image)
+     example_dog.click(lambda: load_example_image("Dog"), outputs=input_image)
+     example_car.click(lambda: load_example_image("Car"), outputs=input_image)
+     example_food.click(lambda: load_example_image("Food"), outputs=input_image)
+
+     # Auto-classify when image is uploaded
+     input_image.change(
+         fn=classify_image,
+         inputs=input_image,
+         outputs=output_text
+     )
+
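+     # (The change listener above also fires when an example button populates the
+     # image, so predictions refresh without pressing "Classify Image".)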
+ # Launch the demo
+ if __name__ == "__main__":
+     demo.launch()