# wlmov / pyramid_dit/modeling_embedding.py
# Uploaded by multimodalart ("Upload 33 files")
# Commit f0533a5 (verified)
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
import numpy as np
import math
from diffusers.models.activations import get_activation
from einops import rearrange
def get_1d_sincos_pos_embed(
    embed_dim, num_frames, cls_token=False, extra_tokens=0,
):
    """
    Build a sin/cos embedding table for the frame indices 0 .. num_frames - 1.

    Returns an array of shape (num_frames, embed_dim). When `cls_token` is truthy and
    `extra_tokens > 0`, `extra_tokens` all-zero rows are prepended to the table.
    """
    frame_positions = np.arange(num_frames, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, frame_positions)  # (T, D)

    # Without extra tokens the table is returned untouched.
    if not (cls_token and extra_tokens > 0):
        return table

    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a 2D sin/cos positional embedding table.

    grid_size: an int (square grid) or a (height, width) pair.
    Coordinates along each axis are rescaled by `base_size / axis_length` and divided by
    `interpolation_scale` before being encoded.
    return: pos_embed of shape [grid_h*grid_w, embed_dim], with `extra_tokens` zero rows
    prepended when `cls_token` is truthy and `extra_tokens > 0`.
    """
    if isinstance(grid_size, int):
        rows, cols = grid_size, grid_size
    else:
        rows, cols = grid_size[0], grid_size[1]

    coords_h = np.arange(rows, dtype=np.float32) / (rows / base_size) / interpolation_scale
    coords_w = np.arange(cols, dtype=np.float32) / (cols / base_size) / interpolation_scale

    # meshgrid is called with w first, so index 0 of the stacked grid holds the w coordinates.
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, cols, rows])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table

    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a stacked 2-axis coordinate grid as sin/cos embeddings.

    Half of the channels encode `grid[0]` and the other half encode `grid[1]`; the two
    halves are concatenated along the channel axis, giving an (M, embed_dim) array.
    Raises ValueError when `embed_dim` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # One (M, D/2) table per grid axis, in the order grid[0], grid[1].
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions as sin/cos embeddings.

    embed_dim: output dimension for each position (must be even).
    pos: positions to be encoded; any array-like (list, tuple, ndarray) of any shape,
         flattened to size (M,).
    out: float64 array of shape (M, embed_dim); the first half of the channels holds the
         sines and the second half the cosines.
    Raises ValueError when `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # Frequencies 1 / 10000^(i / (D/2)) for i = 0 .. D/2 - 1.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # Accept plain Python sequences as well as arrays; ndarrays pass through unchanged.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings (the DDPM-style formulation).

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: when True the cosine half is placed before the sine half.
    :param downscale_freq_shift: subtracted from `embedding_dim // 2` in the exponent denominator.
    :param scale: multiplier applied to the angles before sin/cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] Tensor of positional embeddings (zero padded by one
             column when `embedding_dim` is odd).
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    freq_index = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    log_freqs = -math.log(max_period) * freq_index
    log_freqs = log_freqs / (half_dim - downscale_freq_shift)
    freqs = torch.exp(log_freqs)

    # Outer product of timesteps and frequencies, then the global scale.
    angles = timesteps[:, None].float() * freqs[None, :]
    angles = scale * angles

    sin_half = torch.sin(angles)
    cos_half = torch.cos(angles)
    if flip_sin_to_cos:
        emb = torch.cat([cos_half, sin_half], dim=-1)
    else:
        emb = torch.cat([sin_half, cos_half], dim=-1)

    # An odd embedding_dim leaves one channel uncovered; fill it with zeros.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class Timesteps(nn.Module):
    """Parameter-free module that maps a 1-D timestep tensor to sinusoidal features.

    The constructor only stores the configuration; `forward` delegates to
    `get_timestep_embedding` with `num_channels` as the embedding dimension.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        """Return the sinusoidal embedding of `timesteps` with `num_channels` channels."""
        options = {
            "flip_sin_to_cos": self.flip_sin_to_cos,
            "downscale_freq_shift": self.downscale_freq_shift,
        }
        return get_timestep_embedding(timesteps, self.num_channels, **options)
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to timestep features: linear -> activation -> linear.

    Args:
        in_channels: size of the input feature vector.
        time_embed_dim: hidden width of the MLP.
        act_fn: activation name resolved through `get_activation`.
        out_dim: output width of the second linear layer; defaults to `time_embed_dim`
            when None (so the default construction is unchanged).
        post_act_fn: optional activation name applied after the second linear layer;
            nothing is applied when None.
        sample_proj_bias: whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.act = get_activation(act_fn)

        # Honour `out_dim` instead of silently ignoring it; None keeps the hidden width.
        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        # Optional trailing activation; it has no parameters, so state dicts are unaffected.
        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample):
        """Project `sample` through the MLP and return the result."""
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
class TextProjection(nn.Module):
    """Project a caption feature vector to `hidden_size` with linear -> activation -> linear.

    Both linear layers use a bias; the activation is looked up by name via `get_activation`.
    """

    def __init__(self, in_features, hidden_size, act_fn="silu"):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, hidden_size, bias=True)
        self.act_1 = get_activation(act_fn)
        self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True)

    def forward(self, caption):
        """Return the projected caption features."""
        projected = self.act_1(self.linear_1(caption))
        return self.linear_2(projected)
class CombinedTimestepConditionEmbeddings(nn.Module):
    """Sum of a timestep embedding and a projected pooled text embedding.

    The timestep is first expanded to 256 sinusoidal channels, then mapped to
    `embedding_dim`; the pooled projection is mapped from `pooled_projection_dim`
    to `embedding_dim` by a `TextProjection`.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        """Return `timestep` embedding + projected `pooled_projection`."""
        # The sinusoidal features are cast to the dtype of the pooled projection before the MLP.
        sinusoid = self.time_proj(timestep).to(dtype=pooled_projection.dtype)
        time_part = self.timestep_embedder(sinusoid)  # (N, D)
        text_part = self.text_embedder(pooled_projection)
        return time_part + text_part
class CombinedTimestepEmbeddings(nn.Module):
    """Timestep-only conditioning: 256 sinusoidal channels mapped to `embedding_dim`."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep):
        """Return the (N, D) embedding of the 1-D `timestep` tensor."""
        timesteps_proj = self.time_proj(timestep)
        # The sinusoidal features are always float32; cast them to the embedder's weight
        # dtype so a module converted to half/bfloat16 does not hit a dtype mismatch.
        # This is a no-op when the weights are float32.
        embedder_dtype = self.timestep_embedder.linear_1.weight.dtype
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=embedder_dtype))  # (N, D)
        return timesteps_emb
class PatchEmbed3D(nn.Module):
    """Patch embedding for 5-D latents laid out as (batch, channels, frames, height, width).

    Every frame is patchified with a strided 2D convolution. With
    `pos_embed_type="sincos"` a spatial sin/cos table (center-cropped from a
    `pos_embed_max_size` x `pos_embed_max_size` grid) is added, optionally followed by a
    temporal sin/cos table. With `pos_embed_type="rope"` no positional embedding is added
    by this module.

    Args:
        height, width: reference latent size; `height // patch_size` is used as the
            `base_size` of the sin/cos grid.
        patch_size: kernel size and stride of the patchifying convolution.
        in_channels: channels of the input latent.
        embed_dim: channels of the produced tokens.
        layer_norm: apply a non-affine LayerNorm to the tokens after projection.
        bias: whether the convolution has a bias.
        interpolation_scale: forwarded to `get_2d_sincos_pos_embed`.
        pos_embed_type: "sincos", "rope" or None.
        temp_pos_embed_type: a temporal table is only built/added when this is 'sincos'.
        pos_embed_max_size: side length of the stored spatial table (in patches).
        max_num_frames: number of rows of the temporal table.
        add_temp_pos_embed: enable the temporal sin/cos embedding.
        interp_condition_pos: crop the table at the original resolution and bilinearly
            resize it to the current resolution instead of cropping directly.
    """

    def __init__(
        self,
        height=128,
        width=128,
        patch_size=2,
        in_channels=16,
        embed_dim=1536,
        layer_norm=False,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        temp_pos_embed_type='rope',
        pos_embed_max_size=192,  # For SD3 cropping
        max_num_frames=64,
        add_temp_pos_embed=False,
        interp_condition_pos=False,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        self.add_temp_pos_embed = add_temp_pos_embed

        # Calculate positional embeddings based on max size or default
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)

        if pos_embed_type is None:
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )
            # The table is only stored in the state dict when a max size is configured.
            persistent = bool(pos_embed_max_size)
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)

            if add_temp_pos_embed and temp_pos_embed_type == 'sincos':
                time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
                self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True)
        elif pos_embed_type == "rope":
            print("Using the rotary position embedding")
        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.interp_condition_pos = interp_condition_pos

    def cropped_pos_embed(self, height, width, ori_height=None, ori_width=None):
        """Crops positional embeddings for SD3 compatibility.

        Args:
            height, width: current latent size (before dividing by `patch_size`).
            ori_height, ori_width: original latent size (before dividing by `patch_size`);
                each defaults to the corresponding current size when None.

        Returns:
            Tensor of shape (1, (height // patch_size) * (width // patch_size), C).

        Raises:
            ValueError: if `pos_embed_max_size` is None or a requested size exceeds it.
            AssertionError: if the original size is smaller than the current size.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        # A caller that has no separate "original" resolution uses the current one.
        if ori_height is None:
            ori_height = height
        if ori_width is None:
            ori_width = width

        height = height // self.patch_size
        width = width // self.patch_size
        ori_height = ori_height // self.patch_size
        ori_width = ori_width // self.patch_size

        assert ori_height >= height, "The ori_height needs >= height"
        assert ori_width >= width, "The ori_width needs >= width"

        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        table = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)

        if self.interp_condition_pos:
            # The crop is taken at the original resolution here, so that size must fit in
            # the table as well; otherwise the offsets go negative and the slice is wrong.
            if ori_height > self.pos_embed_max_size:
                raise ValueError(
                    f"Original height ({ori_height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
                )
            if ori_width > self.pos_embed_max_size:
                raise ValueError(
                    f"Original width ({ori_width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
                )
            top = (self.pos_embed_max_size - ori_height) // 2
            left = (self.pos_embed_max_size - ori_width) // 2
            spatial_pos_embed = table[:, top : top + ori_height, left : left + ori_width, :]  # [b h w c]
            if ori_height != height or ori_width != width:
                # Resize the original-resolution crop down to the current resolution.
                spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2)
                spatial_pos_embed = torch.nn.functional.interpolate(spatial_pos_embed, size=(height, width), mode='bilinear')
                spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1)
        else:
            top = (self.pos_embed_max_size - height) // 2
            left = (self.pos_embed_max_size - width) // 2
            spatial_pos_embed = table[:, top : top + height, left : left + width, :]

        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
        """Embed one (b, c, t, h, w) latent and return tokens of shape (b, t, n, c).

        Args:
            latent: 5-D tensor laid out as (batch, channels, frames, height, width).
            time_index: first row of the temporal table to use for this latent.
            ori_height, ori_width: original latent size forwarded to `cropped_pos_embed`;
                None means "same as this latent".

        Raises:
            NotImplementedError: sincos embedding requested without `pos_embed_max_size`.
        """
        if self.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        bs = latent.shape[0]
        temp = latent.shape[2]

        # Fold frames into the batch so the 2D convolution sees one frame per sample.
        latent = rearrange(latent, 'b c t h w -> (b t) c h w')
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)  # (BT)CHW -> (BT)NC

        if self.layer_norm:
            latent = self.norm(latent)

        if self.pos_embed_type == 'sincos':
            # Spatial position embedding: only the cropped (max-size) table is supported.
            if not self.pos_embed_max_size:
                raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop")
            pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)

            latent_dtype = latent.dtype
            if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
                latent = latent + pos_embed
                # Move frames to the sequence axis to add the temporal table.
                latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
                latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
                latent = latent.to(latent_dtype)
                latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
            else:
                latent = (latent + pos_embed).to(latent_dtype)
                latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
        else:
            assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding"
            latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)

        return latent

    def forward(self, latent):
        """
        Arguments:
            latent (torch.Tensor or list): either a single (b, c, t, h, w) tensor, or a list
                whose items are such tensors or lists of such tensors.

        Returns:
            For a tensor input, the tokens flattened to (b, t * n, c).
            For a list input, a list with one tensor per item; the tensors inside an item
            are embedded with consecutive temporal offsets, use the spatial size of the
            item's last tensor as the original size, and are concatenated along the
            sequence axis.
        """
        if not isinstance(latent, list):
            hidden_states = self.forward_func(latent)
            return rearrange(hidden_states, "b t n c -> b (t n) c")

        output_list = []
        for latent_ in latent:
            if not isinstance(latent_, list):
                latent_ = [latent_]

            output_latent = []
            time_index = 0
            ori_height, ori_width = latent_[-1].shape[-2:]
            for each_latent in latent_:
                hidden_state = self.forward_func(each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width)
                time_index += each_latent.shape[2]
                hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c")
                output_latent.append(hidden_state)

            output_list.append(torch.cat(output_latent, dim=1))

        return output_list