TangYiJay committed on
Commit 5e5a776 · verified · 1 Parent(s): 398ed96
Files changed (1)
  1. app.py +54 -19
app.py CHANGED
@@ -1,28 +1,63 @@
  import gradio as gr
  from PIL import Image, ImageChops

- # === Load the local base image ===
- BASE_IMAGE_PATH = "base.jpg"
- base_img = Image.open(BASE_IMAGE_PATH).convert("RGB")

- def detect_difference(uploaded_img):
-     uploaded_img = uploaded_img.convert("RGB")
-     diff = ImageChops.difference(uploaded_img, base_img)

-     # Check the degree of difference
-     bbox = diff.getbbox()
-     if bbox is None:
-         return "No difference detected (same as base)."
-     else:
-         return "Trash detected! Differences found."

- demo = gr.Interface(
-     fn=detect_difference,
      inputs=gr.Image(type="pil", label="Upload Trash Image"),
-     outputs=gr.Textbox(label="Detection Result"),
-     title="♻️ Trash Detector (Static Base)",
-     description="Base image (empty bin) is preloaded from base.jpg during build."
  )

- if __name__ == "__main__":
-     demo.launch()
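The removed detect_difference check relies on a single PIL behaviour: ImageChops.difference of two identical images is entirely black, so getbbox() returns None, while any differing pixel yields a bounding box. A minimal, self-contained sketch of that behaviour (synthetic images, illustrative only):

    from PIL import Image, ImageChops

    # Two synthetic frames: an "empty bin" and a copy with one changed pixel
    base = Image.new("RGB", (64, 64), "white")
    same = base.copy()
    changed = base.copy()
    changed.putpixel((10, 10), (255, 0, 0))

    print(ImageChops.difference(base, same).getbbox())     # None -> no difference
    print(ImageChops.difference(base, changed).getbbox())  # (10, 10, 11, 11) -> difference found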
 
  import gradio as gr
  from PIL import Image, ImageChops
+ from transformers import BlipProcessor, BlipForQuestionAnswering
+ import torch

+ # Load BLIP-VQA model
+ model_name = "Salesforce/blip-vqa-base"
+ processor = BlipProcessor.from_pretrained(model_name)
+ model = BlipForQuestionAnswering.from_pretrained(model_name)

+ valid_classes = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]
+ base_img = None  # Global variable to store base image

+ # Function to compute difference image
+ def get_difference_image(base: Image.Image, trash: Image.Image) -> Image.Image:
+     diff = ImageChops.difference(base, trash).convert("RGB")
+     # Optional: enhance contrast to highlight difference
+     return diff

+ # Set base image
+ def set_base(image):
+     global base_img
+     base_img = image.convert("RGB")
+     return "Base image saved successfully."
+
+ # Detect trash material
+ def detect_material(trash_image):
+     global base_img
+     if base_img is None:
+         return "Please set base image first."
+
+     trash_image = trash_image.convert("RGB")
+     diff_image = get_difference_image(base_img, trash_image)
+
+     question = "What material is this object? Choose one of: plastic, metal, paper, cardboard, glass, trash."
+
+     inputs = processor(diff_image, question, return_tensors="pt")
+     out = model.generate(**inputs)
+     answer = processor.decode(out[0], skip_special_tokens=True).lower()
+
+     # Ensure answer is one of the valid classes
+     material = next((c for c in valid_classes if c in answer), "trash")
+     return material.capitalize()
+
+ # Build Gradio UI
+ set_base_ui = gr.Interface(
+     fn=set_base,
+     inputs=gr.Image(type="pil", label="Upload Base Image (Empty Bin)"),
+     outputs=gr.Textbox(label="Result"),
+     title="Set Base Image",
+     api_name="/set_base"
+ )
+
+ detect_ui = gr.Interface(
+     fn=detect_material,
      inputs=gr.Image(type="pil", label="Upload Trash Image"),
+     outputs=gr.Textbox(label="Detected Material"),
+     title="Trash Material Detector",
+     api_name="/detect_material"
  )

+ demo = gr.TabbedInterface([set_base_ui, detect_ui], ["Set Base", "Detect Trash"])
+ demo.launch()
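Both interfaces declare api_name values, so the Space can also be driven programmatically once it is running. A minimal sketch using gradio_client; the Space id and image file names below are placeholders, and handle_file assumes a reasonably recent gradio_client release:

    from gradio_client import Client, handle_file

    # Placeholder Space id -- replace with the repo that hosts this app.py
    client = Client("TangYiJay/your-space-name")

    # 1) Register a photo of the empty bin as the base image
    print(client.predict(handle_file("empty_bin.jpg"), api_name="/set_base"))

    # 2) Classify the material of a newly photographed item
    print(client.predict(handle_file("bottle_in_bin.jpg"), api_name="/detect_material"))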