from tqdm import tqdm
import torch
import torch.nn.functional as F
from diffusers.utils.torch_utils import randn_tensor
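

# NOTE: The samplers below only assume that each element of `views` exposes
# `view(image)` and `inverse_view(image)`, where `inverse_view` undoes the
# transform. The project ships its own view classes; the two below are
# minimal illustrative stand-ins (identity, and a self-inverse vertical
# flip), shown only to document the expected interface.
class IdentityView:
    def view(self, im):
        return im

    def inverse_view(self, im):
        return im


class FlipView:
    def view(self, im):
        # Flip along the height axis; the flip is its own inverse
        return torch.flip(im, dims=[-2])

    def inverse_view(self, im):
        return torch.flip(im, dims=[-2])
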
@torch.no_grad()
def sample_stage_1(model,
                   prompt_embeds,
                   negative_prompt_embeds,
                   views,
                   num_inference_steps=100,
                   guidance_scale=7.0,
                   reduction='mean',
                   generator=None):
    '''
    Sample a 64x64 illusion from the stage 1 pipeline. Each prompt
    embedding is paired with one view: at every denoising step the noisy
    image is transformed by each view, denoised under the matching prompt,
    mapped back with the inverse view, and the per-view noise estimates
    are combined according to `reduction` ('mean' averages them,
    'alternate' cycles through them across steps).
    '''
# Params
num_images_per_prompt = 1
device = model.device
height = model.unet.config.sample_size
width = model.unet.config.sample_size
batch_size = 1 # TODO: Support larger batch sizes, maybe
num_prompts = prompt_embeds.shape[0]
assert num_prompts == len(views), \
"Number of prompts must match number of views!"
# For CFG
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# Setup timesteps
model.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = model.scheduler.timesteps
# Make intermediate_images
noisy_images = model.prepare_intermediate_images(
batch_size * num_images_per_prompt,
model.unet.config.in_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
for i, t in enumerate(tqdm(timesteps)):
# Apply views to noisy_image
viewed_noisy_images = []
for view_fn in views:
viewed_noisy_images.append(view_fn.view(noisy_images[0]))
viewed_noisy_images = torch.stack(viewed_noisy_images)
# Duplicate inputs for CFG
# Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
model_input = torch.cat([viewed_noisy_images] * 2)
model_input = model.scheduler.scale_model_input(model_input, t)
# Predict noise estimate
noise_pred = model.unet(
model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
return_dict=False,
)[0]
# Extract uncond (neg) and cond noise estimates
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # Invert the unconditional (negative) and conditional estimates,
        # mapping each prediction back to the canonical orientation
        noise_pred_uncond = torch.stack(
            [view.inverse_view(pred) for pred, view in zip(noise_pred_uncond, views)]
        )
        noise_pred_text = torch.stack(
            [view.inverse_view(pred) for pred, view in zip(noise_pred_text, views)]
        )
# Split into noise estimate and variance estimates
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Reduce predicted noise and variances
        noise_pred = noise_pred.view(-1, num_prompts, 3, height, width)
        predicted_variance = predicted_variance.view(-1, num_prompts, 3, height, width)
if reduction == 'mean':
noise_pred = noise_pred.mean(1)
predicted_variance = predicted_variance.mean(1)
        elif reduction == 'alternate':
            noise_pred = noise_pred[:, i % num_prompts]
            predicted_variance = predicted_variance[:, i % num_prompts]
else:
raise ValueError('Reduction must be either `mean` or `alternate`')
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
# compute the previous noisy sample x_t -> x_t-1
noisy_images = model.scheduler.step(
noise_pred, t, noisy_images, generator=generator, return_dict=False
)[0]
# Return denoised images
return noisy_images
@torch.no_grad()
def sample_stage_2(model,
                   image,
                   prompt_embeds,
                   negative_prompt_embeds,
                   views,
                   num_inference_steps=100,
                   guidance_scale=7.0,
                   reduction='mean',
                   noise_level=50,
                   generator=None):
    '''
    Upscale a stage 1 result to 256x256 with the super-resolution stage 2
    pipeline, applying the same multi-view denoising as `sample_stage_1`.
    The low-resolution `image` is bilinearly upscaled, noised according to
    `noise_level`, and concatenated to the model input as conditioning.
    '''
# Params
batch_size = 1 # TODO: Support larger batch sizes, maybe
num_prompts = prompt_embeds.shape[0]
height = model.unet.config.sample_size
width = model.unet.config.sample_size
device = model.device
num_images_per_prompt = 1
# For CFG
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# Get timesteps
model.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = model.scheduler.timesteps
num_channels = model.unet.config.in_channels // 2
noisy_images = model.prepare_intermediate_images(
batch_size * num_images_per_prompt,
num_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
# Prepare upscaled image and noise level
image = model.preprocess_image(image, num_images_per_prompt, device)
upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
# Condition on noise level, for each model input
noise_level = torch.cat([noise_level] * num_prompts * 2)
# Denoising Loop
for i, t in enumerate(tqdm(timesteps)):
# Cat noisy image with upscaled conditioning image
model_input = torch.cat([noisy_images, upscaled], dim=1)
# Apply views to noisy_image
viewed_inputs = []
for view_fn in views:
viewed_inputs.append(view_fn.view(model_input[0]))
viewed_inputs = torch.stack(viewed_inputs)
# Duplicate inputs for CFG
# Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
model_input = torch.cat([viewed_inputs] * 2)
model_input = model.scheduler.scale_model_input(model_input, t)
# predict the noise residual
noise_pred = model.unet(
model_input,
t,
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
cross_attention_kwargs=None,
return_dict=False,
)[0]
# Extract uncond (neg) and cond noise estimates
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # Invert the unconditional (negative) and conditional estimates,
        # mapping each prediction back to the canonical orientation
        noise_pred_uncond = torch.stack(
            [view.inverse_view(pred) for pred, view in zip(noise_pred_uncond, views)]
        )
        noise_pred_text = torch.stack(
            [view.inverse_view(pred) for pred, view in zip(noise_pred_text, views)]
        )
# Split predicted noise and predicted variances
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Combine noise estimates (and variance estimates)
        noise_pred = noise_pred.view(-1, num_prompts, 3, height, width)
        predicted_variance = predicted_variance.view(-1, num_prompts, 3, height, width)
if reduction == 'mean':
noise_pred = noise_pred.mean(1)
predicted_variance = predicted_variance.mean(1)
        elif reduction == 'alternate':
            noise_pred = noise_pred[:, i % num_prompts]
            predicted_variance = predicted_variance[:, i % num_prompts]
        else:
            raise ValueError('Reduction must be either `mean` or `alternate`')
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
# compute the previous noisy sample x_t -> x_t-1
noisy_images = model.scheduler.step(
noise_pred, t, noisy_images, generator=generator, return_dict=False
)[0]
# Return denoised images
return noisy_images
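

if __name__ == '__main__':
    # Minimal usage sketch. The checkpoint ids, prompts, and views here are
    # illustrative assumptions (the project wires these up elsewhere);
    # `encode_prompt` is the standard DeepFloyd IF call in diffusers and
    # returns (prompt_embeds, negative_prompt_embeds).
    from diffusers import DiffusionPipeline

    stage_1 = DiffusionPipeline.from_pretrained(
        'DeepFloyd/IF-I-M-v1.0', variant='fp16', torch_dtype=torch.float16
    ).to('cuda')
    stage_2 = DiffusionPipeline.from_pretrained(
        'DeepFloyd/IF-II-M-v1.0', variant='fp16', torch_dtype=torch.float16
    ).to('cuda')

    # One prompt per view: e.g. the image reads as a hot dog upright and
    # as a dog when flipped
    views = [IdentityView(), FlipView()]
    prompts = ['a painting of a hot dog', 'a painting of a dog']
    embeds = [stage_1.encode_prompt(p) for p in prompts]
    prompt_embeds = torch.cat([e[0] for e in embeds])
    negative_embeds = torch.cat([e[1] for e in embeds])

    image_64 = sample_stage_1(stage_1, prompt_embeds, negative_embeds, views)
    image_256 = sample_stage_2(stage_2, image_64, prompt_embeds, negative_embeds, views)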