import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from .denoiser import ConditionalUNet


def extract(v, i, shape):
    """Gather the i-th entries of v along dim 0 and reshape them for broadcasting."""
    out = torch.gather(v, index=i, dim=0)
    out = out.to(device=i.device, dtype=torch.float32)
    # reshape to (batch_size, 1, 1, ..., 1) for broadcasting purposes
    out = out.view([i.shape[0]] + [1] * (len(shape) - 1))
    return out
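
# Shape sketch for `extract` (illustrative only): with v of shape [T] and
# index i of shape [B], extract(v, i, x.shape) returns a tensor of shape
# [B, 1, ..., 1] with the same number of dimensions as x, e.g.
#
#   v = torch.linspace(0.0, 1.0, 10)       # a per-step schedule, T = 10
#   i = torch.tensor([0, 9])               # one step index per sample
#   out = extract(v, i, (2, 4, 8))         # -> shape [2, 1, 1]
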
class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        # generate T steps of beta
        self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))
        # cumulative product of $\alpha_t$, written $\bar{\alpha}_t$ in the paper
        alpha_t = 1.0 - self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        # calculate and store the two coefficients of $q(x_t | x_0)$
        self.register_buffer("signal_rate", torch.sqrt(alpha_t_bar))
        self.register_buffer("noise_rate", torch.sqrt(1.0 - alpha_t_bar))

    def forward(self, x_0, z, **kwargs):
        # replace NaN entries with zero, remembering where they were
        mask = torch.isnan(x_0)
        x_0 = torch.nan_to_num(x_0, 0.0)
        # draw a random training step $t \sim \mathrm{Uniform}(\{1, ..., T\})$
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        # generate $\epsilon \sim N(0, I)$
        epsilon = torch.randn_like(x_0)
        # sample $x_t$ from $q(x_t | x_0)$ in closed form
        x_t = (extract(self.signal_rate, t, x_0.shape) * x_0 +
               extract(self.noise_rate, t, x_0.shape) * epsilon)
        # predict the noise that was added to $x_0$
        epsilon_theta = self.model(x_t, t, z)
        # per-element MSE between predicted and true noise,
        # excluding positions that were NaN in the input
        loss = F.mse_loss(epsilon_theta, epsilon, reduction="none")
        loss[mask] = torch.nan
        return loss.nanmean()
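
# Minimal usage sketch (all shapes and `net` are hypothetical; any nn.Module
# with the call signature model(x_t, t, condition) will do):
#
#   net = MyDenoiser()                                  # placeholder denoiser
#   trainer = GaussianDiffusionTrainer(net, beta=(1e-4, 0.02), T=1000)
#   x_0 = torch.randn(16, 1, 256)                       # clean data
#   z = torch.randn(16, 64)                             # conditions
#   loss = trainer(x_0, z)                              # scalar noise-prediction loss
#   loss.backward()
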
class DDPMSampler(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        # generate T steps of beta
        self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))
        # cumulative product of $\alpha_t$, written $\bar{\alpha}_t$ in the paper
        alpha_t = 1.0 - self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        alpha_t_bar_prev = F.pad(alpha_t_bar[:-1], (1, 0), value=1.0)
        # coefficients of the posterior mean and the (fixed) posterior variance
        self.register_buffer("coeff_1", torch.sqrt(1.0 / alpha_t))
        self.register_buffer("coeff_2", self.coeff_1 * (1.0 - alpha_t) / torch.sqrt(1.0 - alpha_t_bar))
        self.register_buffer("posterior_variance", self.beta_t * (1.0 - alpha_t_bar_prev) / (1.0 - alpha_t_bar))

    @torch.no_grad()
    def cal_mean_variance(self, x_t, t, c):
        """Calculate the mean and variance of $q(x_{t-1} | x_t, x_0)$."""
        epsilon_theta = self.model(x_t, t, c)
        mean = extract(self.coeff_1, t, x_t.shape) * x_t - extract(self.coeff_2, t, x_t.shape) * epsilon_theta
        # the variance is a constant that depends only on t
        var = extract(self.posterior_variance, t, x_t.shape)
        return mean, var
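
    # For reference, the mean above implements Eq. (11) of Ho et al.,
    # "Denoising Diffusion Probabilistic Models" (2020):
    #   mu_theta(x_t, t) = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
    # with coeff_1 = 1 / sqrt(alpha_t) and coeff_2 = coeff_1 * (1 - alpha_t) / sqrt(1 - alpha_bar_t).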

    @torch.no_grad()
    def sample_one_step(self, x_t, time_step, c):
        """Calculate $x_{t-1}$ from $x_t$."""
        t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
        mean, var = self.cal_mean_variance(x_t, t, c)
        # no fresh noise is added at the final step (t = 0)
        z = torch.randn_like(x_t) if time_step > 0 else 0
        x_t_minus_one = mean + torch.sqrt(var) * z
        if torch.isnan(x_t_minus_one).int().sum() != 0:
            raise ValueError("nan in tensor!")
        return x_t_minus_one

    @torch.no_grad()
    def forward(self, x_t, c, only_return_x_0=True, interval=1, **kwargs):
        x = [x_t]
        for time_step in reversed(range(self.T)):
            x_t = self.sample_one_step(x_t, time_step, c)
            # optionally record intermediate states every `interval` steps
            if not only_return_x_0 and ((self.T - time_step) % interval == 0 or time_step == 0):
                x.append(x_t)
        if only_return_x_0:
            return x_t  # [batch_size, channels, height, width]
        return torch.stack(x, dim=1)  # [batch_size, sample, channels, height, width]
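
# Minimal usage sketch (hypothetical shapes; `net` as in the trainer sketch
# above). DDPM sampling starts from pure noise and runs all T reverse steps:
#
#   sampler = DDPMSampler(net, beta=(1e-4, 0.02), T=1000)
#   x_T = torch.randn(16, 1, 256)                       # pure Gaussian noise
#   z = torch.randn(16, 64)                             # conditions
#   x_0 = sampler(x_T, z)                               # final denoised samples
#   trajectory = sampler(x_T, z, only_return_x_0=False, interval=100)
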
class DDIMSampler(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        # generate T steps of beta
        beta_t = torch.linspace(*beta, T, dtype=torch.float32)
        # cumulative product of $\alpha_t$, written $\bar{\alpha}_t$ in the paper
        alpha_t = 1.0 - beta_t
        self.register_buffer("alpha_t_bar", torch.cumprod(alpha_t, dim=0))

    @torch.no_grad()
    def sample_one_step(self, x_t, time_step, c, prev_time_step, eta):
        t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
        prev_t = torch.full((x_t.shape[0],), prev_time_step, device=x_t.device, dtype=torch.long)
        # get current and previous alpha_cumprod
        alpha_t = extract(self.alpha_t_bar, t, x_t.shape)
        alpha_t_prev = extract(self.alpha_t_bar, prev_t, x_t.shape)
        # predict noise using the model
        epsilon_theta_t = self.model(x_t, t, c)
        # calculate x_{t-1}
        sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))
        epsilon_t = torch.randn_like(x_t)
        x_t_minus_one = (torch.sqrt(alpha_t_prev / alpha_t) * x_t +
                         (torch.sqrt(1 - alpha_t_prev - sigma_t ** 2) - torch.sqrt(
                             (alpha_t_prev * (1 - alpha_t)) / alpha_t)) * epsilon_theta_t +
                         sigma_t * epsilon_t)
        return x_t_minus_one
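
    # For reference, the update above is Eq. (12) of Song et al., "Denoising
    # Diffusion Implicit Models" (2021) with the predicted
    #   x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_bar_t)
    # substituted in; eta = 0 gives deterministic DDIM, while eta = 1 recovers
    # DDPM-like stochastic sampling.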

    @torch.no_grad()
    def forward(self, x_t, c, steps=60, method="linear", eta=0.05, only_return_x_0=True, interval=1, **kwargs):
        if steps == 0:
            return c
        # choose the subsequence of timesteps to visit
        if method == "linear":
            a = self.T // steps
            time_steps = np.asarray(list(range(0, self.T, a)))
        elif method == "quadratic":
            time_steps = (np.linspace(0, np.sqrt(self.T * 0.8), steps) ** 2).astype(int)
        else:
            raise NotImplementedError(f"sampling method {method} is not implemented!")
        # add one to get the final alpha values right (the ones from first scale to data during sampling)
        time_steps = time_steps + 1
        # previous timestep in the shortened sequence
        time_steps_prev = np.concatenate([[0], time_steps[:-1]])
        x = [x_t]
        for i in reversed(range(0, steps)):
            x_t = self.sample_one_step(x_t, time_steps[i], c, time_steps_prev[i], eta)
            if not only_return_x_0 and ((steps - i) % interval == 0 or i == 0):
                x.append(x_t)
        if only_return_x_0:
            return x_t  # [batch_size x channels, dim]
        return torch.stack(x, dim=1)  # [batch_size x channels, sample, dim]
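
# Minimal usage sketch (hypothetical shapes; `net` as above). DDIM visits only
# a subsequence of timesteps, so e.g. 60 model calls replace T = 1000:
#
#   sampler = DDIMSampler(net, beta=(1e-4, 0.02), T=1000)
#   x_T = torch.randn(16, 1, 256)
#   z = torch.randn(16, 64)
#   x_0 = sampler(x_T, z, steps=60, method="linear", eta=0.05)
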
class DiffusionLoss(nn.Module):
    # populated externally (e.g. by the training script) before instantiation
    config = {}

    def __init__(self):
        super().__init__()
        self.net = ConditionalUNet(
            layer_channels=self.config["layer_channels"],
            model_dim=self.config["model_dim"],
            condition_dim=self.config["condition_dim"],
            kernel_size=self.config["kernel_size"],
        )
        self.diffusion_trainer = GaussianDiffusionTrainer(
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )
        # sample_mode selects the sampler class, e.g. DDPMSampler or DDIMSampler
        self.diffusion_sampler = self.config["sample_mode"](
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )

    def forward(self, x, c, **kwargs):
        # given condition c and ground-truth tokens x, compute the diffusion loss
        if kwargs.get("parameter_weight_decay"):
            x = x * (1.0 - kwargs["parameter_weight_decay"])
        # flatten everything but the last dimension into one token batch
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        real_batch = x.size(0)
        batch = self.config.get("diffusion_batch")
        if self.config.get("forward_once"):
            # train on one random micro-batch of tokens instead of all of them
            random_indices = torch.randperm(x.size(0))[:batch]
            x, c = x[random_indices], c[random_indices]
            real_batch = x.size(0)
        if batch is not None and real_batch > batch:
            # split into micro-batches and accumulate a size-weighted mean loss
            loss = 0.0
            num_loops = x.size(0) // batch if x.size(0) % batch != 0 else x.size(0) // batch - 1
            for _ in range(num_loops):
                loss += self.diffusion_trainer(x[:batch], c[:batch], **kwargs) * batch
                x, c = x[batch:], c[batch:]
            # the remaining tokens form the final (possibly smaller) micro-batch
            loss += self.diffusion_trainer(x, c, **kwargs) * x.size(0)
            loss = loss / real_batch
        else:  # all tokens as one batch
            loss = self.diffusion_trainer(x, c, **kwargs)
        return loss
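
    # Note on the micro-batching above: each chunk's mean loss is re-weighted
    # by its chunk size and the sum is divided by real_batch, so chunks of
    # sizes (b, b, r) yield (b * L1 + b * L2 + r * L3) / (2b + r), matching a
    # single full-batch forward pass (up to the NaN masking inside the trainer).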

    @torch.no_grad()
    def sample(self, x, c, **kwargs):
        # given condition c and noise x, sample tokens via the reverse diffusion process
        batch = self.config.get("diffusion_batch")
        # if batch is not None:
        #     batch = max(batch, 256)
        x_shape = x.shape
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        if kwargs.get("only_return_x_0") is False:
            # return the whole trajectory of intermediate states
            diffusion_steps = self.diffusion_sampler(x, c, **kwargs)
            return torch.permute(diffusion_steps, (1, 0, 2))  # [sample, 1 x channels, dim]
        if batch is not None and x.size(0) > batch:
            # sample in micro-batches and concatenate the results
            result = []
            num_loops = x.size(0) // batch if x.size(0) % batch != 0 else x.size(0) // batch - 1
            for _ in range(num_loops):
                result.append(self.diffusion_sampler(x[:batch], c[:batch], **kwargs))
                x, c = x[batch:], c[batch:]
            result.append(self.diffusion_sampler(x, c, **kwargs))
            return torch.cat(result, dim=0).view(x_shape)
        else:  # all as one batch
            return self.diffusion_sampler(x, c, **kwargs).view(x_shape)
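
# End-to-end usage sketch (all config values hypothetical; the UNet keys must
# match what ConditionalUNet actually expects). `config` is a class attribute,
# so it is set before instantiation:
#
#   DiffusionLoss.config = {
#       "layer_channels": [1, 32, 64, 32, 1], "model_dim": 256,
#       "condition_dim": 64, "kernel_size": 7,
#       "beta": (1e-4, 0.02), "T": 1000,
#       "sample_mode": DDIMSampler, "diffusion_batch": 512,
#   }
#   criterion = DiffusionLoss()
#   loss = criterion(x, c)                               # training loss
#   samples = criterion.sample(torch.randn_like(x), c)   # reverse-process sampling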