import json
import os
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Any, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from einops import rearrange
from torch import nn

from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from xora.models.autoencoders.pixel_norm import PixelNorm
from xora.models.autoencoders.vae import AutoencoderKLWrapper


def _make_norm_layer(norm_layer: str, num_channels: int, num_groups: int, eps: float) -> nn.Module:
    """Build the normalization module selected by ``norm_layer``.

    Args:
        norm_layer: One of ``"group_norm"``, ``"pixel_norm"`` or ``"layer_norm"``.
        num_channels: Number of channels to normalize (unused by ``pixel_norm``).
        num_groups: Number of groups for ``group_norm`` (ignored otherwise).
        eps: Numerical-stability epsilon (unused by ``pixel_norm``).

    Returns:
        ``nn.GroupNorm``, ``PixelNorm`` or this module's channel-wise ``LayerNorm``.

    Raises:
        ValueError: If ``norm_layer`` is not a supported name.
    """
    if norm_layer == "group_norm":
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True)
    if norm_layer == "pixel_norm":
        return PixelNorm()
    if norm_layer == "layer_norm":
        return LayerNorm(num_channels, eps=eps, elementwise_affine=True)
    raise ValueError(f"Invalid norm_layer: {norm_layer}")


def _make_checkpoint_fn(enabled: bool):
    """Return a ``module -> callable`` wrapper used by the encoder/decoder block loops.

    When ``enabled`` is True, the wrapped callable runs the module through
    ``torch.utils.checkpoint.checkpoint`` in non-reentrant mode, which also forwards
    keyword arguments such as ``causal=``. Otherwise the module itself is returned.
    The module is deliberately not invoked while wrapping: ``checkpoint(module)``
    would call ``module()`` immediately with no inputs.
    """
    if enabled:
        return lambda module: partial(torch.utils.checkpoint.checkpoint, module, use_reentrant=False)
    return lambda module: module


class CausalVideoAutoencoder(AutoencoderKLWrapper):
    """Video VAE assembled from a causal :class:`Encoder` and a :class:`Decoder`.

    Instances are normally built with :meth:`from_config`, or with
    :meth:`from_pretrained`, which additionally loads weights and, when present,
    per-channel latent statistics.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Load a model from a local directory.

        The directory must contain ``config.json`` and ``autoencoder.pth`` and may
        contain ``per_channel_statistics.json``.

        Args:
            pretrained_model_name_or_path: Local model directory (``str`` or path-like).
            *args: Accepted for API compatibility; unused.
            **kwargs: Forwarded to ``load_config``. If ``torch_dtype`` is present and
                not ``None``, the model is cast to that dtype before weights are loaded.

        Returns:
            The loaded :class:`CausalVideoAutoencoder`.
        """
        # Accept plain strings as well as path-like objects; `/` needs a Path.
        model_dir = Path(pretrained_model_name_or_path)

        config = cls.load_config(model_dir / "config.json", **kwargs)
        video_vae = cls.from_config(config)
        torch_dtype = kwargs.get("torch_dtype")
        if torch_dtype is not None:
            video_vae.to(torch_dtype)

        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
        # checkpoints (consider weights_only=True if the file holds only tensors).
        ckpt_state_dict = torch.load(model_dir / "autoencoder.pth", map_location=torch.device("cpu"))
        video_vae.load_state_dict(ckpt_state_dict)

        statistics_local_path = model_dir / "per_channel_statistics.json"
        if statistics_local_path.exists():
            with open(statistics_local_path, "r") as file:
                data = json.load(file)
            # "data" holds rows and "columns" holds the column names; transpose the
            # rows so that each column name maps to one tensor.
            transposed_data = list(zip(*data["data"]))
            data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)}
            video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
            video_vae.register_buffer(
                "mean_of_means",
                data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"])),
            )

        return video_vae

    @staticmethod
    def from_config(config):
        """Build an (untrained) model from a config mapping.

        Required keys are ``_class_name`` (must equal ``"CausalVideoAutoencoder"``),
        ``dims``, ``latent_channels`` and ``blocks``. All other keys fall back to the
        defaults used below. The caller's mapping is not modified.

        Raises:
            AssertionError: On a wrong ``_class_name`` or unsupported ``dims``.
            ValueError: If ``latent_log_var == "uniform"`` together with ``use_quant_conv``.
        """
        # Work on a copy so that normalising "dims" does not mutate the caller's dict.
        config = dict(config)
        assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none")
        use_quant_conv = config.get("use_quant_conv", True)

        if use_quant_conv and latent_log_var == "uniform":
            raise ValueError("uniform latent_log_var requires use_quant_conv=False")

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config["blocks"],
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config["blocks"],
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
        )

    @property
    def config(self):
        """Reconstruct a config namespace describing this model's architecture."""
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            blocks=self.encoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        """Total H/W reduction: 2 per spatial-compressing block, times the patch size."""
        return (
            2 ** len([block for block in self.encoder.blocks_desc if block[0] in ["compress_space", "compress_all"]])
            * self.encoder.patch_size
        )

    @property
    def temporal_downscale_factor(self):
        """Total frame-axis reduction: 2 per time-compressing block."""
        return 2 ** len([block for block in self.encoder.blocks_desc if block[0] in ["compress_time", "compress_all"]])

    def to_json_string(self) -> str:
        """Serialize :attr:`config` to a JSON string."""
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load weights after renaming legacy keys.

        Renames ``.resnets.`` -> ``.res_blocks.``, ``downsamplers.0`` -> ``downsample``
        and ``upsamplers.0`` -> ``upsample``. Keys containing ``"norm"`` that do not match
        a model parameter are dropped, e.g. when loading into a parameter-free norm.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns (missing/unexpected keys).
        """
        model_keys = {name for name, _ in self.named_parameters()}

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in state_dict.items():
            for old, new in key_mapping.items():
                key = key.replace(old, new)

            if "norm" in key and key not in model_keys:
                print(f"Removing key {key} from state_dict as it is not present in the model")
                continue

            converted_state_dict[key] = value

        return super().load_state_dict(converted_state_dict, strict=strict)

    def last_layer(self):
        """Return the decoder's final layer (the last element if ``conv_out`` is a Sequential)."""
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of latent channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use, as `(block_name, num_layers)` pairs. Supported names: `res_x`,
            `res_x_y` (doubles the channel count), `compress_time`, `compress_space`, `compress_all`.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for group normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The spatial patch size. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use: `group_norm`, `pixel_norm` or `layer_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            How the log variance is produced: `per_channel` (one log-var channel per latent
            channel), `uniform` (a single log-var channel, broadcast to all latent channels in
            `forward`) or `none` (no log-var channels).
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: Optional[List[Tuple[str, int]]] = None,
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm, layer_norm
        latent_log_var: str = "per_channel",
    ):
        super().__init__()
        if blocks is None:
            blocks = [("res_x", 1)]
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        # Patchify folds each patch_size x patch_size spatial patch into channels.
        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, num_layers in blocks:
            input_channel = output_channel
            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=num_layers,
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "res_x_y":
                output_channel = 2 * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        self.conv_norm_out = _make_norm_layer(norm_layer, output_channel, norm_num_groups, 1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(dims, output_channel, conv_out_channels, 3, padding=1, causal=True)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class.

        Returns the raw moments tensor. For `latent_log_var="uniform"` the single log-var
        channel is repeated so that the output has `2 * latent_channels` channels.
        """
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        checkpoint_fn = _make_checkpoint_fn(self.gradient_checkpointing and self.training)

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            # Channels are [means (C), log-var (1)]; append C - 1 copies of the log-var
            # channel to obtain [means (C), log-var (C)].
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1)
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")

        return sample


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        dims (`int` or `Tuple[int, int]`):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of latent (input) channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use, as `(block_name, num_layers)` pairs; they are applied in reverse order.
        base_channels (`int`, *optional*, defaults to 128):
            The number of channels after the last `res_x_y` block is undone.
        layers_per_block (`int`, *optional*, defaults to 2):
            Stored on the module; not used by the visible construction code.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for group normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The spatial patch size. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use: `group_norm`, `pixel_norm` or `layer_norm`.
        causal (`bool`, *optional*, defaults to `True`):
            Whether to run the convolutions causally (passed as `causal=` at call time).
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: Optional[List[Tuple[str, int]]] = None,
        base_channels: int = 128,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        causal: bool = True,
    ):
        super().__init__()
        if blocks is None:
            blocks = [("res_x", 1)]
        self.patch_size = patch_size
        self.layers_per_block = layers_per_block
        # The final conv emits patch_size**2 channels per output channel for unpatchify.
        out_channels = out_channels * patch_size**2

        # Mirror the encoder: start at the widest channel count it reaches.
        num_channel_doubles = len([x for x in blocks if x[0] == "res_x_y"])
        output_channel = base_channels * 2**num_channel_doubles
        self.causal = causal

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, num_layers in list(reversed(blocks)):
            input_channel = output_channel
            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=num_layers,
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // 2
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "compress_time":
                block = DepthToSpaceUpsample(dims=dims, in_channels=input_channel, stride=(2, 1, 1))
            elif block_name == "compress_space":
                block = DepthToSpaceUpsample(dims=dims, in_channels=input_channel, stride=(1, 2, 2))
            elif block_name == "compress_all":
                block = DepthToSpaceUpsample(dims=dims, in_channels=input_channel, stride=(2, 2, 2))
            else:
                raise ValueError(f"unknown layer: {block_name}")

            self.up_blocks.append(block)

        self.conv_norm_out = _make_norm_layer(norm_layer, output_channel, norm_num_groups, 1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(dims, output_channel, out_channels, 3, padding=1, causal=True)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class.

        `target_shape` must be provided but is not otherwise used here.
        """
        assert target_shape is not None, "target_shape must be provided"

        sample = self.conv_in(sample, causal=self.causal)

        # Run the up blocks in their parameter dtype; fall back to the input dtype if
        # the up blocks have no parameters (e.g. an empty block list).
        first_param = next(self.up_blocks.parameters(), None)
        upscale_dtype = first_param.dtype if first_param is not None else sample.dtype

        checkpoint_fn = _make_checkpoint_fn(self.gradient_checkpointing and self.training)

        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input (and output) channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
            If `None`, `min(in_channels // 4, 32)` is used.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, with `in_channels` channels
        (for 3D convolutions a `(batch, channels, frames, height, width)` tensor).
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, hidden_states: torch.FloatTensor, causal: bool = True) -> torch.FloatTensor:
        """Apply the residual blocks in sequence."""
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states, causal=causal)

        return hidden_states


class DepthToSpaceUpsample(nn.Module):
    """Upsample by a convolution followed by a depth-to-space rearrangement.

    The convolution expands channels by ``prod(stride)``, and the extra channels are
    then moved into the frame/height/width axes. Expects 5D ``(B, C, D, H, W)`` input.
    """

    def __init__(self, dims, in_channels, stride):
        super().__init__()
        self.stride = stride
        # Plain int (rather than a numpy scalar) for the channel count.
        self.out_channels = int(np.prod(stride)) * in_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
        )

    def forward(self, x, causal: bool = True):
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        if self.stride[0] == 2:
            # Drop the first frame so that d input frames yield 2 * d - 1 output frames
            # (a single frame stays a single frame).
            x = x[:, :, 1:, :, :]
        return x


class LayerNorm(nn.Module):
    """``nn.LayerNorm`` over the channel axis of a 5D ``(B, C, D, H, W)`` tensor."""

    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        # Move channels last so nn.LayerNorm normalizes them, then move them back.
        x = rearrange(x, "b c d h w -> b d h w c")
        x = self.norm(x)
        x = rearrange(x, "b d h w c -> b c d h w")
        return x


class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            Stored as `use_conv_shortcut`; the shortcut projection is actually chosen by
            whether `in_channels != out_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`): The number of groups to use for group normalization.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            `group_norm`, `pixel_norm` or `layer_norm`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = _make_norm_layer(norm_layer, in_channels, groups, eps)

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1, causal=True)

        self.norm2 = _make_norm_layer(norm_layer, out_channels, groups, eps)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1, causal=True)

        # When the channel count changes, the skip path is normalized (always with
        # LayerNorm, independent of `norm_layer`) and projected to out_channels.
        self.conv_shortcut = (
            make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True) if in_channels != out_channels else nn.Identity()
        )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states, causal=causal)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, causal=causal)

        input_tensor = self.norm3(input_tensor)
        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


def patchify(x, patch_size_hw, patch_size_t=1):
    """Fold spatial (and optionally temporal) patches into the channel axis.

    Accepts 4D ``(B, C, H, W)`` or 5D ``(B, C, F, H, W)`` tensors. It is the inverse of
    :func:`unpatchify` with the same patch sizes.

    Raises:
        ValueError: For tensors that are neither 4D nor 5D.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
    elif x.dim() == 5:
        x = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x


def unpatchify(x, patch_size_hw, patch_size_t=1):
    """Inverse of :func:`patchify`: move channel-packed patches back into the spatial/temporal axes.

    Raises:
        ValueError: For tensors that are neither 4D nor 5D.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
    elif x.dim() == 5:
        x = rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def create_video_autoencoder_config(
    latent_channels: int = 64,
):
    """Return the demo configuration dict accepted by ``CausalVideoAutoencoder.from_config``."""
    config = {
        "_class_name": "CausalVideoAutoencoder",
        "dims": 3,  # (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "blocks": [
            ("res_x", 4),
            ("compress_space", 1),
            ("res_x_y", 1),
            ("res_x", 2),
            ("compress_all", 1),
            ("res_x", 3),
            ("compress_all", 1),
            ("res_x_y", 1),
            ("res_x", 2),
            ("compress_time", 1),
            ("res_x", 3),
            ("res_x", 3),
        ],
        "patch_size": 4,
        "latent_log_var": "uniform",
        "use_quant_conv": False,
        "norm_layer": "layer_norm",
        "causal_decoder": True,
    }

    return config


def test_vae_patchify_unpatchify():
    """Check that unpatchify(patchify(x)) round-trips a 5D tensor."""
    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    """Build the demo model, run encode/decode plus a backward pass, and check causality."""
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = CausalVideoAutoencoder.from_config(config)

    print(video_autoencoder)
    video_autoencoder.eval()
    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 17, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")

    reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Validate that single image gets treated the same way as first frame
    input_image = input_videos[:, :, :1, :, :]
    image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
    reconstructed_image = video_autoencoder.decode(image_latent, target_shape=input_image.shape).sample

    first_frame_latent = latent[:, :, :1, :, :]

    assert (image_latent == first_frame_latent).all()
    assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Ensure to call the demo function to execute the forward and backward pass
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()