Decoder: add AttentionResBlocks block
Browse files1. Support attention block after residual block.
2. Add flash attention support.
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -9,10 +9,12 @@ import numpy as np
|
|
9 |
from einops import rearrange
|
10 |
from torch import nn
|
11 |
from diffusers.utils import logging
|
|
|
12 |
|
13 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
14 |
from xora.models.autoencoders.pixel_norm import PixelNorm
|
15 |
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
|
|
16 |
|
17 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
18 |
|
@@ -212,6 +214,12 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
212 |
last_layer = self.decoder.layers[-1]
|
213 |
return last_layer
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
class Encoder(nn.Module):
|
217 |
r"""
|
@@ -483,6 +491,16 @@ class Decoder(nn.Module):
|
|
483 |
norm_layer=norm_layer,
|
484 |
inject_noise=block_params.get("inject_noise", False),
|
485 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
elif block_name == "res_x_y":
|
487 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
488 |
block = ResnetBlock3D(
|
@@ -558,6 +576,129 @@ class Decoder(nn.Module):
|
|
558 |
return sample
|
559 |
|
560 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
561 |
class UNetMidBlock3D(nn.Module):
|
562 |
"""
|
563 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
|
|
9 |
from einops import rearrange
|
10 |
from torch import nn
|
11 |
from diffusers.utils import logging
|
12 |
+
import torch.nn.functional as F
|
13 |
|
14 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
15 |
from xora.models.autoencoders.pixel_norm import PixelNorm
|
16 |
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
17 |
+
from xora.models.transformers.attention import Attention
|
18 |
|
19 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
20 |
|
|
|
214 |
last_layer = self.decoder.layers[-1]
|
215 |
return last_layer
|
216 |
|
217 |
+
def set_use_tpu_flash_attention(self):
|
218 |
+
for block in self.decoder.up_blocks:
|
219 |
+
if isinstance(block, AttentionResBlocks):
|
220 |
+
for attention_block in block.attention_blocks:
|
221 |
+
attention_block.set_use_tpu_flash_attention()
|
222 |
+
|
223 |
|
224 |
class Encoder(nn.Module):
|
225 |
r"""
|
|
|
491 |
norm_layer=norm_layer,
|
492 |
inject_noise=block_params.get("inject_noise", False),
|
493 |
)
|
494 |
+
elif block_name == "attn_res_x":
|
495 |
+
block = AttentionResBlocks(
|
496 |
+
dims=dims,
|
497 |
+
in_channels=input_channel,
|
498 |
+
num_layers=block_params["num_layers"],
|
499 |
+
resnet_groups=norm_num_groups,
|
500 |
+
norm_layer=norm_layer,
|
501 |
+
attention_head_dim=block_params["attention_head_dim"],
|
502 |
+
inject_noise=block_params.get("inject_noise", False),
|
503 |
+
)
|
504 |
elif block_name == "res_x_y":
|
505 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
506 |
block = ResnetBlock3D(
|
|
|
576 |
return sample
|
577 |
|
578 |
|
579 |
+
class AttentionResBlocks(nn.Module):
|
580 |
+
"""
|
581 |
+
A 3D convolution residual block followed by self attention residual block
|
582 |
+
|
583 |
+
Args:
|
584 |
+
dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
|
585 |
+
in_channels (`int`): The number of input channels.
|
586 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
587 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
588 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
589 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
590 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
591 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use.
|
592 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
|
593 |
+
inject_noise (`bool`, *optional*, defaults to `False`): Whether to inject noise or not between convolution layers.
|
594 |
+
|
595 |
+
Returns:
|
596 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
597 |
+
in_channels, height, width)`.
|
598 |
+
|
599 |
+
"""
|
600 |
+
|
601 |
+
def __init__(
|
602 |
+
self,
|
603 |
+
dims: Union[int, Tuple[int, int]],
|
604 |
+
in_channels: int,
|
605 |
+
dropout: float = 0.0,
|
606 |
+
num_layers: int = 1,
|
607 |
+
resnet_eps: float = 1e-6,
|
608 |
+
resnet_groups: int = 32,
|
609 |
+
norm_layer: str = "group_norm",
|
610 |
+
attention_head_dim: int = 64,
|
611 |
+
inject_noise: bool = False,
|
612 |
+
):
|
613 |
+
super().__init__()
|
614 |
+
|
615 |
+
if attention_head_dim > in_channels:
|
616 |
+
raise ValueError(
|
617 |
+
"attention_head_dim must be less than or equal to in_channels"
|
618 |
+
)
|
619 |
+
|
620 |
+
resnet_groups = (
|
621 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
622 |
+
)
|
623 |
+
|
624 |
+
self.res_blocks = []
|
625 |
+
self.attention_blocks = []
|
626 |
+
for i in range(num_layers):
|
627 |
+
self.res_blocks.append(
|
628 |
+
ResnetBlock3D(
|
629 |
+
dims=dims,
|
630 |
+
in_channels=in_channels,
|
631 |
+
out_channels=in_channels,
|
632 |
+
eps=resnet_eps,
|
633 |
+
groups=resnet_groups,
|
634 |
+
dropout=dropout,
|
635 |
+
norm_layer=norm_layer,
|
636 |
+
inject_noise=inject_noise,
|
637 |
+
)
|
638 |
+
)
|
639 |
+
self.attention_blocks.append(
|
640 |
+
Attention(
|
641 |
+
query_dim=in_channels,
|
642 |
+
heads=in_channels // attention_head_dim,
|
643 |
+
dim_head=attention_head_dim,
|
644 |
+
bias=True,
|
645 |
+
out_bias=True,
|
646 |
+
qk_norm="rms_norm",
|
647 |
+
residual_connection=True,
|
648 |
+
)
|
649 |
+
)
|
650 |
+
|
651 |
+
self.res_blocks = nn.ModuleList(self.res_blocks)
|
652 |
+
self.attention_blocks = nn.ModuleList(self.attention_blocks)
|
653 |
+
|
654 |
+
def forward(
|
655 |
+
self, hidden_states: torch.FloatTensor, causal: bool = True
|
656 |
+
) -> torch.FloatTensor:
|
657 |
+
for resnet, attention in zip(self.res_blocks, self.attention_blocks):
|
658 |
+
hidden_states = resnet(hidden_states, causal=causal)
|
659 |
+
|
660 |
+
# Reshape the hidden states to be (batch_size, frames * height * width, channel)
|
661 |
+
batch_size, channel, frames, height, width = hidden_states.shape
|
662 |
+
hidden_states = hidden_states.view(
|
663 |
+
batch_size, channel, frames * height * width
|
664 |
+
).transpose(1, 2)
|
665 |
+
|
666 |
+
if attention.use_tpu_flash_attention:
|
667 |
+
# Pad the second dimension to be divisible by block_k_major (block in flash attention)
|
668 |
+
seq_len = hidden_states.shape[1]
|
669 |
+
block_k_major = 512
|
670 |
+
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
|
671 |
+
if pad_len > 0:
|
672 |
+
hidden_states = F.pad(
|
673 |
+
hidden_states, (0, 0, 0, pad_len), "constant", 0
|
674 |
+
)
|
675 |
+
|
676 |
+
# Create a mask with ones for the original sequence length and zeros for the padded indexes
|
677 |
+
mask = torch.ones(
|
678 |
+
(hidden_states.shape[0], seq_len),
|
679 |
+
device=hidden_states.device,
|
680 |
+
dtype=hidden_states.dtype,
|
681 |
+
)
|
682 |
+
if pad_len > 0:
|
683 |
+
mask = F.pad(mask, (0, pad_len), "constant", 0)
|
684 |
+
|
685 |
+
hidden_states = attention(
|
686 |
+
hidden_states,
|
687 |
+
attention_mask=None if not attention.use_tpu_flash_attention else mask,
|
688 |
+
)
|
689 |
+
|
690 |
+
if attention.use_tpu_flash_attention:
|
691 |
+
# Remove the padding
|
692 |
+
if pad_len > 0:
|
693 |
+
hidden_states = hidden_states[:, :-pad_len, :]
|
694 |
+
|
695 |
+
# Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
|
696 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
697 |
+
batch_size, channel, frames, height, width
|
698 |
+
)
|
699 |
+
return hidden_states
|
700 |
+
|
701 |
+
|
702 |
class UNetMidBlock3D(nn.Module):
|
703 |
"""
|
704 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|