nightfury commited on
Commit
4d7c61e
·
1 Parent(s): c5957c4

Create new file

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ # import torch
3
+ # from torch import autocast
4
+ # from diffusers import StableDiffusionPipeline
5
+ # from datasets import load_dataset
6
+ from PIL import Image
7
+ import re
8
+ import os
9
+
10
+ # auth_token = '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' #os.getenv("auth_token")
11
+ # model_id = "CompVis/stable-diffusion-v1-4"
12
+ # device = "cpu"
13
+ # pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)
14
+ # pipe = pipe.to(device)
15
+
16
+ stable_diffusion = gr.Blocks.load(name="spaces/stabilityai/stable-diffusion")
17
+
18
+ def get_images(prompt):
19
+ gallery_dir = stable_diffusion(prompt, fn_index=2)
20
+ return [os.path.join(gallery_dir, img) for img in os.listdir(gallery_dir)]
21
+
22
+
23
+ def infer(prompt, samples, steps, scale, seed):
24
+ generator = torch.Generator(device=device).manual_seed(seed)
25
+ images_list = pipe(
26
+ [prompt] * samples,
27
+ num_inference_steps=steps,
28
+ guidance_scale=scale,
29
+ generator=generator,
30
+ )
31
+ images = []
32
+ # safe_image = Image.open(r"unsafe.png")
33
+ for i, image in enumerate(images_list["sample"]):
34
+ images.append(image)
35
+ # if(images_list["nsfw_content_detected"][i]):
36
+ # images.append(safe_image)
37
+ # else:
38
+ # images.append(image)
39
+ return images
40
+
41
+
42
+ block = gr.Blocks()
43
+
44
+ with block:
45
+ with gr.Group():
46
+ with gr.Box():
47
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
48
+ text = gr.Textbox(
49
+ label="Enter your prompt",
50
+ show_label=False,
51
+ max_lines=1,
52
+ placeholder="Enter your prompt",
53
+ ).style(
54
+ border=(True, False, True, True),
55
+ rounded=(True, False, False, True),
56
+ container=False,
57
+ )
58
+ btn = gr.Button("Generate image").style(
59
+ margin=False,
60
+ rounded=(False, True, True, False),
61
+ )
62
+ gallery = gr.Gallery(
63
+ label="Generated images", show_label=False, elem_id="gallery"
64
+ ).style(grid=[2], height="auto")
65
+
66
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
67
+
68
+ with gr.Row(elem_id="advanced-options"):
69
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
70
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
71
+ scale = gr.Slider(
72
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
73
+ )
74
+ seed = gr.Slider(
75
+ label="Seed",
76
+ minimum=0,
77
+ maximum=2147483647,
78
+ step=1,
79
+ randomize=True,
80
+ )
81
+ text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
82
+ btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
83
+ advanced_button.click(
84
+ None,
85
+ [],
86
+ text,
87
+ )
88
+
89
+ block.launch()