import random
import numpy as np
from PIL import Image, ImageOps
import base64
from io import BytesIO
import torch
import torchvision.transforms.functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.pix2pix_turbo import Pix2Pix_Turbo
import nltk
from nltk import pos_tag
from nltk.tokenize import word_tokenize
import re
import os
import json
import logging
import gc
import gradio as gr
from torch.cuda.amp import autocast

# Set environment variable for better memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'


# Function to clear the CUDA cache and collect garbage
def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()


# Load the configuration from config.json
with open('config.json', 'r') as config_file:
    config = json.load(config_file)

# Set up logging as per config
logging.basicConfig(level=config["logging"]["level"], format=config["logging"]["format"])

# Ensure NLTK resources are downloaded
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')

# File paths for storing sketches and outputs
SKETCH_PATH = config["file_paths"]["sketch_path"]
OUTPUT_PATH = config["file_paths"]["output_path"]

# Global constants and configuration
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEFAULT_STYLE_NAME = config["default_style_name"]
RANDOM_VALUES = config["random_values"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
MAX_SEED = config["model_params"]["max_seed"]

# Canvas configuration
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
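# The expected shape of config.json, inferred from the keys read above. A minimal
# illustrative example follows; every value here is a hypothetical placeholder, not
# the configuration actually shipped with this app:
#
# {
#   "logging": {"level": "DEBUG", "format": "%(asctime)s - %(levelname)s - %(message)s"},
#   "file_paths": {"sketch_path": "sketch.png", "output_path": "output.png"},
#   "style_list": [{"name": "Photorealistic", "prompt": "{prompt}, photorealistic, detailed"}],
#   "default_style_name": "Photorealistic",
#   "random_values": ["high quality", "sharp focus"],
#   "model_params": {
#     "pix2pix_model_name": "sketch_to_image_stochastic",
#     "device": "cuda",
#     "default_seed": 42,
#     "val_r_default": 0.4,
#     "max_seed": 2147483647
#   },
#   "canvas": {"width": 512, "height": 512}
# }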
# Preload models
logging.debug("Loading BLIP and Pix2Pix models...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
logging.debug("Models loaded.")
def pil_image_to_data_uri(img: Image.Image, format: str = "PNG") -> str:
    """Convert a PIL image to a base64-encoded data URI."""
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
def generate_prompt_from_sketch(image: Image.Image) -> str:
    """Generate a text prompt from a sketch using the BLIP captioning model."""
    logging.debug("Generating prompt from sketch...")
    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)
    out = blip_model.generate(**inputs, max_new_tokens=50)
    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Generated prompt: {text_prompt}")
    recognized_items = [extract_main_words(item) for item in text_prompt.split(', ') if item.strip()]
    random_prefix = random.choice(RANDOM_VALUES)
    prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
    logging.debug(f"Final prompt: {prompt}")
    return prompt
def extract_main_words(item: str) -> str:
    """Extract all nouns from a given text fragment and return them as a space-separated string."""
    words = word_tokenize(item.strip())
    tagged = pos_tag(words)
    nouns = [word.capitalize() for word, tag in tagged if tag in ('NN', 'NNP', 'NNPS', 'NNS')]
    return ' '.join(nouns)
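# Illustrative behaviour (output depends on NLTK's part-of-speech tagger and may vary):
# extract_main_words("a drawing of a red fox in a forest")  # -> "Drawing Fox Forest"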
def normalize_image(image, range_from=(-1, 1)):
    """
    Normalize the input image to a specified range.

    :param image: The PIL Image to be normalized.
    :param range_from: The target range for normalization, typically (-1, 1) or (0, 1).
    :return: Normalized image tensor.
    """
    # Convert the image to a tensor in [0, 1]
    image_t = F.to_tensor(image)
    if range_from == (-1, 1):
        # Rescale from [0, 1] to [-1, 1]
        image_t = image_t * 2 - 1
    return image_t
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Run the main image processing pipeline."""
    logging.debug("Running model inference...")
    if image is None:
        blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255)
        blank_image.save(SKETCH_PATH)  # Save a blank image as the sketch
        logging.debug("No image provided. Saving blank image.")
        return None  # Single return value to match the single image output
    if not prompt.strip():
        prompt = generate_prompt_from_sketch(image)

    # Save the sketch to a file
    image.save(SKETCH_PATH)

    # Show the original prompt before processing
    original_prompt = f"Original Prompt: {prompt}"
    logging.debug(original_prompt)

    prompt = prompt_template.replace("{prompt}", prompt)
    logging.debug(f"Processing with prompt: {prompt}")

    image = image.convert("RGB")
    image_tensor = F.to_tensor(image) * 2 - 1  # Normalize to [-1, 1]
    clear_memory()  # Clear memory before running the model

    try:
        with torch.no_grad():
            c_t = image_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            B, C, H, W = c_t.shape
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            logging.debug("Calling Pix2Pix model...")
            # Enable mixed precision
            with autocast():
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
            logging.debug("Model inference completed.")
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            logging.warning("CUDA out of memory error. Falling back to CPU.")
            with torch.no_grad():
                c_t = c_t.cpu()
                noise = noise.cpu()
                pix2pix_model_cpu = pix2pix_model.cpu()  # Move the model to the CPU
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
        else:
            raise
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    output_pil.save(OUTPUT_PATH)
    logging.debug("Output image saved.")
    return output_pil
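# Illustrative direct call outside the UI (hypothetical file name; the Gradio interface
# below normally supplies these arguments):
# sketch = Image.open("sketch_input.png")
# result = run(sketch, "", STYLES[DEFAULT_STYLE_NAME], DEFAULT_STYLE_NAME, DEFAULT_SEED, VAL_R_DEFAULT)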
def gradio_interface(image, prompt, style_name, seed, val_r):
    """Gradio interface function to handle inputs and generate outputs."""
    # Endpoint: `image` - Input sketch image from the user
    # Endpoint: `prompt` - Text prompt (optional)
    # Endpoint: `style_name` - Selected style from the dropdown
    # Endpoint: `seed` - Seed for reproducibility
    # Endpoint: `val_r` - Sketch guidance value
    prompt_template = STYLES.get(style_name, STYLES[DEFAULT_STYLE_NAME])
    result_image = run(image, prompt, prompt_template, style_name, seed, val_r)
    return result_image
# Create the Gradio interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(source="upload", type="pil", label="Sketch Image"),  # Endpoint: `image`
        gr.Textbox(lines=2, placeholder="Enter a text prompt (optional)", label="Prompt"),  # Endpoint: `prompt`
        gr.Dropdown(choices=list(STYLES.keys()), value=DEFAULT_STYLE_NAME, label="Style"),  # Endpoint: `style_name`
        gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, label="Seed"),  # Endpoint: `seed`
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance"),  # Endpoint: `val_r`
    ],
    outputs=gr.Image(label="Generated Image"),  # Output endpoint: `result_image`
    title="Sketch to Image Generation",
    description="Upload a sketch and generate an image based on a prompt and style.",
)
if __name__ == "__main__":
    # Launch the Gradio interface
    interface.launch(share=True)