import spaces
import torch
import time
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from typing import List

MODEL_ID = "remyxai/SpaceQwen2.5-VL-3B-Instruct"

def load_model():
    print("Loading model and processor...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
    ).to(device)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor

model, processor = load_model()
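
# Rough sizing note (added, not from the original file): a 3B-parameter model holds roughly
# 3e9 * 2 bytes ≈ 6 GB of weights in bfloat16 on GPU, or about 12 GB in float32 on CPU,
# before activations and KV-cache overhead.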

def process_image(image_path_or_obj):
    """Loads, resizes, and preprocesses an image path or Pillow Image."""
    if isinstance(image_path_or_obj, str):
        # Path on disk or from history
        image = Image.open(image_path_or_obj).convert("RGB")
    elif isinstance(image_path_or_obj, Image.Image):
        image = image_path_or_obj.convert("RGB")
    else:
        raise ValueError("process_image expects a file path (str) or PIL.Image")

    max_width = 512
    if image.width > max_width:
        aspect_ratio = image.height / image.width
        new_height = int(max_width * aspect_ratio)
        image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
        print(f"Resized image to: {max_width}x{new_height}")
    return image
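
# Illustrative usage (sizes are examples only):
#   process_image("some_photo.jpg")             # a 1024x768 file -> 512x384 RGB PIL.Image
#   process_image(Image.new("RGB", (300, 200))) # already <= 512 px wide -> returned unresized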

def get_latest_image(history):
    """
    Look from the end to find the last user-uploaded image (stored as (file_path,)).
    Return None if not found.
    """
    for user_msg, _assistant_msg in reversed(history):
        if isinstance(user_msg, tuple) and len(user_msg) > 0:
            return user_msg[0]
    return None
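
# Illustrative example (history shape as produced by add_message below; file name is made up):
#   history = [[("room.jpg",), None], ["How wide is the table?", None]]
#   get_latest_image(history)         # -> "room.jpg"
#   get_latest_image([["hi", None]])  # -> None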

def only_assistant_text(full_text: str) -> str:
    """
    Return only the final assistant answer from the fully decoded output.
    The decoded text still contains the chat-template role sections
    ('system', 'user', 'assistant'), so we split on the first 'assistant'
    marker and keep everything after it.
    Adjust this parsing if your model's output format differs.
    """
    # Example decoded output:
    #   system
    #   ...
    #   user
    #   ...
    #   assistant
    #   The final answer
    if "assistant" in full_text:
        parts = full_text.split("assistant", 1)
        result = parts[-1].strip()
        # Remove any leading punctuation (like a colon)
        result = result.lstrip(":").strip()
        return result
    return full_text.strip()
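
# Note (assumption, not stated in the original listing): the unused `import spaces` above is
# only meaningful on a ZeroGPU Space, where the GPU-bound function is typically decorated,
# e.g.:
#   @spaces.GPU
#   def run_inference(image, prompt): ...
# No such decorator appears here, so the Space presumably runs on standard GPU/CPU hardware.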

def run_inference(image, prompt):
    """Runs Qwen2.5-VL inference on a single image and text prompt."""
    system_msg = (
        "You are a Vision Language Model specialized in interpreting visual data from images. "
        "Your task is to analyze the provided image and respond to queries with concise answers."
    )
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_msg}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_input = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text_input], images=[image], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Parse out only the final assistant text
    return only_assistant_text(output_text)
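
# Illustrative call (the reply text is made up):
#   reply = run_inference(process_image("./examples/warehouse_rgb.jpg"),
#                         "How far is the pallet from the forklift?")
#   # -> a short free-text answer such as "About 2 meters."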

def add_message(history, user_input):
    """
    Step 1 (triggered by the user's 'Submit' or 'Send'):
      - Save new text or images into `history`.
      - The Chatbot display uses pairs: [user_text_or_image, assistant_reply].
    """
    if not isinstance(history, list):
        history = []

    files = user_input.get("files", [])
    text = user_input.get("text", "")

    # Store images
    for f in files:
        # Each image is stored as `[(file_path,), None]`
        history.append([(f,), None])

    # Store text
    if text:
        history.append([text, None])

    return history, gr.MultimodalTextbox(value=None)
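
# Illustrative result (made-up inputs): one uploaded image plus one question becomes two
# history rows, each still awaiting an assistant reply:
#   add_message([], {"files": ["cat.jpg"], "text": "What breed is this?"})
#   # -> ([[("cat.jpg",), None], ["What breed is this?", None]], <cleared MultimodalTextbox>)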

def inference_interface(history):
    """
    Step 2: Use the most recent text + the most recent image to run Qwen2.5-VL.
    Instead of adding another entry, we fill the assistant's answer into
    the last user text entry.
    """
    if not history:
        return history, gr.MultimodalTextbox(value=None)

    # 1) Get the user's most recent text
    user_text = ""
    # We'll search from the end for the first str we find
    for idx in range(len(history) - 1, -1, -1):
        user_msg, assistant_msg = history[idx]
        if isinstance(user_msg, str):
            user_text = user_msg
            # We'll also keep track of this index so we can fill in the assistant reply
            user_idx = idx
            break
    else:
        # No user text found
        print("No user text found in history. Skipping inference.")
        return history, gr.MultimodalTextbox(value=None)

    # 2) Get the latest image from the entire conversation
    latest_image = get_latest_image(history)
    if not latest_image:
        # No image found => can't run the model
        print("No image found in history. Skipping inference.")
        return history, gr.MultimodalTextbox(value=None)

    # 3) Process the image
    pil_image = process_image(latest_image)

    # 4) Run inference
    assistant_reply = run_inference(pil_image, user_text)

    # 5) Fill that assistant reply back into the last user text entry
    history[user_idx][1] = assistant_reply
    return history, gr.MultimodalTextbox(value=None)
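
# Illustrative end-to-end flow (made-up question and answer):
#   history = [[("./examples/warehouse_rgb.jpg",), None], ["How tall are the shelves?", None]]
#   history, _ = inference_interface(history)
#   # history[-1] is now ["How tall are the shelves?", "<model's answer>"]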

def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# SpaceQwen2.5-VL Image Prompt Chatbot")

        chatbot = gr.Chatbot([], line_breaks=True)
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_types=["image"],
            placeholder="Enter text or upload an image (or both).",
            show_label=True
        )

        # When the user presses Enter in the MultimodalTextbox:
        submit_event = chat_input.submit(
            fn=add_message,          # Step 1: store user data
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input]
        )
        # After storing, run inference
        submit_event.then(
            fn=inference_interface,  # Step 2: run Qwen2.5-VL
            inputs=[chatbot],
            outputs=[chatbot, chat_input]
        )

        # Same logic for a "Send" button
        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])

        send_click = send_button.click(
            fn=add_message,
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input]
        )
        send_click.then(
            fn=inference_interface,
            inputs=[chatbot],
            outputs=[chatbot, chat_input]
        )

        # Example
        gr.Examples(
            examples=[
                {
                    "text": "Give me the height of the man in the red hat in feet.",
                    "files": ["./examples/warehouse_rgb.jpg"]
                }
            ],
            inputs=[chat_input],
        )
    return demo

if __name__ == "__main__":
    demo = build_demo()
    demo.launch(share=True)
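
# Note (assumption): share=True only matters when this file is run locally, where it requests a
# public gradio.live link; on a hosted Hugging Face Space the flag is ignored with a warning.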