|
import torch |
|
from torch.nn import functional as F |
|
from .dit import DiffusionTransformer |
|
from .adp import UNet1d |
|
from .sampling import sample |
|
import math |
|
from model.base import BaseModule |
|
import pdb |
|
|
|
# Default maximum sequence length along dim 2 used by pad_and_create_mask —
# presumably mel-spectrogram frames; TODO confirm against callers.
target_length = 1536
|
def pad_and_create_mask(matrix, target_length):
    """Right-pad `matrix` along its third dimension and build a validity mask.

    Args:
        matrix: tensor of shape (B, C, T) padded with zeros up to `target_length`.
        target_length: desired size of the third dimension after padding.

    Returns:
        (padded, mask): `padded` has shape (B, C, target_length); `mask` has
        shape (1, target_length) with 1.0 over the original T positions and
        0.0 over the padding. Both are placed on `matrix`'s device.

    Raises:
        ValueError: if the input is already longer than `target_length`.
    """
    current_len = matrix.shape[2]
    if current_len > target_length:
        raise ValueError("The third dimension length %s should not exceed %s" % (current_len, target_length))

    pad_amount = target_length - current_len
    padded = F.pad(matrix, (0, pad_amount), "constant", 0)

    # 1.0 marks real frames, 0.0 marks padding; batch dim is fixed at 1.
    mask = torch.cat(
        [torch.ones((1, current_len)), torch.zeros((1, pad_amount))],
        dim=1,
    )

    return padded.to(matrix.device), mask.to(matrix.device)
|
|
|
|
|
class Stable_Diffusion(BaseModule):
    """Continuous-time diffusion model over 80-channel feature sequences.

    Wraps a `DiffusionTransformer` backbone. `forward` runs reverse-diffusion
    sampling conditioned on `mu`; `compute_loss` trains with a cosine
    alpha/sigma schedule and a masked MSE objective.
    """

    def __init__(self):
        super().__init__()
        # 80 io channels — presumably mel-spectrogram bins; TODO confirm.
        self.diffusion = DiffusionTransformer(
            io_channels=80,
            embed_dim=768,
            depth=24,
            num_heads=24,
            project_cond_tokens=False,
            transformer_type="continuous_transformer",
        )
        # Scrambled 1-D Sobol sequence for drawing diffusion times t in [0, 1):
        # low-discrepancy points cover the interval more evenly than i.i.d. uniforms.
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

    @torch.no_grad()
    def forward(self, mu, mask, n_timesteps):
        """Sample from the model conditioned on `mu`.

        Args:
            mu: conditioning tensor fed to the sampler as the starting input.
            mask: validity mask of shape (B, 1, T); the channel dim is dropped
                before being passed to the backbone.
            n_timesteps: number of reverse-diffusion steps.

        Returns:
            The sampled output of `sample(...)`.
        """
        mask = mask.squeeze(1)  # (B, 1, T) -> (B, T)
        extra_args = {"mask": mask}
        # Final positional arg 0 — presumably eta / final noise level; TODO confirm
        # against the `sample` implementation.
        fakes = sample(self.diffusion, mu, n_timesteps, 0, **extra_args)
        return fakes

    def compute_loss(self, x0, mask, mu):
        """Compute the training loss for a batch.

        Args:
            x0: clean targets, shape (B, C, T).
            mask: validity mask, shape (B, 1, T).
            mu: conditioning tensor with the same shape as `x0`.

        Returns:
            (loss, output): scalar masked-MSE loss and the raw model output.
        """
        # One Sobol-drawn time per batch element, shape (B,).
        t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
        # Cosine schedule: alpha = cos(t*pi/2), sigma = sin(t*pi/2),
        # so alpha^2 + sigma^2 = 1 for every t.
        alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
        # Broadcast over (channel, time) dims.
        alphas = alphas[:, None, None]
        sigmas = sigmas[:, None, None]
        noise = torch.randn_like(x0)
        noised_inputs = x0 * alphas + noise * sigmas
        # NOTE(review): standard v-prediction targets are
        # `noise * alphas - x0 * sigmas`; this uses the conditioning `mu` in
        # place of `noise` — confirm this is intentional.
        targets = mu * alphas - x0 * sigmas
        mask = mask.squeeze(1)  # (B, 1, T) -> (B, T)
        # cfg_dropout_prob=0.1: conditioning dropout for classifier-free guidance.
        output = self.diffusion(noised_inputs, t, mask=mask, cfg_dropout_prob=0.1)
        return self.mse_loss(output, targets, mask), output

    def mse_loss(self, output, targets, mask):
        """Mean squared error averaged over the positions selected by `mask`.

        Args:
            output: model prediction, shape (B, C, T).
            targets: regression targets, same shape as `output`.
            mask: validity mask, shape (B, T) or (B, C, T); nonzero entries
                mark positions that contribute to the loss.

        Returns:
            Scalar mean of the per-element squared errors at masked positions.
        """
        mse_loss = F.mse_loss(output, targets, reduction='none')

        # Lift a (B, T) mask to (B, 1, T), then tile across channels so it
        # matches the (B, C, T) error tensor.
        if mask.ndim == 2 and mse_loss.ndim == 3:
            mask = mask.unsqueeze(1)
        if mask.shape[1] != mse_loss.shape[1]:
            mask = mask.repeat(1, mse_loss.shape[1], 1)

        # Fix: boolean-select valid positions. Indexing with a float 0/1 mask
        # (as produced by pad_and_create_mask) raises IndexError; casting to
        # bool accepts both float and bool masks and is a no-op for the latter.
        mse_loss = mse_loss[mask.bool()]

        mse_loss = mse_loss.mean()
        return mse_loss