| import torch |
| import math |
|
|
|
|
def interpolate_fn(x, xp, yp):
    """
    Differentiable piecewise-linear interpolation y = f(x) through the keypoints (xp, yp).

    f is defined on the whole x-axis: queries that fall outside the keypoint range are
    extrapolated along the outermost segment on that side.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size and C is the number of channels.
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    device = x.device

    # Put every query in front of its channel's keypoints and sort along the last axis.
    # The query carries original index 0, so the argmin of the sort permutation is the
    # query's rank among the keypoints.
    keys_per_sample = xp.unsqueeze(0).repeat((batch, 1, 1))
    merged = torch.cat([x.unsqueeze(2), keys_per_sample], dim=2)
    merged_sorted, order = torch.sort(merged, dim=2)
    rank = torch.argmin(order, dim=2)
    left_of_query = rank - 1

    is_below_all = torch.eq(rank, 0)
    is_above_all = torch.eq(rank, num_keys)

    # Shared by both index computations: queries above every keypoint use the last segment.
    inner_choice = torch.where(
        is_above_all, torch.tensor(num_keys - 2, device=device), left_of_query
    )

    # Positions (inside the sorted, merged array) of the segment's two x end points.
    lo_pos = torch.where(is_below_all, torch.tensor(1, device=device), inner_choice)
    # For an interior query the slot right after lo_pos is the query itself, so skip it.
    hi_pos = torch.where(torch.eq(lo_pos, left_of_query), lo_pos + 2, lo_pos + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo_pos.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi_pos.unsqueeze(2)).squeeze(2)

    # Index (inside yp, which has no query inserted) of the segment's left end point.
    key_lo = torch.where(is_below_all, torch.tensor(0, device=device), inner_choice)
    yp_per_sample = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_per_sample, dim=2, index=key_lo.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_per_sample, dim=2, index=(key_lo + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
|
|
|
|
|
|
class NoiseScheduleVP:
    """
    Variance-preserving (VP) noise schedule, i.e. sigma_t^2 = 1 - alpha_t^2.

    The continuous time label t always runs over [0, T] with T <= 1. Any rescaling of t
    needed by the denoising model has to be handled by the model wrapper.
    """

    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
        dtype=torch.float32,
        eps=1e-3,
    ):
        """
        Args:
            schedule: One of 'discrete', 'linear' or 'cosine'.
            betas: 1-D tensor of discrete betas (only for 'discrete').
            alphas_cumprod: 1-D tensor of cumulative alpha products (only for 'discrete',
                used when `betas` is None).
            continuous_beta_0: beta at t = 0 for the 'linear' schedule.
            continuous_beta_1: beta at t = 1 for the 'linear' schedule.
            dtype: dtype of the stored discrete arrays.
            eps: smallest time label; stored as `self.eps` and used for `self.lambda_max`.
        Raises:
            ValueError: for an unsupported schedule, or when 'discrete' is requested
                without `betas` or `alphas_cumprod`.
        """
        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            elif alphas_cumprod is not None:
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            else:
                raise ValueError("The 'discrete' schedule needs either betas or alphas_cumprod")
            self.T = 1.
            self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1)).to(dtype=dtype)
            # Count the keypoints AFTER clipping so that t_array and log_alpha_array
            # always have the same length (interpolate_fn requires matching [C, K] shapes).
            self.total_N = self.log_alpha_array.shape[1]
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            if schedule == 'cosine':
                self.T = 0.9946
            else:
                self.T = 1.

        self.lambda_max = self.marginal_lambda(eps).item()
        self.lambda_min = self.marginal_lambda(self.T).item()

        # The constructor argument is the start time for every schedule, including 'discrete'.
        self.eps = eps

    @staticmethod
    def derivative(f, t, h=1e-6):
        """
        Approximate the derivative of f at t with a forward finite difference.

        Parameters:
            f (function): The function to differentiate.
            t (float or tensor): The point at which to differentiate.
            h (float): The step size. Default is 1e-6.

        Returns:
            (f(t + h) - f(t)) / h
        """
        return (f(t + h) - f(t)) / h

    def update_lambda_max(self, eps):
        """Recompute `self.lambda_max` as the half-logSNR at time `eps`."""
        self.lambda_max = self.marginal_lambda(eps).item()

    def update_time_min(self, lambda_max):
        """Set `self.eps` to the time whose half-logSNR equals `lambda_max`."""
        self.eps = self.inverse_lambda(lambda_max).item()

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        if self.schedule == 'discrete':
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            angle = (t + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.
            return torch.log(torch.cos(angle)) - self.cosine_log_alpha_0
        else:
            raise ValueError("Unsupported noise schedule {}".format(self.schedule))

    def _dlog_alpha_dt(self, t):
        """Compute d log(alpha_t) / dt: analytic for 'linear'/'cosine', finite difference for 'discrete'."""
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        if self.schedule == 'discrete':
            return NoiseScheduleVP.derivative(self.marginal_log_mean_coeff, t)
        elif self.schedule == 'linear':
            return -0.5 * t * (self.beta_1 - self.beta_0) - 0.5 * self.beta_0
        elif self.schedule == 'cosine':
            # d/dt log cos(u(t)) = -tan(u) * du/dt, with u = (t + s) / (1 + s) * pi / 2.
            angle = (t + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.
            return -0.5 * math.pi / (1. + self.cosine_s) * torch.tan(angle)
        else:
            raise ValueError("Unsupported noise schedule {}".format(self.schedule))

    def dalpha_dt(self, t):
        """
        Compute d alpha_t / dt = alpha_t * d log(alpha_t) / dt.
        """
        return self.marginal_alpha(t) * self._dlog_alpha_dt(t)

    def dsigma_dt(self, t):
        """
        Compute d sigma_t / dt. With sigma_t = sqrt(1 - alpha_t^2) this is
        -alpha_t * (d alpha_t / dt) / sigma_t.
        """
        alpha_t = torch.exp(self.marginal_log_mean_coeff(t))
        return - alpha_t * self.dalpha_dt(t) / torch.sqrt(1. - alpha_t ** 2)

    def ft(self, t):
        """
        Drift coefficient f(t) of the forward SDE dx = f(t) x dt + g(t) dw,
        which equals d log(alpha_t) / dt.
        """
        return self._dlog_alpha_dt(t)

    def gt(self, t):
        """
        Diffusion coefficient g(t) of the forward SDE.

        In general g(t)^2 = d sigma_t^2 / dt - 2 f(t) sigma_t^2. Because this class always
        has sigma_t^2 = 1 - alpha_t^2, that reduces to g(t)^2 = -2 f(t). Using the reduced
        form avoids the 0 * inf that the general form produces at sigma_t = 0.
        """
        # Clamp so that finite-difference noise ('discrete') can never yield sqrt(<0) = NaN.
        return torch.sqrt(torch.clamp(-2. * self.ft(t), min=0.))

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if not isinstance(lamb, torch.Tensor):
            lamb = torch.tensor(lamb)

        if self.schedule == 'linear':
            # Solve the quadratic log(alpha_t) equation for t in a numerically stable form.
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            ret = tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            # Both arrays are flipped before being used as interpolation keypoints.
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            ret = t.reshape((-1,))
        elif self.schedule == 'cosine':
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            cos_angle = torch.exp(log_alpha + self.cosine_log_alpha_0)
            ret = torch.arccos(cos_angle) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
        else:
            raise ValueError("Unsupported noise schedule {}".format(self.schedule))

        return ret

    def prior_transformation(self, latents):
        """Return `latents` unchanged."""
        return latents

    def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
        """
        For some beta schedules such as cosine schedule, the log-SNR has numerical issues.
        We clip the log-SNR near t=T within -5.1 to ensure the stability.
        Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.

        Trailing entries whose half-logSNR is below `clipped_lambda` are dropped.
        NOTE(review): the search assumes the half-logSNR decreases along `log_alphas`.
        """
        log_sigmas = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_alphas))
        lambs = log_alphas - log_sigmas
        # Number of entries (counted from the end) with lambda < clipped_lambda.
        idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
        if idx > 0:
            log_alphas = log_alphas[:-idx]
        return log_alphas
|
|
|
|
class NoiseScheduleVE:
    """Wrapper class for the forward SDE (VE type), where alpha_t = 1 for every t.

    Two parameterisations are supported:
        'edm':  sigma_t = t,        t in [sigma_min, sigma_max]
        'heat': sigma_t = sqrt(t),  t in [sigma_min^2, sigma_max^2]
    """

    def __init__(
        self,
        schedule='edm',
        sigma_min=0.002,
        sigma_max=80.,
        N=1000,
    ):
        """Create a wrapper class for the forward SDE (VE type).

        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        Args:
            schedule: A `str`. Either 'edm' or 'heat'.
            sigma_min: A `float` number. The smallest sigma for the VE schedule.
            sigma_max: A `float` number. The largest sigma for the VE schedule.
            N: An `int` number. The number of time steps for the VE schedule.
        Returns:
            A wrapper object of the forward SDE (VE type).
        Raises:
            ValueError: if `schedule` is neither 'heat' nor 'edm'.
        """
        if schedule not in ('heat', 'edm'):
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'heat' or 'edm'".format(schedule))

        self.schedule = schedule
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.N = N

        if schedule == 'heat':
            # sigma_t = sqrt(t), so the time range is the squared sigma range.
            self.T = self.sigma_max ** 2
            self.eps = self.sigma_min ** 2
        else:
            self.T = self.sigma_max
            self.eps = self.sigma_min

        self.lambda_max = self.marginal_lambda(self.eps).item()
        self.lambda_min = self.marginal_lambda(self.T).item()

    # ---- VE SDE coefficients ----

    def ft(self, t):
        """Drift coefficient f(t); always 0."""
        return 0.

    def gt(self, t):
        """Diffusion coefficient g(t): 1 for 'heat', sqrt(2 t) for 'edm'."""
        if self.schedule == 'heat':
            return 1.
        elif self.schedule == 'edm':
            # Accept plain Python numbers like the other methods of this class do.
            if not isinstance(t, torch.Tensor):
                t = torch.tensor(t)
            return torch.sqrt(2 * t)

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return torch.zeros_like(t)

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return torch.ones_like(t)

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        if self.schedule == 'heat':
            return torch.sqrt(t)
        elif self.schedule == 'edm':
            return t

    def inverse_std(self, sigma):
        """
        Compute the continuous-time label t in [0, T] of a given standard deviation sigma_t.
        """
        if not isinstance(sigma, torch.Tensor):
            sigma = torch.tensor(sigma)
        if self.schedule == 'heat':
            return sigma ** 2
        elif self.schedule == 'edm':
            return sigma

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        if self.schedule == 'heat':
            return -0.5 * torch.log(t)
        elif self.schedule == 'edm':
            return -torch.log(t)

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if not isinstance(lamb, torch.Tensor):
            lamb = torch.tensor(lamb)
        if self.schedule == 'heat':
            return torch.exp(-2. * lamb)
        elif self.schedule == 'edm':
            return torch.exp(-lamb)

    def prior_transformation(self, latents):
        """Scale `latents` by sigma at t = T: sqrt(T) for 'heat', T for 'edm'."""
        if self.schedule == 'heat':
            # self.T is a Python float here, so use math.sqrt (torch.sqrt requires a tensor).
            return latents * math.sqrt(self.T)
        elif self.schedule == 'edm':
            return latents * self.T
|
|
|
|