skytnt committed on
Commit 3f924fb · 1 Parent(s): 086a8d8

Update pipeline.py

Files changed (1)
  1. pipeline.py +545 -237
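For context, here is a minimal usage sketch of the long-prompt-weighting pipeline this file implements. The model id, the `custom_pipeline` reference, and the example prompts are illustrative assumptions and are not part of this commit:

import torch
from diffusers import DiffusionPipeline

# Load Stable Diffusion with this file as a custom/community pipeline
# (commonly published under the name "lpw_stable_diffusion"; adjust to point at this file if needed).
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Prompts longer than the 77-token CLIP limit are supported, and parenthesized
# weights such as (word:1.2) raise or lower the emphasis of individual terms.
prompt = "a photo of an astronaut riding a (white horse:1.3) on mars, (highly detailed:1.2), sharp focus"
negative = "(low quality:1.4), blurry"
image = pipe.text2img(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
image.save("astronaut.png")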
pipeline.py CHANGED
@@ -1,39 +1,30 @@
  import inspect
  import re
- from typing import Callable, List, Optional, Union

  import numpy as np
  import torch

- import diffusers
- import PIL
- from diffusers import SchedulerMixin, StableDiffusionPipeline
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
- from diffusers.utils import deprecate, logging
- from packaging import version
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-
-
- try:
-     from diffusers.utils import PIL_INTERPOLATION
- except ImportError:
-     if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
-         PIL_INTERPOLATION = {
-             "linear": PIL.Image.Resampling.BILINEAR,
-             "bilinear": PIL.Image.Resampling.BILINEAR,
-             "bicubic": PIL.Image.Resampling.BICUBIC,
-             "lanczos": PIL.Image.Resampling.LANCZOS,
-             "nearest": PIL.Image.Resampling.NEAREST,
-         }
-     else:
-         PIL_INTERPOLATION = {
-             "linear": PIL.Image.LINEAR,
-             "bilinear": PIL.Image.BILINEAR,
-             "bicubic": PIL.Image.BICUBIC,
-             "lanczos": PIL.Image.LANCZOS,
-             "nearest": PIL.Image.NEAREST,
-         }
  # ------------------------------------------------------------------------------

  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -144,7 +135,7 @@ def parse_prompt_attention(text):
      return res


- def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
      r"""
      Tokenize a list of prompts and return its tokens with weights of each token.

@@ -179,14 +170,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], m
      return tokens, weights


- def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
      r"""
      Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
      """
      max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
      weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
      for i in range(len(tokens)):
-         tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
          if no_boseos_middle:
              weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
          else:
@@ -205,7 +196,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
205
 
206
 
207
  def get_unweighted_text_embeddings(
208
- pipe: StableDiffusionPipeline,
209
  text_input: torch.Tensor,
210
  chunk_length: int,
211
  no_boseos_middle: Optional[bool] = True,
@@ -245,14 +236,13 @@ def get_unweighted_text_embeddings(
245
 
246
 
247
  def get_weighted_text_embeddings(
248
- pipe: StableDiffusionPipeline,
249
  prompt: Union[str, List[str]],
250
  uncond_prompt: Optional[Union[str, List[str]]] = None,
251
  max_embeddings_multiples: Optional[int] = 3,
252
  no_boseos_middle: Optional[bool] = False,
253
  skip_parsing: Optional[bool] = False,
254
  skip_weighting: Optional[bool] = False,
255
- **kwargs,
256
  ):
257
  r"""
258
  Prompts can be assigned with local weights using brackets. For example,
@@ -262,7 +252,7 @@ def get_weighted_text_embeddings(
262
  Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
263
 
264
  Args:
265
- pipe (`StableDiffusionPipeline`):
266
  Pipe to provide access to the tokenizer and the text encoder.
267
  prompt (`str` or `List[str]`):
268
  The prompt or prompts to guide the image generation.
@@ -318,12 +308,14 @@ def get_weighted_text_embeddings(
318
  # pad the length of tokens and weights
319
  bos = pipe.tokenizer.bos_token_id
320
  eos = pipe.tokenizer.eos_token_id
 
321
  prompt_tokens, prompt_weights = pad_tokens_and_weights(
322
  prompt_tokens,
323
  prompt_weights,
324
  max_length,
325
  bos,
326
  eos,
 
327
  no_boseos_middle=no_boseos_middle,
328
  chunk_length=pipe.tokenizer.model_max_length,
329
  )
@@ -335,6 +327,7 @@ def get_weighted_text_embeddings(
335
  max_length,
336
  bos,
337
  eos,
 
338
  no_boseos_middle=no_boseos_middle,
339
  chunk_length=pipe.tokenizer.model_max_length,
340
  )
@@ -375,30 +368,50 @@
      return text_embeddings, None


- def preprocess_image(image):
      w, h = image.size
-     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
      image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
      image = np.array(image).astype(np.float32) / 255.0
-     image = image[None].transpose(0, 3, 1, 2)
      image = torch.from_numpy(image)
      return 2.0 * image - 1.0


- def preprocess_mask(mask, scale_factor=8):
-     mask = mask.convert("L")
-     w, h = mask.size
-     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
-     mask = np.array(mask).astype(np.float32) / 255.0
-     mask = np.tile(mask, (4, 1, 1))
-     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-     mask = 1 - mask  # repaint white, keep black
-     mask = torch.from_numpy(mask)
-     return mask


- class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
      r"""
      Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
      weighting in prompt.
@@ -423,70 +436,200 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
423
  safety_checker ([`StableDiffusionSafetyChecker`]):
424
  Classification module that estimates whether generated images could be considered offensive or harmful.
425
  Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
426
- feature_extractor ([`CLIPFeatureExtractor`]):
427
  Model that extracts features from generated images to be used as inputs for the `safety_checker`.
428
  """
429
 
430
- if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
431
-
432
- def __init__(
433
- self,
434
- vae: AutoencoderKL,
435
- text_encoder: CLIPTextModel,
436
- tokenizer: CLIPTokenizer,
437
- unet: UNet2DConditionModel,
438
- scheduler: SchedulerMixin,
439
- safety_checker: StableDiffusionSafetyChecker,
440
- feature_extractor: CLIPFeatureExtractor,
441
- requires_safety_checker: bool = True,
442
- ):
443
- super().__init__(
444
- vae=vae,
445
- text_encoder=text_encoder,
446
- tokenizer=tokenizer,
447
- unet=unet,
448
- scheduler=scheduler,
449
- safety_checker=safety_checker,
450
- feature_extractor=feature_extractor,
451
- requires_safety_checker=requires_safety_checker,
  )
453
- self.__init__additional__()
454
 
455
- else:
 
 
 
 
456
 
457
- def __init__(
458
- self,
459
- vae: AutoencoderKL,
460
- text_encoder: CLIPTextModel,
461
- tokenizer: CLIPTokenizer,
462
- unet: UNet2DConditionModel,
463
- scheduler: SchedulerMixin,
464
- safety_checker: StableDiffusionSafetyChecker,
465
- feature_extractor: CLIPFeatureExtractor,
466
- ):
467
- super().__init__(
468
- vae=vae,
469
- text_encoder=text_encoder,
470
- tokenizer=tokenizer,
471
- unet=unet,
472
- scheduler=scheduler,
473
- safety_checker=safety_checker,
474
- feature_extractor=feature_extractor,
475
  )
476
- self.__init__additional__()
- def __init__additional__(self):
479
- if not hasattr(self, "vae_scale_factor"):
480
- setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
  @property
 
483
  def _execution_device(self):
484
  r"""
485
  Returns the device on which the pipeline's models will be executed. After calling
486
  `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
487
  hooks.
488
  """
489
- if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
490
  return self.device
491
  for module in self.unet.modules():
492
  if (
@@ -503,8 +646,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
503
  device,
504
  num_images_per_prompt,
505
  do_classifier_free_guidance,
506
- negative_prompt,
507
- max_embeddings_multiples,
 
 
508
  ):
509
  r"""
510
  Encodes the prompt into text encoder hidden states.
@@ -524,47 +669,71 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
524
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
525
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
526
  """
527
- batch_size = len(prompt) if isinstance(prompt, list) else 1
528
-
529
- if negative_prompt is None:
530
- negative_prompt = [""] * batch_size
531
- elif isinstance(negative_prompt, str):
532
- negative_prompt = [negative_prompt] * batch_size
533
- if batch_size != len(negative_prompt):
534
- raise ValueError(
535
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
536
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
537
- " the batch size of `prompt`."
  )
 
 
 
 
539
 
540
- text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
541
- pipe=self,
542
- prompt=prompt,
543
- uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
544
- max_embeddings_multiples=max_embeddings_multiples,
545
- )
546
- bs_embed, seq_len, _ = text_embeddings.shape
547
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
548
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
549
 
550
  if do_classifier_free_guidance:
551
- bs_embed, seq_len, _ = uncond_embeddings.shape
552
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
553
- uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
554
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
555
 
556
- return text_embeddings
557
 
558
- def check_inputs(self, prompt, height, width, strength, callback_steps):
559
- if not isinstance(prompt, str) and not isinstance(prompt, list):
560
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
  if strength < 0 or strength > 1:
563
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
564
 
565
- if height % 8 != 0 or width % 8 != 0:
566
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
567
-
568
  if (callback_steps is None) or (
569
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
570
  ):
@@ -573,17 +742,42 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
573
  f" {type(callback_steps)}."
574
  )
  def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
577
  if is_text2img:
578
  return self.scheduler.timesteps.to(device), num_inference_steps
579
  else:
580
  # get the original timestep using init_timestep
581
- offset = self.scheduler.config.get("steps_offset", 0)
582
- init_timestep = int(num_inference_steps * strength) + offset
583
- init_timestep = min(init_timestep, num_inference_steps)
 
584
 
585
- t_start = max(num_inference_steps - init_timestep + offset, 0)
586
- timesteps = self.scheduler.timesteps[t_start:].to(device)
587
  return timesteps, num_inference_steps - t_start
588
 
589
  def run_safety_checker(self, image, device, dtype):
@@ -597,10 +791,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
          return image, has_nsfw_concept

      def decode_latents(self, latents):
-         latents = 1 / 0.18215 * latents
          image = self.vae.decode(latents).sample
          image = (image / 2 + 0.5).clamp(0, 1)
-         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
          image = image.cpu().permute(0, 2, 3, 1).float().numpy()
          return image

@@ -621,43 +815,51 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
621
  extra_step_kwargs["generator"] = generator
622
  return extra_step_kwargs
623
 
624
- def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
  if image is None:
626
- shape = (
627
- batch_size,
628
- self.unet.in_channels,
629
- height // self.vae_scale_factor,
630
- width // self.vae_scale_factor,
631
- )
 
632
 
633
  if latents is None:
634
- if device.type == "mps":
635
- # randn does not work reproducibly on mps
636
- latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
637
- else:
638
- latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
639
  else:
640
- if latents.shape != shape:
641
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
642
  latents = latents.to(device)
643
 
644
  # scale the initial noise by the standard deviation required by the scheduler
645
  latents = latents * self.scheduler.init_noise_sigma
646
  return latents, None, None
647
  else:
 
648
  init_latent_dist = self.vae.encode(image).latent_dist
649
  init_latents = init_latent_dist.sample(generator=generator)
650
- init_latents = 0.18215 * init_latents
651
- init_latents = torch.cat([init_latents] * batch_size, dim=0)
 
 
652
  init_latents_orig = init_latents
653
- shape = init_latents.shape
654
 
655
  # add noise to latents using the timesteps
656
- if device.type == "mps":
657
- noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
658
- else:
659
- noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
660
- latents = self.scheduler.add_noise(init_latents, noise, timestep)
661
  return latents, init_latents_orig, noise
662
 
663
  @torch.no_grad()
@@ -673,16 +875,19 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
673
  guidance_scale: float = 7.5,
674
  strength: float = 0.8,
675
  num_images_per_prompt: Optional[int] = 1,
 
676
  eta: float = 0.0,
677
- generator: Optional[torch.Generator] = None,
678
  latents: Optional[torch.FloatTensor] = None,
 
 
679
  max_embeddings_multiples: Optional[int] = 3,
680
  output_type: Optional[str] = "pil",
681
  return_dict: bool = True,
682
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
683
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
684
- callback_steps: Optional[int] = 1,
685
- **kwargs,
686
  ):
687
  r"""
688
  Function invoked when calling the pipeline for generation.
@@ -722,16 +927,26 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
722
  `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
723
  num_images_per_prompt (`int`, *optional*, defaults to 1):
724
  The number of images to generate per prompt.
 
 
 
725
  eta (`float`, *optional*, defaults to 0.0):
726
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
727
  [`schedulers.DDIMScheduler`], will be ignored for others.
728
- generator (`torch.Generator`, *optional*):
729
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
730
- deterministic.
731
  latents (`torch.FloatTensor`, *optional*):
732
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
733
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
734
  tensor will ge generated by sampling using the supplied random `generator`.
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
736
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
737
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -749,6 +964,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
749
  callback_steps (`int`, *optional*, defaults to 1):
750
  The frequency at which the `callback` function will be called. If not specified, the callback will be
751
  called at every step.
 
753
  Returns:
754
  `None` if cancelled by `is_cancelled_callback`,
@@ -758,19 +977,23 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
758
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
759
  (nsfw) content, according to the `safety_checker`.
760
  """
761
- message = "Please use `image` instead of `init_image`."
762
- init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
763
- image = init_image or image
764
-
765
  # 0. Default height and width to unet
766
  height = height or self.unet.config.sample_size * self.vae_scale_factor
767
  width = width or self.unet.config.sample_size * self.vae_scale_factor
768
 
769
  # 1. Check inputs. Raise error if not correct
770
- self.check_inputs(prompt, height, width, strength, callback_steps)
  # 2. Define call parameters
773
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
  device = self._execution_device
775
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
776
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -778,26 +1001,28 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
778
  do_classifier_free_guidance = guidance_scale > 1.0
779
 
780
  # 3. Encode input prompt
781
- text_embeddings = self._encode_prompt(
782
  prompt,
783
  device,
784
  num_images_per_prompt,
785
  do_classifier_free_guidance,
786
  negative_prompt,
787
  max_embeddings_multiples,
 
 
788
  )
789
- dtype = text_embeddings.dtype
790
 
791
  # 4. Preprocess image and mask
792
  if isinstance(image, PIL.Image.Image):
793
- image = preprocess_image(image)
794
  if image is not None:
795
  image = image.to(device=self.device, dtype=dtype)
796
  if isinstance(mask_image, PIL.Image.Image):
797
- mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
798
  if mask_image is not None:
799
  mask = mask_image.to(device=self.device, dtype=dtype)
800
- mask = torch.cat([mask] * batch_size * num_images_per_prompt)
801
  else:
802
  mask = None
803
 
@@ -810,7 +1035,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
810
  latents, init_latents_orig, noise = self.prepare_latents(
811
  image,
812
  latent_timestep,
813
- batch_size * num_images_per_prompt,
 
 
814
  height,
815
  width,
816
  dtype,
@@ -823,43 +1050,70 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
823
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
824
 
825
  # 8. Denoising loop
826
- for i, t in enumerate(self.progress_bar(timesteps)):
827
- # expand the latents if we are doing classifier free guidance
828
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
829
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
830
-
831
- # predict the noise residual
832
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
833
-
834
- # perform guidance
835
- if do_classifier_free_guidance:
836
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
837
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
838
-
839
- # compute the previous noisy sample x_t -> x_t-1
840
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
841
-
842
- if mask is not None:
843
- # masking
844
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
845
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
846
-
847
- # call the callback, if provided
848
- if i % callback_steps == 0:
849
- if callback is not None:
850
- callback(i, t, latents)
851
- if is_cancelled_callback is not None and is_cancelled_callback():
852
- return None
853
-
854
- # 9. Post-processing
855
- image = self.decode_latents(latents)
856
-
857
- # 10. Run safety checker
858
- image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
859
-
860
- # 11. Convert to PIL
861
- if output_type == "pil":
  image = self.numpy_to_pil(image)
 
864
  if not return_dict:
865
  return image, has_nsfw_concept
@@ -876,15 +1130,17 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
876
  guidance_scale: float = 7.5,
877
  num_images_per_prompt: Optional[int] = 1,
878
  eta: float = 0.0,
879
- generator: Optional[torch.Generator] = None,
880
  latents: Optional[torch.FloatTensor] = None,
 
 
881
  max_embeddings_multiples: Optional[int] = 3,
882
  output_type: Optional[str] = "pil",
883
  return_dict: bool = True,
884
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
885
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
886
- callback_steps: Optional[int] = 1,
887
- **kwargs,
888
  ):
889
  r"""
890
  Function for text-to-image generation.
@@ -912,13 +1168,20 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
912
  eta (`float`, *optional*, defaults to 0.0):
913
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
914
  [`schedulers.DDIMScheduler`], will be ignored for others.
915
- generator (`torch.Generator`, *optional*):
916
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
917
- deterministic.
918
  latents (`torch.FloatTensor`, *optional*):
919
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
920
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
921
  tensor will ge generated by sampling using the supplied random `generator`.
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
923
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
924
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -936,7 +1199,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
936
  callback_steps (`int`, *optional*, defaults to 1):
937
  The frequency at which the `callback` function will be called. If not specified, the callback will be
938
  called at every step.
 
 
 
 
 
939
  Returns:
 
940
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
941
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
942
  When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -954,13 +1223,15 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
954
  eta=eta,
955
  generator=generator,
956
  latents=latents,
 
 
957
  max_embeddings_multiples=max_embeddings_multiples,
958
  output_type=output_type,
959
  return_dict=return_dict,
960
  callback=callback,
961
  is_cancelled_callback=is_cancelled_callback,
962
  callback_steps=callback_steps,
963
- **kwargs,
964
  )
965
 
966
  def img2img(
@@ -973,14 +1244,16 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
973
  guidance_scale: Optional[float] = 7.5,
974
  num_images_per_prompt: Optional[int] = 1,
975
  eta: Optional[float] = 0.0,
976
- generator: Optional[torch.Generator] = None,
 
 
977
  max_embeddings_multiples: Optional[int] = 3,
978
  output_type: Optional[str] = "pil",
979
  return_dict: bool = True,
980
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
981
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
982
- callback_steps: Optional[int] = 1,
983
- **kwargs,
984
  ):
985
  r"""
986
  Function for image-to-image generation.
@@ -1013,9 +1286,16 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1013
  eta (`float`, *optional*, defaults to 0.0):
1014
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1015
  [`schedulers.DDIMScheduler`], will be ignored for others.
1016
- generator (`torch.Generator`, *optional*):
1017
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1018
- deterministic.
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1020
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1021
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1033,8 +1313,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1033
  callback_steps (`int`, *optional*, defaults to 1):
1034
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1035
  called at every step.
 
 
 
 
 
1036
  Returns:
1037
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1038
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1039
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1040
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
@@ -1050,13 +1335,15 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1050
  num_images_per_prompt=num_images_per_prompt,
1051
  eta=eta,
1052
  generator=generator,
 
 
1053
  max_embeddings_multiples=max_embeddings_multiples,
1054
  output_type=output_type,
1055
  return_dict=return_dict,
1056
  callback=callback,
1057
  is_cancelled_callback=is_cancelled_callback,
1058
  callback_steps=callback_steps,
1059
- **kwargs,
1060
  )
1061
 
1062
  def inpaint(
@@ -1069,15 +1356,18 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1069
  num_inference_steps: Optional[int] = 50,
1070
  guidance_scale: Optional[float] = 7.5,
1071
  num_images_per_prompt: Optional[int] = 1,
 
1072
  eta: Optional[float] = 0.0,
1073
- generator: Optional[torch.Generator] = None,
 
 
1074
  max_embeddings_multiples: Optional[int] = 3,
1075
  output_type: Optional[str] = "pil",
1076
  return_dict: bool = True,
1077
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1078
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
1079
- callback_steps: Optional[int] = 1,
1080
- **kwargs,
1081
  ):
1082
  r"""
1083
  Function for inpaint.
@@ -1111,12 +1401,22 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1111
  usually at the expense of lower image quality.
1112
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1113
  The number of images to generate per prompt.
 
 
 
1114
  eta (`float`, *optional*, defaults to 0.0):
1115
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1116
  [`schedulers.DDIMScheduler`], will be ignored for others.
1117
- generator (`torch.Generator`, *optional*):
1118
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1119
- deterministic.
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1121
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1122
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1134,8 +1434,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1134
  callback_steps (`int`, *optional*, defaults to 1):
1135
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1136
  called at every step.
 
 
 
 
 
1137
  Returns:
1138
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1139
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1140
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1141
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
@@ -1150,13 +1455,16 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1150
  guidance_scale=guidance_scale,
1151
  strength=strength,
1152
  num_images_per_prompt=num_images_per_prompt,
 
1153
  eta=eta,
1154
  generator=generator,
 
 
1155
  max_embeddings_multiples=max_embeddings_multiples,
1156
  output_type=output_type,
1157
  return_dict=return_dict,
1158
  callback=callback,
1159
  is_cancelled_callback=is_cancelled_callback,
1160
  callback_steps=callback_steps,
1161
- **kwargs,
1162
  )
 
  import inspect
  import re
+ from typing import Any, Callable, Dict, List, Optional, Union

  import numpy as np
+ import PIL
  import torch
+ from packaging import version
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+ from diffusers import DiffusionPipeline
+ from diffusers.configuration_utils import FrozenDict
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ from diffusers.utils import (
+     PIL_INTERPOLATION,
+     deprecate,
+     is_accelerate_available,
+     is_accelerate_version,
+     logging,
+     randn_tensor,
+ )
+
+
  # ------------------------------------------------------------------------------

  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
135
  return res
136
 
137
 
138
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
139
  r"""
140
  Tokenize a list of prompts and return its tokens with weights of each token.
141
 
 
170
  return tokens, weights
171
 
172
 
173
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
174
  r"""
175
  Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
176
  """
177
  max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
178
  weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
179
  for i in range(len(tokens)):
180
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
181
  if no_boseos_middle:
182
  weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
183
  else:
 
196
 
197
 
198
  def get_unweighted_text_embeddings(
199
+ pipe: DiffusionPipeline,
200
  text_input: torch.Tensor,
201
  chunk_length: int,
202
  no_boseos_middle: Optional[bool] = True,
 
236
 
237
 
238
  def get_weighted_text_embeddings(
239
+ pipe: DiffusionPipeline,
240
  prompt: Union[str, List[str]],
241
  uncond_prompt: Optional[Union[str, List[str]]] = None,
242
  max_embeddings_multiples: Optional[int] = 3,
243
  no_boseos_middle: Optional[bool] = False,
244
  skip_parsing: Optional[bool] = False,
245
  skip_weighting: Optional[bool] = False,
 
246
  ):
247
  r"""
248
  Prompts can be assigned with local weights using brackets. For example,
 
252
  Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
253
 
254
  Args:
255
+ pipe (`DiffusionPipeline`):
256
  Pipe to provide access to the tokenizer and the text encoder.
257
  prompt (`str` or `List[str]`):
258
  The prompt or prompts to guide the image generation.
 
308
  # pad the length of tokens and weights
309
  bos = pipe.tokenizer.bos_token_id
310
  eos = pipe.tokenizer.eos_token_id
311
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
312
  prompt_tokens, prompt_weights = pad_tokens_and_weights(
313
  prompt_tokens,
314
  prompt_weights,
315
  max_length,
316
  bos,
317
  eos,
318
+ pad,
319
  no_boseos_middle=no_boseos_middle,
320
  chunk_length=pipe.tokenizer.model_max_length,
321
  )
 
327
  max_length,
328
  bos,
329
  eos,
330
+ pad,
331
  no_boseos_middle=no_boseos_middle,
332
  chunk_length=pipe.tokenizer.model_max_length,
333
  )
 
368
  return text_embeddings, None
369
 
370
 
371
+ def preprocess_image(image, batch_size):
372
  w, h = image.size
373
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
374
  image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
375
  image = np.array(image).astype(np.float32) / 255.0
376
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
377
  image = torch.from_numpy(image)
378
  return 2.0 * image - 1.0
379
 
380
 
381
+ def preprocess_mask(mask, batch_size, scale_factor=8):
382
+ if not isinstance(mask, torch.FloatTensor):
383
+ mask = mask.convert("L")
384
+ w, h = mask.size
385
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
386
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
387
+ mask = np.array(mask).astype(np.float32) / 255.0
388
+ mask = np.tile(mask, (4, 1, 1))
389
+ mask = np.vstack([mask[None]] * batch_size)
390
+ mask = 1 - mask # repaint white, keep black
391
+ mask = torch.from_numpy(mask)
392
+ return mask
393
+
394
+ else:
395
+ valid_mask_channel_sizes = [1, 3]
396
+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
397
+ if mask.shape[3] in valid_mask_channel_sizes:
398
+ mask = mask.permute(0, 3, 1, 2)
399
+ elif mask.shape[1] not in valid_mask_channel_sizes:
400
+ raise ValueError(
401
+ f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
402
+ f" but received mask of shape {tuple(mask.shape)}"
403
+ )
404
+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
405
+ mask = mask.mean(dim=1, keepdim=True)
406
+ h, w = mask.shape[-2:]
407
+ h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
408
+ mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
409
+ return mask
410
 
411
 
412
+ class StableDiffusionLongPromptWeightingPipeline(
413
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
414
+ ):
415
  r"""
416
  Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
417
  weighting in prompt.
 
436
  safety_checker ([`StableDiffusionSafetyChecker`]):
437
  Classification module that estimates whether generated images could be considered offensive or harmful.
438
  Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
439
+ feature_extractor ([`CLIPImageProcessor`]):
440
  Model that extracts features from generated images to be used as inputs for the `safety_checker`.
441
  """
442
 
443
+ _optional_components = ["safety_checker", "feature_extractor"]
444
+
445
+ def __init__(
446
+ self,
447
+ vae: AutoencoderKL,
448
+ text_encoder: CLIPTextModel,
449
+ tokenizer: CLIPTokenizer,
450
+ unet: UNet2DConditionModel,
451
+ scheduler: KarrasDiffusionSchedulers,
452
+ safety_checker: StableDiffusionSafetyChecker,
453
+ feature_extractor: CLIPImageProcessor,
454
+ requires_safety_checker: bool = True,
455
+ ):
456
+ super().__init__()
457
+
458
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
459
+ deprecation_message = (
460
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
461
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
462
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
463
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
464
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
465
+ " file"
466
+ )
467
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
468
+ new_config = dict(scheduler.config)
469
+ new_config["steps_offset"] = 1
470
+ scheduler._internal_dict = FrozenDict(new_config)
471
+
472
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
473
+ deprecation_message = (
474
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
475
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
476
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
477
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
478
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
479
+ )
480
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
481
+ new_config = dict(scheduler.config)
482
+ new_config["clip_sample"] = False
483
+ scheduler._internal_dict = FrozenDict(new_config)
484
+
485
+ if safety_checker is None and requires_safety_checker:
486
+ logger.warning(
487
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
488
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
489
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
490
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
491
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
492
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
493
  )
 
494
 
495
+ if safety_checker is not None and feature_extractor is None:
496
+ raise ValueError(
497
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
498
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
499
+ )
500
 
501
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
502
+ version.parse(unet.config._diffusers_version).base_version
503
+ ) < version.parse("0.9.0.dev0")
504
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
505
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
506
+ deprecation_message = (
507
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
508
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
509
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
510
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
511
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
512
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
513
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
514
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
515
+ " the `unet/config.json` file"
  )
517
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
518
+ new_config = dict(unet.config)
519
+ new_config["sample_size"] = 64
520
+ unet._internal_dict = FrozenDict(new_config)
521
+ self.register_modules(
522
+ vae=vae,
523
+ text_encoder=text_encoder,
524
+ tokenizer=tokenizer,
525
+ unet=unet,
526
+ scheduler=scheduler,
527
+ safety_checker=safety_checker,
528
+ feature_extractor=feature_extractor,
529
+ )
530
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
531
+
532
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
533
+ self.register_to_config(
534
+ requires_safety_checker=requires_safety_checker,
535
+ )
536
 
537
+ def enable_vae_slicing(self):
538
+ r"""
539
+ Enable sliced VAE decoding.
540
+
541
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
542
+ steps. This is useful to save some memory and allow larger batch sizes.
543
+ """
544
+ self.vae.enable_slicing()
545
+
546
+ def disable_vae_slicing(self):
547
+ r"""
548
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
549
+ computing decoding in one step.
550
+ """
551
+ self.vae.disable_slicing()
552
+
553
+ def enable_vae_tiling(self):
554
+ r"""
555
+ Enable tiled VAE decoding.
556
+
557
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
558
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
559
+ """
560
+ self.vae.enable_tiling()
561
+
562
+ def disable_vae_tiling(self):
563
+ r"""
564
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
565
+ computing decoding in one step.
566
+ """
567
+ self.vae.disable_tiling()
568
+
569
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
570
+ def enable_sequential_cpu_offload(self, gpu_id=0):
571
+ r"""
572
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
573
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
574
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
575
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
576
+ `enable_model_cpu_offload`, but performance is lower.
577
+ """
578
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
579
+ from accelerate import cpu_offload
580
+ else:
581
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
582
+
583
+ device = torch.device(f"cuda:{gpu_id}")
584
+
585
+ if self.device.type != "cpu":
586
+ self.to("cpu", silence_dtype_warnings=True)
587
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
588
+
589
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
590
+ cpu_offload(cpu_offloaded_model, device)
591
+
592
+ if self.safety_checker is not None:
593
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
594
+
595
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
596
+ def enable_model_cpu_offload(self, gpu_id=0):
597
+ r"""
598
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
599
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
600
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
601
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
602
+ """
603
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
604
+ from accelerate import cpu_offload_with_hook
605
+ else:
606
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
607
+
608
+ device = torch.device(f"cuda:{gpu_id}")
609
+
610
+ if self.device.type != "cpu":
611
+ self.to("cpu", silence_dtype_warnings=True)
612
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
613
+
614
+ hook = None
615
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
616
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
617
+
618
+ if self.safety_checker is not None:
619
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
620
+
621
+ # We'll offload the last model manually.
622
+ self.final_offload_hook = hook
623
 
624
  @property
625
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
626
  def _execution_device(self):
627
  r"""
628
  Returns the device on which the pipeline's models will be executed. After calling
629
  `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
630
  hooks.
631
  """
632
+ if not hasattr(self.unet, "_hf_hook"):
633
  return self.device
634
  for module in self.unet.modules():
635
  if (
 
646
  device,
647
  num_images_per_prompt,
648
  do_classifier_free_guidance,
649
+ negative_prompt=None,
650
+ max_embeddings_multiples=3,
651
+ prompt_embeds: Optional[torch.FloatTensor] = None,
652
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
653
  ):
654
  r"""
655
  Encodes the prompt into text encoder hidden states.
 
669
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
670
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
671
  """
672
+ if prompt is not None and isinstance(prompt, str):
673
+ batch_size = 1
674
+ elif prompt is not None and isinstance(prompt, list):
675
+ batch_size = len(prompt)
676
+ else:
677
+ batch_size = prompt_embeds.shape[0]
678
+
679
+ if negative_prompt_embeds is None:
680
+ if negative_prompt is None:
681
+ negative_prompt = [""] * batch_size
682
+ elif isinstance(negative_prompt, str):
683
+ negative_prompt = [negative_prompt] * batch_size
684
+ if batch_size != len(negative_prompt):
685
+ raise ValueError(
686
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
687
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
688
+ " the batch size of `prompt`."
689
+ )
690
+ if prompt_embeds is None or negative_prompt_embeds is None:
691
+ if isinstance(self, TextualInversionLoaderMixin):
692
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
693
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
694
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
695
+
696
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
697
+ pipe=self,
698
+ prompt=prompt,
699
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
700
+ max_embeddings_multiples=max_embeddings_multiples,
701
  )
702
+ if prompt_embeds is None:
703
+ prompt_embeds = prompt_embeds1
704
+ if negative_prompt_embeds is None:
705
+ negative_prompt_embeds = negative_prompt_embeds1
706
 
707
+ bs_embed, seq_len, _ = prompt_embeds.shape
708
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
709
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
710
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
712
  if do_classifier_free_guidance:
713
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
714
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
715
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
716
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
717
 
718
+ return prompt_embeds
719
 
720
+ def check_inputs(
721
+ self,
722
+ prompt,
723
+ height,
724
+ width,
725
+ strength,
726
+ callback_steps,
727
+ negative_prompt=None,
728
+ prompt_embeds=None,
729
+ negative_prompt_embeds=None,
730
+ ):
731
+ if height % 8 != 0 or width % 8 != 0:
732
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
733
 
734
  if strength < 0 or strength > 1:
735
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
736
 
 
 
 
737
  if (callback_steps is None) or (
738
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
739
  ):
 
742
  f" {type(callback_steps)}."
743
  )
744
 
745
+ if prompt is not None and prompt_embeds is not None:
746
+ raise ValueError(
747
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
748
+ " only forward one of the two."
749
+ )
750
+ elif prompt is None and prompt_embeds is None:
751
+ raise ValueError(
752
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
753
+ )
754
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
755
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
756
+
757
+ if negative_prompt is not None and negative_prompt_embeds is not None:
758
+ raise ValueError(
759
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
760
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
761
+ )
762
+
763
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
764
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
765
+ raise ValueError(
766
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
767
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
768
+ f" {negative_prompt_embeds.shape}."
769
+ )
770
+
771
  def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
772
  if is_text2img:
773
  return self.scheduler.timesteps.to(device), num_inference_steps
774
  else:
775
  # get the original timestep using init_timestep
776
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
777
+
778
+ t_start = max(num_inference_steps - init_timestep, 0)
779
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
780
 
 
 
781
  return timesteps, num_inference_steps - t_start
782
 
783
  def run_safety_checker(self, image, device, dtype):
 
791
  return image, has_nsfw_concept
792
 
793
  def decode_latents(self, latents):
794
+ latents = 1 / self.vae.config.scaling_factor * latents
795
  image = self.vae.decode(latents).sample
796
  image = (image / 2 + 0.5).clamp(0, 1)
797
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
798
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
799
  return image
800
 
 
815
  extra_step_kwargs["generator"] = generator
816
  return extra_step_kwargs
817
 
818
+ def prepare_latents(
819
+ self,
820
+ image,
821
+ timestep,
822
+ num_images_per_prompt,
823
+ batch_size,
824
+ num_channels_latents,
825
+ height,
826
+ width,
827
+ dtype,
828
+ device,
829
+ generator,
830
+ latents=None,
831
+ ):
832
  if image is None:
833
+ batch_size = batch_size * num_images_per_prompt
834
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
835
+ if isinstance(generator, list) and len(generator) != batch_size:
836
+ raise ValueError(
837
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
838
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
839
+ )
840
 
841
  if latents is None:
842
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
  else:
 
 
844
  latents = latents.to(device)
845
 
846
  # scale the initial noise by the standard deviation required by the scheduler
847
  latents = latents * self.scheduler.init_noise_sigma
848
  return latents, None, None
849
  else:
850
+ image = image.to(device=self.device, dtype=dtype)
851
  init_latent_dist = self.vae.encode(image).latent_dist
852
  init_latents = init_latent_dist.sample(generator=generator)
853
+ init_latents = self.vae.config.scaling_factor * init_latents
854
+
855
+ # Expand init_latents for batch_size and num_images_per_prompt
856
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
857
  init_latents_orig = init_latents
 
858
 
859
  # add noise to latents using the timesteps
860
+ noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
861
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
862
+ latents = init_latents
 
 
863
  return latents, init_latents_orig, noise
864
 
865
  @torch.no_grad()
 
875
  guidance_scale: float = 7.5,
876
  strength: float = 0.8,
877
  num_images_per_prompt: Optional[int] = 1,
878
+ add_predicted_noise: Optional[bool] = False,
879
  eta: float = 0.0,
880
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
881
  latents: Optional[torch.FloatTensor] = None,
882
+ prompt_embeds: Optional[torch.FloatTensor] = None,
883
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
884
  max_embeddings_multiples: Optional[int] = 3,
885
  output_type: Optional[str] = "pil",
886
  return_dict: bool = True,
887
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
888
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
889
+ callback_steps: int = 1,
890
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
891
  ):
892
  r"""
893
  Function invoked when calling the pipeline for generation.
 
927
  `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
928
  num_images_per_prompt (`int`, *optional*, defaults to 1):
929
  The number of images to generate per prompt.
930
+ add_predicted_noise (`bool`, *optional*, defaults to True):
931
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
932
+ the reverse diffusion process
933
  eta (`float`, *optional*, defaults to 0.0):
934
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
935
  [`schedulers.DDIMScheduler`], will be ignored for others.
936
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
937
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
938
+ to make generation deterministic.
939
  latents (`torch.FloatTensor`, *optional*):
940
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
941
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
942
  tensor will ge generated by sampling using the supplied random `generator`.
943
+ prompt_embeds (`torch.FloatTensor`, *optional*):
944
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
945
+ provided, text embeddings will be generated from `prompt` input argument.
946
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
947
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
948
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
949
+ argument.
950
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
951
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
952
  output_type (`str`, *optional*, defaults to `"pil"`):
 
964
  callback_steps (`int`, *optional*, defaults to 1):
965
  The frequency at which the `callback` function will be called. If not specified, the callback will be
966
  called at every step.
967
+ cross_attention_kwargs (`dict`, *optional*):
968
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
969
+ `self.processor` in
970
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
971
 
972
  Returns:
973
  `None` if cancelled by `is_cancelled_callback`,
 
977
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
978
  (nsfw) content, according to the `safety_checker`.
979
  """
 
 
 
 
980
  # 0. Default height and width to unet
981
  height = height or self.unet.config.sample_size * self.vae_scale_factor
982
  width = width or self.unet.config.sample_size * self.vae_scale_factor
983
 
984
  # 1. Check inputs. Raise error if not correct
985
+ self.check_inputs(
986
+ prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
987
+ )
988
 
989
  # 2. Define call parameters
990
+ if prompt is not None and isinstance(prompt, str):
991
+ batch_size = 1
992
+ elif prompt is not None and isinstance(prompt, list):
993
+ batch_size = len(prompt)
994
+ else:
995
+ batch_size = prompt_embeds.shape[0]
996
+
997
  device = self._execution_device
998
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
999
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
 
1001
  do_classifier_free_guidance = guidance_scale > 1.0
1002
 
1003
  # 3. Encode input prompt
1004
+ prompt_embeds = self._encode_prompt(
1005
  prompt,
1006
  device,
1007
  num_images_per_prompt,
1008
  do_classifier_free_guidance,
1009
  negative_prompt,
1010
  max_embeddings_multiples,
1011
+ prompt_embeds=prompt_embeds,
1012
+ negative_prompt_embeds=negative_prompt_embeds,
1013
  )
1014
+ dtype = prompt_embeds.dtype
1015
 
1016
  # 4. Preprocess image and mask
1017
  if isinstance(image, PIL.Image.Image):
1018
+ image = preprocess_image(image, batch_size)
1019
  if image is not None:
1020
  image = image.to(device=self.device, dtype=dtype)
1021
  if isinstance(mask_image, PIL.Image.Image):
1022
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
1023
  if mask_image is not None:
1024
  mask = mask_image.to(device=self.device, dtype=dtype)
1025
+ mask = torch.cat([mask] * num_images_per_prompt)
1026
  else:
1027
  mask = None
1028
 
 
1035
  latents, init_latents_orig, noise = self.prepare_latents(
1036
  image,
1037
  latent_timestep,
1038
+ num_images_per_prompt,
1039
+ batch_size,
1040
+ self.unet.config.in_channels,
1041
  height,
1042
  width,
1043
  dtype,
 
1050
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1051
 
1052
  # 8. Denoising loop
1053
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1054
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1055
+ for i, t in enumerate(timesteps):
1056
+ # expand the latents if we are doing classifier free guidance
1057
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1058
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1059
+
1060
+ # predict the noise residual
1061
+ noise_pred = self.unet(
1062
+ latent_model_input,
1063
+ t,
1064
+ encoder_hidden_states=prompt_embeds,
1065
+ cross_attention_kwargs=cross_attention_kwargs,
1066
+ ).sample
1067
+
1068
+ # perform guidance
1069
+ if do_classifier_free_guidance:
1070
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1071
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1072
+
1073
+ # compute the previous noisy sample x_t -> x_t-1
1074
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1075
+
1076
+ if mask is not None:
1077
+ # masking
1078
+ if add_predicted_noise:
1079
+ init_latents_proper = self.scheduler.add_noise(
1080
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
1081
+ )
1082
+ else:
1083
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
1084
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
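# The blend above pins part of the latents to the re-noised original image at every step
# while the complementary region keeps being denoised; with `add_predicted_noise` the
# re-noising uses the model's unconditional noise prediction instead of the fixed `noise`.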
1085
+
1086
+ # call the callback, if provided
1087
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1088
+ progress_bar.update()
1089
+ if i % callback_steps == 0:
1090
+ if callback is not None:
1091
+ callback(i, t, latents)
1092
+ if is_cancelled_callback is not None and is_cancelled_callback():
1093
+ return None
1094
+
1095
+ if output_type == "latent":
1096
+ image = latents
1097
+ has_nsfw_concept = None
1098
+ elif output_type == "pil":
1099
+ # 9. Post-processing
1100
+ image = self.decode_latents(latents)
1101
+
1102
+ # 10. Run safety checker
1103
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1104
+
1105
+ # 11. Convert to PIL
1106
  image = self.numpy_to_pil(image)
1107
+ else:
1108
+ # 9. Post-processing
1109
+ image = self.decode_latents(latents)
1110
+
1111
+ # 10. Run safety checker
1112
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1113
+
1114
+ # Offload last model to CPU
1115
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1116
+ self.final_offload_hook.offload()
1117
 
1118
  if not return_dict:
1119
  return image, has_nsfw_concept
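# Illustrative use of the `callback` / `is_cancelled_callback` hooks checked in the loop
# above; `cancel_requested` is a hypothetical flag toggled elsewhere (e.g. by a UI thread):
#
#     cancel_requested = False
#
#     def on_step(step, timestep, latents):
#         print(f"finished step {step}")
#
#     result = pipe("a castle on a hill", callback=on_step,
#                   is_cancelled_callback=lambda: cancel_requested)
#     if result is None:
#         print("generation was cancelled before it finished")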
 
1130
  guidance_scale: float = 7.5,
1131
  num_images_per_prompt: Optional[int] = 1,
1132
  eta: float = 0.0,
1133
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1134
  latents: Optional[torch.FloatTensor] = None,
1135
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1136
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1137
  max_embeddings_multiples: Optional[int] = 3,
1138
  output_type: Optional[str] = "pil",
1139
  return_dict: bool = True,
1140
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1141
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
1142
+ callback_steps: int = 1,
1143
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1144
  ):
1145
  r"""
1146
  Function for text-to-image generation.
 
1168
  eta (`float`, *optional*, defaults to 0.0):
1169
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1170
  [`schedulers.DDIMScheduler`], will be ignored for others.
1171
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1172
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1173
+ to make generation deterministic.
1174
  latents (`torch.FloatTensor`, *optional*):
1175
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1176
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1177
  tensor will be generated by sampling using the supplied random `generator`.
1178
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1179
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1180
+ provided, text embeddings will be generated from `prompt` input argument.
1181
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1182
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1183
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1184
+ argument.
1185
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1186
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1187
  output_type (`str`, *optional*, defaults to `"pil"`):
 
1199
  callback_steps (`int`, *optional*, defaults to 1):
1200
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1201
  called at every step.
1202
+ cross_attention_kwargs (`dict`, *optional*):
1203
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
1204
+ `self.processor` in
1205
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1206
+
1207
  Returns:
1208
+ `None` if cancelled by `is_cancelled_callback`,
1209
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1210
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1211
  When returning a tuple, the first element is a list with the generated images, and the second element is a
 
1223
  eta=eta,
1224
  generator=generator,
1225
  latents=latents,
1226
+ prompt_embeds=prompt_embeds,
1227
+ negative_prompt_embeds=negative_prompt_embeds,
1228
  max_embeddings_multiples=max_embeddings_multiples,
1229
  output_type=output_type,
1230
  return_dict=return_dict,
1231
  callback=callback,
1232
  is_cancelled_callback=is_cancelled_callback,
1233
  callback_steps=callback_steps,
1234
+ cross_attention_kwargs=cross_attention_kwargs,
1235
  )
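A hedged usage sketch for the `text2img` wrapper above; the model id, dtype, the way this file is registered as a custom pipeline, and the LoRA scale are all illustrative placeholders (`cross_attention_kwargs={"scale": ...}` only has an effect if LoRA attention processors are attached):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "some/model-id",                                  # placeholder model repository
        custom_pipeline="path/or/repo/with/pipeline.py",  # however this file is registered
        torch_dtype=torch.float16,
    ).to("cuda")

    generator = torch.Generator("cuda").manual_seed(0)    # deterministic sampling
    image = pipe.text2img(
        "a (photorealistic:1.2) portrait of an astronaut, highly detailed",
        negative_prompt="lowres, bad anatomy",
        num_inference_steps=30,
        guidance_scale=7.5,
        generator=generator,
        cross_attention_kwargs={"scale": 0.8},            # e.g. LoRA scale, if applicable
    ).images[0]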
1236
 
1237
  def img2img(
 
1244
  guidance_scale: Optional[float] = 7.5,
1245
  num_images_per_prompt: Optional[int] = 1,
1246
  eta: Optional[float] = 0.0,
1247
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1248
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1249
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1250
  max_embeddings_multiples: Optional[int] = 3,
1251
  output_type: Optional[str] = "pil",
1252
  return_dict: bool = True,
1253
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1254
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
1255
+ callback_steps: int = 1,
1256
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1257
  ):
1258
  r"""
1259
  Function for image-to-image generation.
 
1286
  eta (`float`, *optional*, defaults to 0.0):
1287
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1288
  [`schedulers.DDIMScheduler`], will be ignored for others.
1289
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1290
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1291
+ to make generation deterministic.
1292
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1293
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1294
+ provided, text embeddings will be generated from `prompt` input argument.
1295
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1296
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1297
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1298
+ argument.
1299
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1300
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1301
  output_type (`str`, *optional*, defaults to `"pil"`):
 
1313
  callback_steps (`int`, *optional*, defaults to 1):
1314
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1315
  called at every step.
1316
+ cross_attention_kwargs (`dict`, *optional*):
1317
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
1318
+ `self.processor` in
1319
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1320
+
1321
  Returns:
1322
+ `None` if cancelled by `is_cancelled_callback`,
1323
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1324
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1325
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 
1335
  num_images_per_prompt=num_images_per_prompt,
1336
  eta=eta,
1337
  generator=generator,
1338
+ prompt_embeds=prompt_embeds,
1339
+ negative_prompt_embeds=negative_prompt_embeds,
1340
  max_embeddings_multiples=max_embeddings_multiples,
1341
  output_type=output_type,
1342
  return_dict=return_dict,
1343
  callback=callback,
1344
  is_cancelled_callback=is_cancelled_callback,
1345
  callback_steps=callback_steps,
1346
+ cross_attention_kwargs=cross_attention_kwargs,
1347
  )
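A similar sketch for the `img2img` wrapper, reusing `pipe` and `torch` from the previous example; the file name is a placeholder:

    from PIL import Image

    init_image = Image.open("input.png").convert("RGB").resize((512, 512))
    image = pipe.img2img(
        prompt="the same scene in (autumn:1.1), golden light",
        image=init_image,
        strength=0.6,                     # lower values stay closer to the input image
        num_inference_steps=40,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]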
1348
 
1349
  def inpaint(
 
1356
  num_inference_steps: Optional[int] = 50,
1357
  guidance_scale: Optional[float] = 7.5,
1358
  num_images_per_prompt: Optional[int] = 1,
1359
+ add_predicted_noise: Optional[bool] = False,
1360
  eta: Optional[float] = 0.0,
1361
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1362
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1363
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1364
  max_embeddings_multiples: Optional[int] = 3,
1365
  output_type: Optional[str] = "pil",
1366
  return_dict: bool = True,
1367
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1368
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
1369
+ callback_steps: int = 1,
1370
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1371
  ):
1372
  r"""
1373
  Function for inpaint.
 
1401
  usually at the expense of lower image quality.
1402
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1403
  The number of images to generate per prompt.
1404
+ add_predicted_noise (`bool`, *optional*, defaults to False):
1405
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
1406
+ the reverse diffusion process.
1407
  eta (`float`, *optional*, defaults to 0.0):
1408
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1409
  [`schedulers.DDIMScheduler`], will be ignored for others.
1410
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1411
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1412
+ to make generation deterministic.
1413
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1414
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1415
+ provided, text embeddings will be generated from `prompt` input argument.
1416
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1417
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1418
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1419
+ argument.
1420
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1421
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1422
  output_type (`str`, *optional*, defaults to `"pil"`):
 
1434
  callback_steps (`int`, *optional*, defaults to 1):
1435
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1436
  called at every step.
1437
+ cross_attention_kwargs (`dict`, *optional*):
1438
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
1439
+ `self.processor` in
1440
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1441
+
1442
  Returns:
1443
+ `None` if cancelled by `is_cancelled_callback`,
1444
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1445
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1446
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 
1455
  guidance_scale=guidance_scale,
1456
  strength=strength,
1457
  num_images_per_prompt=num_images_per_prompt,
1458
+ add_predicted_noise=add_predicted_noise,
1459
  eta=eta,
1460
  generator=generator,
1461
+ prompt_embeds=prompt_embeds,
1462
+ negative_prompt_embeds=negative_prompt_embeds,
1463
  max_embeddings_multiples=max_embeddings_multiples,
1464
  output_type=output_type,
1465
  return_dict=return_dict,
1466
  callback=callback,
1467
  is_cancelled_callback=is_cancelled_callback,
1468
  callback_steps=callback_steps,
1469
+ cross_attention_kwargs=cross_attention_kwargs,
1470
  )
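And a sketch for the `inpaint` wrapper, again reusing `pipe`; `add_predicted_noise=True` opts into the behaviour documented above, and the file names are placeholders:

    import torch
    from PIL import Image

    init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
    mask_image = Image.open("mask.png").convert("L").resize((512, 512))

    result = pipe.inpaint(
        prompt="a (fluffy:1.2) white cat sitting on the bench",
        image=init_image,
        mask_image=mask_image,
        strength=0.75,
        add_predicted_noise=True,         # re-noise the original with the model's own prediction
        num_inference_steps=50,
        generator=torch.Generator("cuda").manual_seed(0),
    )
    if result is not None:                # None means is_cancelled_callback stopped the run
        result.images[0].save("inpainted.png")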