# pix2pix-zero-01 / src/utils/edit_pipeline.py
# Uploaded by ysharma (HF staff) in commit d950775 ("upload git code base").
import pdb, sys
import numpy as np
import torch
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
sys.path.insert(0, "src/utils")
from base_pipeline import BasePipeline
from cross_attention import prep_unet
class EditingPipeline(BasePipeline):
    """Stable Diffusion pipeline performing pix2pix-zero style image editing.

    A call runs two denoising passes starting from the same initial latent:

    1. A reference pass with the original prompt embeddings. At every
       timestep it records the ``attn_probs`` of each ``attn2``
       ``CrossAttention`` module, and at the end it decodes a
       reconstruction image.
    2. An edit pass in which ``edit_dir`` is added to the conditional
       prompt embeddings. At every timestep the model input receives one SGD
       step that pulls its cross-attention maps toward the reference maps
       recorded for that timestep. The noise is then re-predicted and the
       scheduler advances.
    """

    def _iter_cross_attention(self):
        """Yield ``(name, module)`` for each ``attn2`` CrossAttention module of the UNet."""
        for name, module in self.unet.named_modules():
            if type(module).__name__ == "CrossAttention" and "attn2" in name:
                yield name, module

    @staticmethod
    def _apply_cfg(noise_pred, guidance_scale):
        """Combine a stacked [unconditional, conditional] noise prediction with classifier-free guidance."""
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # pix2pix parameters
        guidance_amount=0.1,
        edit_dir=None,
        x_in=None,
    ):
        """Reconstruct and edit an image, returning ``(image_rec, image_edit)``.

        Args:
            prompt, height, width, num_inference_steps, guidance_scale,
            negative_prompt, num_images_per_prompt, eta, generator,
            prompt_embeds, negative_prompt_embeds, cross_attention_kwargs:
                Same meaning as in the diffusers Stable Diffusion pipeline.
            latents: Fallback starting latent. It is used only when ``x_in``
                is None.
            guidance_amount: Learning rate of the per-timestep SGD step that
                aligns the edited cross-attention maps with the reference maps.
            edit_dir: Tensor (or array-like) added to the conditional
                (positive) prompt embeddings. It is moved to their
                device/dtype and must broadcast against them. ``None`` means
                no edit, giving a pure attention-guided reconstruction.
            x_in: Starting latent for both passes (e.g. an inverted latent).
                When both ``x_in`` and ``latents`` are None, the latent comes
                from ``prepare_latents`` (presumably random sampling; verify
                in BasePipeline).

        Returns:
            ``(image_rec, image_edit)``: both are outputs of ``numpy_to_pil``.
            Only the edited image passes through the safety checker; its
            NSFW flags are discarded.

        Raises:
            RuntimeError: if the UNet exposes no ``attn2`` CrossAttention
                modules, so the attention-guidance loss cannot be formed.
        """
        device = self._execution_device

        # Previously x_in=None crashed here and `latents` was ignored. Now
        # x_in takes precedence and `latents` is the fallback.
        init_latents = x_in if x_in is not None else latents
        if init_latents is not None:
            init_latents = init_latents.to(dtype=self.unet.dtype, device=device)

        # 0. Patch the UNet. The code below reads `module.attn_probs` from its
        #    attn2 modules, which prep_unet is presumably responsible for exposing.
        self.unet = prep_unet(self.unet)

        # 1. Default height and width to the UNet's native resolution.
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # TODO: add the input checker function
        # self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters.
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode the prompt. With CFG the rows are stacked as
        #    [negative..., positive...] (consumed via chunk(2) below).
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare the initial latent, shared by both passes.
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            init_latents,
        )
        latents_init = latents.clone()

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 7. Reference pass: record the cross-attention maps per timestep.
        ref_attn = {}  # {timestep: {module name: attn_probs on CPU}}
        with torch.no_grad():
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample
                    # Maps are parked on the CPU so accelerator memory does not
                    # grow with the number of timesteps.
                    ref_attn[t.item()] = {
                        name: module.attn_probs.detach().cpu()
                        for name, module in self._iter_cross_attention()
                    }
                    if do_classifier_free_guidance:
                        noise_pred = self._apply_cfg(noise_pred, guidance_scale)
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()

            # Reference (reconstruction) image. This runs under no_grad so the
            # VAE decode builds no autograd graph.
            image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))

        # 8. Apply the edit direction to the conditional rows only. With CFG,
        #    the first batch_size * num_images_per_prompt rows are the negative
        #    prompt. The original `[1:2]` slice was only correct for batch 1
        #    with CFG: it dropped the edit silently without CFG.
        prompt_embeds_edit = prompt_embeds.clone()
        if edit_dir is not None:
            num_uncond = batch_size * num_images_per_prompt if do_classifier_free_guidance else 0
            prompt_embeds_edit[num_uncond:] += torch.as_tensor(
                edit_dir, dtype=prompt_embeds_edit.dtype, device=prompt_embeds_edit.device
            )

        # 9. Edit pass with cross-attention guidance.
        latents = latents_init
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                x_in = latent_model_input.detach().clone()
                x_in.requires_grad_(True)
                opt = torch.optim.SGD([x_in], lr=guidance_amount)

                # This forward pass is needed only for the attention maps it
                # leaves on the attn2 modules; its noise prediction is unused.
                self.unet(
                    x_in,
                    t,
                    encoder_hidden_states=prompt_embeds_edit.detach(),
                    cross_attention_kwargs=cross_attention_kwargs,
                )
                ref_maps = ref_attn[t.item()]
                loss = 0.0
                for name, module in self._iter_cross_attention():
                    curr = module.attn_probs
                    # Follow the live tensor's device instead of hard-coding CUDA.
                    ref = ref_maps[name].to(curr.device)
                    loss = loss + ((curr - ref) ** 2).sum((1, 2)).mean(0)
                if not torch.is_tensor(loss):
                    raise RuntimeError(
                        "No 'attn2' CrossAttention modules found in the UNet; "
                        "cannot compute the cross-attention guidance loss."
                    )
                # Differentiate w.r.t. x_in only, so UNet parameters never
                # accumulate .grad across timesteps.
                (x_in.grad,) = torch.autograd.grad(loss, [x_in])
                opt.step()

                # Re-predict the noise from the optimized input.
                with torch.no_grad():
                    noise_pred = self.unet(
                        x_in.detach(),
                        t,
                        encoder_hidden_states=prompt_embeds_edit,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample
                # With CFG, both duplicated halves were optimized; the first
                # (unconditional) copy becomes the latent. Without CFG the whole
                # batch is kept; chunk(2) would drop half of it.
                latents = x_in.detach().chunk(2)[0] if do_classifier_free_guidance else x_in.detach()

                if do_classifier_free_guidance:
                    noise_pred = self._apply_cfg(noise_pred, guidance_scale)
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # 10. Post-processing and safety check (no autograd graph needed).
        with torch.no_grad():
            image = self.decode_latents(latents.detach())
            image, _has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 11. Convert to PIL.
        image_edit = self.numpy_to_pil(image)
        return image_rec, image_edit