|
import typing as tp |
|
|
|
import torch |
|
|
|
from einops import rearrange |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from x_transformers import ContinuousTransformerWrapper, Encoder |
|
|
|
from .blocks import FourierFeatures |
|
from .transformer import ContinuousTransformer |
|
from model.stable import transformer_use_mask |
|
|
|
|
|
class DiffusionTransformerV2(nn.Module):
    """Conditional transformer denoiser over ``(batch, channels, time)`` signals.

    Pipeline of ``_forward``:

    1. The input is optionally concatenated with ``input_concat_cond`` along the
       channel axis.
    2. The result is folded into tokens of ``patch_size`` time steps.
    3. The tokens are processed by the backbone selected with ``transformer_type``.

    The timestep embedding (plus an optional global embedding) is injected in one
    of two ways:

    * ``global_cond_type="prepend"``: as an extra prepended token.
    * ``global_cond_type="adaLN"``: through the backbone's ``global_cond`` argument.

    Cross-attention tokens and extra prepended tokens are further optional
    conditioning. ``forward`` adds classifier-free guidance on top of ``_forward``.

    Args:
        io_channels: channels of the noisy input and of the output.
        patch_size: number of time steps folded into one transformer token.
        embed_dim: transformer width; also the width of the timestep embedding.
        cond_token_dim: width of cross-attention tokens (0 disables them).
        project_cond_tokens: embed cross-attention tokens to ``embed_dim``
            (otherwise they keep width ``cond_token_dim``).
        global_cond_dim: width of the global conditioning vector (0 disables it).
        project_global_cond: embed the global vector to ``embed_dim``. If False,
            ``global_cond_dim`` must already equal ``embed_dim``, because the
            vector is added to the timestep embedding.
        input_concat_dim: channels of ``input_concat_cond`` concatenated to the input.
        prepend_cond_dim: width of extra prepended conditioning tokens (0 disables).
        depth: number of transformer layers.
        num_heads: number of attention heads.
        transformer_type: backbone, ``"x-transformers"`` or ``"continuous_transformer"``.
        global_cond_type: how the timestep/global embedding is injected.
        **kwargs: forwarded to the backbone constructor.

    Raises:
        ValueError: unknown ``transformer_type`` or ``global_cond_type``, or an
            unprojected global condition whose width differs from ``embed_dim``.
    """

    def __init__(self,
                 io_channels=32,
                 patch_size=1,
                 embed_dim=768,
                 cond_token_dim=0,
                 project_cond_tokens=True,
                 global_cond_dim=0,
                 project_global_cond=True,
                 input_concat_dim=0,
                 prepend_cond_dim=0,
                 depth=12,
                 num_heads=8,
                 transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
                 global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
                 **kwargs):
        super().__init__()

        # Any other value would silently drop the timestep conditioning in _forward.
        if global_cond_type not in ("prepend", "adaLN"):
            raise ValueError(f"Unknown global_cond_type: {global_cond_type}")

        self.io_channels = io_channels
        self.patch_size = patch_size
        self.cond_token_dim = cond_token_dim
        self.input_concat_dim = input_concat_dim
        self.transformer_type = transformer_type
        self.global_cond_type = global_cond_type

        # Timestep embedding: Fourier features of the scalar t, followed by an MLP.
        timestep_features_dim = 256
        self.timestep_features = FourierFeatures(1, timestep_features_dim)
        self.to_timestep_embed = nn.Sequential(
            nn.Linear(timestep_features_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim, bias=True),
        )

        # Cross-attention conditioning tokens.
        cond_embed_dim = 0
        if cond_token_dim > 0:
            cond_embed_dim = embed_dim if project_cond_tokens else cond_token_dim
            self.to_cond_embed = self._make_projection(cond_token_dim, cond_embed_dim)

        # Global conditioning vector. _forward adds it to the timestep embedding,
        # so its embedded width has to be embed_dim.
        if global_cond_dim > 0:
            global_embed_dim = embed_dim if project_global_cond else global_cond_dim
            if global_embed_dim != embed_dim:
                raise ValueError(
                    f"global_cond_dim ({global_cond_dim}) must equal embed_dim ({embed_dim}) "
                    f"when project_global_cond is False")
            self.to_global_embed = self._make_projection(global_cond_dim, global_embed_dim)

        # Extra tokens prepended to the token sequence.
        if prepend_cond_dim > 0:
            self.to_prepend_embed = self._make_projection(prepend_cond_dim, embed_dim)

        dim_in = io_channels + input_concat_dim

        if transformer_type == "x-transformers":
            self.transformer = ContinuousTransformerWrapper(
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                max_seq_len=0,  # unused: absolute positional embeddings are disabled
                use_abs_pos_emb=False,
                attn_layers=Encoder(
                    dim=embed_dim,
                    depth=depth,
                    heads=num_heads,
                    attn_flash=True,
                    cross_attend=cond_token_dim > 0,
                    dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
                    zero_init_branch_output=True,
                    rotary_pos_emb=True,
                    ff_swish=True,
                    ff_glu=True,
                    **kwargs,
                ),
            )
        elif transformer_type == "continuous_transformer":
            # With adaLN the global embedding reaching the backbone is embed_dim wide.
            global_dim = embed_dim if global_cond_type == "adaLN" else None
            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend=cond_token_dim > 0,
                cond_token_dim=cond_embed_dim,
                global_cond_dim=global_dim,
                **kwargs,
            )
        else:
            raise ValueError(f"Unknown transformer type: {transformer_type}")

        # Residual 1x1 convolutions around the backbone. Their weights are zeroed,
        # so `conv(x) + x` starts out as the identity.
        self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
        nn.init.zeros_(self.preprocess_conv.weight)
        self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
        nn.init.zeros_(self.postprocess_conv.weight)

    @staticmethod
    def _make_projection(dim_in, dim_out):
        """Build the two-layer, bias-free SiLU MLP used to embed conditioning inputs."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out, bias=False),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out, bias=False),
        )

    @staticmethod
    def _cfg_dropout(cond, prob):
        """Zero out each batch item of ``cond`` independently with probability ``prob``.

        The (batch, 1, 1) drop mask assumes ``cond`` is 3-D — TODO confirm with callers.
        """
        drop = torch.bernoulli(
            torch.full((cond.shape[0], 1, 1), prob, device=cond.device)).to(torch.bool)
        return torch.where(drop, torch.zeros_like(cond), cond)

    def _forward(
            self,
            x,
            t,
            *,
            mask=None,
            cross_attn_cond=None,
            cross_attn_cond_mask=None,
            input_concat_cond=None,
            global_embed=None,
            prepend_cond=None,
            prepend_cond_mask=None,
            return_info=False,
            **kwargs):
        """Run a single, unguided denoiser pass.

        Args:
            x: noisy input of shape ``(batch, io_channels, time)``. ``time`` must be
                divisible by ``patch_size``.
            t: timesteps, one scalar per batch item (indexed as ``t[:, None]``).
            mask: forwarded to the backbone as its sequence mask.
            cross_attn_cond: raw cross-attention tokens (embedded here).
            cross_attn_cond_mask: forwarded to the backbone as ``context_mask``.
            input_concat_cond: ``(batch, input_concat_dim, time')`` tensor. It is
                resampled to ``time`` if needed and concatenated on channels.
            global_embed: raw global conditioning vector (embedded here).
            prepend_cond: raw prepended tokens (embedded here).
            prepend_cond_mask: validity mask for ``prepend_cond``. Treated as
                all-valid when omitted.
            return_info: also return the backbone's info object.
            **kwargs: forwarded to the backbone's forward call.

        Returns:
            Tensor of shape ``(batch, io_channels, time)``, or ``(output, info)``
            when ``return_info`` is True.
        """
        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        if prepend_cond is not None:
            prepend_inputs = self.to_prepend_embed(prepend_cond)
            prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:
            # Nearest-neighbour resample the concat condition to the input length.
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode="nearest")
            x = torch.cat([x, input_concat_cond], dim=1)

        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]))
        global_embed = timestep_embed if global_embed is None else global_embed + timestep_embed

        if self.global_cond_type == "prepend":
            global_token = global_embed.unsqueeze(1)
            global_token_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            if prepend_inputs is None:
                prepend_inputs = global_token
                prepend_mask = global_token_mask
            else:
                if prepend_mask is None:
                    # No mask supplied for prepend_cond: every prepended token is valid.
                    prepend_mask = torch.ones(prepend_inputs.shape[:2], device=x.device, dtype=torch.bool)
                prepend_inputs = torch.cat([prepend_inputs, global_token], dim=1)
                prepend_mask = torch.cat([prepend_mask, global_token_mask], dim=1)

        # Prepended positions come back from the backbone and must be stripped.
        # This applies to any prepended tokens, not only the global-embedding token.
        prepend_length = 0 if prepend_inputs is None else prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x
        x = rearrange(x, "b c t -> b t c")

        extra_args = {}
        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        # The x-transformers backbone reports no info; an empty dict keeps the
        # (output, info) return shape unpackable for every backbone.
        info = {}
        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
                                      context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
                                      **extra_args, **kwargs)
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
                                      context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
                                      return_info=return_info, **extra_args, **kwargs)
            if return_info:
                output, info = output
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info
        return output

    def forward(
            self,
            x,
            t,
            cross_attn_cond=None,
            cross_attn_cond_mask=None,
            negative_cross_attn_cond=None,
            negative_cross_attn_mask=None,
            input_concat_cond=None,
            global_embed=None,
            negative_global_embed=None,
            prepend_cond=None,
            prepend_cond_mask=None,
            cfg_scale=1.0,
            cfg_dropout_prob=0.0,
            causal=False,
            scale_phi=0.0,
            mask=None,
            return_info=False,
            **kwargs):
        """Denoise ``x`` at timesteps ``t``, with optional classifier-free guidance.

        Guidance runs only when ``cfg_scale != 1.0`` and there is cross-attention
        or prepend conditioning. The batch is then doubled:

        * first half: conditional.
        * second half: unconditional. Cross-attention tokens are zeroed, or
          replaced by ``negative_cross_attn_cond`` where
          ``negative_cross_attn_mask`` is set. Prepended tokens are zeroed.

        The two halves are combined as ``uncond + (cond - uncond) * cfg_scale``.

        Args:
            x, t: see ``_forward``.
            cross_attn_cond: cross-attention tokens.
            cross_attn_cond_mask: accepted but discarded (see note in the body).
            negative_cross_attn_cond: unconditional-branch replacement for the
                cross-attention tokens.
            negative_cross_attn_mask: where False, the negative tokens fall back to zeros.
            input_concat_cond, global_embed, prepend_cond, prepend_cond_mask, mask:
                see ``_forward``. They are duplicated unchanged for both CFG halves.
            negative_global_embed: accepted for API compatibility; unused.
            cfg_scale: guidance scale.
            cfg_dropout_prob: per-item probability of zeroing cross-attention and
                prepend conditioning before anything else (training-time CFG dropout).
            causal: must be False.
            scale_phi: blend weight between the guided output rescaled to the
                conditional output's std over dim 1, and the plain guided output.
            return_info: also return the backbone info from ``_forward``.
            **kwargs: forwarded to ``_forward``.

        Returns:
            The denoiser output, or ``(output, info)`` when ``return_info`` is True.
        """
        assert not causal, "Causal mode is not supported for DiffusionTransformer"

        # NOTE(review): any cross-attention mask passed in is discarded and never
        # reaches the backbone. Presumably a deliberate workaround — confirm before
        # re-enabling.
        cross_attn_cond_mask = None

        if prepend_cond_mask is not None:
            prepend_cond_mask = prepend_cond_mask.bool()

        if cfg_dropout_prob > 0.0:
            if cross_attn_cond is not None:
                cross_attn_cond = self._cfg_dropout(cross_attn_cond, cfg_dropout_prob)
            if prepend_cond is not None:
                prepend_cond = self._cfg_dropout(prepend_cond, cfg_dropout_prob)

        has_droppable_cond = cross_attn_cond is not None or prepend_cond is not None
        if cfg_scale == 1.0 or not has_droppable_cond:
            return self._forward(
                x,
                t,
                mask=mask,
                cross_attn_cond=cross_attn_cond,
                cross_attn_cond_mask=cross_attn_cond_mask,
                input_concat_cond=input_concat_cond,
                global_embed=global_embed,
                prepend_cond=prepend_cond,
                prepend_cond_mask=prepend_cond_mask,
                return_info=return_info,
                **kwargs,
            )

        def doubled(tensor):
            # Same tensor for the conditional and the unconditional half of the batch.
            return None if tensor is None else torch.cat([tensor, tensor], dim=0)

        batch_cond = None
        if cross_attn_cond is not None:
            null_embed = torch.zeros_like(cross_attn_cond)
            uncond_tokens = null_embed
            if negative_cross_attn_cond is not None:
                uncond_tokens = negative_cross_attn_cond
                if negative_cross_attn_mask is not None:
                    keep = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
                    uncond_tokens = torch.where(keep, negative_cross_attn_cond, null_embed)
            batch_cond = torch.cat([cross_attn_cond, uncond_tokens], dim=0)

        batch_prepend_cond = None
        batch_prepend_cond_mask = None
        if prepend_cond is not None:
            batch_prepend_cond = torch.cat([prepend_cond, torch.zeros_like(prepend_cond)], dim=0)
            batch_prepend_cond_mask = doubled(prepend_cond_mask)

        batch_output = self._forward(
            doubled(x),
            doubled(t),
            mask=doubled(mask),
            cross_attn_cond=batch_cond,
            cross_attn_cond_mask=doubled(cross_attn_cond_mask),
            input_concat_cond=doubled(input_concat_cond),
            global_embed=doubled(global_embed),
            prepend_cond=batch_prepend_cond,
            prepend_cond_mask=batch_prepend_cond_mask,
            return_info=return_info,
            **kwargs,
        )

        info = None
        if return_info:
            batch_output, info = batch_output

        cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
        cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

        if scale_phi != 0.0:
            # Blend between the guided output rescaled to the conditional output's
            # std (over dim 1) and the plain guided output.
            cond_out_std = cond_output.std(dim=1, keepdim=True)
            out_cfg_std = cfg_output.std(dim=1, keepdim=True)
            output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
        else:
            output = cfg_output

        if return_info:
            return output, info
        return output
|
|