#####
# Modified from https://github.com/huggingface/diffusers/blob/v0.29.1/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
# PhotoMaker v2 @ TencentARC and MCG-NKU
# Author: Zhen Li
#####

# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    USE_PEFT_BACKEND,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.pipelines import StableDiffusionXLAdapterPipeline

from diffusers.utils import _get_model_file
from safetensors import safe_open
from huggingface_hub.utils import validate_hf_hub_args

from module.model_v2 import PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def _preprocess_adapter_image(image, height, width):
    """
    Convert an adapter conditioning input into a single batched tensor.

    A `torch.Tensor` is returned untouched. A PIL image (or list of PIL images) is resized to `(width, height)` with
    Lanczos resampling, scaled to `[0, 1]` and laid out channels-first. A list of tensors is stacked (3-D items) or
    concatenated along dim 0 (4-D items).

    Raises:
        ValueError: if the list holds tensors that are neither 3-D nor 4-D.
    """
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
        image = [
            i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
        ]  # expand [h, w] or [h, w, c] to [b, h, w, c]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        if image[0].ndim == 3:
            image = torch.stack(image, dim=0)
        elif image[0].ndim == 4:
            image = torch.cat(image, dim=0)
        else:
            raise ValueError(
                f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
            )
    return image


class PhotoMakerStableDiffusionXLAdapterPipeline(StableDiffusionXLAdapterPipeline):
    """
    `StableDiffusionXLAdapterPipeline` extended with PhotoMaker ID conditioning.

    `load_photomaker_adapter` attaches the ID encoder and LoRA weights and registers the trigger word with the
    tokenizers; `__call__` then merges ID embeddings into the prompt embeddings at the trigger-word position.
    """

    @validate_hf_hub_args
    def load_photomaker_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: str = '',
        trigger_word: str = 'img',
        pm_version: str = 'v2',
        **kwargs,
    ):
        """
        Load the PhotoMaker ID encoder and LoRA weights into the pipeline.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
                      It must contain the keys `id_encoder` and `lora_weights`.

            weight_name (`str`):
                The weight name NOT the path to the weight.

            subfolder (`str`, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.

            trigger_word (`str`, *optional*, defaults to `"img"`):
                The trigger word is used to identify the position of class word in the text prompt,
                and it is recommended not to set it as a common word.
                This trigger word must be placed after the class word when used, otherwise, it will affect the
                performance of the personalized generation.

            pm_version (`str`, *optional*, defaults to `"v2"`):
                PhotoMaker version of the checkpoint. Only `"v2"` has an encoder class imported by this module.

        Raises:
            ValueError: if the state dict lacks `id_encoder` or `lora_weights`.
            NotImplementedError: if no ID encoder is available for `pm_version`.
        """
        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            source_label = pretrained_model_name_or_path_or_dict
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                state_dict = {"id_encoder": {}, "lora_weights": {}}
                id_prefix = "id_encoder."
                lora_prefix = "lora_weights."
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        # Slice the prefix off instead of `str.replace`, which would also drop any later
                        # occurrence of the same substring inside the parameter name.
                        if key.startswith(id_prefix):
                            state_dict["id_encoder"][key[len(id_prefix):]] = f.get_tensor(key)
                        elif key.startswith(lora_prefix):
                            state_dict["lora_weights"][key[len(lora_prefix):]] = f.get_tensor(key)
            else:
                # NOTE(security): `torch.load` unpickles the file; only use with checkpoints from a trusted source.
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            # Do not print the dict itself below: it would dump every tensor to stdout.
            source_label = "<in-memory state dict>"
            state_dict = pretrained_model_name_or_path_or_dict

        # Presence check rather than an ordered list comparison, so key order does not matter.
        if any(required not in state_dict for required in ("id_encoder", "lora_weights")):
            raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.")

        self.num_tokens = 2
        self.trigger_word = trigger_word
        # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
        print(f"Loading PhotoMaker {pm_version} components [1] id_encoder from [{source_label}]...")
        self.id_image_processor = CLIPImageProcessor()
        if pm_version == "v1":  # PhotoMaker v1
            # `PhotoMakerIDEncoder` is not imported by this module; look it up defensively so the failure is a
            # clear NotImplementedError instead of a NameError.
            v1_encoder_cls = globals().get("PhotoMakerIDEncoder")
            if v1_encoder_cls is None:
                raise NotImplementedError(
                    f"The PhotoMaker version [{pm_version}] does not support: "
                    "`PhotoMakerIDEncoder` is not available in this module"
                )
            id_encoder = v1_encoder_cls()
        elif pm_version == "v2":  # PhotoMaker v2
            id_encoder = PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken()
        else:
            raise NotImplementedError(f"The PhotoMaker version [{pm_version}] does not support")

        id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
        id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
        self.id_encoder = id_encoder

        # load lora into models
        print(f"Loading PhotoMaker {pm_version} components [2] lora_weights from [{source_label}]")
        self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")

        # Add trigger word token
        if self.tokenizer is not None:
            self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)

        self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)

    def encode_prompt_with_trigger_word(
        self,
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        ### Added args
        num_id_images: int = 1,
        class_tokens_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Encode a prompt that contains the trigger word and locate the class-word slots.

        The trigger-word token is removed from the token ids and the token preceding it (the class word) is
        repeated `num_id_images * self.num_tokens` times; `class_tokens_mask` marks those positions. Only the first
        prompt of a batch is inspected for the trigger word.

        Returns:
            A 5-tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
            negative_pooled_prompt_embeds, class_tokens_mask)`. Unlike the negative and pooled embeddings,
            `prompt_embeds` is NOT repeated for `num_images_per_prompt` here.

        Raises:
            ValueError: if the prompt does not contain exactly one trigger word, or if `prompt_embeds` is given
                without `class_tokens_mask`.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

            if self.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Find the token id of the trigger word
        image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
        text_encoders = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    prompt = self.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    print(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                clean_index = 0
                clean_input_ids = []
                class_token_index = []
                # Find out the corresponding class word token based on the newly added trigger word token
                for token_id in text_input_ids.tolist()[0]:
                    if token_id == image_token_id:
                        class_token_index.append(clean_index - 1)
                    else:
                        clean_input_ids.append(token_id)
                        clean_index += 1

                if len(class_token_index) != 1:
                    # Covers both "no trigger word" and "several trigger words".
                    raise ValueError(
                        f"PhotoMaker requires exactly one trigger word in a prompt, but found "
                        f"{len(class_token_index)}. Trigger word: {self.trigger_word}, Prompt: {prompt}."
                    )
                class_token_index = class_token_index[0]

                # Expand the class word token and corresponding mask
                num_class_slots = num_id_images * self.num_tokens
                class_token = clean_input_ids[class_token_index]
                clean_input_ids = (
                    clean_input_ids[:class_token_index]
                    + [class_token] * num_class_slots
                    + clean_input_ids[class_token_index + 1:]
                )

                # Truncation or padding
                max_len = tokenizer.model_max_length
                if len(clean_input_ids) > max_len:
                    clean_input_ids = clean_input_ids[:max_len]
                else:
                    clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
                        max_len - len(clean_input_ids)
                    )

                class_tokens_mask = [
                    class_token_index <= i < class_token_index + num_class_slots
                    for i in range(len(clean_input_ids))
                ]

                clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
                class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)

                prompt_embeds = text_encoder(clean_input_ids.to(device), output_hidden_states=True)

                # We are only ALWAYS interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        elif class_tokens_mask is None:
            # Pre-computed embeddings carry no token positions, so the mask cannot be derived here.
            raise ValueError(
                "If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed."
            )

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        class_tokens_mask = class_tokens_mask.to(device=device)  # TODO: ignoring two-prompt case

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are only ALWAYS interested in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if self.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, class_tokens_mask

    @property
    def interrupt(self):
        """Interrupt flag; `False` when `_interrupt` has never been set on this instance."""
        # `_interrupt` is not assigned anywhere in this module, so default instead of raising AttributeError.
        return getattr(self, "_interrupt", False)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        adapter_conditioning_scale: Union[float, List[float]] = 1.0,
        adapter_conditioning_factor: float = 1.0,
        clip_skip: Optional[int] = None,
        # Added parameters (for PhotoMaker)
        input_id_images: PipelineImageInput = None,
        start_merge_step: int = 10,  # TODO: change to `style_strength_ratio` in the future
        class_tokens_mask: Optional[torch.LongTensor] = None,
        id_embeds: Optional[torch.FloatTensor] = None,
        prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Only the parameters introduced by PhotoMaker are discussed here.
        For explanations of the other parameters, please refer to `StableDiffusionXLAdapterPipeline`:
        https://github.com/huggingface/diffusers/blob/v0.29.1/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py

        Args:
            input_id_images (`PipelineImageInput`, *optional*):
                Input ID Image to work with PhotoMaker. Required. Non-tensor inputs go through the CLIP image
                processor; tensors are stacked (3-D) or concatenated (4-D) as they are.
            start_merge_step (`int`, defaults to 10):
                Denoising steps with index `<= start_merge_step` use the text-only embeddings; later steps use the
                ID-merged embeddings.
            class_tokens_mask (`torch.LongTensor`, *optional*):
                Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary
                to prepare the `class_tokens_mask` beforehand for marking out the position of class word.
            id_embeds (`torch.FloatTensor`, *optional*):
                Extra ID embeddings forwarded to the ID encoder as its fourth argument when given.
            prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.

        Returns:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        height, width = self._default_height_width(height, width, image)
        device = self._execution_device

        use_adapter = image is not None
        print(f"Use adapter: {use_adapter} | output size: {(height, width)}")
        if use_adapter:
            if isinstance(self.adapter, MultiAdapter):
                adapter_input = []

                for one_image in image:
                    one_image = _preprocess_adapter_image(one_image, height, width)
                    one_image = one_image.to(device=device, dtype=self.adapter.dtype)
                    adapter_input.append(one_image)
            else:
                adapter_input = _preprocess_adapter_image(image, height, width)
                adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip

        if prompt_embeds is not None and class_tokens_mask is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
            )
        # check the input id images
        if input_id_images is None:
            raise ValueError(
                "Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline."
            )
        if not isinstance(input_id_images, list):
            input_id_images = [input_id_images]

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )

        num_id_images = len(input_id_images)
        (
            prompt_embeds,
            _,
            pooled_prompt_embeds,
            _,
            class_tokens_mask,
        ) = self.encode_prompt_with_trigger_word(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_id_images=num_id_images,
            class_tokens_mask=class_tokens_mask,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self._clip_skip,
        )

        # 4. Encode input prompt without the trigger word for delayed conditioning
        if prompt is not None:
            # encode, remove trigger word token, then decode.
            # Fall back to the second tokenizer when the first one is absent; the trigger word is
            # registered on both by `load_photomaker_adapter`.
            text_tokenizer = self.tokenizer if self.tokenizer is not None else self.tokenizer_2
            tokens_text_only = text_tokenizer.encode(prompt, add_special_tokens=False)
            trigger_word_token = text_tokenizer.convert_tokens_to_ids(self.trigger_word)
            tokens_text_only.remove(trigger_word_token)
            prompt_text_only = text_tokenizer.decode(tokens_text_only, add_special_tokens=False)
        else:
            # Embedding-only call: there is no text to strip the trigger word from.
            if prompt_embeds_text_only is None or pooled_prompt_embeds_text_only is None:
                raise ValueError(
                    "If `prompt` is not provided, `prompt_embeds_text_only` and `pooled_prompt_embeds_text_only`"
                    " have to be passed."
                )
            prompt_text_only = None
        (
            prompt_embeds_text_only,
            negative_prompt_embeds,
            pooled_prompt_embeds_text_only,  # TODO: replace the pooled_prompt_embeds with text only prompt
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt_text_only,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds_text_only,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds_text_only,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self._clip_skip,
        )

        # 5. Prepare the input ID images
        dtype = next(self.id_encoder.parameters()).dtype
        if not isinstance(input_id_images[0], torch.Tensor):
            id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values
        elif input_id_images[0].ndim == 3:
            # Tensor inputs bypass the image processor.
            # NOTE(review): assumes each tensor is an already-preprocessed [C, H, W] image — confirm with callers.
            id_pixel_values = torch.stack(input_id_images, dim=0)
        else:
            # NOTE(review): assumes already-preprocessed [N, C, H, W] batches — confirm with callers.
            id_pixel_values = torch.cat(input_id_images, dim=0)

        id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype)  # TODO: multiple prompts

        # 6. Get the update text embedding with the stacked ID embedding
        if id_embeds is not None:
            id_embeds = id_embeds.unsqueeze(0).to(device=device, dtype=dtype)
            prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds)
        else:
            prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # 6.1 Get the ip adapter embedding
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 7. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )

        # 8. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9.1 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 10. Prepare added time ids & embeddings & adapter features
        if use_adapter:
            if isinstance(self.adapter, MultiAdapter):
                # The conditioning scale is handed to the adapter call itself in this branch.
                adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
            else:
                adapter_state = self.adapter(adapter_input)
                for k, v in enumerate(adapter_state):
                    adapter_state[k] = v * adapter_conditioning_scale
            if num_images_per_prompt > 1:
                for k, v in enumerate(adapter_state):
                    adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
            if self.do_classifier_free_guidance:
                for k, v in enumerate(adapter_state):
                    adapter_state[k] = torch.cat([v] * 2, dim=0)

        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 11. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        # Apply denoising_end
        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Delayed ID conditioning: text-only embeddings first, ID-merged embeddings afterwards.
                if i <= start_merge_step:
                    current_prompt_embeds = torch.cat(
                        [negative_prompt_embeds, prompt_embeds_text_only], dim=0
                    ) if self.do_classifier_free_guidance else prompt_embeds_text_only
                    add_text_embeds = torch.cat(
                        [negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0
                    ) if self.do_classifier_free_guidance else pooled_prompt_embeds_text_only
                else:
                    current_prompt_embeds = torch.cat(
                        [negative_prompt_embeds, prompt_embeds], dim=0
                    ) if self.do_classifier_free_guidance else prompt_embeds
                    add_text_embeds = torch.cat(
                        [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
                    ) if self.do_classifier_free_guidance else pooled_prompt_embeds

                if i < int(num_inference_steps * adapter_conditioning_factor) and (use_adapter):
                    down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                else:
                    down_intrablock_additional_residuals = None

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=current_prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if output_type != "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            # NOTE(review): this early return ignores `return_dict` and skips `maybe_free_model_hooks()`;
            # kept as is because callers may rely on it.
            image = latents
            return StableDiffusionXLPipelineOutput(images=image)

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)