Update pipeline.py
pipeline.py CHANGED (+62, -6)
@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
+from diffusers.utils import deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
@@ -40,7 +40,7 @@ re_attention = re.compile(
 
 def parse_prompt_attention(text):
     """
-    Parses a string with attention tokens and returns a list of pairs: text and its
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
     Accepted tokens are:
       (abc) - increases attention to abc by a multiplier of 1.1
       (abc:3.12) - increases attention to abc by a multiplier of 3.12
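
For illustration, the grammar above in action — a quick sketch, with indicative outputs assuming the default 1.1 multiplier:

    parse_prompt_attention("a (white) cat")      # -> [["a ", 1.0], ["white", 1.1], [" cat", 1.0]]
    parse_prompt_attention("a (white:1.5) cat")  # -> [["a ", 1.0], ["white", 1.5], [" cat", 1.0]]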
@@ -237,9 +237,9 @@ def get_weighted_text_embeddings(
     r"""
     Prompts can be assigned with local weights using brackets. For example,
     prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
-    and the embedding tokens corresponding to the words get
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
 
-    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the
+    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
 
     Args:
         pipe (`DiffusionPipeline`):
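
The mean-preservation mentioned in that docstring boils down to roughly the following (a simplified sketch with illustrative tensor names, not the exact code):

    # text_embeddings: (batch, seq_len, dim); prompt_weights: (batch, seq_len)
    previous_mean = text_embeddings.mean()
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    # rescale so the weighted embedding keeps the original mean
    text_embeddings = text_embeddings * (previous_mean / text_embeddings.mean())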
@@ -431,6 +431,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
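
The dict-copy/re-freeze pattern in this hunk is needed because scheduler configs are `FrozenDict`s; a minimal sketch of that semantics (assuming diffusers' `FrozenDict`):

    from diffusers.configuration_utils import FrozenDict

    cfg = FrozenDict({"clip_sample": True})
    # cfg["clip_sample"] = False  # would raise: FrozenDict forbids item assignment
    patched = dict(cfg)           # copy into a plain, mutable dict
    patched["clip_sample"] = False
    cfg = FrozenDict(patched)     # re-freeze the patched config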
@@ -451,6 +464,24 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
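
Usage sketch for the two methods above (assumes xformers is installed and the pipeline is on a CUDA device; the prompt is illustrative):

    pipe = pipe.to("cuda")
    pipe.enable_xformers_memory_efficient_attention()
    image = pipe("a photo of an astronaut riding a horse").images[0]
    # fall back to the default attention implementation
    pipe.disable_xformers_memory_efficient_attention()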
@@ -478,6 +509,23 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = self.device
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
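
Usage sketch (requires accelerate; since offload dispatches to `self.device`, one plausible pattern is to move the pipeline to CUDA first — the prompt is illustrative):

    pipe = pipe.to("cuda")
    pipe.enable_sequential_cpu_offload()  # weights live on CPU/meta between forward calls
    image = pipe("a fantasy landscape, trending on artstation").images[0]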
@@ -498,6 +546,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -560,11 +609,15 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
 
         Returns:
+            `None` if cancelled by `is_cancelled_callback`,
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
             When returning a tuple, the first element is a list with the generated images, and the second element is a
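
A sketch of wiring the new cancellation hook to a flag set from another thread (names are illustrative):

    import threading

    cancel_event = threading.Event()  # e.g. set by a UI "stop" button

    result = pipe("a photo of a cat", is_cancelled_callback=cancel_event.is_set)
    if result is None:
        print("inference was cancelled mid-run")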
@@ -757,8 +810,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                     latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
                 # call the callback, if provided
-                if callback is not None and i % callback_steps == 0:
-                    callback(i, t, latents)
+                if i % callback_steps == 0:
+                    if callback is not None:
+                        callback(i, t, latents)
+                    if is_cancelled_callback is not None and is_cancelled_callback():
+                        return None
 
         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
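
Putting the pieces together, an end-to-end sketch (model id and prompt are illustrative; assumes this file is consumed as a diffusers custom pipeline):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="lpw_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.enable_xformers_memory_efficient_attention()

    # weighted prompt syntax handled by parse_prompt_attention
    image = pipe("a (very beautiful:1.3) masterpiece, highly detailed").images[0]
    image.save("out.png")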