Caching methods
Pyramid Attention Broadcast
Pyramid Attention Broadcast by Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. This works because attention states change little between successive inference steps. The difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most often, followed by temporal and then spatial attention computations. Combined with other techniques such as sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
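The reuse pattern described above can be pictured with a small, library-independent sketch. The `BroadcastCache` class, the `fake_attention` stand-in, the skip values (2, 4, 6), and the 48-step loop are illustrative assumptions, not the diffusers implementation. The sketch only shows that a larger skip range leads to fewer fresh attention computations.

```python
class BroadcastCache:
    """Toy wrapper (hypothetical, not a diffusers class): recompute every
    `skip_range` calls and reuse the cached output in between."""

    def __init__(self, fn, skip_range):
        self.fn = fn
        self.skip_range = skip_range
        self.iteration = 0
        self.cached = None
        self.num_computed = 0

    def __call__(self, x):
        # Recompute on the first call and on every `skip_range`-th call after it.
        if self.cached is None or self.iteration % self.skip_range == 0:
            self.cached = self.fn(x)
            self.num_computed += 1
        self.iteration += 1
        return self.cached


def fake_attention(x):
    # Stand-in for an expensive attention computation.
    return x * 2


# Assumed skip ranges: cross attention changes least, so it is skipped most.
layers = {
    "spatial": BroadcastCache(fake_attention, skip_range=2),
    "temporal": BroadcastCache(fake_attention, skip_range=4),
    "cross": BroadcastCache(fake_attention, skip_range=6),
}

for step in range(48):
    for layer in layers.values():
        layer(step)

computed = {name: layer.num_computed for name, layer in layers.items()}
print(computed)  # {'spatial': 24, 'temporal': 12, 'cross': 8}
```

In this toy setup the spatial layer computes attention on half of the steps, while the cross attention layer computes it on only one step in six.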
Enable PAB with PyramidAttentionBroadcastConfig on any pipeline. For some benchmarks, refer to this pull request.
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
CacheMixin
A class for enabling and disabling caching techniques on diffusion models.
Supported caching techniques: Pyramid Attention Broadcast
enable_cache
( config )
Enable caching techniques on the model.
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
...     spatial_attention_block_skip_range=2,
...     spatial_attention_timestep_skip_range=(100, 800),
...     current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> pipe.transformer.enable_cache(config)
PyramidAttentionBroadcastConfig
class diffusers.PyramidAttentionBroadcastConfig
(
    spatial_attention_block_skip_range: typing.Optional[int] = None,
    temporal_attention_block_skip_range: typing.Optional[int] = None,
    cross_attention_block_skip_range: typing.Optional[int] = None,
    spatial_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800),
    temporal_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800),
    cross_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800),
    spatial_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks', 'single_transformer_blocks'),
    temporal_attention_block_identifiers: typing.Tuple[str, ...] = ('temporal_transformer_blocks',),
    cross_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks'),
    current_timestep_callback: typing.Callable[[], int] = None,
)
Parameters
- spatial_attention_block_skip_range (int, optional, defaults to None): The number of times a specific spatial attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
- temporal_attention_block_skip_range (int, optional, defaults to None): The number of times a specific temporal attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
- cross_attention_block_skip_range (int, optional, defaults to None): The number of times a specific cross-attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
- spatial_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)): The range of timesteps to skip in the spatial attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- temporal_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)): The range of timesteps to skip in the temporal attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- cross_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)): The range of timesteps to skip in the cross-attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- spatial_attention_block_identifiers (Tuple[str, ...], defaults to ("blocks", "transformer_blocks", "single_transformer_blocks")): The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
- temporal_attention_block_identifiers (Tuple[str, ...], defaults to ("temporal_transformer_blocks",)): The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
- cross_attention_block_identifiers (Tuple[str, ...], defaults to ("blocks", "transformer_blocks")): The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
Configuration for Pyramid Attention Broadcast.
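How a block skip range and a timestep skip range interact for a single layer can be sketched in plain Python. The helper `pab_schedule` below is hypothetical, as is the 20-step timestep list. It assumes that attention is always recomputed outside the timestep range, that the range bounds are exclusive, and that inside the range a fresh computation happens on every N-th call. The library's actual logic may differ in detail.

```python
def pab_schedule(timesteps, block_skip_range, timestep_skip_range):
    """Return "compute" or "reuse" for each timestep (illustrative only)."""
    low, high = timestep_skip_range
    decisions = []
    has_cache = False
    iteration = 0
    for t in timesteps:
        in_range = low < t < high  # assumption: exclusive bounds
        due = iteration % block_skip_range == 0  # every N-th call recomputes
        if not has_cache or not in_range or due:
            decisions.append("compute")
            has_cache = True
        else:
            decisions.append("reuse")
        iteration += 1
    return decisions


# Hypothetical schedule: 20 timesteps from 950 down to 0.
timesteps = list(range(950, -1, -50))

wide = pab_schedule(timesteps, block_skip_range=2, timestep_skip_range=(100, 800))
narrow = pab_schedule(timesteps, block_skip_range=2, timestep_skip_range=(300, 600))

print(wide.count("reuse"), narrow.count("reuse"))  # 6 2
```

Under these assumptions, narrowing the timestep range from (100, 800) to (300, 600) cuts the number of reused steps from 6 to 2, which is why a smaller range gives slower but more faithful inference.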
diffusers.apply_pyramid_attention_broadcast
( module: Module, config: PyramidAttentionBroadcastConfig )
Apply Pyramid Attention Broadcast to a given module, such as a pipeline's transformer.
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to reduce the computational cost of attention computation. The key takeaway from the paper is that the attention similarity between timesteps is highest in the cross-attention layers, followed by the temporal and then the spatial layers. This allows attention computation to be skipped more frequently in the cross-attention layers than in the temporal and spatial layers. Applying PAB therefore speeds up the inference process.
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
...     spatial_attention_block_skip_range=2,
...     spatial_attention_timestep_skip_range=(100, 800),
...     current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
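The block identifiers in the config decide which layers are treated as spatial, temporal, or cross-attention layers. A rough sketch of such name matching follows. It assumes a regular-expression search of each identifier against the layer name, which may not be exactly what the library does. The `matches_any` helper and the layer names are made up for illustration.

```python
import re


def matches_any(layer_name, identifiers):
    """True if any identifier is found in the layer name (assumed regex search)."""
    return any(re.search(identifier, layer_name) is not None for identifier in identifiers)


# Identifier tuples taken from the config defaults shown in the signature.
spatial_identifiers = ("blocks", "transformer_blocks", "single_transformer_blocks")
temporal_identifiers = ("temporal_transformer_blocks",)

# Hypothetical layer names, as they might appear when walking a module tree.
layer_names = [
    "transformer_blocks.0.attn1",
    "temporal_transformer_blocks.3.attn1",
    "proj_out",
]

for name in layer_names:
    print(name, matches_any(name, spatial_identifiers), matches_any(name, temporal_identifiers))
```

Under this search-style matching, a name such as `temporal_transformer_blocks.3.attn1` also contains `blocks`, so name matching alone would not be enough to tell the layer types apart.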