# Copyright © Alibaba, Inc. and its affiliates. # The implementation here is modifed based on diffusers.StableDiffusionControlNetImg2ImgPipeline, # originally Apache 2.0 License and public available at # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py import copy import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from diffusers import (AutoencoderKL, DiffusionPipeline, StableDiffusionControlNetImg2ImgPipeline) from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import ControlNetModel try: from diffusers.models.autoencoders.vae import DecoderOutput except: from diffusers.models.vae import DecoderOutput from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging, replace_example_docstring from diffusers.utils.torch_utils import is_compiled_module from transformers import CLIPTokenizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from PIL import Image >>> from txt2panoimage.pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline >>> base_model_path = "models/sr-base" >>> controlnet_path = "models/sr-control" >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(base_model_path, controlnet=controlnet, ... torch_dtype=torch.float16) >>> pipe.vae.enable_tiling() >>> # remove following line if xformers is not installed >>> pipe.enable_xformers_memory_efficient_attention() >>> pipe.enable_model_cpu_offload() >>> input_image_path = 'data/test.png' >>> image = Image.open(input_image_path) >>> image = pipe( ... "futuristic-looking woman", ... num_inference_steps=20, ... image=image, ... height=768, ... width=1536, ... control_image=image, ... ).images[0] ``` """ re_attention = re.compile( r""" \\\(| \\\)| \\\[| \\]| \\\\| \\| \(| \[| :([+-]?[.\d]+)\)| \)| ]| [^\\()\[\]:]+| : """, re.X, ) def parse_prompt_attention(text): """ Parses a string with attention tokens and returns a list of pairs: text and its associated weight. Accepted tokens are: (abc) - increases attention to abc by a multiplier of 1.1 (abc:3.12) - increases attention to abc by a multiplier of 3.12 [abc] - decreases attention to abc by a multiplier of 1.1 """ res = [] round_brackets = [] square_brackets = [] round_bracket_multiplier = 1.1 square_bracket_multiplier = 1 / 1.1 def multiply_range(start_position, multiplier): for p in range(start_position, len(res)): res[p][1] *= multiplier for m in re_attention.finditer(text): text = m.group(0) weight = m.group(1) if text.startswith('\\'): res.append([text[1:], 1.0]) elif text == '(': round_brackets.append(len(res)) elif text == '[': square_brackets.append(len(res)) elif weight is not None and len(round_brackets) > 0: multiply_range(round_brackets.pop(), float(weight)) elif text == ')' and len(round_brackets) > 0: multiply_range(round_brackets.pop(), round_bracket_multiplier) elif text == ']' and len(square_brackets) > 0: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: res.append([text, 1.0]) for pos in round_brackets: multiply_range(pos, round_bracket_multiplier) for pos in square_brackets: multiply_range(pos, square_bracket_multiplier) if len(res) == 0: res = [['', 1.0]] # merge runs of identical weights i = 0 while i + 1 < len(res): if res[i][1] == res[i + 1][1]: res[i][0] += res[i + 1][0] res.pop(i + 1) else: i += 1 return res def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int): r""" Tokenize a list of prompts and return its tokens with weights of each token. No padding, starting or ending token is included. """ tokens = [] weights = [] truncated = False for text in prompt: texts_and_weights = parse_prompt_attention(text) text_token = [] text_weight = [] for word, weight in texts_and_weights: # tokenize and discard the starting and the ending token token = pipe.tokenizer(word).input_ids[1:-1] text_token += token # copy the weight by length of token text_weight += [weight] * len(token) # stop if the text is too long (longer than truncation limit) if len(text_token) > max_length: truncated = True break # truncate if len(text_token) > max_length: truncated = True text_token = text_token[:max_length] text_weight = text_weight[:max_length] tokens.append(text_token) weights.append(text_weight) if truncated: logger.warning( 'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples' ) return tokens, weights def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): r""" Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. """ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length for i in range(len(tokens)): tokens[i] = [ bos ] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos] if no_boseos_middle: weights[i] = [1.0] + weights[i] + [1.0] * ( max_length - 1 - len(weights[i])) else: w = [] if len(weights[i]) == 0: w = [1.0] * weights_length else: for j in range(max_embeddings_multiples): w.append(1.0) # weight for starting token in this chunk w += weights[i][j * (chunk_length - 2):min( len(weights[i]), (j + 1) * (chunk_length - 2))] w.append(1.0) # weight for ending token in this chunk w += [1.0] * (weights_length - len(w)) weights[i] = w[:] return tokens, weights def get_unweighted_text_embeddings( pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True, ): """ When the length of tokens is a multiple of the capacity of the text encoder, it should be split into chunks and sent to the text encoder individually. """ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) if max_embeddings_multiples > 1: text_embeddings = [] for i in range(max_embeddings_multiples): # extract the i-th chunk text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1) * (chunk_length - 2) + 2].clone() # cover the head and the tail by the starting and the ending tokens text_input_chunk[:, 0] = text_input[0, 0] text_input_chunk[:, -1] = text_input[0, -1] text_embedding = pipe.text_encoder(text_input_chunk)[0] if no_boseos_middle: if i == 0: # discard the ending token text_embedding = text_embedding[:, :-1] elif i == max_embeddings_multiples - 1: # discard the starting token text_embedding = text_embedding[:, 1:] else: # discard both starting and ending tokens text_embedding = text_embedding[:, 1:-1] text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: text_embeddings = pipe.text_encoder(text_input)[0] return text_embeddings def get_weighted_text_embeddings( pipe: DiffusionPipeline, prompt: Union[str, List[str]], uncond_prompt: Optional[Union[str, List[str]]] = None, max_embeddings_multiples: Optional[int] = 3, no_boseos_middle: Optional[bool] = False, skip_parsing: Optional[bool] = False, skip_weighting: Optional[bool] = False, ): r""" Prompts can be assigned with local weights using brackets. For example, prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. Args: pipe (`DiffusionPipeline`): Pipe to provide access to the tokenizer and the text encoder. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. uncond_prompt (`str` or `List[str]`): The unconditional prompt or prompts for guide the image generation. If unconditional prompt is provided, the embeddings of prompt and uncond_prompt are concatenated. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. no_boseos_middle (`bool`, *optional*, defaults to `False`): If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and ending token in each of the chunk in the middle. skip_parsing (`bool`, *optional*, defaults to `False`): Skip the parsing of brackets. skip_weighting (`bool`, *optional*, defaults to `False`): Skip the weighting. When the parsing is skipped, it is forced True. """ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 if isinstance(prompt, str): prompt = [prompt] if not skip_parsing: prompt_tokens, prompt_weights = get_prompts_with_weights( pipe, prompt, max_length - 2) if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] uncond_tokens, uncond_weights = get_prompts_with_weights( pipe, uncond_prompt, max_length - 2) else: prompt_tokens = [ token[1:-1] for token in pipe.tokenizer( prompt, max_length=max_length, truncation=True).input_ids ] prompt_weights = [[1.0] * len(token) for token in prompt_tokens] if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] uncond_tokens = [ token[1:-1] for token in pipe.tokenizer( uncond_prompt, max_length=max_length, truncation=True).input_ids ] uncond_weights = [[1.0] * len(token) for token in uncond_tokens] # round up the longest length of tokens to a multiple of (model_max_length - 2) max_length = max([len(token) for token in prompt_tokens]) if uncond_prompt is not None: max_length = max(max_length, max([len(token) for token in uncond_tokens])) max_embeddings_multiples = min( max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, ) max_embeddings_multiples = max(1, max_embeddings_multiples) max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 # pad the length of tokens and weights bos = pipe.tokenizer.bos_token_id eos = pipe.tokenizer.eos_token_id pad = getattr(pipe.tokenizer, 'pad_token_id', eos) prompt_tokens, prompt_weights = pad_tokens_and_weights( prompt_tokens, prompt_weights, max_length, bos, eos, pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) prompt_tokens = torch.tensor( prompt_tokens, dtype=torch.long, device=pipe.device) if uncond_prompt is not None: uncond_tokens, uncond_weights = pad_tokens_and_weights( uncond_tokens, uncond_weights, max_length, bos, eos, pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) uncond_tokens = torch.tensor( uncond_tokens, dtype=torch.long, device=pipe.device) # get the embeddings text_embeddings = get_unweighted_text_embeddings( pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, ) prompt_weights = torch.tensor( prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device) if uncond_prompt is not None: uncond_embeddings = get_unweighted_text_embeddings( pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, ) uncond_weights = torch.tensor( uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device) # assign weights to the prompts and normalize in the sense of mean # TODO: should we normalize by chunk or in a whole (current implementation)? if (not skip_parsing) and (not skip_weighting): previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to( text_embeddings.dtype) text_embeddings *= prompt_weights.unsqueeze(-1) current_mean = text_embeddings.float().mean(axis=[-2, -1]).to( text_embeddings.dtype) text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) if uncond_prompt is not None: previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to( uncond_embeddings.dtype) uncond_embeddings *= uncond_weights.unsqueeze(-1) current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to( uncond_embeddings.dtype) uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) if uncond_prompt is not None: return text_embeddings, uncond_embeddings return text_embeddings, None def prepare_image(image): if isinstance(image, torch.Tensor): # Batch single image if image.ndim == 3: image = image.unsqueeze(0) image = image.to(dtype=torch.float32) else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): image = [np.array(i.convert('RGB'))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 return image class StableDiffusionControlNetImg2ImgPanoPipeline( StableDiffusionControlNetImg2ImgPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/ model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ['safety_checker', 'feature_extractor'] def check_inputs( self, prompt, image, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, ): if height % 8 != 0 or width % 8 != 0: raise ValueError( f'`height` and `width` have to be divisible by 8 but are {height} and {width}.' ) condition_1 = callback_steps is not None condition_2 = not isinstance(callback_steps, int) or callback_steps <= 0 if (callback_steps is None) or (condition_1 and condition_2): raise ValueError( f'`callback_steps` has to be a positive integer but is {callback_steps} of type' f' {type(callback_steps)}.') if prompt is not None and prompt_embeds is not None: raise ValueError( f'Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to' ' only forward one of the two.') elif prompt is None and prompt_embeds is None: raise ValueError( 'Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.' ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError( f'`prompt` has to be of type `str` or `list` but is {type(prompt)}' ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f'Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:' f' {negative_prompt_embeds}. Please make sure to only forward one of the two.' ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( '`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but' f' got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`' f' {negative_prompt_embeds.shape}.') # `prompt` needs more sophisticated handling when there are multiple # conditionings. if isinstance(self.controlnet, MultiControlNetModel): if isinstance(prompt, list): logger.warning( f'You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}' ' prompts. The conditionings will be fixed across the prompts.' ) # Check `image` is_compiled = hasattr( F, 'scaled_dot_product_attention') and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule) if (isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)): self.check_image(image, prompt, prompt_embeds) elif (isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)): if not isinstance(image, list): raise TypeError( 'For multiple controlnets: `image` must be type `list`') # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): raise ValueError( 'A single batch of multiple conditionings are supported at the moment.' ) elif len(image) != len(self.controlnet.nets): raise ValueError( 'For multiple controlnets: `image` must have the same length as the number of controlnets.' ) for image_ in image: self.check_image(image_, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if (isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)): if not isinstance(controlnet_conditioning_scale, float): raise TypeError( 'For single controlnet: `controlnet_conditioning_scale` must be type `float`.' ) elif (isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)): if isinstance(controlnet_conditioning_scale, list): if any( isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError( 'A single batch of multiple conditionings are supported at the moment.' ) elif isinstance( controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets): raise ValueError( 'For multiple controlnets: When `controlnet_conditioning_scale` ' 'is specified as `list`, it must have' ' the same length as the number of controlnets') else: assert False def _default_height_width(self, height, width, image): # NOTE: It is possible that a list of images have different # dimensions for each image, so just checking the first image # is not _exactly_ correct, but it is simple. while isinstance(image, list): image = image[0] if height is None: if isinstance(image, PIL.Image.Image): height = image.height elif isinstance(image, torch.Tensor): height = image.shape[2] height = (height // 8) * 8 # round down to nearest multiple of 8 if width is None: if isinstance(image, PIL.Image.Image): width = image.width elif isinstance(image, torch.Tensor): width = image.shape[3] width = (width // 8) * 8 # round down to nearest multiple of 8 return height, width def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, max_embeddings_multiples=3, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `list(int)`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. """ if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if negative_prompt_embeds is None: if negative_prompt is None: negative_prompt = [''] * batch_size elif isinstance(negative_prompt, str): negative_prompt = [negative_prompt] * batch_size if batch_size != len(negative_prompt): raise ValueError( f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:' f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches' ' the batch size of `prompt`.') if prompt_embeds is None or negative_prompt_embeds is None: if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = self.maybe_convert_prompt( negative_prompt, self.tokenizer) prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings( pipe=self, prompt=prompt, uncond_prompt=negative_prompt if do_classifier_free_guidance else None, max_embeddings_multiples=max_embeddings_multiples, ) if prompt_embeds is None: prompt_embeds = prompt_embeds1 if negative_prompt_embeds is None: negative_prompt_embeds = negative_prompt_embeds1 bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: bs_embed, seq_len, _ = negative_prompt_embeds.shape negative_prompt_embeds = negative_prompt_embeds.repeat( 1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view( bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def denoise_latents(self, latents, t, prompt_embeds, control_image, controlnet_conditioning_scale, guess_mode, cross_attention_kwargs, do_classifier_free_guidance, guidance_scale, extra_step_kwargs, views_scheduler_status): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat( [latents] * 2) if do_classifier_free_guidance else latents self.scheduler.__dict__.update(views_scheduler_status[0]) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t) # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. controlnet_latent_model_input = latents controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: controlnet_latent_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds down_block_res_samples, mid_block_res_sample = self.controlnet( controlnet_latent_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, conditioning_scale=controlnet_conditioning_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [ torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples ] mid_block_res_sample = torch.cat( [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] return latents def blend_v(self, a, b, blend_extent): blend_extent = min(a.shape[2], b.shape[2], blend_extent) for y in range(blend_extent): b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * ( y / blend_extent) return b def blend_h(self, a, b, blend_extent): blend_extent = min(a.shape[3], b.shape[3], blend_extent) for x in range(blend_extent): b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * ( x / blend_extent) return b def get_blocks(self, latents, control_image, tile_latent_min_size, overlap_size): rows_latents = [] rows_control_images = [] for i in range(0, latents.shape[2] - overlap_size, overlap_size): row_latents = [] row_control_images = [] for j in range(0, latents.shape[3] - overlap_size, overlap_size): latents_input = latents[:, :, i:i + tile_latent_min_size, j:j + tile_latent_min_size] c_start_i = self.vae_scale_factor * i c_end_i = self.vae_scale_factor * (i + tile_latent_min_size) c_start_j = self.vae_scale_factor * j c_end_j = self.vae_scale_factor * (j + tile_latent_min_size) control_image_input = control_image[:, :, c_start_i:c_end_i, c_start_j:c_end_j] row_latents.append(latents_input) row_control_images.append(control_image_input) rows_latents.append(row_latents) rows_control_images.append(row_control_images) return rows_latents, rows_control_images @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = 'pil', return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.8, guess_mode: bool = False, context_size: int = 768, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/ src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting than for [`~StableDiffusionControlNetPipeline.__call__`]. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. context_size ('int', *optional*, defaults to '768'): tiled size when denoise the latents. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ def tiled_decode( self, z: torch.FloatTensor, return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: r"""Decode a batch of images using a tiled decoder. Args: When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the look of the output, but they should be much less noticeable. z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ _tile_overlap_factor = 1 - self.tile_overlap_factor overlap_size = int(self.tile_latent_min_size * _tile_overlap_factor) blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) row_limit = self.tile_sample_min_size - blend_extent w = z.shape[3] z = torch.cat([z, z[:, :, :, :w // 4]], dim=-1) # Split z into overlapping 64x64 tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. rows = [] for i in range(0, z.shape[2], overlap_size): row = [] tile = z[:, :, i:i + self.tile_latent_min_size, :] tile = self.post_quant_conv(tile) decoded = self.decoder(tile) vae_scale_factor = decoded.shape[-1] // tile.shape[-1] row.append(decoded) rows.append(row) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append( self.blend_h( tile[:, :, :row_limit, w * vae_scale_factor:], tile[:, :, :row_limit, :w * vae_scale_factor], tile.shape[-1] - w * vae_scale_factor)) result_rows.append(torch.cat(result_row, dim=3)) dec = torch.cat(result_rows, dim=2) if not return_dict: return (dec, ) return DecoderOutput(sample=dec) self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL) # 0. Default height and width to unet height, width = self._default_height_width(height, width, image) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 controlnet = self.controlnet._orig_mod if is_compiled_module( self.controlnet) else self.controlnet if isinstance(controlnet, MultiControlNetModel) and isinstance( controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale ] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance( controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions) guess_mode = guess_mode or global_pool_conditions # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare image, and controlnet_conditioning_image image = prepare_image(image) # 5. Prepare image if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] # value = torch.zeros_like(latents) _, _, height, width = control_image.size() tile_latent_min_size = context_size // self.vae_scale_factor tile_overlap_factor = 0.5 overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor)) blend_extent = int(tile_latent_min_size * tile_overlap_factor) row_limit = tile_latent_min_size - blend_extent w = latents.shape[3] latents = torch.cat([latents, latents[:, :, :, :overlap_size]], dim=-1) control_image_extend = control_image[:, :, :, :overlap_size * self.vae_scale_factor] control_image = torch.cat([control_image, control_image_extend], dim=-1) # 8. Denoising loop num_warmup_steps = len( timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latents_input, control_image_input = self.get_blocks( latents, control_image, tile_latent_min_size, overlap_size) rows = [] for latents_input_, control_image_input_ in zip( latents_input, control_image_input): num_block = len(latents_input_) # get batched latents_input latents_input_ = torch.cat( latents_input_[:num_block], dim=0) # get batched prompt_embeds prompt_embeds_ = torch.cat( [prompt_embeds.chunk(2)[0]] * num_block + [prompt_embeds.chunk(2)[1]] * num_block, dim=0) # get batched control_image_input control_image_input_ = torch.cat( [ x[0, :, :, ][None, :, :, :] for x in control_image_input_[:num_block] ] + [ x[1, :, :, ][None, :, :, :] for x in control_image_input_[:num_block] ], dim=0) latents_output = self.denoise_latents( latents_input_, t, prompt_embeds_, control_image_input_, controlnet_conditioning_scale, guess_mode, cross_attention_kwargs, do_classifier_free_guidance, guidance_scale, extra_step_kwargs, views_scheduler_status) rows.append(list(latents_output.chunk(num_block))) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) if j == 0: tile = self.blend_h(row[-1], tile, blend_extent) if i != len(rows) - 1: if j == len(row) - 1: result_row.append(tile[:, :, :row_limit, :]) else: result_row.append( tile[:, :, :row_limit, :row_limit]) else: if j == len(row) - 1: result_row.append(tile[:, :, :, :]) else: result_row.append(tile[:, :, :, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) latents = torch.cat(result_rows, dim=2) # call the callback, if provided condition_i = i == len(timesteps) - 1 condition_warm = (i + 1) > num_warmup_steps and ( i + 1) % self.scheduler.order == 0 if condition_i or condition_warm: progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) latents = latents[:, :, :, :w] # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr( self, 'final_offload_hook') and self.final_offload_hook is not None: self.unet.to('cpu') self.controlnet.to('cpu') torch.cuda.empty_cache() if not output_type == 'latent': image = self.vae.decode( latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker( image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess( image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr( self, 'final_offload_hook') and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput( images=image, nsfw_content_detected=has_nsfw_concept)