from tqdm import tqdm

import torch
import torch.nn.functional as F

from diffusers.utils.torch_utils import randn_tensor


@torch.no_grad()
def sample_stage_1(model,
                   prompt_embeds,
                   negative_prompt_embeds,
                   views,
                   num_inference_steps=100,
                   guidance_scale=7.0,
                   reduction='mean',
                   generator=None):
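    '''
    Sample from the stage-1 (base) model while applying a different view to
    the noisy image for each prompt. Per-view noise estimates are mapped back
    through the inverse views and combined, so a single image is denoised to
    satisfy every (view, prompt) pair at once.

    model: a DeepFloyd IF stage-1 pipeline (or compatible)
    prompt_embeds / negative_prompt_embeds: stacked text embeddings, one row
        per view
    views: view objects exposing `view()` and `inverse_view()`
    reduction: 'mean' averages the per-view estimates; 'alternate' cycles
        through them across timesteps
    '''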
    # Params
    num_images_per_prompt = 1
    device = model.device
    height = model.unet.config.sample_size
    width = model.unet.config.sample_size
    batch_size = 1
    num_prompts = prompt_embeds.shape[0]
    assert num_prompts == len(views), \
        "Number of prompts must match number of views!"

    # Concatenate negative and positive embeddings for classifier-free guidance
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Set up timesteps
    model.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = model.scheduler.timesteps

    # Sample the initial noise to denoise
    noisy_images = model.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        model.unet.config.in_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )
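
    # Denoising loop: at each timestep, every view sees its own transformed
    # copy of the (single) noisy image; predictions are mapped back through
    # the inverse views so they can be combined in the canonical orientation.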
    for i, t in enumerate(tqdm(timesteps)):
        # Apply the views to the noisy image
        viewed_noisy_images = []
        for view_fn in views:
            viewed_noisy_images.append(view_fn.view(noisy_images[0]))
        viewed_noisy_images = torch.stack(viewed_noisy_images)

        # Duplicate inputs for classifier-free guidance
        model_input = torch.cat([viewed_noisy_images] * 2)
        model_input = model.scheduler.scale_model_input(model_input, t)

        # Predict the noise residual
        noise_pred = model.unet(
            model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        # Extract unconditional and conditional noise estimates
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # Invert the unconditional estimates back to the canonical view
        inverted_preds = []
        for pred, view in zip(noise_pred_uncond, views):
            inverted_preds.append(view.inverse_view(pred))
        noise_pred_uncond = torch.stack(inverted_preds)

        # Invert the conditional estimates back to the canonical view
        inverted_preds = []
        for pred, view in zip(noise_pred_text, views):
            inverted_preds.append(view.inverse_view(pred))
        noise_pred_text = torch.stack(inverted_preds)
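
        # The stage-1 UNet predicts a learned variance alongside the noise, so
        # its output has twice as many channels as its input; split off the
        # variance before applying classifier-free guidance to the noise only.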
        noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
        noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Combine the per-view estimates into a single one
        noise_pred = noise_pred.view(-1, num_prompts, 3, height, width)
        predicted_variance = predicted_variance.view(-1, num_prompts, 3, height, width)
        if reduction == 'mean':
            noise_pred = noise_pred.mean(1)
            predicted_variance = predicted_variance.mean(1)
        elif reduction == 'alternate':
            noise_pred = noise_pred[:, i % num_prompts]
            predicted_variance = predicted_variance[:, i % num_prompts]
        else:
            raise ValueError('Reduction must be either `mean` or `alternate`')
        noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

        # Compute the previous noisy sample x_t -> x_{t-1}
        noisy_images = model.scheduler.step(
            noise_pred, t, noisy_images, generator=generator, return_dict=False
        )[0]

    return noisy_images


@torch.no_grad()
def sample_stage_2(model,
                   image,
                   prompt_embeds,
                   negative_prompt_embeds,
                   views,
                   num_inference_steps=100,
                   guidance_scale=7.0,
                   reduction='mean',
                   noise_level=50,
                   generator=None):
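    '''
    Sample from the stage-2 (super-resolution) model, conditioned on the
    stage-1 output. Mirrors `sample_stage_1`, except the views are applied to
    the noisy image concatenated with the noised, upscaled conditioning image.

    image: the stage-1 output to upscale
    noise_level: how much noise to add to the upscaled conditioning image,
        passed to the pipeline's image-noising scheduler
    Remaining arguments are as in `sample_stage_1`.
    '''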
    # Params
    batch_size = 1
    num_prompts = prompt_embeds.shape[0]
    height = model.unet.config.sample_size
    width = model.unet.config.sample_size
    device = model.device
    num_images_per_prompt = 1

    # Concatenate negative and positive embeddings for classifier-free guidance
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Set up timesteps
    model.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = model.scheduler.timesteps
    # Sample the initial noise; the stage-2 UNet takes the noisy image and the
    # upscaled conditioning image concatenated along channels, so the noise
    # itself uses only half of `in_channels`
    num_channels = model.unet.config.in_channels // 2
    noisy_images = model.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        num_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # Prepare the upscaled conditioning image and add noise to it
    image = model.preprocess_image(image, num_images_per_prompt, device)
    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    # Duplicate the noise level for every view and for classifier-free guidance
    noise_level = torch.cat([noise_level] * num_prompts * 2)
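
    # Denoising loop, analogous to stage 1, except the views act on the noisy
    # image concatenated with the upscaled conditioning image.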
    for i, t in enumerate(tqdm(timesteps)):
        # Concatenate the noisy image with the upscaled conditioning image
        model_input = torch.cat([noisy_images, upscaled], dim=1)

        # Apply the views to the model input
        viewed_inputs = []
        for view_fn in views:
            viewed_inputs.append(view_fn.view(model_input[0]))
        viewed_inputs = torch.stack(viewed_inputs)

        # Duplicate inputs for classifier-free guidance
        model_input = torch.cat([viewed_inputs] * 2)
        model_input = model.scheduler.scale_model_input(model_input, t)

        # Predict the noise residual
        noise_pred = model.unet(
            model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            class_labels=noise_level,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        # Extract unconditional and conditional noise estimates
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # Invert the unconditional estimates back to the canonical view
        inverted_preds = []
        for pred, view in zip(noise_pred_uncond, views):
            inverted_preds.append(view.inverse_view(pred))
        noise_pred_uncond = torch.stack(inverted_preds)

        # Invert the conditional estimates back to the canonical view
        inverted_preds = []
        for pred, view in zip(noise_pred_text, views):
            inverted_preds.append(view.inverse_view(pred))
        noise_pred_text = torch.stack(inverted_preds)

        # Split off the learned variance; `model_input.shape[1] // 2` because
        # half of the input channels belong to the conditioning image
        noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
        noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Combine the per-view estimates into a single one
        noise_pred = noise_pred.view(-1, num_prompts, 3, height, width)
        predicted_variance = predicted_variance.view(-1, num_prompts, 3, height, width)
        if reduction == 'mean':
            noise_pred = noise_pred.mean(1)
            predicted_variance = predicted_variance.mean(1)
        elif reduction == 'alternate':
            noise_pred = noise_pred[:, i % num_prompts]
            predicted_variance = predicted_variance[:, i % num_prompts]
        else:
            raise ValueError('Reduction must be either `mean` or `alternate`')
        noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

        # Compute the previous noisy sample x_t -> x_{t-1}
        noisy_images = model.scheduler.step(
            noise_pred, t, noisy_images, generator=generator, return_dict=False
        )[0]

    return noisy_images
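

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the sampler). The checkpoint names, the
# `make_views` helper, and the prompt-encoding loop below are assumptions
# about how this module is driven, not guarantees; adapt them to your setup.

if __name__ == '__main__':
    from diffusers import DiffusionPipeline

    # Stage-1 (base) and stage-2 (super-resolution) DeepFloyd IF pipelines
    stage_1 = DiffusionPipeline.from_pretrained('DeepFloyd/IF-I-M-v1.0').to('cuda')
    stage_2 = DiffusionPipeline.from_pretrained('DeepFloyd/IF-II-M-v1.0').to('cuda')

    # One prompt per view; `make_views` is a hypothetical helper returning
    # objects with `view()` / `inverse_view()` (e.g. identity and a rotation)
    prompts = ['an oil painting of a dog', 'an oil painting of a cat']
    views = make_views(['identity', 'rotate_180'])  # hypothetical helper

    # Encode each prompt and stack the embeddings, one row per view
    embeds = [stage_1.encode_prompt(p) for p in prompts]
    prompt_embeds = torch.cat([e[0] for e in embeds])
    negative_prompt_embeds = torch.cat([e[1] for e in embeds])

    # Stage 1 produces the low-resolution anagram; stage 2 upscales it
    image_64 = sample_stage_1(stage_1, prompt_embeds,
                              negative_prompt_embeds, views)
    image_256 = sample_stage_2(stage_2, image_64, prompt_embeds,
                               negative_prompt_embeds, views)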