Diffusers documentation

Caching methods


Pyramid Attention Broadcast

Pyramid Attention Broadcast is from Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You.

Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. This works because attention states change only slightly between successive steps: the difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most often, followed by the temporal and then the spatial attention blocks. Combined with other techniques such as sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
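
The snippet below is a minimal, illustrative sketch of this skip-and-reuse idea, not the actual diffusers hook implementation; the attention_block, hidden_states, timesteps, and skip settings are dummy assumptions made only for the example.

import torch

# Illustrative sketch: recompute attention every `skip_range` steps inside the timestep
# window and broadcast (reuse) the cached output on the steps in between.
attention_block = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
hidden_states = torch.randn(1, 16, 64)
timesteps = torch.linspace(999, 0, 50, dtype=torch.long)

skip_range = 2
timestep_skip_range = (100, 800)

cached_attention = None
for step, timestep in enumerate(timesteps):
    within_window = timestep_skip_range[0] < timestep < timestep_skip_range[1]
    if cached_attention is None or not within_window or step % skip_range == 0:
        # Full attention computation; cache the output for reuse on later steps.
        cached_attention, _ = attention_block(hidden_states, hidden_states, hidden_states)
    # On skipped steps, the cached attention output is broadcast (reused) as-is.
    hidden_states = hidden_states + cached_attention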

Enable PAB with PyramidAttentionBroadcastConfig on any pipeline. For some benchmarks, refer to this pull request.

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will narrow the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)

CacheMixin

class diffusers.CacheMixin


( )

A class for enabling and disabling caching techniques on diffusion models.

Supported caching techniques:

  • Pyramid Attention Broadcast

enable_cache


( config )

Parameters

  • config (Union[PyramidAttentionBroadcastConfig]) — The configuration for applying the caching technique. Currently supported caching techniques: Pyramid Attention Broadcast.

Enable caching techniques on the model.

Example:

>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> config = PyramidAttentionBroadcastConfig(
...     spatial_attention_block_skip_range=2,
...     spatial_attention_timestep_skip_range=(100, 800),
...     current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> pipe.transformer.enable_cache(config)
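
The caching hooks can later be removed through the mixin's disabling counterpart; the disable_cache call shown below is assumed from the enabling/disabling description of CacheMixin above.

>>> # Assumed counterpart to `enable_cache`; removes the caching hooks again.
>>> pipe.transformer.disable_cache()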

PyramidAttentionBroadcastConfig

class diffusers.PyramidAttentionBroadcastConfig


( spatial_attention_block_skip_range: typing.Optional[int] = None temporal_attention_block_skip_range: typing.Optional[int] = None cross_attention_block_skip_range: typing.Optional[int] = None spatial_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800) temporal_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800) cross_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800) spatial_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks', 'single_transformer_blocks') temporal_attention_block_identifiers: typing.Tuple[str, ...] = ('temporal_transformer_blocks',) cross_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks') current_timestep_callback: typing.Callable[[], int] = None )

Parameters

  • spatial_attention_block_skip_range (int, optional, defaults to None) — The number of times a specific spatial attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
  • temporal_attention_block_skip_range (int, optional, defaults to None) — The number of times a specific temporal attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
  • cross_attention_block_skip_range (int, optional, defaults to None) — The number of times a specific cross-attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
  • spatial_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)) — The range of timesteps to skip in the spatial attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
  • temporal_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)) — The range of timesteps to skip in the temporal attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
  • cross_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)) — The range of timesteps to skip in the cross-attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
  • spatial_attention_block_identifiers (Tuple[str, ...], defaults to ("blocks", "transformer_blocks", "single_transformer_blocks")) — The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
  • temporal_attention_block_identifiers (Tuple[str, ...], defaults to ("temporal_transformer_blocks",)) — The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
  • cross_attention_block_identifiers (Tuple[str, ...], defaults to ("blocks", "transformer_blocks")) — The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
  • current_timestep_callback (Callable[[], int], defaults to None) — A callback that returns the current inference timestep. Pyramid Attention Broadcast uses it to check whether the current timestep falls within the configured skip ranges.

Configuration for Pyramid Attention Broadcast.
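
For reference, a configuration that broadcasts all three attention types could look like the following sketch. The skip values mirror the observation that cross attention can be skipped most aggressively, but they are illustrative rather than recommended defaults, and pipe is assumed to be a pipeline whose transformer supports caching.

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    temporal_attention_block_skip_range=3,
    cross_attention_block_skip_range=6,
    spatial_attention_timestep_skip_range=(100, 800),
    temporal_attention_timestep_skip_range=(100, 800),
    cross_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)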

diffusers.apply_pyramid_attention_broadcast


( module: Module config: PyramidAttentionBroadcastConfig )

Parameters

  • module (torch.nn.Module) — The module to apply Pyramid Attention Broadcast to.
  • config (PyramidAttentionBroadcastConfig) — The configuration to use for Pyramid Attention Broadcast.

Apply Pyramid Attention Broadcast to a given module (for example, a pipeline's transformer).

PAB is an attention approximation method that leverages the similarity in attention states between timesteps to reduce the computational cost of attention computation. The key takeaway from the paper is that the attention similarity between timesteps is highest in the cross-attention layers, lower in the temporal layers, and lowest in the spatial layers. This allows attention computation to be skipped more frequently in the cross-attention layers than in the temporal and spatial layers. Applying PAB therefore speeds up the inference process.

Example:

>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video

>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> config = PyramidAttentionBroadcastConfig(
...     spatial_attention_block_skip_range=2,
...     spatial_attention_timestep_skip_range=(100, 800),
...     current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)