skechtoimg / app.py
baondi's picture
Create app.py
fef413e verified
import os
import sys
import pdb
import random
import numpy as np
from PIL import Image
import base64
from io import BytesIO
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
import gradio as gr
from src.model import make_1step_sched
from src.pix2pix_turbo import Pix2Pix_Turbo
# Build the sketch-to-image model once at import time; run() reads this global
# on every request instead of constructing a new model per call.
# NOTE(review): "sketch_to_image_stochastic" is presumably a pretrained
# checkpoint name resolved inside Pix2Pix_Turbo - confirm in src/pix2pix_turbo.
model = Pix2Pix_Turbo("sketch_to_image_stochastic")
# Style presets offered in the UI. Each entry pairs a display name with a
# prompt template; the literal "{prompt}" marks where the user's text goes.
style_list = [
    {"name": preset_name, "prompt": preset_template}
    for preset_name, preset_template in (
        ("No Style", "{prompt}"),
        (
            "Cinematic",
            "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        ),
        (
            "3D Model",
            "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        ),
        (
            "Anime",
            "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        ),
        (
            "Digital Art",
            "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        ),
        (
            "Photographic",
            "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        ),
        (
            "Pixel art",
            "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        ),
        (
            "Fantasy art",
            "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        ),
        (
            "Neonpunk",
            "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        ),
        (
            "Manga",
            "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        ),
    )
]

# Lookup table: style name -> prompt template (insertion order follows style_list).
styles = {}
for _entry in style_list:
    styles[_entry["name"]] = _entry["prompt"]
del _entry

# Names in display order, the style selected by default, and the largest seed
# value accepted (maximum of a signed 32-bit integer).
STYLE_NAMES = list(styles)
DEFAULT_STYLE_NAME = "Fantasy art"
MAX_SEED = np.iinfo(np.int32).max
def pil_image_to_data_uri(img, format='PNG'):
    """Encode an image as a base64 ``data:`` URI.

    Args:
        img: Object exposing ``save(fp, format=...)`` (e.g. a PIL image).
        format: Image format name handed to ``img.save``; its lower-cased
            form is used as the MIME subtype in the URI.

    Returns:
        A string of the form ``data:image/<format>;base64,<payload>``.
    """
    with BytesIO() as stream:
        img.save(stream, format=format)
        encoded_bytes = base64.b64encode(stream.getvalue())
    payload = encoded_bytes.decode()
    mime_subtype = format.lower()
    return "data:image/" + mime_subtype + ";base64," + payload
def run(image, prompt, prompt_template, style_name, seed, val_r, brush_size, brush_color):
    """Turn a sketch into an image with the module-level ``model``.

    Args:
        image: PIL image holding the sketch, or ``None`` when nothing has
            been drawn yet.
        prompt: User text; replaces ``{prompt}`` inside ``prompt_template``.
            ``None`` is treated as an empty string.
        prompt_template: Style template containing a ``{prompt}`` placeholder.
        style_name: Not used by this function; accepted for interface
            compatibility.
        seed: Seed for the random noise map. Coerced to ``int`` because
            ``torch.manual_seed`` does not accept floats.
        val_r: Forwarded to the model as its ``r`` argument.
        brush_size: Not used by this function.
        brush_color: Not used by this function.

    Returns:
        A 3-tuple ``(output_image, sketch_update, output_update)``: the
        generated PIL image, then two ``gr.update(link=...)`` payloads that
        carry data URIs of the colour-inverted input sketch and of the
        generated image. With no input image, a blank white 512x512 image
        is returned and both links point at it.
    """
    print(f"seed: {seed}, r_val: {val_r}")
    print("sketch updated")

    # Nothing drawn yet: answer with a white placeholder instead of running the model.
    if image is None:
        ones = Image.new("L", (512, 512), 255)
        temp_uri = pil_image_to_data_uri(ones)
        return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)

    prompt = prompt_template.replace("{prompt}", prompt or "")
    image = image.convert("RGB")
    # Binarize the sketch: True where a channel value exceeds 0.5.
    image_t = TF.to_tensor(image) > 0.5

    with torch.no_grad():
        c_t = image_t.unsqueeze(0).cuda().float()
        torch.manual_seed(int(seed))
        height, width = c_t.shape[-2:]
        # Noise map with 4 channels at 1/8 of the input resolution.
        # NOTE(review): presumably matches the model's latent shape - confirm.
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # NOTE(review): the *0.5+0.5 rescale presumably maps [-1, 1] to [0, 1] - confirm.
    output_pil = TF.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
    output_image_uri = pil_image_to_data_uri(output_pil)
    return output_pil, gr.update(link=input_sketch_uri), gr.update(link=output_image_uri)
def update_canvas(brush_size, brush_color):
    """Build a Gradio update that applies the brush settings to the canvas.

    Args:
        brush_size: Value forwarded as ``brush_radius``.
        brush_color: Value forwarded as ``brush_color``.

    Returns:
        The result of ``gr.update`` with the brush settings and
        ``interactive=True``.
    """
    canvas_settings = {
        "brush_radius": brush_size,
        "brush_color": brush_color,
        "interactive": True,
    }
    return gr.update(**canvas_settings)
def upload_sketch(file):
    """Load an uploaded file as a grayscale sketch for the canvas.

    Args:
        file: Object whose ``name`` attribute is the path of the uploaded
            image file.

    Returns:
        A ``gr.update`` that sets the grayscale ("L" mode) image as the
        value, with ``source="upload"`` and ``interactive=True``.
    """
    grayscale_sketch = Image.open(file.name).convert("L")
    return gr.update(value=grayscale_sketch, source="upload", interactive=True)
# Custom CSS handed to the Gradio interface. It draws a 1px black border
# around colour-picker swatches and fixes number inputs and range sliders to a
# width of 120px. The CSS comments inside the string are in Spanish
# ("button colours", "button size", "slider size"); they are part of the
# runtime string and are left untouched.
style_css = """
/* Colores del botón */
.gradio .input_color input[type="color"]::-webkit-color-swatch {
border: 1px solid black;
}
/* Tamaño del botón */
.gradio .input_number input[type="number"] {
width: 120px;
}
/* Tamaño del deslizador */
.gradio .input_slider input[type="range"] {
width: 120px;
}
"""
# JavaScript source (an async arrow function) that installs UI helpers on
# globalThis when executed in the browser:
#   - theSketchDownloadFunction: downloads the href of #download_sketch as
#     "sketch.png", then also calls theOutputDownloadFunction.
#   - theOutputDownloadFunction: downloads the href of #download_output as
#     "output.png".
#   - UNDO_SKETCH_FUNCTION / DELETE_SKETCH_FUNCTION: dispatch a synthetic click
#     on the first / second button located under #input_image.
#   - togglePencil / toggleEraser: toggle the 'clicked' class on
#     #my-toggle-pencil / #my-toggle-eraser, click the checkbox inside
#     #cb-line / #cb-eraser, and swap the gray/white backgrounds of
#     #my-div-pencil and #my-div-eraser.
# NOTE(review): none of the element ids referenced here are created in the
# visible part of this file, and the undo/delete selectors rely on
# svelte-generated class names (svelte-p3y7hu, svelte-s6ybro) that presumably
# change between Gradio releases - confirm against the deployed UI.
scripts = """
async () => {
globalThis.theSketchDownloadFunction = () => {
console.log("test")
var link = document.createElement("a");
dataUri = document.getElementById('download_sketch').href
link.setAttribute("href", dataUri)
link.setAttribute("download", "sketch.png")
document.body.appendChild(link); // Required for Firefox
link.click();
document.body.removeChild(link); // Clean up
// also call the output download function
theOutputDownloadFunction();
return false
}
globalThis.theOutputDownloadFunction = () => {
console.log("test output download function")
var link = document.createElement("a");
dataUri = document.getElementById('download_output').href
link.setAttribute("href", dataUri);
link.setAttribute("download", "output.png");
document.body.appendChild(link); // Required for Firefox
link.click();
document.body.removeChild(link); // Clean up
return false
}
globalThis.UNDO_SKETCH_FUNCTION = () => {
console.log("undo sketch function")
var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
// Create a new 'click' event
var event = new MouseEvent('click', {
'view': window,
'bubbles': true,
'cancelable': true
});
button_undo.dispatchEvent(event);
}
globalThis.DELETE_SKETCH_FUNCTION = () => {
console.log("delete sketch function")
var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
// Create a new 'click' event
var event = new MouseEvent('click', {
'view': window,
'bubbles': true,
'cancelable': true
});
button_del.dispatchEvent(event);
}
globalThis.togglePencil = () => {
el_pencil = document.getElementById('my-toggle-pencil');
el_pencil.classList.toggle('clicked');
// simulate a click on the gradio button
btn_gradio = document.querySelector("#cb-line > label > input");
var event = new MouseEvent('click', {
'view': window,
'bubbles': true,
'cancelable': true
});
btn_gradio.dispatchEvent(event);
if (el_pencil.classList.contains('clicked')) {
document.getElementById('my-toggle-eraser').classList.remove('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
document.getElementById('my-div-eraser').style.backgroundColor = "white";
}
else {
document.getElementById('my-toggle-eraser').classList.add('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "white";
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
}
}
globalThis.toggleEraser = () => {
element = document.getElementById('my-toggle-eraser');
element.classList.toggle('clicked');
// simulate a click on the gradio button
btn_gradio = document.querySelector("#cb-eraser > label > input");
var event = new MouseEvent('click', {
'view': window,
'bubbles': true,
'cancelable': true
});
btn_gradio.dispatchEvent(event);
if (element.classList.contains('clicked')) {
document.getElementById('my-toggle-pencil').classList.remove('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "white";
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
}
else {
document.getElementById('my-toggle-pencil').classList.add('clicked');
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
document.getElementById('my-div-eraser').style.backgroundColor = "white";
}
}
}
"""
# Input components, in the order the prediction callback receives them:
# sketch image, prompt text, style name, seed, r, brush size, brush colour.
# The top-level component classes are used (with `value`, `minimum` and
# `maximum`) because the legacy `gr.inputs` / `gr.outputs` namespaces are
# deprecated and never offered `Text`, `ColorPicker`, `min_value` or
# `max_value`.
inputs = [
    gr.Image(label="Input Sketch", type="pil"),
    gr.Textbox(label="Prompt", value=""),
    gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME),
    gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1),
    gr.Slider(label="R", minimum=0, maximum=1, value=0.6, step=0.01),
    gr.Slider(label="Brush Size", minimum=1, maximum=50, value=4, step=1),
    gr.ColorPicker(label="Brush Color", value="#000000"),
]

# Output components, matching the 3-tuple returned by run().
# NOTE(review): run() returns gr.update(link=...) for the second and third
# outputs, which looks intended for link/button components rather than images
# - confirm which component type these two outputs should be.
outputs = [
    gr.Image(label="Result"),
    gr.Image(label="Input Sketch", type="pil"),
    gr.Image(label="Output Image", type="pil"),
]

# Page metadata shown above the interface.
title = "pix2pix-Turbo: Sketch"
description = "One-Step Image Translation with Text-to-Image Models. Paper: [One-Step Image Translation with Text-to-Image Models](https://arxiv.org/abs/2403.12036). GitHub: [pix2pix-Turbo](https://github.com/GaParmar/img2img-turbo)"

# One example row; the values follow the order of `inputs`.
examples = [["example.jpg", "A cat", "Fantasy art", 42, 0.6, 4, "#000000"]]

# TCP port the local server listens on.
server_port = 7860
if __name__ == "__main__":

    def _predict(image, prompt, style_name, seed, val_r, brush_size, brush_color):
        """Adapt the seven UI inputs to run(), which takes eight arguments.

        run() expects the style's prompt template in addition to the style
        name, but the interface only supplies the name. The template is looked
        up in ``styles``; an unknown name falls back to the pass-through
        template "{prompt}".
        """
        prompt_template = styles.get(style_name, "{prompt}")
        return run(
            image,
            prompt,
            prompt_template,
            style_name,
            seed,
            val_r,
            brush_size,
            brush_color,
        )

    # Only arguments that gr.Interface defines are passed here. `layout`,
    # `capture_session`, `scripts`, `update_canvas` and `upload_sketch` are
    # not Interface parameters, so they are no longer forwarded.
    # NOTE(review): wiring the `scripts` JavaScript and the update_canvas /
    # upload_sketch callbacks presumably needs a gr.Blocks layout with
    # explicit event handlers - confirm the intended UI.
    demo = gr.Interface(
        fn=_predict,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=description,
        examples=examples,
        theme="huggingface",
        allow_flagging="never",
        live=True,
        css=style_css,
    )
    # The port is a launch() option rather than an Interface option.
    demo.launch(share=True, server_port=server_port)