Ahsen Khaliq committed on
Commit
ab119a0
·
1 Parent(s): a197d62

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# Clone the SOAT repo so that `model` and `util` become importable below.
os.system("git clone https://github.com/mchong6/SOAT.git")

import torch
import torchvision
from torch import nn
import numpy as np
import torch.backends.cudnn as cudnn

# Let cuDNN auto-tune convolution algorithms (harmless no-op on CPU).
cudnn.benchmark = True

import math
import matplotlib.pyplot as plt
import torch.nn.functional as F
from model import *  # Generator — from the cloned SOAT repo
from tqdm import tqdm as tqdm
import pickle
from copy import deepcopy
import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # get rid of interpolation warning
import kornia.filters as k
from torchvision.utils import save_image
from util import *  # load_model, index_layers — from the cloned SOAT repo
import scipy
# `scipy.ndimage` is used later (gaussian_filter); it must be imported
# explicitly — a bare `import scipy` does not guarantee the subpackage.
import scipy.ndimage

import gradio as gr

import PIL

from torchvision import transforms

device = 'cpu'  #@param ['cuda', 'cpu']

# 256px generator with 512-dim style vectors and an 8-layer mapping network,
# matching the pretrained landscape checkpoint loaded in inferece().
generator = Generator(256, 512, 8, channel_multiplier=2).eval().to(device)
truncation = 0.7
def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    """Convert a generator output tensor to an HWC numpy image in [0, 1].

    `image` is a [3,h,w] or [1,3,h,w] tensor in roughly [-1, 1]; it is
    optionally resized to `size` x `size`, clamped, rescaled to [0, 1],
    and returned as a (h, w, 3) numpy array.
    (`unnorm` and `title` are accepted for API compatibility but unused.)
    """
    if image.is_cuda:
        image = image.cpu()
    needs_resize = size is not None and image.size(-1) != size
    if needs_resize:
        image = F.interpolate(image, size=(size, size), mode=mode)
    if image.dim() == 4:
        image = image[0]
    # Map [-1, 1] -> [0, 1], then reorder CHW -> HWC for image display.
    scaled = (image.clamp(-1, 1) + 1) / 2
    return scaled.permute(1, 2, 0).detach().numpy()
def inferece(num, seed):
    """Generate a landscape panorama by stitching blended StyleGAN outputs.

    Args:
        num: number of latent frames to blend (clamped to >= 2, since the
            panorama is stitched from consecutive pairs of frames).
        seed: integer seed for the latent sampler, for reproducible output.

    Returns:
        An HWC numpy array in [0, 1] suitable for `gr.outputs.Image`.
    """
    model_type = 'landscape'
    # At least 2 frames are required: the panorama is built from pairwise
    # merges, and the original code crashed (IndexError) for num < 2.
    num_im = max(int(num), 2)
    random_seed = int(seed)

    plt.rcParams['figure.dpi'] = 300

    mean_latent = load_model(generator, f'{model_type}.pt')

    # pad determines how much of an image is involved in the blending
    pad = 512 // 4

    all_im = []

    random_state = np.random.RandomState(random_seed)

    # Latent smoothing: low-pass filter consecutive z vectors (wrap mode)
    # so neighbouring frames are correlated, then renormalise to unit RMS.
    with torch.no_grad():
        z = random_state.randn(num_im, 512).astype(np.float32)
        z = scipy.ndimage.gaussian_filter(z, [.7, 0], mode='wrap')
        z /= np.sqrt(np.mean(np.square(z)))
        # BUG FIX: the original called .cuda() unconditionally, which crashes
        # when device == 'cpu' (the configured default). Use the configured
        # device instead.
        z = torch.from_numpy(z).to(device)

        source = generator.get_latent(z, truncation=truncation, mean_latent=mean_latent)

        # merge images 2 at a time
        for i in range(num_im - 1):
            source1 = index_layers(source, i)
            source2 = index_layers(source, i + 1)
            all_im.append(generator.merge_extension(source1, source2))

        b, c, h, w = all_im[0].shape
        panorama_im = torch.zeros(b, c, h, 512 + (num_im - 2) * 256)

        # We created a series of 2-blended images which we can overlay to
        # form a large panorama. Add the first image, then splice the
        # middle (non-padded) part of each subsequent blend.
        coord = 256 + pad
        panorama_im[..., :coord] = all_im[0][..., :coord]

        for im in all_im[1:]:
            panorama_im[..., coord:coord + 512 - 2 * pad] = im[..., pad:-pad]
            coord += 512 - 2 * pad
        # The right edge of the panorama comes from the last blend.
        panorama_im[..., coord:] = all_im[-1][..., 512 - pad:]

    img = display_image(panorama_im)
    return img
title = "SOAT"
# User-facing demo description (typo fixes: "Generaton" -> "Generation",
# stray space before the period removed).
description = (
    "Gradio demo for SOAT Panorama Generation for landscapes. "
    "Generate a panorama using a pretrained stylegan by stitching "
    "intermediate activations. To use it, simply add the number of images "
    "and random seed number. Read more at the links below."
)
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.01619' target='_blank'>StyleGAN of All Trades: Image Manipulation with Only Pretrained StyleGAN</a> | <a href='https://github.com/mchong6/SOAT' target='_blank'>Github Repo</a></p>"

# Wire inferece() into a simple two-number-input, one-image-output demo.
gr.Interface(
    inferece,
    [gr.inputs.Number(default=5, label="Number of Images"),
     gr.inputs.Number(default=90, label="Random Seed")],
    gr.outputs.Image(type="numpy", label="Output"),
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    enable_queue=True,
).launch(debug=True)