import random
import numpy as np
from PIL import Image, ImageOps
import base64
from io import BytesIO
import torch
import torchvision.transforms.functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.pix2pix_turbo import Pix2Pix_Turbo
import nltk
from nltk import pos_tag
from nltk.tokenize import word_tokenize
import os
import threading
import hashlib
from flask import Flask, request, send_file, jsonify, render_template_string
from flask_cors import CORS
import signal
import sys
import logging
import json
import gc
from torch.cuda.amp import autocast
# Set environment variable for better CUDA memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# Clear the CUDA cache (when a GPU is present) and collect garbage
def clear_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
# Load the configuration from config.json
with open('config.json', 'r') as config_file:
    config = json.load(config_file)
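# For reference, this script expects config.json to follow roughly the shape
# below. The keys are the ones read in this file; the values are illustrative
# examples, not the project's actual defaults:
#
# {
#   "logging":      {"level": "DEBUG", "format": "%(asctime)s %(levelname)s %(message)s"},
#   "file_paths":   {"sketch_path": "sketch.png", "output_path": "output.png"},
#   "style_list":   [{"name": "Cinematic", "prompt": "cinematic still {prompt} . ..."}],
#   "default_style_name": "Fantasy art",
#   "random_values": ["4k photo", "masterpiece"],
#   "model_params": {"pix2pix_model_name": "sketch_to_image_stochastic",
#                    "device": "cuda", "default_seed": 42,
#                    "val_r_default": 0.4, "max_seed": 2147483647},
#   "canvas":       {"width": 512, "height": 512, "background_color": "#ffffff",
#                    "default_brush_color": "#000000", "default_brush_size": 4,
#                    "eraser_color": "#ffffff", "max_brush_size": 20, "min_brush_size": 1},
#   "server":       {"host": "0.0.0.0", "port": 5000}
# }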
# Set up logging as configured
logging.basicConfig(level=config["logging"]["level"], format=config["logging"]["format"])

# Ensure the required NLTK resources are available
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('punkt', quiet=True)
# File paths for storing sketches and outputs
SKETCH_PATH = config["file_paths"]["sketch_path"]
OUTPUT_PATH = config["file_paths"]["output_path"]

# Processing queue (recent tasks and their status, exposed via /get_status)
processing_queue = []

# Global constants and configuration
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
STYLE_NAMES = list(STYLES.keys())
DEFAULT_STYLE_NAME = config["default_style_name"]
RANDOM_VALUES = config["random_values"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
MAX_SEED = config["model_params"]["max_seed"]

# Canvas configuration
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
BACKGROUND_COLOR = config["canvas"]["background_color"]
DEFAULT_BRUSH_COLOR = config["canvas"]["default_brush_color"]
DEFAULT_BRUSH_SIZE = config["canvas"]["default_brush_size"]
ERASER_COLOR = config["canvas"]["eraser_color"]
MAX_BRUSH_SIZE = config["canvas"]["max_brush_size"]
MIN_BRUSH_SIZE = config["canvas"]["min_brush_size"]
# Preload models
logging.debug("Loading BLIP and Pix2Pix models...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE).eval()  # Inference only
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME).to(DEVICE).eval()  # Inference only
logging.debug("Models loaded.")
# Shared flag and thread for managing the current processing
current_thread = None
cancel_flag = threading.Event()

def pil_image_to_data_uri(img: Image.Image, format: str = "PNG") -> str:
    """Converts a PIL image to a data URI."""
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
def generate_prompt_from_sketch(image: Image.Image) -> str:
    """Generates a text prompt based on a sketch using the BLIP model."""
    logging.debug("Generating prompt from sketch...")
    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=50)
    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Generated prompt: {text_prompt}")
    recognized_items = [extract_main_words(item) for item in text_prompt.split(', ') if item.strip()]
    random_prefix = random.choice(RANDOM_VALUES)
    prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
    logging.debug(f"Final prompt: {prompt}")
    return prompt
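# For example, a BLIP caption such as "a drawing of a cat, a tree" would
# typically produce something like "a photo of a Drawing Cat and Tree, 4k photo"
# (the suffix depends on the RANDOM_VALUES list in config.json).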
def extract_main_words(item: str) -> str:
    """Extracts all nouns from a given text fragment and returns them as a space-separated string."""
    words = word_tokenize(item.strip())
    tagged = pos_tag(words)
    nouns = [word.capitalize() for word, tag in tagged if tag in ('NN', 'NNP', 'NNPS', 'NNS')]
    return ' '.join(nouns)
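# e.g. extract_main_words("a red apple on the table") usually yields
# "Apple Table" (the exact result depends on the NLTK tagger model in use).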
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Runs the main image processing pipeline."""
    logging.debug("Running model inference...")
    if image is None:
        blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255)
        blank_image.save(SKETCH_PATH)  # Save a blank image as the sketch
        logging.debug("No image provided. Saving blank image.")
        return "", "", "", ""
    if not prompt.strip():
        prompt = generate_prompt_from_sketch(image)

    # Save the sketch to a file
    image.save(SKETCH_PATH)

    # Record the original prompt before templating
    original_prompt = f"Original Prompt: {prompt}"
    logging.debug(original_prompt)

    # Add the task to the processing queue
    task = {"prompt": prompt, "status": "processing"}
    processing_queue.append(task)

    prompt = prompt_template.replace("{prompt}", prompt)
    logging.debug(f"Processing with prompt: {prompt}")

    image = image.convert("RGB")
    image_tensor = F.to_tensor(image) * 2 - 1  # Normalize to [-1, 1]
    clear_memory()  # Clear memory before running the model
    try:
        with torch.no_grad():
            c_t = image_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            B, C, H, W = c_t.shape
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            logging.debug("Calling Pix2Pix model...")
            # Enable mixed precision (only meaningful on CUDA)
            with autocast(enabled=DEVICE.startswith("cuda")):
                if cancel_flag.is_set():
                    logging.debug("Processing canceled.")
                    task["status"] = "canceled"
                    return "", "", "", original_prompt
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
        logging.debug("Model inference completed.")
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            logging.warning("CUDA out of memory. Falling back to CPU.")
            with torch.no_grad():
                c_t = c_t.cpu()
                noise = noise.cpu()
                # Note: this moves the global model to CPU, so subsequent
                # requests will also run on CPU until the process restarts.
                pix2pix_model_cpu = pix2pix_model.cpu()
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
        else:
            raise
    # Map the output from [-1, 1] back to [0, 1] and clamp before conversion
    output_pil = F.to_pil_image((output_image[0].cpu() * 0.5 + 0.5).clamp(0, 1))
    output_pil.save(OUTPUT_PATH)
    logging.debug("Output image saved.")
    task["status"] = "done"

    input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
    output_image_uri = pil_image_to_data_uri(output_pil)
    logging.debug(f"Generated output URI: {output_image_uri}")
    clear_memory()  # Clear memory after running the model
    return output_image_uri, input_sketch_uri, output_image_uri, original_prompt
def process_image_task(image, prompt, style_name, seed, val_r):
    """Thread target: runs the pipeline and logs the result. Return values of
    thread targets are discarded (and jsonify needs a request context), so
    results and errors are reported via logging instead."""
    try:
        cancel_flag.clear()  # Clear any previous cancellation flag
        output_image_uri, _, _, _ = run(image, prompt, STYLES.get(style_name, STYLES[DEFAULT_STYLE_NAME]), style_name, seed, val_r)
        logging.debug(f"Processed image URI: {output_image_uri}")
    except Exception as e:
        logging.error(f"Error processing image: {e}")
# Flask server setup for the preview page and JSON endpoints
app = Flask(__name__)
CORS(app)  # Enable CORS
@app.route("/process-image", methods=["POST"])
def process_image():
    global current_thread

    # Cancel any ongoing processing
    if current_thread is not None and current_thread.is_alive():
        logging.debug("Cancelling previous processing...")
        cancel_flag.set()
        current_thread.join()  # Wait for the thread to finish

    data = request.get_json()

    # Extract and decode the base64 image (data URI format: "data:image/png;base64,...")
    image_field = data.get("image", "")
    if "," not in image_field:
        return jsonify({"error": "Invalid or missing image data."}), 400
    image_data = image_field.split(",", 1)[1]
    image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")

    prompt = data.get("prompt", "")
    style_name = data.get("style_name", DEFAULT_STYLE_NAME)
    seed = int(data.get("seed", DEFAULT_SEED))
    val_r = float(data.get("val_r", VAL_R_DEFAULT))

    # Start new processing in a separate thread
    current_thread = threading.Thread(target=process_image_task, args=(image, prompt, style_name, seed, val_r))
    current_thread.start()
    return jsonify({"status": "processing_started"})
@app.route("/get_sketch")
def get_sketch():
    if os.path.exists(SKETCH_PATH):
        return send_file(SKETCH_PATH, mimetype='image/png')
    return jsonify({"status": "error", "message": "Sketch not found."}), 404
@app.route("/get_output")
def get_output():
    if os.path.exists(OUTPUT_PATH):
        return send_file(OUTPUT_PATH, mimetype='image/png')
    return jsonify({"status": "error", "message": "Output not found."}), 404
@app.route("/get_status")
def get_status():
    """Returns JSON with the last output image (base64), its checksum, and the processing queue."""
    if os.path.exists(OUTPUT_PATH):
        with open(OUTPUT_PATH, "rb") as f:
            img_data = f.read()
        base64_image = base64.b64encode(img_data).decode('utf-8')
        checksum = hashlib.sha256(img_data).hexdigest()
    else:
        base64_image = ""
        checksum = ""
    return jsonify({
        "image_base64": base64_image,
        "checksum": checksum,
        "processing_queue": processing_queue
    })
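# Example /get_status response shape (values abbreviated):
#
# {
#   "image_base64": "iVBORw0KGgo...",
#   "checksum": "9f86d081884c7d65...",
#   "processing_queue": [{"prompt": "a house", "status": "done"}]
# }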
@app.route("/")
def index():
    # HTML template for the preview page
    html_template = """
    <!doctype html>
    <html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Preview Page</title>
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                background-color: black;
            }
            .full-screen-image {
                width: 100%;
                height: 100%;
                object-fit: contain;
            }
        </style>
        <script>
            function refreshImage() {
                var img = document.getElementById("output-image");
                img.src = "/get_output?" + new Date().getTime();
            }
            // Auto-refresh every 2 seconds to show the latest image
            setInterval(refreshImage, 2000);
        </script>
    </head>
    <body>
        <img id="output-image" src="/get_output" class="full-screen-image">
    </body>
    </html>
    """
    return render_template_string(html_template)
@app.route("/draw")
def draw_page():
    # HTML template for the drawing page at /draw
    html_template = """
| <!doctype html> | |
| <html lang="en"> | |
| <head> | |
| <meta charset="utf-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>Drawing Page</title> | |
| <style> | |
| body, html { | |
| margin: 0; | |
| padding: 0; | |
| height: 100%; | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| background-color: #f0f0f0; | |
| } | |
| .canvas-container { | |
| border: none; | |
| position: relative; | |
| } | |
| .toolbar { | |
| display: flex; | |
| justify-content: center; | |
| margin-bottom: 10px; | |
| } | |
| button { | |
| margin-right: 5px; | |
| } | |
| canvas { | |
| cursor: crosshair; | |
| } | |
| </style> | |
| </head> | |
| <body> | |
| <div style="position: fixed; | |
| bottom: 0; | |
| width: 100%;"> | |
| <div class="toolbar"> | |
| <button id="brush" onclick="setTool('brush')">Brush</button> | |
| <button id="line" onclick="setTool('line')">Line</button> | |
| <button id="eraser" onclick="setTool('eraser')">Eraser</button> | |
| <button id="clear" onclick="clearCanvas()">Clear</button> | |
| <input type="color" id="colorPicker" value="#000000"> | |
| <input type="range" id="brushSize" min="1" max="20" value="4"> | |
| </div> | |
| </div> | |
| <div class="canvas-container"> | |
| <canvas id="drawingCanvas" width="512" height="512"></canvas> | |
| </div> | |
| <script> | |
| let canvas = document.getElementById('drawingCanvas'); | |
| let ctx = canvas.getContext('2d'); | |
| let drawing = false; | |
| let tool = 'brush'; | |
| let lastX = 0, lastY = 0; | |
| // Fill the canvas with white background | |
| ctx.fillStyle = "#ffffff"; | |
| ctx.fillRect(0, 0, canvas.width, canvas.height); | |
| canvas.addEventListener('mousedown', (e) => { | |
| drawing = true; | |
| [lastX, lastY] = [e.offsetX, e.offsetY]; | |
| }); | |
| canvas.addEventListener('mousemove', draw); | |
| canvas.addEventListener('mouseup', () => { | |
| drawing = false; | |
| sendDrawingToBackend(); | |
| }); | |
| canvas.addEventListener('mouseout', () => drawing = false); | |
            function draw(e) {
                if (!drawing) return;
                // The eraser paints with the white background color; every
                // other tool uses the color picker value.
                ctx.strokeStyle = (tool === 'eraser')
                    ? "#ffffff"
                    : document.getElementById('colorPicker').value;
                ctx.lineWidth = document.getElementById('brushSize').value;
                ctx.lineJoin = 'round';
                ctx.lineCap = 'round';
                ctx.beginPath();
                ctx.moveTo(lastX, lastY);
                ctx.lineTo(e.offsetX, e.offsetY);
                ctx.stroke();
                [lastX, lastY] = [e.offsetX, e.offsetY];
            }
            function setTool(selectedTool) {
                // Note: the 'line' tool currently behaves like the brush.
                tool = selectedTool;
                ctx.globalCompositeOperation = 'source-over';
            }
            function clearCanvas() {
                ctx.fillStyle = "#ffffff";
                ctx.fillRect(0, 0, canvas.width, canvas.height);
                fetch('/clear_preview', { method: 'POST' })
                    .then(response => response.json())
                    .then(data => console.log('Cleared preview', data))
                    .catch(error => console.error('Error clearing preview:', error));
            }
            function sendDrawingToBackend() {
                let dataURL = canvas.toDataURL('image/png');
                fetch('/process-image', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ image: dataURL }),
                })
                .then(response => response.json())
                .then(data => console.log('Image processed', data))
                .catch(error => console.error('Error processing image:', error));
            }
        </script>
    </body>
    </html>
    """
    return render_template_string(html_template)
@app.route("/clear_preview", methods=["POST"])
def clear_preview():
    if os.path.exists(OUTPUT_PATH):
        os.remove(OUTPUT_PATH)
    return jsonify({"status": "cleared"})
def start_flask_app():
    app.run(host=config["server"]["host"], port=config["server"]["port"], threaded=True)

def signal_handler(sig, frame):
    logging.info("Ctrl+C pressed, shutting down.")
    sys.exit(0)

# Register the signal handler for Ctrl+C
signal.signal(signal.SIGINT, signal_handler)

if __name__ == "__main__":
    start_flask_app()
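# Usage (assuming this file is saved as app.py, config.json sits next to it,
# and the server block points at 0.0.0.0:5000):
#
#   python app.py
#
# Endpoints:
#   GET  /              full-screen preview of the latest output (auto-refreshes)
#   GET  /draw          drawing canvas that posts sketches to /process-image
#   POST /process-image start (or restart) generation from a base64 sketch
#   GET  /get_sketch    last saved sketch as PNG
#   GET  /get_output    last generated image as PNG
#   GET  /get_status    JSON status: last image, checksum, processing queue
#   POST /clear_preview delete the last generated image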