import os
import sys

# Fetch the SOAT repo so its modules (model, util) can be imported below
os.system("git clone https://github.com/mchong6/SOAT.git")
sys.path.append("SOAT")

import torch
import torchvision
from torch import nn
import numpy as np
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import math
import matplotlib.pyplot as plt
import torch.nn.functional as F
from model import *
from tqdm import tqdm as tqdm
import pickle
from copy import deepcopy
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # get rid of interpolation warning
import kornia.filters as k
from torchvision.utils import save_image
from util import *
import scipy

import gradio as gr

import PIL

from torchvision import transforms

device = 'cpu'  # 'cuda' or 'cpu'

generator = Generator(256, 512, 8, channel_multiplier=2).eval().to(device)
truncation = 0.7

def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    # image: [3,h,w] or [1,3,h,w] tensor with values in [-1,1];
    # returns an HWC numpy array in [0,1]
    if image.is_cuda:
        image = image.cpu()
    if size is not None and image.size(-1) != size:
        image = F.interpolate(image, size=(size,size), mode=mode)
    if image.dim() == 4:
        image = image[0]
    image = ((image.clamp(-1,1)+1)/2).permute(1, 2, 0).detach().numpy()
    return image
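
# Quick visual check outside Gradio (illustrative sketch only; `some_tensor`
# stands for any generator output):
#   img = display_image(some_tensor)
#   plt.imshow(img); plt.axis('off'); plt.show()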
    

#mean_latentland = load_model(generator, 'landscape.pt')
#mean_latentface = load_model(generator, 'face.pt')
#mean_latentchurch = load_model(generator, 'church.pt')


def inference(num, seed):
    mean_latent = load_model(generator, 'landscape.pt')

    num_im =  int(num)
    random_seed =  int(seed)
    
    plt.rcParams['figure.dpi'] = 300
    
    
    # pad determines how much of each image is involved in the blending
    pad = 512//4
    
    all_im = []
    
    random_state = np.random.RandomState(random_seed)
    
    # latent smoothing: gaussian-filter the latents along the batch dimension
    # (wrap mode makes the sequence cyclic) so neighboring frames are correlated,
    # then renormalize to unit RMS
    with torch.no_grad():
        z = random_state.randn(num_im, 512).astype(np.float32)
        z = scipy.ndimage.gaussian_filter(z, [.7, 0], mode='wrap')
        z /= np.sqrt(np.mean(np.square(z)))
        z = torch.from_numpy(z).to(device)
    
        source = generator.get_latent(z, truncation=truncation, mean_latent=mean_latent)
            
        # merge images 2 at a time
        for i in range(num_im-1):
            source1 = index_layers(source, i)
            source2 = index_layers(source, i+1)
            all_im.append(generator.merge_extension(source1, source2))
    
    # display intermediate generations    
    # for i in all_im:
    #     display_image(i)
        
    
    b,c,h,w = all_im[0].shape
    panorama_im = torch.zeros(b,c,h,512+(num_im-2)*256)
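    # Worked example of the width formula above: with num_im = 5 and pad = 128,
    # the panorama is 512 + (5-2)*256 = 1280 px wide -- the first merged frame
    # contributes 256+pad px, each later frame adds 512-2*pad = 256 px in the
    # loop below, and the tail of the last frame adds the final pad px.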
    
    # We now have a series of pairwise-blended images that we overlay
    # (with overlapping regions of width `pad`) to form one large panorama.
    # add first image
    coord = 256+pad
    panorama_im[..., :coord] = all_im[0][..., :coord]
    
    for im in all_im[1:]:
        panorama_im[..., coord:coord+512-2*pad] = im[..., pad:-pad]
        coord += 512-2*pad
    panorama_im[..., coord:] = all_im[-1][..., 512-pad:]
    
    img = display_image(panorama_im)
    return img
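
# A non-Gradio usage sketch (kept as a comment so the demo below is unchanged);
# it assumes 'landscape.pt' is available in the working directory, as above:
#
#   img = inference(5, 90)            # HWC numpy array in [0, 1]
#   plt.imsave('panorama.png', img)   # save the stitched panorama to disk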

title = "SOAT"
description = "Gradio demo for SOAT panorama generation for landscapes. Generate a panorama with a pretrained StyleGAN by stitching intermediate activations. To use it, simply set the number of images and the random seed. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.01619' target='_blank'>StyleGAN of All Trades: Image Manipulation with Only Pretrained StyleGAN</a> | <a href='https://github.com/mchong6/SOAT' target='_blank'>Github Repo</a></p>"

gr.Interface(
    inference,
    [
        gr.inputs.Number(default=5, label="Number of Images"),
        gr.inputs.Number(default=90, label="Random Seed"),
    ],
    gr.outputs.Image(type="numpy", label="Output"),
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    enable_queue=True,
).launch(debug=True)