# File size: 3,875 Bytes
# ab119a0 d343810 ab119a0 adc50c3 86c32c4 adc50c3 86c32c4 2a0de73 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 c27187e ab119a0 871d6aa 86c32c4 ab119a0 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
os.system("git clone https://github.com/mchong6/SOAT.git")
import sys
sys.path.append("SOAT")
import os
import torch
import torchvision
from torch import nn
import numpy as np
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import math
import matplotlib.pyplot as plt
import torch.nn.functional as F
from model import *
from tqdm import tqdm as tqdm
import pickle
from copy import deepcopy
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # get rid of interpolation warning
import kornia.filters as k
from torchvision.utils import save_image
from util import *
import scipy
import gradio as gr
import PIL
from torchvision import transforms
# Device the generator runs on (Colab form field: choose 'cuda' or 'cpu').
device = 'cpu' #@param ['cuda', 'cpu']

# Build the StyleGAN generator, switch it to inference mode, move it to `device`.
# NOTE(review): the positional args (256, 512, 8) presumably are output
# resolution, latent dimension and mapping depth - confirm in SOAT/model.py.
generator = Generator(256, 512, 8, channel_multiplier=2)
generator = generator.eval()
generator = generator.to(device)

# Truncation value handed to generator.get_latent() when sampling latents.
truncation = 0.7
def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    """Convert a generator output tensor into an H x W x 3 numpy image in [0, 1].

    Args:
        image: tensor of shape [3, H, W] or [N, 3, H, W]; for a batch only the
            first item is returned. Values are clamped to [-1, 1] and then
            rescaled to [0, 1].
        size: when given and the last dimension differs from it, the image is
            resized to ``size x size`` with ``F.interpolate`` first.
        mode: interpolation mode forwarded to ``F.interpolate``.
        unnorm: unused; kept so existing callers keep working.
        title: unused; kept so existing callers keep working.

    Returns:
        numpy array of shape [H, W, 3] with values in [0, 1].
    """
    image = image.detach().cpu()
    if image.dim() == 3:
        # F.interpolate reads a 3-D tensor as (N, C, L) with a single spatial
        # dim, so a [3, H, W] image needs a batch dim before 2-D resizing.
        image = image.unsqueeze(0)
    if size is not None and image.size(-1) != size:
        image = F.interpolate(image, size=(size, size), mode=mode)
    image = image[0]
    # Map [-1, 1] -> [0, 1] and move channels last for image consumers.
    return ((image.clamp(-1, 1) + 1) / 2).permute(1, 2, 0).numpy()
#mean_latentland = load_model(generator, 'landscape.pt')
#mean_latentface = load_model(generator, 'face.pt')
#mean_latentchurch = load_model(generator, 'church.pt')
def inferece(num, seed):
    """Generate a landscape panorama by blending consecutive StyleGAN samples.

    ``num`` latent codes are drawn from a seeded RNG and smoothed along the
    sample axis. They are mapped with ``generator.get_latent`` and each
    neighbouring pair is blended with ``generator.merge_extension``. The
    overlapping pair blends are then cropped and written side by side into
    one wide tensor, which is converted with ``display_image``.

    Args:
        num: number of images to stitch (passed through ``int``). Must be
            at least 2, since a panorama needs at least one blended pair.
        seed: seed for the latent RNG (passed through ``int``).

    Returns:
        numpy array of shape [H, W, 3] in [0, 1], as returned by display_image.

    Raises:
        ValueError: if ``num`` is smaller than 2.
    """
    # A bare `import scipy` does not load the ndimage submodule on older SciPy.
    from scipy import ndimage

    num_im = int(num)
    random_seed = int(seed)
    if num_im < 2:
        raise ValueError(f"Number of Images must be at least 2, got {num_im}.")

    # NOTE(review): weights are reloaded on every request; consider caching.
    mean_latent = load_model(generator, 'landscape.pt')

    # Widths the stitching assumes (carried over from the original code):
    # each blended pair is pair_w columns wide and consecutive pairs are
    # offset by tile_w columns.
    pair_w = 512
    tile_w = pair_w // 2
    # pad determines how much of each blended pair is cropped away at its edges
    pad = pair_w // 4

    random_state = np.random.RandomState(random_seed)
    all_im = []
    with torch.no_grad():
        # Latent smoothing: blur along the sample axis (wrapping around) so
        # neighbouring latents are correlated, then renormalise to unit RMS.
        z = random_state.randn(num_im, 512).astype(np.float32)
        z = ndimage.gaussian_filter(z, [.7, 0], mode='wrap')
        z /= np.sqrt(np.mean(np.square(z)))
        z = torch.from_numpy(z).to(device)
        source = generator.get_latent(z, truncation=truncation, mean_latent=mean_latent)

        # Merge images 2 at a time: pair i blends sample i with sample i+1.
        for i in range(num_im - 1):
            source1 = index_layers(source, i)
            source2 = index_layers(source, i + 1)
            all_im.append(generator.merge_extension(source1, source2))

    b, c, h, _ = all_im[0].shape
    panorama_im = torch.zeros(
        b, c, h, pair_w + (num_im - 2) * tile_w,
        dtype=all_im[0].dtype, device=all_im[0].device,
    )

    # Overlay the pair blends: the left part of the first pair, the central
    # strip of every later pair, then the right edge of the last pair.
    coord = tile_w + pad
    panorama_im[..., :coord] = all_im[0][..., :coord]
    for im in all_im[1:]:
        panorama_im[..., coord:coord + pair_w - 2 * pad] = im[..., pad:-pad]
        coord += pair_w - 2 * pad
    panorama_im[..., coord:] = all_im[-1][..., pair_w - pad:]

    return display_image(panorama_im)
# Text shown on the Gradio page.
title = "SOAT"
description = "Gradio demo for SOAT Panorama Generation for landscapes. Generate a panorama using a pretrained stylegan by stitching intermediate activations. To use it, simply add the number of images (at least 2) and random seed number. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.01619' target='_blank'>StyleGAN of All Trades: Image Manipulation with Only Pretrained StyleGAN</a> | <a href='https://github.com/mchong6/SOAT' target='_blank'>Github Repo</a></p>"

# Wire inferece(num, seed) to two numeric inputs and a numpy image output,
# then start the server (runs at import time, as before).
gr.Interface(
    inferece,
    [
        gr.inputs.Number(default=5, label="Number of Images"),
        gr.inputs.Number(default=90, label="Random Seed"),
    ],
    gr.outputs.Image(type="numpy", label="Output"),
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    enable_queue=True,
).launch(debug=True)