import torch
import math

from tqdm import trange, tqdm

import k_diffusion as K


def get_alphas_sigmas(t):
    """Returns the scaling factors for the clean image (alpha) and for the
    noise (sigma), given a timestep."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)


def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2


def t_to_alpha_sigma(t):
    """Returns the scaling factors for the clean image and for the noise, given
    a timestep."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)


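# A minimal sanity-check sketch (not used by the samplers below): on this
# trigonometric schedule alpha**2 + sigma**2 == 1 for every t, and
# alpha_sigma_to_t inverts t_to_alpha_sigma.
#
#   t = torch.rand(8)
#   alpha, sigma = t_to_alpha_sigma(t)
#   assert torch.allclose(alpha**2 + sigma**2, torch.ones_like(t))
#   assert torch.allclose(alpha_sigma_to_t(alpha, sigma), t, atol=1e-6)

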
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
    """Draws samples from a model given starting noise. Euler method"""

    # All the timesteps, from sigma_max down to 0.
    t = torch.linspace(sigma_max, 0, steps + 1)

    for t_curr, t_prev in tqdm(zip(t[:-1], t[1:]), total=steps):
        # Broadcast the current timestep to the batch.
        t_curr_tensor = t_curr * torch.ones(
            (x.shape[0],), dtype=x.dtype, device=x.device
        )
        # The step is negative: we integrate backwards from sigma_max to 0.
        dt = t_prev - t_curr
        x = x + dt * model(x, t_curr_tensor, **extra_args)

    # After the final step the sample is fully denoised.
    return x


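# Hedged usage sketch (the model name and shapes are assumptions, not part of
# this module): the model is expected to return dx/dt at the given timestep,
# so each Euler step moves x from t_curr toward t_prev along the predicted
# flow.
#
#   noise = torch.randn(2, 64, 1024)
#   latents = sample_discrete_euler(rf_model, noise, steps=100)

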
@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
    """Draws samples from a model given starting noise. v-diffusion"""
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule.
    t = torch.linspace(1, 0, steps + 1)[:-1]

    alphas, sigmas = get_alphas_sigmas(t)

    # The sampling loop
    for i in trange(steps):

        # Get the model output (v, the predicted velocity).
        with torch.cuda.amp.autocast():
            v = model(x, ts * t[i], **extra_args).float()

        # Predict the denoised image and the noise.
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy input for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add.
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step.
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise.
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # On the last timestep, output the denoised image.
    return pred


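# Note on `eta`: eta == 0 gives deterministic DDIM steps, while eta == 1 adds
# the full ancestral amount of fresh noise at each step. Hedged usage sketch
# (the model name and shapes are assumptions):
#
#   noise = torch.randn(1, 2, 65536, device="cuda")
#   audio = sample(v_model, noise, steps=250, eta=0.0)

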
def get_bmask(i, steps, mask):
    # Fraction of the schedule completed after step i.
    strength = (i + 1) / steps
    # Binarize: regions whose mask value is at or below the current strength
    # are pinned to the (noised) init data at this step.
    bmask = torch.where(mask <= strength, 1, 0)
    return bmask


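# Sketch of how a soft mask hardens over the course of sampling (values are
# illustrative, not from this module): with mask = tensor([0.0, 0.3, 0.7, 1.0])
# and steps = 4,
#
#   get_bmask(0, 4, mask)  ->  tensor([1, 0, 0, 0])   # strength 0.25
#   get_bmask(1, 4, mask)  ->  tensor([1, 1, 0, 0])   # strength 0.50
#   get_bmask(3, 4, mask)  ->  tensor([1, 1, 1, 1])   # strength 1.00

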
def make_cond_model_fn(model, cond_fn):
    def cond_model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            denoised = model(x, sigma, **kwargs)
            # Gradient guidance: shift the denoised prediction along the
            # gradient supplied by cond_fn, scaled by sigma**2.
            cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
        return cond_denoised
    return cond_model_fn


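# A hedged sketch of what a `cond_fn` can look like (the loss and the `target`
# extra argument are illustrative assumptions, not part of this module): it
# receives the noisy input, the current sigma, and the model's denoised
# prediction, and returns a gradient with respect to `x` that nudges the
# denoised output.
#
#   def example_cond_fn(x, sigma, denoised=None, **kwargs):
#       target = kwargs["target"]  # assumed extra arg
#       loss = (denoised - target).pow(2).mean()
#       return -torch.autograd.grad(loss, x)[0]
#
#   guided_denoiser = make_cond_model_fn(denoiser, example_cond_fn)

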
def sample_k(
    model_fn,
    noise,
    init_data=None,
    mask=None,
    steps=100,
    sampler_type="dpmpp-2m-sde",
    sigma_min=0.5,
    sigma_max=50,
    rho=1.0,
    device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):
    # Wrap the v-prediction model as a k-diffusion denoiser.
    denoiser = K.external.VDenoiser(model_fn)

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    # Noise schedule for the sampler.
    sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)

    # Scale the starting noise to the first sigma.
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # img2img-style sampling: start from the noised init data.
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # Inpainting: pin masked regions to the (noised) init data.
        bmask = get_bmask(0, steps, mask)
        input_noised = init_data + noise
        x = input_noised * bmask + noise * (1 - bmask)

        # Re-apply the mask at every sampler step.
        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            # Noise the init data to the current sigma.
            input_noised = init_data + torch.randn_like(init_data) * sigma
            bmask = get_bmask(i, steps, mask)
            new_x = input_noised * bmask + x * (1 - bmask)
            # Copy in place so the sampler sees the update.
            x[:, :, :] = new_x[:, :, :]

        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # Unconditional generation: start from pure noise.
        x = noise

    with torch.cuda.amp.autocast():
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        else:
            raise ValueError(f"Unknown sampler type: {sampler_type}")


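# Hedged usage sketch for sample_k (the model name and shapes are assumptions):
#
#   noise = torch.randn(1, 64, 1024, device="cuda")
#   latents = sample_k(v_model, noise, steps=100,
#                      sampler_type="dpmpp-2m-sde",
#                      sigma_min=0.5, sigma_max=50, rho=1.0, device="cuda")

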
def sample_rf(
    model_fn,
    noise,
    init_data=None,
    steps=100,
    sigma_max=1,
    device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):
    """Draws samples from a model given starting noise. Rectified flow (Euler method)"""

    if sigma_max > 1:
        sigma_max = 1

    if cond_fn is not None:
        model_fn = make_cond_model_fn(model_fn, cond_fn)

    # Note: wrapped_callback is currently not passed to the Euler sampler.
    wrapped_callback = callback

    if init_data is not None:
        # img2img-style sampling: interpolate between the init data and noise.
        x = init_data * (1 - sigma_max) + noise * sigma_max
    else:
        # Unconditional generation: start from pure noise.
        x = noise

    with torch.cuda.amp.autocast():
        return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
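

# Hedged usage sketch for sample_rf (the model name and shapes are assumptions):
#
#   noise = torch.randn(1, 64, 1024, device="cuda")
#   latents = sample_rf(rf_model, noise, steps=100, sigma_max=1.0)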