import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor, nn
from torch.nn import functional
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder

# make_conv_nd builds the 2D / (2+1)D / 3D convolution used by Downsample3D below; the
# module path is assumed from the xora autoencoders package layout.
from xora.models.autoencoders.conv_nd_factory import make_conv_nd
class Downsample3D(nn.Module):
    def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        super().__init__()
        stride: int = 2
        self.padding = padding
        self.in_channels = in_channels
        self.dims = dims
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, downsample_in_time=True):
        conv = self.conv
        if self.padding == 0:
            # With no built-in padding, pad asymmetrically (right/bottom/end) with zeros so
            # odd spatial or temporal sizes still downsample cleanly with stride 2.
            if self.dims == 2:
                padding = (0, 1, 0, 1)
            else:
                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
            x = functional.pad(x, padding, mode="constant", value=0)
        if self.dims == (2, 1) and not downsample_in_time:
            return conv(x, skip_time_conv=True)
        return conv(x)
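# Shape sketch (illustrative sizes only, assuming make_conv_nd returns a stride-2
# convolution for the requested dims): with dims=(2, 1) a Downsample3D maps
# (B, C, 8, 64, 64) -> (B, C_out, 4, 32, 32) when downsample_in_time=True, and keeps
# the 8 frames while halving H and W when downsample_in_time=False.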
def vae_encode(
    media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize: bool = False
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos.
        vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
            pre-configured and loaded with the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. Defaults to 1, which processes all items in a single batch.
        vae_per_channel_normalize (bool, optional): If True, normalizes the latents with the VAE's
            per-channel statistics (`mean_of_means` / `std_of_means`) instead of the global
            `config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
        to match the input shape, scaled by the model's configuration.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> videos = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames each.
        >>> latents = vae_encode(videos, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        For videos, the function encodes the media item frame by frame, unless the VAE is a
        `CausalVideoAutoencoder`, which consumes the temporal dimension directly.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    # A plain image VAE sees each frame independently, so fold the frame axis into the batch.
    if is_video_shaped and not isinstance(vae, CausalVideoAutoencoder):
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError("The batch size must be divisible by 'train.vae_bs_split'.")
        encode_bs = len(media_items) // split_size
        latents = []
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)

    # Restore the (batch, channels, frames, h, w) layout that was folded above.
    if is_video_shaped and not isinstance(vae, CausalVideoAutoencoder):
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents
def vae_decode(
    latents: Tensor, vae: AutoencoderKL, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize: bool = False
) -> Tensor:
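    """
    Decodes latent representations back into pixel-space media items using the given VAE.

    Mirrors `vae_encode`: video-shaped latents are decoded frame by frame unless the VAE is a
    `CausalVideoAutoencoder`, and `split_size` > 1 decodes the batch in smaller sub-batches.
    Latents are un-normalized with `un_normalize_latents` before decoding.

    Args:
        latents (Tensor): Latents of shape (batch_size, channels, height, width) or
            (batch_size, channels, frames, height, width).
        vae (AutoencoderKL): The VAE used for decoding.
        is_video (bool, optional): Whether to decode to a video (used to build the target shape
            for a `CausalVideoAutoencoder`). Defaults to True.
        split_size (int, optional): Number of sub-batches to decode sequentially. Defaults to 1.
        vae_per_channel_normalize (bool, optional): Must match the value used in `vae_encode`.
            Defaults to False.

    Returns:
        Tensor: Decoded media items, shaped to match the input latents' layout.
    """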
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    if is_video_shaped and not isinstance(vae, CausalVideoAutoencoder):
        latents = rearrange(latents, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError("The batch size must be divisible by 'train.vae_bs_split'.")
        decode_bs = len(latents) // split_size
        image_batch = [
            _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
            for latent_batch in latents.split(decode_bs)
        ]
        images = torch.cat(image_batch, dim=0)
    else:
        images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)

    if is_video_shaped and not isinstance(vae, CausalVideoAutoencoder):
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images
def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
    if isinstance(vae, CausalVideoAutoencoder):
        *_, fl, hl, wl = latents.shape
        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
        latents = latents.to(vae.dtype)
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=(1, 3, fl * temporal_scale if is_video else 1, hl * spatial_scale, wl * spatial_scale),
        )[0]
    else:
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]
    return image
def get_vae_size_scale_factor(vae: AutoencoderKL) -> tuple:
    """Returns the (temporal, height, width) downscale factors between pixels and latents."""
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
        spatial = vae.config.patch_size * 2**down_blocks
        # Only VAEs that define a temporal patch size downscale the frame axis; image VAEs keep it at 1.
        temporal = vae.config.patch_size_t * 2**down_blocks if getattr(vae.config, "patch_size_t", None) else 1
    return (temporal, spatial, spatial)
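# Sketch of how the factors are used (sizes below are illustrative, not read from any model):
# if get_vae_size_scale_factor(vae) returned (8, 32, 32), a latent grid of
# (frames=4, h=16, w=16) would decode to a (frames=32, h=512, w=512) video via the
# target_shape computed in _run_decoder above.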
def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
    # Per-channel mode standardizes with the VAE's stored channel statistics;
    # otherwise fall back to the usual diffusers scaling factor.
    return (
        (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
        / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents * vae.config.scaling_factor
    )
def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
    # Exact inverse of normalize_latents, so the same flag must be used for encode and decode.
    return (
        latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents / vae.config.scaling_factor
    )
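
# Usage sketch (assumptions: "your-model-name" is a placeholder checkpoint id, and the host has
# enough memory for the batch; neither comes from this module). Demonstrates the encode -> decode
# round trip for a small batch of videos, using split_size to bound peak memory.
if __name__ == "__main__":
    vae = AutoencoderKL.from_pretrained("your-model-name")
    videos = torch.rand(4, 3, 8, 256, 256)  # (batch, channels, frames, height, width)
    latents = vae_encode(videos, vae, split_size=2, vae_per_channel_normalize=False)
    decoded = vae_decode(latents, vae, is_video=True, split_size=2, vae_per_channel_normalize=False)
    print(latents.shape, decoded.shape)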