# LD3 / samplers/dpm_solverpp.py
# Uploaded by vinhtong97 using huggingface_hub (commit d382778, verified).
import torch
from samplers.general_solver import ODESolver
from samplers.general_solver import update_lists
class DPM_SolverPP(ODESolver):
    """Multistep exponential-integrator sampler (orders 1-3) on top of ``ODESolver``.

    Every update has the form
    ``x_t = (sigma_t / sigma_prev) * x - alpha_t * expm1(-h) * model + corrections``
    where ``h`` is the difference of ``noise_schedule.marginal_lambda`` between
    the target time and the most recent time, and the corrections are built
    from finite differences of previously stored model outputs.

    NOTE(review): the formulas look like the data-prediction ("++") variant of
    DPM-Solver; ``self.model_fn`` and ``self.prepare_timesteps`` come from the
    ``ODESolver`` base class, which is not visible here.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
    ):
        """Store the noise schedule and forward both arguments to ``ODESolver``.

        Args:
            noise_schedule: Object providing ``marginal_lambda``,
                ``marginal_log_mean_coeff`` and ``marginal_std`` (and ``eps`` /
                ``T``, which are read by :meth:`sample`).
            algorithm_type: Passed through unchanged to ``ODESolver.__init__``.
        """
        super().__init__(noise_schedule, algorithm_type)
        self.noise_schedule = noise_schedule

    def dpm_solver_first_update(self, x, s, t, model_s=None):
        """First-order step from time ``s`` to time ``t``.

        Args:
            x: Current state at time ``s``.
            s: Start time.
            t: Target time.
            model_s: Model output at ``(x, s)``. Evaluated with
                ``self.model_fn`` when omitted.

        Returns:
            The state advanced to time ``t``.
        """
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))
        # expm1 keeps precision for small step sizes h.
        phi_1 = torch.expm1(-h)
        if model_s is None:
            model_s = self.model_fn(x, s)
        return sigma_t / sigma_s * x - alpha_t * phi_1 * model_s

    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t):
        """Second-order multistep step using the two most recent model outputs.

        Args:
            x: Current state at time ``t_prev_list[-1]``.
            model_prev_list: History of model outputs; the last two are used.
            t_prev_list: Times matching ``model_prev_list``.
            t: Target time.

        Returns:
            The state advanced to time ``t``.
        """
        ns = self.noise_schedule
        model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
        t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
        lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))

        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        # Scaled backward difference of the two latest model outputs.
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        phi_1 = torch.expm1(-h)
        return (
            (sigma_t / sigma_prev_0) * x
            - (alpha_t * phi_1) * model_prev_0
            - 0.5 * (alpha_t * phi_1) * D1_0
        )

    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t):
        """Third-order multistep step using the three most recent model outputs.

        Args:
            x: Current state at time ``t_prev_list[-1]``.
            model_prev_list: History of model outputs; the last three are used.
            t_prev_list: Times matching ``model_prev_list``.
            t: Target time.

        Returns:
            The state advanced to time ``t``.
        """
        ns = self.noise_schedule
        # Read the tail of the history, as the second-order update does,
        # instead of requiring the lists to hold exactly three entries.
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list[-3:]
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list[-3:]
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_2),
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))

        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        # First- and second-order difference terms built from the history.
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

        phi_1 = torch.expm1(-h)
        phi_2 = phi_1 / h + 1.0
        phi_3 = phi_2 / h - 0.5
        return (
            (sigma_t / sigma_prev_0) * x
            - (alpha_t * phi_1) * model_prev_0
            + (alpha_t * phi_2) * D1
            - (alpha_t * phi_3) * D2
        )

    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order):
        """Dispatch to the first-, second- or third-order update.

        Args:
            x: Current state at time ``t_prev_list[-1]``.
            model_prev_list: History of model outputs.
            t_prev_list: Times matching ``model_prev_list``.
            t: Target time.
            order: Solver order for this step (1, 2 or 3).

        Returns:
            The state advanced to time ``t``.

        Raises:
            ValueError: If ``order`` is not 1, 2 or 3.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
        if order == 2:
            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t)
        if order == 3:
            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t)
        raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")

    def one_step(self, t1, t2, t_prev_list, model_prev_list, step, x_next, order, first=True):
        """Advance the state to ``t1`` and record the new model output.

        Args:
            t1: Target time used in the solver coefficients; also the time
                stored in the history.
            t2: Time passed to the model when evaluating the new state.
            t_prev_list: Time history, updated in place via ``update_lists``.
            model_prev_list: Model-output history, updated in place via
                ``update_lists``.
            step: Solver order used for this update. During warm-up the caller
                passes the step index, which equals the usable order.
            x_next: Current state.
            order: Forwarded to ``update_lists``.
            first: Forwarded to ``update_lists``. NOTE(review): presumably
                selects append (warm-up) vs. rolling replacement; confirm in
                ``samplers.general_solver``.

        Returns:
            The state advanced to ``t1``.
        """
        x_next = self.multistep_dpm_solver_update(x_next, model_prev_list, t_prev_list, t1, step)
        model_x_next = self.model_fn(x_next, t2)
        update_lists(t_prev_list, model_prev_list, t1, model_x_next, order, first=first)
        return x_next

    def sample(
        self,
        model_fn,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        lower_order_final=True,
        flags=None,
    ):
        """Build the time grids and run :meth:`sample_simple` without gradients.

        Args:
            model_fn: Model callable. :meth:`sample_simple` invokes it as
                ``model_fn(x, t, condition, unconditional_condition)``.
            x: Initial state; its device is used for the time grids.
            steps: Number of steps requested from ``prepare_timesteps``.
            t_start: Start time; defaults to ``noise_schedule.T``.
            t_end: End time; defaults to ``noise_schedule.eps``.
            order: Maximum solver order.
            skip_type: Spacing rule forwarded to ``prepare_timesteps``.
            lower_order_final: Whether to reduce the order on the last steps.
            flags: Optional object whose ``load_from`` attribute is forwarded
                to ``prepare_timesteps``; ``None`` is forwarded when ``flags``
                is ``None`` or lacks the attribute.

        Returns:
            The final state produced by :meth:`sample_simple`.
        """
        # sample_simple rebinds self.model; this assignment is kept in case
        # prepare_timesteps (defined on the base class) relies on it.
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        t_0 = self.noise_schedule.eps if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        # flags defaults to None, so do not dereference it unconditionally.
        load_from = getattr(flags, "load_from", None)
        timesteps, timesteps2 = self.prepare_timesteps(
            steps=steps,
            t_start=t_T,
            t_end=t_0,
            skip_type=skip_type,
            device=device,
            load_from=load_from,
        )
        with torch.no_grad():
            # The time grids are positional parameters 3 and 4 of
            # sample_simple; order/lower_order_final go by keyword so they
            # cannot be bound to the wrong slot.
            return self.sample_simple(
                model_fn,
                x,
                timesteps,
                timesteps2,
                order=order,
                lower_order_final=lower_order_final,
            )

    def sample_simple(self, model_fn, x, timesteps, timesteps2, order=2, lower_order_final=True, condition=None, unconditional_condition=None, **kwargs):
        """Run the multistep solver over precomputed time grids.

        Args:
            model_fn: Called as
                ``model_fn(x, t.expand(x.shape[0]), condition, unconditional_condition)``.
            x: Initial state at ``timesteps[0]``.
            timesteps: Times used in the solver coefficients.
            timesteps2: Times fed to the model, indexed in lockstep with
                ``timesteps``.
            order: Maximum solver order.
            lower_order_final: If true, the order of the trailing steps is
                capped at the number of steps remaining.
            condition: Forwarded to ``model_fn``.
            unconditional_condition: Forwarded to ``model_fn``.
            **kwargs: Accepted and ignored.

        Returns:
            The state after the last entry of ``timesteps``.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])), condition, unconditional_condition)
        steps = len(timesteps) - 1

        t1, t2 = timesteps[0], timesteps2[0]
        t_prev_list = [t1]
        model_prev_list = [self.model_fn(x, t2)]

        # Warm-up: grow the history one entry at a time, using the highest
        # order the history supports (equal to the step index). Clamp to the
        # available steps so a short grid does not index past its end.
        for step in range(1, min(order, steps + 1)):
            t1, t2 = timesteps[step], timesteps2[step]
            x = self.one_step(t1, t2, t_prev_list, model_prev_list, step, x, order, first=True)

        # Main loop: full-order multistep updates.
        for step in range(order, steps + 1):
            t1, t2 = timesteps[step], timesteps2[step]
            if lower_order_final:
                step_order = min(order, steps + 1 - step)
            else:
                step_order = order
            x = self.one_step(t1, t2, t_prev_list, model_prev_list, step_order, x, order, first=False)
        return x