import torch
from torch import nn
import torch.nn.functional as F
from .denoiser import ConditionalUNet
import numpy as np


def extract(v, i, shape):
    """Gather the per-sample entries v[i] of a 1-D schedule and reshape them to
    (batch_size, 1, 1, ...) so they broadcast over tensors of `shape`."""
    out = torch.gather(v, index=i, dim=0)
    out = out.to(device=i.device, dtype=torch.float32)
    out = out.view([i.shape[0]] + [1] * (len(shape) - 1))
    return out
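

# A minimal, hypothetical sketch (not part of the original module) showing how
# `extract` broadcasts per-timestep coefficients over a batch; the schedule
# length and shapes below are arbitrary placeholders.
def _extract_demo():
    schedule = torch.linspace(0.1, 0.9, 10)           # e.g. a schedule of length T = 10
    t = torch.tensor([0, 4, 9])                       # one timestep index per sample
    coeff = extract(schedule, t, shape=(3, 1, 8, 8))  # -> shape (3, 1, 1, 1)
    assert coeff.shape == (3, 1, 1, 1)
    return coeff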


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        # generate T steps of beta
        self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))
        # calculate the cumulative product of $\alpha_t$, named $\bar{\alpha}_t$ in the paper
        alpha_t = 1.0 - self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        # calculate and store the two coefficients of $q(x_t | x_0)$
        self.register_buffer("signal_rate", torch.sqrt(alpha_t_bar))
        self.register_buffer("noise_rate", torch.sqrt(1.0 - alpha_t_bar))

    def forward(self, x_0, z, **kwargs):
        # replace NaN entries with zero for the forward pass; the mask excludes them from the loss
        mask = torch.isnan(x_0)
        x_0 = torch.nan_to_num(x_0, 0.)
        # get a random training step $t \sim Uniform(\{0, ..., T-1\})$ (0-indexed)
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        # generate $\epsilon \sim N(0, 1)$
        epsilon = torch.randn_like(x_0)
        # diffuse $x_0$ to $x_t$ in closed form: $x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon$
        x_t = (extract(self.signal_rate, t, x_0.shape) * x_0 +
               extract(self.noise_rate, t, x_0.shape) * epsilon)
        # predict the noise that was added
        epsilon_theta = self.model(x_t, t, z)
        # element-wise MSE between predicted and true noise; masked entries are dropped via nanmean
        loss = F.mse_loss(epsilon_theta, epsilon, reduction="none")
        loss[mask] = torch.nan
        return loss.nanmean()
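

# A minimal, hypothetical training sketch (a toy linear denoiser stands in for
# ConditionalUNet; the shapes and beta schedule below are placeholder assumptions).
def _trainer_demo():
    class _ToyDenoiser(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim + 1, dim)

        def forward(self, x_t, t, z):
            # a real denoiser would use sinusoidal timestep embeddings and the condition z
            t_emb = t.float().unsqueeze(-1) / 1000.0
            return self.proj(torch.cat([x_t, t_emb], dim=-1))

    trainer = GaussianDiffusionTrainer(_ToyDenoiser(16), beta=(1e-4, 0.02), T=1000)
    x_0 = torch.randn(4, 16)     # toy "ground truth" batch
    loss = trainer(x_0, z=None)  # scalar noise-prediction loss
    loss.backward()
    return loss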


class DDPMSampler(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        # generate T steps of beta
        self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))
        # calculate the cumulative product of $\alpha_t$, named $\bar{\alpha}_t$ in the paper
        alpha_t = 1.0 - self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        alpha_t_bar_prev = F.pad(alpha_t_bar[:-1], (1, 0), value=1.0)
        self.register_buffer("coeff_1", torch.sqrt(1.0 / alpha_t))
        self.register_buffer("coeff_2", self.coeff_1 * (1.0 - alpha_t) / torch.sqrt(1.0 - alpha_t_bar))
        self.register_buffer("posterior_variance", self.beta_t * (1.0 - alpha_t_bar_prev) / (1.0 - alpha_t_bar))

    @torch.no_grad()
    def cal_mean_variance(self, x_t, t, c):
        # """ Calculate the mean and variance for $q(x_{t-1} | x_t, x_0)$ """
        epsilon_theta = self.model(x_t, t, c)
        mean = extract(self.coeff_1, t, x_t.shape) * x_t - extract(self.coeff_2, t, x_t.shape) * epsilon_theta
        # var is a constant
        var = extract(self.posterior_variance, t, x_t.shape)
        return mean, var

    @torch.no_grad()
    def sample_one_step(self, x_t, time_step, c):
        # """ Calculate $x_{t-1}$ according to $x_t$ """
        t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
        mean, var = self.cal_mean_variance(x_t, t, c)
        z = torch.randn_like(x_t) if time_step > 0 else 0
        x_t_minus_one = mean + torch.sqrt(var) * z
        if torch.isnan(x_t_minus_one).any():
            raise ValueError("nan in tensor!")
        return x_t_minus_one

    @torch.no_grad()
    def forward(self, x_t, c, only_return_x_0=True, interval=1, **kwargs):
        x = [x_t]
        for time_step in reversed(range(self.T)):
            x_t = self.sample_one_step(x_t, time_step, c)
            if not only_return_x_0 and ((self.T - time_step) % interval == 0 or time_step == 0):
                x.append(x_t)
        if only_return_x_0:
            return x_t  # [batch_size, channels, height, width]
        return torch.stack(x, dim=1)  # [batch_size, sample, channels, height, width]
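

# A minimal, hypothetical sampling sketch (a toy zero-noise predictor and a
# short T = 10 schedule, chosen only to keep the example fast).
def _ddpm_sampler_demo():
    class _ToyEps(nn.Module):
        def forward(self, x_t, t, c):
            return torch.zeros_like(x_t)  # pretend the predicted noise is zero

    sampler = DDPMSampler(_ToyEps(), beta=(1e-4, 0.02), T=10)
    x_T = torch.randn(2, 16)
    x_0 = sampler(x_T, c=None)                                            # -> (2, 16)
    trajectory = sampler(x_T, c=None, only_return_x_0=False, interval=2)  # -> (2, 6, 16)
    return x_0, trajectory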


class DDIMSampler(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        # generate T steps of beta
        beta_t = torch.linspace(*beta, T, dtype=torch.float32)
        # calculate the cumulative product of $\alpha_t$, named $\bar{\alpha}_t$ in the paper
        alpha_t = 1.0 - beta_t
        self.register_buffer("alpha_t_bar", torch.cumprod(alpha_t, dim=0))

    @torch.no_grad()
    def sample_one_step(self, x_t, time_step, c, prev_time_step, eta):
        t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
        prev_t = torch.full((x_t.shape[0],), prev_time_step, device=x_t.device, dtype=torch.long)
        # get current and previous alpha_cumprod
        alpha_t = extract(self.alpha_t_bar, t, x_t.shape)
        alpha_t_prev = extract(self.alpha_t_bar, prev_t, x_t.shape)
        # predict noise using model
        epsilon_theta_t = self.model(x_t, t, c)
        # calculate x_{t-1}
        sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))
        epsilon_t = torch.randn_like(x_t)
        x_t_minus_one = (torch.sqrt(alpha_t_prev / alpha_t) * x_t +
                         (torch.sqrt(1 - alpha_t_prev - sigma_t ** 2) - torch.sqrt(
                             (alpha_t_prev * (1 - alpha_t)) / alpha_t)) * epsilon_theta_t +
                         sigma_t * epsilon_t)
        return x_t_minus_one

    @torch.no_grad()
    def forward(self, x_t, c, steps=60, method="linear", eta=0.05, only_return_x_0=True, interval=1, **kwargs):
        if steps == 0:
            return c
        if method == "linear":
            a = self.T // steps
            time_steps = np.asarray(list(range(0, self.T, a)))
        elif method == "quadratic":
            time_steps = (np.linspace(0, np.sqrt(self.T * 0.8), steps) ** 2).astype(int)
        else:
            raise NotImplementedError(f"sampling method {method} is not implemented!")
        # add one to get the final alpha values right (the ones from first scale to data during sampling)
        time_steps = time_steps + 1
        # previous sequence
        time_steps_prev = np.concatenate([[0], time_steps[:-1]])
        x = [x_t]
        for i in reversed(range(0, steps)):
            x_t = self.sample_one_step(x_t, time_steps[i], c, time_steps_prev[i], eta)
            if not only_return_x_0 and ((steps - i) % interval == 0 or i == 0):
                x.append(x_t)
        if only_return_x_0:
            return x_t  # [batch_size x channels, dim]
        return torch.stack(x, dim=1)  # [batch_size x channels, sample, dim]
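

# A minimal, hypothetical sketch of few-step DDIM sampling (toy zero-noise
# predictor; T = 100 and steps = 10 are placeholders, and eta = 0 makes the
# trajectory deterministic).
def _ddim_sampler_demo():
    class _ToyEps(nn.Module):
        def forward(self, x_t, t, c):
            return torch.zeros_like(x_t)

    sampler = DDIMSampler(_ToyEps(), beta=(1e-4, 0.02), T=100)
    x_T = torch.randn(2, 16)
    x_0 = sampler(x_T, c=None, steps=10, eta=0.0)  # -> (2, 16) in 10 steps instead of 100
    return x_0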


class DiffusionLoss(nn.Module):
    config = {}

    def __init__(self):
        super().__init__()
        self.net = ConditionalUNet(
            layer_channels=self.config["layer_channels"],
            model_dim=self.config["model_dim"],
            condition_dim=self.config["condition_dim"],
            kernel_size=self.config["kernel_size"],
        )
        self.diffusion_trainer = GaussianDiffusionTrainer(
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"]
        )
        self.diffusion_sampler = self.config["sample_mode"](
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"]
        )

    def forward(self, x, c, **kwargs):
        if kwargs.get("parameter_weight_decay"):
            x = x * (1.0 - kwargs["parameter_weight_decay"])
        # Given condition z and ground truth token x, compute loss
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        real_batch = x.size(0)
        batch = self.config.get("diffusion_batch")
        if self.config.get("forward_once"):
            random_indices = torch.randperm(x.size(0))[:batch]
            x, c = x[random_indices], c[random_indices]
            real_batch = x.size(0)
        if batch is not None and real_batch > batch:
            loss = 0.
            # number of full micro-batches before the final (possibly smaller) remainder batch
            num_loops = x.size(0) // batch if x.size(0) % batch != 0 else x.size(0) // batch - 1
            for _ in range(num_loops):
                loss += self.diffusion_trainer(x[:batch], c[:batch], **kwargs) * batch
                x, c = x[batch:], c[batch:]
            loss += self.diffusion_trainer(x, c, **kwargs) * x.size(0)
            loss = loss / real_batch
        else:  # all as a batch
            loss = self.diffusion_trainer(x, c, **kwargs)
        return loss

    @torch.no_grad()
    def sample(self, x, c, **kwargs):
        # Given condition c and starting noise x, sample x_0 via the reverse diffusion process
        batch = self.config.get("diffusion_batch")
        x_shape = x.shape
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        if kwargs.get("only_return_x_0") is False:
            diffusion_steps = self.diffusion_sampler(x, c, **kwargs)
            return torch.permute(diffusion_steps, (1, 0, 2))  # [sample, batch_size x channels, dim]
        if batch is not None and x.size(0) > batch:
            result = []
            # number of full micro-batches before the final (possibly smaller) remainder batch
            num_loops = x.size(0) // batch if x.size(0) % batch != 0 else x.size(0) // batch - 1
            for _ in range(num_loops):
                result.append(self.diffusion_sampler(x[:batch], c[:batch], **kwargs))
                x, c = x[batch:], c[batch:]
            result.append(self.diffusion_sampler(x, c, **kwargs))
            return torch.cat(result, dim=0).view(x_shape)
        else:  # all as a batch
            return self.diffusion_sampler(x, c, **kwargs).view(x_shape)
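

# A hypothetical configuration sketch (all values below are placeholders, not
# taken from the original project): DiffusionLoss reads everything from its
# class-level `config` dict, so it must be populated before construction.
#
# DiffusionLoss.config = {
#     "layer_channels": [1, 32, 64, 32, 1],  # placeholder UNet channel layout
#     "model_dim": 128,                       # placeholder
#     "condition_dim": 128,                   # placeholder
#     "kernel_size": 3,                       # placeholder
#     "beta": (1e-4, 0.02),                   # standard DDPM schedule endpoints
#     "T": 1000,
#     "sample_mode": DDIMSampler,             # or DDPMSampler
#     "diffusion_batch": 512,                 # optional micro-batching
# }
# loss_module = DiffusionLoss()
# loss = loss_module(x, c)             # training: scalar diffusion loss
# samples = loss_module.sample(x, c)   # inference: reverse-diffusion sampling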