origordon committed · Commit 07ddecf · 1 Parent(s): c4b2a35

Decoder: add AttentionResBlocks block

1. Support an attention block after each residual block in the decoder.
2. Add TPU flash attention support.

xora/models/autoencoders/causal_video_autoencoder.py CHANGED
 
@@ -9,10 +9,12 @@ import numpy as np
 from einops import rearrange
 from torch import nn
 from diffusers.utils import logging
+import torch.nn.functional as F

 from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
 from xora.models.autoencoders.pixel_norm import PixelNorm
 from xora.models.autoencoders.vae import AutoencoderKLWrapper
+from xora.models.transformers.attention import Attention

 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
@@ -212,6 +214,12 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
         last_layer = self.decoder.layers[-1]
         return last_layer

+    def set_use_tpu_flash_attention(self):
+        for block in self.decoder.up_blocks:
+            if isinstance(block, AttentionResBlocks):
+                for attention_block in block.attention_blocks:
+                    attention_block.set_use_tpu_flash_attention()
+

 class Encoder(nn.Module):
     r"""
 
@@ -483,6 +491,16 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
                 )
+            elif block_name == "attn_res_x":
+                block = AttentionResBlocks(
+                    dims=dims,
+                    in_channels=input_channel,
+                    num_layers=block_params["num_layers"],
+                    resnet_groups=norm_num_groups,
+                    norm_layer=norm_layer,
+                    attention_head_dim=block_params["attention_head_dim"],
+                    inject_noise=block_params.get("inject_noise", False),
+                )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
                 block = ResnetBlock3D(
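For context, the "attn_res_x" branch is selected from the decoder's block list in the same way as the existing block types. A hypothetical block spec is sketched below; the (block_name, block_params) tuple layout is inferred from the surrounding Decoder.__init__ code, and every concrete value is illustrative only:

# Illustrative only: "num_layers", "attention_head_dim" and "inject_noise" are the
# keys this hunk reads; the rest of the spec is an assumption.
blocks = [
    ("attn_res_x", {"num_layers": 2, "attention_head_dim": 64, "inject_noise": False}),
    ("res_x_y", {"multiplier": 2}),
]
# A Decoder configured with blocks=blocks would then build an AttentionResBlocks
# for the first entry and a ResnetBlock3D for the second.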
 
@@ -558,6 +576,129 @@ class Decoder(nn.Module):
         return sample


+class AttentionResBlocks(nn.Module):
+    """
+    A 3D convolution residual block followed by a self-attention residual block.
+
+    Args:
+        dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
+        in_channels (`int`): The number of input channels.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
+        resnet_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use in the group normalization layers of the resnet blocks.
+        norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use.
+        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
+        inject_noise (`bool`, *optional*, defaults to `False`): Whether to inject noise between convolution layers.
+
+    Returns:
+        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+        in_channels, frames, height, width)`.
+    """
+
+    def __init__(
+        self,
+        dims: Union[int, Tuple[int, int]],
+        in_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_groups: int = 32,
+        norm_layer: str = "group_norm",
+        attention_head_dim: int = 64,
+        inject_noise: bool = False,
+    ):
+        super().__init__()
+
+        if attention_head_dim > in_channels:
+            raise ValueError(
+                "attention_head_dim must be less than or equal to in_channels"
+            )
+
+        resnet_groups = (
+            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+        )
+
+        self.res_blocks = []
+        self.attention_blocks = []
+        for i in range(num_layers):
+            self.res_blocks.append(
+                ResnetBlock3D(
+                    dims=dims,
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    norm_layer=norm_layer,
+                    inject_noise=inject_noise,
+                )
+            )
+            self.attention_blocks.append(
+                Attention(
+                    query_dim=in_channels,
+                    heads=in_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
+                    bias=True,
+                    out_bias=True,
+                    qk_norm="rms_norm",
+                    residual_connection=True,
+                )
+            )
+
+        self.res_blocks = nn.ModuleList(self.res_blocks)
+        self.attention_blocks = nn.ModuleList(self.attention_blocks)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, causal: bool = True
+    ) -> torch.FloatTensor:
+        for resnet, attention in zip(self.res_blocks, self.attention_blocks):
+            hidden_states = resnet(hidden_states, causal=causal)
+
+            # Reshape the hidden states to be (batch_size, frames * height * width, channel)
+            batch_size, channel, frames, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, frames * height * width
+            ).transpose(1, 2)
+
+            if attention.use_tpu_flash_attention:
+                # Pad the second dimension to be divisible by block_k_major (the block size used by flash attention)
+                seq_len = hidden_states.shape[1]
+                block_k_major = 512
+                pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
+                if pad_len > 0:
+                    hidden_states = F.pad(
+                        hidden_states, (0, 0, 0, pad_len), "constant", 0
+                    )
+
+                # Create a mask with ones for the original sequence length and zeros for the padded indexes
+                mask = torch.ones(
+                    (hidden_states.shape[0], seq_len),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+                if pad_len > 0:
+                    mask = F.pad(mask, (0, pad_len), "constant", 0)
+
+            hidden_states = attention(
+                hidden_states,
+                attention_mask=None if not attention.use_tpu_flash_attention else mask,
+            )
+
+            if attention.use_tpu_flash_attention:
+                # Remove the padding
+                if pad_len > 0:
+                    hidden_states = hidden_states[:, :-pad_len, :]
+
+            # Reshape the hidden states back to (batch_size, channel, frames, height, width)
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, frames, height, width
+            )
+        return hidden_states
+
+
 class UNetMidBlock3D(nn.Module):
     """
     A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
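To make the data flow concrete, here is a small, self-contained sketch of the new block on a dummy latent, assuming the xora package is importable and ResnetBlock3D and Attention behave as shown in this diff; the channel and frame sizes are arbitrary:

import torch
from xora.models.autoencoders.causal_video_autoencoder import AttentionResBlocks

# Two residual blocks, each followed by self-attention with 128 // 64 = 2 heads.
# dims=3 (plain 3D convolutions) is an assumed valid setting for make_conv_nd.
block = AttentionResBlocks(dims=3, in_channels=128, num_layers=2, attention_head_dim=64)

latent = torch.randn(1, 128, 4, 16, 16)  # (batch, channels, frames, height, width)
out = block(latent, causal=True)         # attention runs over 4 * 16 * 16 = 1024 tokens
print(out.shape)                         # shape is preserved: torch.Size([1, 128, 4, 16, 16])

With TPU flash attention enabled, the 1024-token sequence is already a multiple of block_k_major = 512, so pad_len is 0; a smaller latent such as 4 x 10 x 10 (400 tokens) would be padded to 512, the padded positions masked out during attention, and the padding cropped afterwards.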