TangYiJay committed
Commit 82b9d60 · verified · 1 Parent(s): 1d11911
Files changed (1): app.py (+29 -36)
app.py CHANGED
@@ -1,61 +1,54 @@
 import gradio as gr
 from transformers import BlipProcessor, BlipForQuestionAnswering
 from PIL import Image
-import os
-import difflib
 
-# Model
+# Load BLIP model
 model_name = "Salesforce/blip-vqa-base"
 processor = BlipProcessor.from_pretrained(model_name)
 model = BlipForQuestionAnswering.from_pretrained(model_name)
 
-# Categories
-CATEGORIES = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]
-BASE_IMAGE_PATH = "base_reference.jpg"
-
-# --- Functions ---
+# Global variable to store base image
+base_image = None
 
 def set_base(image):
-    if image is None:
-        return "⚠️ Please upload an image."
-    image.save(BASE_IMAGE_PATH)
+    global base_image
+    base_image = image
     return "✅ Base image saved successfully."
 
-def identify_material(image):
+def detect_trash(image):
     if image is None:
-        return "⚠️ Please upload an image."
+        return "Please upload an image."
+
+    if base_image is None:
+        return "Please set base image first."
+
+    # Ask question
+    question = "Ignore the base, what material is this? Choose from: plastic, metal, paper, cardboard, glass, trash."
 
-    question = "Ignore the base, what material is this trash made of?"
     inputs = processor(image, question, return_tensors="pt")
     out = model.generate(**inputs)
-    answer = processor.decode(out[0], skip_special_tokens=True).lower()
+    answer = processor.decode(out[0], skip_special_tokens=True)
 
-    # Match to one of the six categories
-    match = difflib.get_close_matches(answer, CATEGORIES, n=1, cutoff=0.3)
-    if match:
-        return f"🧠 Detected: **{match[0]}**"
-    else:
-        return f"🤔 Unclear, model said: {answer}"
+    # Keep only defined categories
+    classes = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]
+    matched = next((c for c in classes if c in answer.lower()), "trash")
 
-# --- UI ---
+    return matched.capitalize()
 
+# Two interfaces
 set_base_ui = gr.Interface(
     fn=set_base,
-    inputs=gr.Image(type="pil", label="Upload Empty Trash Bin Base"),
-    outputs="text",
-    title="🧱 Set Base",
-    description="Upload an image of the empty trash bin (no object)."
+    inputs=gr.Image(type="pil", label="Upload Empty Base Image"),
+    outputs=gr.Textbox(label="Result"),
+    title="🧩 Set Base"
 )
 
-identify_ui = gr.Interface(
-    fn=identify_material,
-    inputs=gr.Image(type="pil", label="Upload Trash Image to Identify"),
-    outputs="markdown",
-    title="🧠 Waste Material Classifier",
-    description="Upload a trash image. Model will predict one of: plastic, metal, paper, cardboard, glass, or trash."
+detect_trash_ui = gr.Interface(
+    fn=detect_trash,
+    inputs=gr.Image(type="pil", label="Upload Trash Image"),
+    outputs=gr.Textbox(label="Detected Material"),
+    title="♻️ Trash Material Detector"
)
 
-demo = gr.TabbedInterface([set_base_ui, identify_ui], ["Set Base", "Detect Trash"])
-
-if __name__ == "__main__":
-    demo.launch()
+demo = gr.TabbedInterface([set_base_ui, detect_trash_ui], ["Set Base", "Detect Trash"])
+demo.launch()
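
For context on the matching change: the old code fuzzy-matched the decoded answer against CATEGORIES with difflib, while the new code scans the answer for a category substring and falls back to "trash". A minimal side-by-side sketch of the two behaviors (the sample answers below are hypothetical, not real model outputs):

import difflib

CATEGORIES = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]

def old_match(answer):
    # Old behavior: fuzzy match the whole decoded string; may return nothing.
    hits = difflib.get_close_matches(answer.lower(), CATEGORIES, n=1, cutoff=0.3)
    return hits[0] if hits else None

def new_match(answer):
    # New behavior: first category appearing as a substring, else "trash".
    return next((c for c in CATEGORIES if c in answer.lower()), "trash")

for answer in ["plastic", "a glass bottle", "aluminum"]:  # hypothetical answers
    print(f"{answer!r}: old={old_match(answer)}, new={new_match(answer)}")

One consequence of the fallback: answers that name none of the six categories are now reported as Trash instead of being surfaced verbatim as in the old "Unclear, model said" branch.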
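To smoke-test the new inference path outside the Gradio UI, the same model and prompt can be run directly (a minimal sketch mirroring the committed code; test.jpg is a hypothetical local file):

from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

# Same model and question as the new app.py
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

image = Image.open("test.jpg")  # hypothetical sample image
question = "Ignore the base, what material is this? Choose from: plastic, metal, paper, cardboard, glass, trash."

inputs = processor(image, question, return_tensors="pt")
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))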