Spaces:
Paused
Paused
| import torch | |
| from PIL import Image | |
| import torchvision.transforms.functional as F | |
| from src.pix2pix_turbo import Pix2Pix_Turbo | |
| import numpy as np | |
def process_sketch(sketch_path, output_path, prompt, val_r=0.4, seed=42):
    """Turn a sketch image into a generated image and write it to disk.

    Args:
        sketch_path: Path of the sketch image to read.
        output_path: Path the generated image is saved to.
        prompt: Text prompt forwarded to the model.
        val_r: Value forwarded to the model as its ``r`` argument.
        seed: Seed given to ``torch.manual_seed`` before the noise map is
            sampled.

    Returns:
        None. The result is saved to ``output_path`` and a confirmation
        line is printed.
    """
    # A fresh model is constructed on every call, before the RNG is seeded.
    generator = Pix2Pix_Turbo("sketch_to_image_stochastic")
    torch.manual_seed(seed)

    # Binarize the sketch: channel values above 0.5 become True, the rest
    # False (turned into 1.0 / 0.0 by the .float() cast below).
    sketch_rgb = Image.open(sketch_path).convert("RGB")
    binary_sketch = F.to_tensor(sketch_rgb) > 0.5

    with torch.no_grad():
        # Add a leading batch axis, move to the CUDA device, cast to float.
        condition = binary_sketch.unsqueeze(0).cuda().float()
        height, width = condition.shape[-2:]

        # Single-sample, 4-channel Gaussian noise at 1/8 of the input's
        # spatial size, created on the same device as the conditioning.
        noise_shape = (1, 4, height // 8, width // 8)
        noise_map = torch.randn(noise_shape, device=condition.device)

        generated = generator(
            condition,
            prompt,
            deterministic=False,
            r=val_r,
            noise_map=noise_map,
        )

    # x * 0.5 + 0.5 before PIL conversion — assumes the model output lies
    # in [-1, 1]; TODO confirm against Pix2Pix_Turbo.
    first_sample = generated[0].cpu()
    result_image = F.to_pil_image(first_sample * 0.5 + 0.5)

    result_image.save(output_path)
    print(f"Output image saved to {output_path}")
if __name__ == "__main__":
    # Example run: reads "sketch.png" from the working directory and writes
    # the result to "output.png", using the default val_r and seed.
    process_sketch(
        sketch_path="sketch.png",
        output_path="output.png",
        prompt="a fantasy concept art of a magical castle in the sky, ",
    )