from importlib import import_module
from typing import Callable, Optional, Union
import math

from einops import rearrange, repeat
import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
from diffusers.models.attention_processor import (
    Attention,
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    AttnProcessor,
    AttnProcessor2_0,
    SpatialNorm,
    LORA_ATTENTION_PROCESSORS,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
    SlicedAttnAddedKVProcessor,
    XFormersAttnAddedKVProcessor,
    LoRAAttnAddedKVProcessor,
    XFormersAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    SlicedAttnProcessor,
    AttentionProcessor,
)

from .rotary_embedding import RotaryEmbedding

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def _default_attn_processor(scale_qk: bool):
    """Return the default (non-LoRA, non-xformers) attention processor.

    Args:
        scale_qk: The owning layer's `scale_qk` flag.

    Returns:
        `AttnProcessor2_0` when the installed torch provides
        `F.scaled_dot_product_attention` and `scale_qk` is True, otherwise `AttnProcessor`.
    """
    # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
    # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
    # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
    if hasattr(F, "scaled_dot_product_attention") and scale_qk:
        return AttnProcessor2_0()
    return AttnProcessor()


def _check_xformers_usable():
    """Raise if xformers memory-efficient attention cannot be run in this process.

    Raises:
        ModuleNotFoundError: xformers is not installed.
        ValueError: CUDA is not available.
        Exception: Whatever a tiny trial call to `xformers.ops.memory_efficient_attention` raises.
    """
    if not is_xformers_available():
        raise ModuleNotFoundError(
            (
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers"
            ),
            name="xformers",
        )
    if not torch.cuda.is_available():
        raise ValueError(
            "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
            " only available for GPU "
        )
    # Make sure we can run the memory efficient attention; any failure propagates to the caller.
    _ = xformers.ops.memory_efficient_attention(
        torch.randn((1, 2, 40), device="cuda"),
        torch.randn((1, 2, 40), device="cuda"),
        torch.randn((1, 2, 40), device="cuda"),
    )


@maybe_allow_in_graph
class ConditionalAttention(nn.Module):
    r"""
    A cross attention layer.

    The actual attention computation is delegated to `self.processor`, which is called from
    `forward` with this module as its first argument.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8):
            The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            If True, `get_attention_scores` casts query and key to float32 before the matmul.
        upcast_softmax (`bool`, *optional*, defaults to False):
            If True, `get_attention_scores` casts the scores to float32 before the softmax.
        cross_attention_norm (`str`, *optional*):
            Normalization for `encoder_hidden_states`: `None`, `"layer_norm"` or `"group_norm"`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups used when `cross_attention_norm == "group_norm"`.
        added_kv_proj_dim (`int`, *optional*):
            If set, extra `add_k_proj` / `add_v_proj` projections with this input width are created.
        norm_num_groups (`int`, *optional*):
            If set, a `GroupNorm` over `query_dim` channels is stored as `self.group_norm`.
        spatial_norm_dim (`int`, *optional*):
            If set, a `SpatialNorm` is stored as `self.spatial_norm`.
        out_bias (`bool`, *optional*, defaults to True):
            Whether the output projection has a bias.
        scale_qk (`bool`, *optional*, defaults to True):
            If True the attention scale is `dim_head ** -0.5`, otherwise 1.0.
        only_cross_attention (`bool`, *optional*, defaults to False):
            If True, `to_k` / `to_v` are not created. Requires `added_kv_proj_dim`.
        eps (`float`, *optional*, defaults to 1e-5):
            Epsilon of `self.group_norm`.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            Stored on the module; processors divide their output by it.
        residual_connection (`bool`, *optional*, defaults to False):
            Stored on the module; processors add the input residual when True.
        processor (`AttnProcessor`, *optional*):
            Attention processor to use. Defaults to `AttnProcessor2_0` when available (and
            `scale_qk` is True), otherwise `AttnProcessor`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout

        # we make use of this private variable to know whether this class is loaded
        # with an deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = self.cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
            self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)

        # to_out[0] is the output projection, to_out[1] the output dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        if processor is None:
            processor = _default_attn_processor(self.scale_qk)
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Switch between xformers-based and default processors, carrying over LoRA / custom-diffusion weights.

        Args:
            use_memory_efficient_attention_xformers: True to install an xformers processor, False to
                install the matching non-xformers processor.
            attention_op: Optional xformers operator forwarded to the xformers processors.

        Raises:
            NotImplementedError: xformers requested for an added-KV processor that is also LoRA or
                custom-diffusion.
            ModuleNotFoundError / ValueError: xformers or CUDA is not available.
        """
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor,
            LORA_ATTENTION_PROCESSORS,
        )
        is_custom_diffusion = hasattr(self, "processor") and isinstance(
            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
        )
        is_added_kv_processor = hasattr(self, "processor") and isinstance(
            self.processor,
            (
                AttnAddedKVProcessor,
                AttnAddedKVProcessor2_0,
                SlicedAttnAddedKVProcessor,
                XFormersAttnAddedKVProcessor,
                LoRAAttnAddedKVProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                raise NotImplementedError(
                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
                )
            _check_xformers_usable()

            if is_lora:
                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                processor = attn_processor_class(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                processor = _default_attn_processor(self.scale_qk)

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        """Install a sliced processor (or restore the default one when `slice_size` is None).

        Raises:
            ValueError: `slice_size` exceeds `self.sliceable_head_dim`.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            processor = _default_attn_processor(self.scale_qk)

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """Replace the attention processor.

        Installing a non-LoRA processor while LoRA layers are attached removes those LoRA layers
        (deprecated path). If the old processor is an `nn.Module` and the new one is not, the old one
        is dropped from `self._modules`.
        """
        if (
            hasattr(self, "processor")
            and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
            and self.to_q.lora_layer is not None
        ):
            deprecate(
                "set_processor to offload LoRA",
                "0.26.0",
                "In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
            )
            # (Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
            # We need to remove all LoRA layers
            for module in self.modules():
                if hasattr(module, "set_lora_layer"):
                    module.set_lora_layer(None)

        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        """Return the current processor, optionally re-packed as a legacy LoRA attention processor.

        Args:
            return_deprecated_lora: If True and LoRA layers are attached to the projections, build a
                `LoRA<ProcessorName>` processor holding copies of those LoRA weights.

        Raises:
            ValueError: Only some projections carry LoRA layers, or no matching LoRA processor
                class is available.
        """
        if not return_deprecated_lora:
            return self.processor

        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
        # serialization format for LoRA Attention Processors. It should be deleted once the integration
        # with PEFT is completed.
        is_lora_activated = {
            name: module.lora_layer is not None
            for name, module in self.named_modules()
            if hasattr(module, "lora_layer")
        }

        # 1. if no layer has a LoRA activated we can return the processor as usual
        if not any(is_lora_activated.values()):
            return self.processor

        # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
        is_lora_activated.pop("add_k_proj", None)
        is_lora_activated.pop("add_v_proj", None)
        # 2. else it is not posssible that only some layers have LoRA activated
        if not all(is_lora_activated.values()):
            raise ValueError(
                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
            )

        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
        non_lora_processor_cls_name = self.processor.__class__.__name__
        lora_processor_cls_name = "LoRA" + non_lora_processor_cls_name
        # The LoRA processor classes are looked up by name among this module's attributes (they are
        # imported above). Use a default so an unknown name reaches the ValueError instead of
        # surfacing as an AttributeError.
        lora_processor_cls = getattr(import_module(__name__), lora_processor_cls_name, None)
        if lora_processor_cls is None:
            raise ValueError(f"{lora_processor_cls_name} does not exist.")

        hidden_size = self.inner_dim

        # now create a LoRA attention processor from the LoRA layers
        if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
            kwargs = {
                "cross_attention_dim": self.cross_attention_dim,
                "rank": self.to_q.lora_layer.rank,
                "network_alpha": self.to_q.lora_layer.network_alpha,
                "q_rank": self.to_q.lora_layer.rank,
                "q_hidden_size": self.to_q.lora_layer.out_features,
                "k_rank": self.to_k.lora_layer.rank,
                "k_hidden_size": self.to_k.lora_layer.out_features,
                "v_rank": self.to_v.lora_layer.rank,
                "v_hidden_size": self.to_v.lora_layer.out_features,
                "out_rank": self.to_out[0].lora_layer.rank,
                "out_hidden_size": self.to_out[0].lora_layer.out_features,
            }

            if hasattr(self.processor, "attention_op"):
                kwargs["attention_op"] = self.processor.attention_op

            lora_processor = lora_processor_cls(hidden_size, **kwargs)
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
            lora_processor = lora_processor_cls(
                hidden_size,
                cross_attention_dim=self.add_k_proj.weight.shape[0],
                rank=self.to_q.lora_layer.rank,
                network_alpha=self.to_q.lora_layer.network_alpha,
            )
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())

            # only save if used
            if self.add_k_proj.lora_layer is not None:
                lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
                lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
            else:
                lora_processor.add_k_proj_lora = None
                lora_processor.add_v_proj_lora = None
        else:
            raise ValueError(f"{lora_processor_cls} does not exist.")

        return lora_processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Delegate the attention computation to `self.processor`."""
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """Reshape `(batch * heads, seq, dim)` to `(batch, seq, dim * heads)`."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        """Split the channel axis into heads.

        `(batch, seq, dim)` becomes `(batch * heads, seq, dim // heads)` when `out_dim == 3`,
        otherwise `(batch, heads, seq, dim // heads)`.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """Return `softmax(self.scale * query @ key^T + attention_mask)` in the dtype of `query`.

        `query` and `key` are 3-D batched tensors (as required by `torch.baddbmm`).
        `attention_mask`, if given, is added to the scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # baddbmm needs an input tensor; with beta=0 its (uninitialized) contents are ignored.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        """Pad the mask to `target_length` and repeat it across heads.

        With `out_dim == 3` the mask is repeated along dim 0 (only if it has fewer than
        `batch_size * heads` rows); with `out_dim == 4` a head axis is inserted at dim 1 and
        repeated. Returns None when `attention_mask` is None.
        """
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.22.0",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        """Apply `self.norm_cross` (LayerNorm or GroupNorm) over the channel axis of the context."""
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states


class TemporalConditionalAttention(Attention):
    """Attention across the frame axis of packed video features.

    Inputs are laid out as `(b t) hw c` (frames packed into the leading axis, `t == n_frames`).
    For every spatial position the `n_frames` frames attend to each other, optionally together
    with extra keys/values from `encoder_hidden_states` or `adjacent_slices` (see `forward`).

    Args:
        n_frames: Number of frames packed into the leading axis of `hidden_states`.
        rotary_emb: If True, positions are encoded with rotary embeddings through
            `RotaryEmbAttnProcessor2_0`; otherwise a sinusoidal `PositionalEncoding` is added to
            the inputs.
        *args, **kwargs: Forwarded to diffusers' `Attention`.
    """

    def __init__(self, n_frames=8, rotary_emb=False, *args, **kwargs):
        super().__init__(processor=RotaryEmbAttnProcessor2_0() if rotary_emb else None, *args, **kwargs)
        if not rotary_emb:
            self.pos_enc = PositionalEncoding(self.inner_dim)
        else:
            # `self.heads` is set by the base class, so this also works when `heads` is passed
            # positionally or left at its default (a `kwargs['heads']` lookup would KeyError).
            # NOTE(review): `rotary_bias` is not used by RotaryEmbAttnProcessor2_0 (its use there
            # is commented out); presumably kept for checkpoint compatibility — confirm.
            rotary_bias = RelativePositionBias(heads=self.heads, max_distance=32)
            self.rotary_bias = rotary_bias
            self.rotary_emb = RotaryEmbedding(self.inner_dim // 2)

        self.use_rotary_emb = rotary_emb
        self.n_frames = n_frames

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        adjacent_slices=None,
        **cross_attention_kwargs,
    ):
        """Run attention over frames for every spatial position.

        Args:
            hidden_states: Tensor with einops layout `(b t) hw c`, where `t == self.n_frames`.
            encoder_hidden_states: Optional context. Every `n_frames`-th row along the leading axis
                is kept and the result (layout `b n c`) is repeated for each of the `hw` positions.
            attention_mask: Forwarded to the processor (the rotary processor rejects a mask).
            adjacent_slices: Optional tensor with layout `b c h w n`. Its `n` entries per position
                are appended after the frames as extra keys/values; `h * w` must equal `hw`.
            **cross_attention_kwargs: Forwarded to the processor.

        Returns:
            Tensor rearranged back to the `(b t) hw c` layout.

        Raises:
            ValueError: both `encoder_hidden_states` and `adjacent_slices` are given.
        """
        if encoder_hidden_states is not None and adjacent_slices is not None:
            raise ValueError("`encoder_hidden_states` and `adjacent_slices` cannot be passed together.")

        key_pos_idx = None
        hw = hidden_states.shape[1]

        hidden_states = rearrange(hidden_states, '(b t) hw c -> b hw t c', t=self.n_frames)
        if not self.use_rotary_emb:
            pos_embed = self.pos_enc(self.n_frames)
            hidden_states = hidden_states + pos_embed
        hidden_states = rearrange(hidden_states, 'b hw t c -> (b hw) t c')

        if encoder_hidden_states is not None:
            # Keep one context entry per video, then share it with every spatial position.
            encoder_hidden_states = encoder_hidden_states[::self.n_frames]
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b hw) n c', hw=hw)

        if adjacent_slices is not None:
            adjacent_slices = rearrange(adjacent_slices, 'b c h w n -> b (h w) n c')
            if not self.use_rotary_emb:
                # Adjacent slices all receive the positional encoding of frame 0.
                first_frame_pos_embed = pos_embed[0:1, :]
                adjacent_slices = adjacent_slices + first_frame_pos_embed
            else:
                # Key positions: frames keep 0..n_frames-1, adjacent slices all sit at position 0.
                pos_idx = torch.arange(self.n_frames, device=hidden_states.device, dtype=hidden_states.dtype)
                first_frame_pos_pad = torch.zeros(
                    adjacent_slices.shape[2], device=hidden_states.device, dtype=hidden_states.dtype
                )
                key_pos_idx = torch.cat([pos_idx, first_frame_pos_pad], dim=0)
            adjacent_slices = rearrange(adjacent_slices, 'b hw n c -> (b hw) n c')
            # Keys/values are the frames themselves followed by the adjacent slices.
            encoder_hidden_states = torch.cat([hidden_states, adjacent_slices], dim=1)

        if not self.use_rotary_emb:
            out = self.processor(
                self,
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
        else:
            out = self.processor(
                self,
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                key_pos_idx=key_pos_idx,
                **cross_attention_kwargs,
            )

        out = rearrange(out, '(b hw) t c -> (b t) hw c', hw=hw)
        return out

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers, attention_op=None):
        """Toggle xformers attention for this layer.

        Rotary-embedding layers always keep `RotaryEmbAttnProcessor2_0`: `forward` passes it the
        `key_pos_idx` keyword, which the xformers / default processors do not accept, so swapping
        it out would make every later forward call fail.

        Raises:
            ModuleNotFoundError / ValueError: xformers requested but xformers or CUDA is unavailable.
        """
        if self.use_rotary_emb:
            if use_memory_efficient_attention_xformers:
                logger.warning(
                    "xformers memory efficient attention is not supported together with rotary embeddings;"
                    " keeping `RotaryEmbAttnProcessor2_0`."
                )
            self.set_processor(RotaryEmbAttnProcessor2_0())
            return

        if use_memory_efficient_attention_xformers:
            _check_xformers_usable()
            processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            processor = _default_attn_processor(self.scale_qk)
        self.set_processor(processor)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table.

    Row `p` interleaves `sin(p / 10000 ** (i / dim))` and `cos(p / 10000 ** (i / dim))` for
    `i in range(dim // 2)`, giving `2 * (dim // 2)` channels. NOTE: for odd `dim` that is one
    channel short of `dim`.

    Args:
        dim: Number of channels of the encoding.
        max_pos: Number of precomputed positions.
    """

    def __init__(self, dim, max_pos=512):
        super().__init__()

        pos = torch.arange(max_pos)

        freq = torch.arange(dim // 2) / dim
        freq = (freq * torch.tensor(10000).log()).exp()

        x = rearrange(pos, 'L -> L 1') / freq
        x = rearrange(x, 'L d -> L d 1')

        pe = torch.cat((x.sin(), x.cos()), dim=-1)
        self.pe = rearrange(pe, 'L d sc -> L (d sc)')

        # Single parameter used only to discover the module's current device and dtype in `forward`.
        self.dummy = nn.Parameter(torch.rand(1))

    def forward(self, length):
        """Return the first `length` rows, moved to this module's device and dtype."""
        enc = self.pe[:length]
        enc = enc.to(self.dummy.device, self.dummy.dtype)
        return enc


# code taken from https://github.com/Vchitect/LaVie/blob/main/base/models/temporal_attention.py
class RelativePositionBias(nn.Module):
    """Learned per-head bias indexed by bucketed relative position (key index minus query index).

    Args:
        heads: Number of attention heads (one bias value per head and bucket).
        num_buckets: Total number of buckets (split evenly between the two signs).
        max_distance: Distance at which the logarithmic buckets saturate.
    """

    def __init__(
        self,
        heads=8,
        num_buckets=32,
        max_distance=128,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        """Map integer relative positions to bucket indices in `[0, num_buckets)`.

        Half of the buckets are used for positive offsets and half for the rest. Within each half,
        small distances get one bucket each; larger distances are spaced logarithmically up to
        `max_distance` and clamped to the last bucket.
        """
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # log(0) for n == 0 is masked out by `is_small` in the torch.where below.
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qlen, klen, device, dtype):
        """Return the bias tensor of shape `(heads, qlen, klen)` on `device` with `dtype`."""
        q_pos = torch.arange(qlen, dtype=torch.long, device=device)
        k_pos = torch.arange(klen, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(
            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
        )
        values = self.relative_attention_bias(rp_bucket)
        values = values.to(device, dtype)
        return rearrange(values, 'i j h -> h i j')  # num_heads, num_frames, num_frames


class RotaryEmbAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    Add rotary embedding support

    Queries are always rotated with `attn.rotary_emb`. Keys are rotated with default positions when
    the query and key lengths match, with `key_pos_idx` positions when provided, and left as-is
    otherwise. Attention masks are not supported.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        scale: float = 1.0,
        key_pos_idx: Optional[torch.Tensor] = None,
    ):
        """Compute attention with rotary position embeddings.

        Args:
            attn: The owning attention module (provides projections, norms and `rotary_emb`).
            hidden_states: Query input, 3-D `(batch, seq, channels)` or 4-D `(batch, channels, h, w)`.
            encoder_hidden_states: Optional key/value input; defaults to `hidden_states`.
            attention_mask: Must be None.
            temb: Passed to `attn.spatial_norm` when present.
            scale: LoRA scale forwarded to the projections.
            key_pos_idx: Optional key positions used when query and key lengths differ.

        Raises:
            ValueError: `attention_mask` is not None.
        """
        if attention_mask is not None:
            raise ValueError("RotaryEmbAttnProcessor2_0 does not support `attention_mask`.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, scale=scale)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        qlen = hidden_states.shape[1]
        klen = encoder_hidden_states.shape[1]
        # NOTE: relative position bias (`attn.rotary_bias`) is intentionally not applied here; it
        # would only make sense for self attention (qlen == klen) and masks are unsupported.

        key = attn.to_k(encoder_hidden_states, scale=scale)
        value = attn.to_v(encoder_hidden_states, scale=scale)

        query = attn.rotary_emb.rotate_queries_or_keys(query)
        if qlen == klen:
            key = attn.rotary_emb.rotate_queries_or_keys(key)
        elif key_pos_idx is not None:
            key = attn.rotary_emb.rotate_queries_or_keys(key, seq_pos=key_pos_idx)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, scale=scale)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states