EnginDev committed on
Commit 6e85785 · verified · 1 Parent(s): 82a6cd8

Update app.py

Files changed (1):
  1. app.py +444 -43

app.py CHANGED
@@ -1,63 +1,464 @@
  import gradio as gr
- from transformers import SamProcessor, SamModel
- from PIL import Image
  import torch
  import numpy as np
- import random
- import traceback

- # Load model
- model_id = "facebook/sam-vit-base"
- processor = SamProcessor.from_pretrained(model_id)
- model = SamModel.from_pretrained(model_id)

- def random_color():
-     """Random RGB color"""
-     return [random.randint(0, 255) for _ in range(3)]

- def segment_image(image):
-     try:
-         device = torch.device("cpu")
-         model.to(device)

-         inputs = processor(images=image, return_tensors="pt").to(device)

-         with torch.no_grad():
-             outputs = model(**inputs)

-         masks = processor.post_process_masks(
              outputs.pred_masks.cpu(),
              inputs["original_sizes"].cpu(),
              inputs["reshaped_input_sizes"].cpu()
          )

-         mask_arrays = masks[0].numpy()
-         img_array = np.array(image)
-         overlay = np.zeros_like(img_array, dtype=np.uint8)

-         # Color each mask
-         for mask in mask_arrays:
-             mask = mask[0]
-             color = random_color()
-             for c in range(3):
-                 overlay[:, :, c] = np.where(mask > 0.5, color[c], overlay[:, :, c])

-         # Stronger color blending (80% mask / 20% original)
-         blended = Image.fromarray(
-             (0.2 * img_array + 0.8 * overlay).astype(np.uint8)
          )

-         return blended
-
-     except Exception:
-         return f"Error:\n{traceback.format_exc()}"

- demo = gr.Interface(
-     fn=segment_image,
-     inputs=gr.Image(type="pil", label="Upload your fish image"),
-     outputs=gr.Image(type="pil", label="Segmented Output"),
-     title="FishBoost – Colorful SAM Segmentation (Enhanced Colors)",
-     description="Produces bold, colorful masks with Meta SAM (CPU version)."
- )

- demo.launch()

  import gradio as gr
  import torch
  import numpy as np
+ from PIL import Image
+ import cv2

+ print("🚀 Starting SAM2 App v2.1 - OPTIMIZED...")

+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"📱 Using device: {device}")

+ model = None
+ processor = None

+ def load_model():
+     global model, processor
+     if model is None:
+         print("📦 Loading SAM model...")
+         try:
+             from transformers import SamModel, SamProcessor
+
+             model_name = "facebook/sam-vit-large"
+
+             processor = SamProcessor.from_pretrained(model_name)
+             model = SamModel.from_pretrained(model_name)
+             model.to(device)
+             print(f"✅ Model loaded: {model_name}")
+         except Exception as e:
+             print(f"❌ Error: {e}, falling back to base model")
+             model_name = "facebook/sam-vit-base"
+             processor = SamProcessor.from_pretrained(model_name)
+             model = SamModel.from_pretrained(model_name)
+             model.to(device)
+     return model, processor

+ def prepare_image(image, max_size=1024):
+     if isinstance(image, np.ndarray):
+         image_pil = Image.fromarray(image)
+     else:
+         image_pil = image
+
+     if image_pil.mode != 'RGB':
+         image_pil = image_pil.convert('RGB')
+
+     image_np = np.array(image_pil)
+     h, w = image_np.shape[:2]
+
+     if max(h, w) > max_size:
+         scale = max_size / max(h, w)
+         new_h, new_w = int(h * scale), int(w * scale)
+         image_pil = image_pil.resize((new_w, new_h), Image.Resampling.LANCZOS)
+         image_np = np.array(image_pil)
+
+     return image_pil, image_np
+
+ def refine_mask(mask, kernel_size=5):
+     """Smooths mask edges"""
+     mask_uint8 = (mask > 0).astype(np.uint8) * 255
+     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+     mask_closed = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel)
+     mask_refined = cv2.morphologyEx(mask_closed, cv2.MORPH_OPEN, kernel)
+     return mask_refined > 0

+ def segment_automatic(image, quality="high", merge_parts=True):
+     """
+     OPTIMIZED automatic segmentation
+     Fast & precise - combines several masks
+     """
+     if image is None:
+         return None, {"error": "No image uploaded"}
+
+     try:
+         print(f"🔄 Starting segmentation (quality: {quality}, merge: {merge_parts})...")
+         model, processor = load_model()
+
+         image_pil, image_np = prepare_image(image)
+         h, w = image_np.shape[:2]
+
+         center_x, center_y = w // 2, h // 2
+
+         # Single-point inference with multimask_output
+         inputs = processor(
+             image_pil,
+             input_points=[[[center_x, center_y]]],
+             input_labels=[[1]],
+             return_tensors="pt"
+         ).to(device)
+
+         print("🧠 Running inference...")
+         with torch.no_grad():
+             outputs = model(**inputs, multimask_output=True)
+
+         masks = processor.image_processor.post_process_masks(
              outputs.pred_masks.cpu(),
              inputs["original_sizes"].cpu(),
              inputs["reshaped_input_sizes"].cpu()
+         )[0]
+
+         scores = outputs.iou_scores.cpu().numpy()
+         if scores.ndim > 1:
+             scores = scores.flatten()
+
+         print(f"✅ Got {len(scores)} masks with scores: {scores}")
+
+         # SMART MERGING: combine all good masks
+         if merge_parts:
+             combined_mask = np.zeros((h, w), dtype=bool)
+             masks_used = 0
+
+             for idx, score in enumerate(scores):
+                 if score > 0.5:  # only masks with a good score
+                     if masks.ndim == 4:
+                         mask = masks[0, idx].numpy()
+                     else:
+                         mask = masks[idx].numpy()
+
+                     # OR combination (very fast!)
+                     combined_mask = combined_mask | (mask > 0)
+                     masks_used += 1
+                     print(f" ✅ Added mask {idx} (score: {score:.3f})")
+
+             final_mask = combined_mask
+             print(f"🔗 Combined {masks_used} masks into one!")
+         else:
+             # Best mask only
+             best_idx = np.argmax(scores)
+             if masks.ndim == 4:
+                 final_mask = masks[0, best_idx].numpy() > 0
+             else:
+                 final_mask = masks[best_idx].numpy() > 0
+             masks_used = 1
+             print(f"✅ Using best mask (score: {scores[best_idx]:.3f})")
+
+         # Refinement for smooth edges
+         if quality == "high":
+             print("🎨 Refining mask...")
+             final_mask = refine_mask(final_mask, kernel_size=7)
+
+         # Create the overlay
+         overlay = image_np.copy()
+         color = np.array([255, 80, 180])  # pink
+
+         mask_float = final_mask.astype(float)
+         if quality == "high":
+             mask_float = cv2.GaussianBlur(mask_float, (5, 5), 0)
+
+         # Colored overlay
+         for c in range(3):
+             overlay[:, :, c] = (
+                 overlay[:, :, c] * (1 - mask_float * 0.65) +
+                 color[c] * mask_float * 0.65
+             )
+
+         # Draw a yellow contour
+         contours, _ = cv2.findContours(
+             final_mask.astype(np.uint8),
+             cv2.RETR_EXTERNAL,
+             cv2.CHAIN_APPROX_SIMPLE
          )
+         cv2.drawContours(overlay, contours, -1, (255, 255, 0), 3)
+
+         metadata = {
+             "success": True,
+             "mode": "automatic_plus" if merge_parts else "automatic",
+             "quality": quality,
+             "masks_combined": masks_used,
+             "all_scores": scores.tolist(),
+             "image_size": [w, h],
+             "mask_area": int(np.sum(final_mask)),
+             "mask_percentage": float(np.sum(final_mask) / (h * w) * 100),
+             "num_contours": len(contours),
+             "device": device
+         }
+
+         print("✅ Segmentation complete!")
+         return Image.fromarray(overlay.astype(np.uint8)), metadata
+
+     except Exception as e:
+         import traceback
+         print(f"❌ ERROR:\n{traceback.format_exc()}")
+         return image, {"error": str(e)}

+ def segment_multi_dense(image, density="medium"):
+     """Multi-object segmentation over a point grid"""
+     if image is None:
+         return None, {"error": "No image"}
+
+     try:
+         print(f"🎯 Starting multi-region segmentation (density: {density})...")
+         model, processor = load_model()
+         image_pil, image_np = prepare_image(image)
+         h, w = image_np.shape[:2]
+
+         # Grid size based on density
+         if density == "high":
+             grid_size = 5
+         elif density == "medium":
+             grid_size = 4
+         else:
+             grid_size = 3
+
+         # Generate grid points
+         points = []
+         for i in range(1, grid_size + 1):
+             for j in range(1, grid_size + 1):
+                 x = int(w * i / (grid_size + 1))
+                 y = int(h * j / (grid_size + 1))
+                 points.append([x, y])
+
+         print(f"📍 Using {len(points)} grid points ({grid_size}x{grid_size})...")
+
+         all_masks = []
+         all_scores = []
+
+         # Segment each point
+         for idx, point in enumerate(points):
+             inputs = processor(
+                 image_pil,
+                 input_points=[[point]],
+                 input_labels=[[1]],
+                 return_tensors="pt"
+             ).to(device)
+
+             with torch.no_grad():
+                 outputs = model(**inputs, multimask_output=True)
+
+             masks = processor.image_processor.post_process_masks(
+                 outputs.pred_masks.cpu(),
+                 inputs["original_sizes"].cpu(),
+                 inputs["reshaped_input_sizes"].cpu()
+             )[0]
+
+             scores = outputs.iou_scores.cpu().numpy().flatten()
+             best_idx = np.argmax(scores)
+
+             if masks.ndim == 4:
+                 mask = masks[0, best_idx].numpy()
+             else:
+                 mask = masks[best_idx].numpy()
+
+             # Keep only masks with a good score
+             if scores[best_idx] > 0.7:
+                 all_masks.append(refine_mask(mask))
+                 all_scores.append(scores[best_idx])
+
+         print(f"✅ Got {len(all_masks)} quality masks")
+
+         # Overlay with distinct colors
+         overlay = image_np.copy()
+
+         # HSV-based color generation
+         colors = []
+         for i in range(len(all_masks)):
+             hue = int(180 * i / max(len(all_masks), 1))
+             color_hsv = np.uint8([[[hue, 255, 200]]])
+             color_rgb = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB)[0][0]
+             colors.append(color_rgb)
+
+         # Apply the masks
+         for mask, color, score in zip(all_masks, colors, all_scores):
+             alpha = 0.4 + (score - 0.7) * 0.2  # higher score = stronger color
+             overlay[mask] = (
+                 overlay[mask] * (1 - alpha) +
+                 np.array(color) * alpha
+             ).astype(np.uint8)
+
+             # Contour
+             contours, _ = cv2.findContours(
+                 mask.astype(np.uint8),
+                 cv2.RETR_EXTERNAL,
+                 cv2.CHAIN_APPROX_SIMPLE
+             )
+             cv2.drawContours(overlay, contours, -1, color.tolist(), 2)
+
+         metadata = {
+             "success": True,
+             "mode": "multi_object_dense",
+             "density": density,
+             "grid_size": f"{grid_size}x{grid_size}",
+             "total_points": len(points),
+             "quality_masks": len(all_masks),
+             "avg_score": float(np.mean(all_scores)) if all_scores else 0,
+             "scores": [float(s) for s in all_scores]
+         }
+
+         print("✅ Multi-region complete!")
+         return Image.fromarray(overlay), metadata
+
+     except Exception as e:
+         import traceback
+         print(f"❌ ERROR:\n{traceback.format_exc()}")
+         return image, {"error": str(e)}

+ # Gradio interface
+ demo = gr.Blocks(title="SAM2 Boostly", theme=gr.themes.Soft())
+
+ with demo:
+     gr.Markdown("# 🎨 SAM2 Segmentation - Boostly Edition")
+     gr.Markdown("### Optimized zero-shot object segmentation")
+
+     with gr.Tab("🤖 Automatic PLUS"):
+         gr.Markdown("**Smart multi-mask combining** - automatically merges all object parts!")
+
+         with gr.Row():
+             with gr.Column():
+                 input_auto = gr.Image(type="pil", label="📸 Upload image")
+
+                 quality_radio = gr.Radio(
+                     choices=["high", "fast"],
+                     value="high",
+                     label="⚙️ Quality",
+                     info="High = more precise edges, Fast = quicker"
+                 )
+
+                 merge_checkbox = gr.Checkbox(
+                     value=True,
+                     label="🔗 Merge parts",
+                     info="Combines all detected regions (fish + fin = 1 object)"
+                 )
+
+                 btn_auto = gr.Button("🚀 Segment", variant="primary", size="lg")
+
+                 gr.Markdown("""
+                 **✨ How it works:**
+                 - SAM generates 3 different masks
+                 - With "Merge parts" ON: all masks are combined → the complete object
+                 - With it OFF: only the most precise mask is kept
+                 - ⚡ Optimized: ~10-30 seconds instead of 25 minutes!
+                 """)
+
+             with gr.Column():
+                 output_auto = gr.Image(label="✨ Segmented image")
+                 json_auto = gr.JSON(label="📊 Metadata")
+
+         btn_auto.click(
+             fn=segment_automatic,
+             inputs=[input_auto, quality_radio, merge_checkbox],
+             outputs=[output_auto, json_auto]
          )
+
+         gr.Examples(
+             examples=[],
+             inputs=input_auto,
+             label="💡 Tip: the object should be centered in the image"
+         )
+
+     with gr.Tab("🎯 Multi-Region"):
+         gr.Markdown("**Grid-based segmentation** - for several separate objects")
+
+         with gr.Row():
+             with gr.Column():
+                 input_multi = gr.Image(type="pil", label="📸 Upload image")
+
+                 density_radio = gr.Radio(
+                     choices=["high", "medium", "low"],
+                     value="medium",
+                     label="📊 Point density",
+                     info="More points = more detail, but slower"
+                 )
+
+                 btn_multi = gr.Button("🎯 Segment all regions", variant="primary", size="lg")
+
+                 gr.Markdown("""
+                 **Grid sizes:**
+                 - 🔥 High: 5x5 = 25 detection points
+                 - ⚡ Medium: 4x4 = 16 points (recommended)
+                 - 💨 Low: 3x3 = 9 points
+
+                 Each object gets its own color!
+                 """)
+
+             with gr.Column():
+                 output_multi = gr.Image(label="✨ Segmented image")
+                 json_multi = gr.JSON(label="📊 Metadata")
+
+         btn_multi.click(
+             fn=segment_multi_dense,
+             inputs=[input_multi, density_radio],
+             outputs=[output_multi, json_multi]
+         )
+
+     with gr.Tab("📡 API Documentation"):
+         gr.Markdown("### 🔗 API Endpoint")
+         gr.Code(
+             "https://EnginDev-Boostly.hf.space/api/predict",
+             label="Base URL"
+         )
+
+         gr.Markdown("### 📝 JavaScript integration (for Lovable)")
+         gr.Code('''
+ // Segmentation service
+ const HUGGINGFACE_API = 'https://EnginDev-Boostly.hf.space';
+
+ async function segmentImage(imageFile, mode = 'automatic') {
+     // Convert the file to Base64
+     const base64 = await new Promise((resolve) => {
+         const reader = new FileReader();
+         reader.onloadend = () => resolve(reader.result);
+         reader.readAsDataURL(imageFile);
+     });
+
+     // API call
+     const response = await fetch(`${HUGGINGFACE_API}/api/predict`, {
+         method: 'POST',
+         headers: {'Content-Type': 'application/json'},
+         body: JSON.stringify({
+             data: [base64, "high", true], // [image, quality, merge]
+             fn_index: mode === 'automatic' ? 0 : 1
+         })
+     });
+
+     const result = await response.json();
+
+     return {
+         segmentedImage: result.data[0], // Base64 segmented image
+         metadata: result.data[1]        // JSON with details
+     };
+ }
+
+ // Usage:
+ const result = await segmentImage(myImageFile, 'automatic');
+ console.log('Mask covers:', result.metadata.mask_percentage + '%');
+ ''', language="javascript")
+
+         gr.Markdown("### ⚙️ Parameters")
+         gr.Markdown("""
+         **fn_index:**
+         - `0` = Automatic PLUS (recommended for single objects)
+         - `1` = Multi-Region (for several objects)
+
+         **quality:**
+         - `"high"` = precise edges, Gaussian blur, refinement (~20-30s)
+         - `"fast"` = faster, less post-processing (~10-15s)
+
+         **merge (fn_index=0 only):**
+         - `true` = combines all masks → the complete object
+         - `false` = best mask only → just the main part
+
+         **density (fn_index=1 only):**
+         - `"high"` = 5x5 grid = 25 points
+         - `"medium"` = 4x4 grid = 16 points
+         - `"low"` = 3x3 grid = 9 points
+         """)
+
+         gr.Markdown("### 📊 Response Format")
+         gr.Code('''
+ {
+     "data": [
+         "data:image/png;base64,iVBORw0KGgo...", // segmented image
+         {
+             "success": true,
+             "mode": "automatic_plus",
+             "masks_combined": 3,
+             "mask_percentage": 12.5,
+             "num_contours": 1,
+             "all_scores": [0.998, 0.583, 0.864]
+         }
+     ]
+ }
+ ''', language="json")

+ if __name__ == "__main__":
+     print("🌐 Launching Boostly SAM2 v2.1...")
+     demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
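
For reference, the same endpoints can also be exercised from Python with the official gradio_client package instead of the raw fetch shown in the API tab. This is a minimal sketch only, not part of the commit: the Space id "EnginDev/Boostly" is inferred from the base URL above, and the api_name relies on Gradio's default of exposing each event handler under its function name (fn_index=0 would select the same endpoint, matching the JavaScript example).

# Minimal client sketch, assuming the Space id "EnginDev/Boostly" and Gradio's
# default endpoint naming ("/" + function name); treat both as assumptions.
from gradio_client import Client, handle_file

client = Client("EnginDev/Boostly")  # assumed Space id

segmented_image, metadata = client.predict(
    handle_file("fish.jpg"),        # image: local path or URL
    "high",                         # quality: "high" or "fast"
    True,                           # merge_parts: combine all good masks
    api_name="/segment_automatic",  # assumed default endpoint name
)
print(f"Mask covers {metadata['mask_percentage']:.1f}% of the image")

gradio_client should return the gr.Image output as a local file path and the gr.JSON output as a Python dict, so no Base64 handling is needed on the client side.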