import torch
import torch.nn as nn
import torch.nn.functional as F
from dotmap import DotMap

from salad.model_components.simple_module import TimePointWiseEncoder, TimestepEmbedder
from salad.model_components.transformer import (
    PositionalEncoding,
    TimeTransformerDecoder,
    TimeTransformerEncoder,
)


class UnCondDiffNetwork(nn.Module):
    def __init__(self, input_dim, residual, **kwargs):
        """
        Transformer encoder that denoises a set of latents without
        conditioning; kwargs are stored as hyperparameters.
        """
        super().__init__()
        self.input_dim = input_dim
        self.residual = residual
        self.__dict__.update(kwargs)
        self.hparams = DotMap(self.__dict__)
        self._build_model()

    def _build_model(self):
        self.act = F.leaky_relu
        if self.hparams.get("use_timestep_embedder"):
            self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
            dim_ctx = self.hparams.timestep_embedder_dim
        else:
            # Fallback time context is [beta, sin(beta), cos(beta)], built in forward().
            dim_ctx = 3

        # Encoder part
        enc_dim = self.hparams.embedding_dim
        self.embedding = nn.Linear(self.hparams.input_dim, enc_dim)
        encoder_type = self.hparams.get("encoder_type") or "transformer"
        if encoder_type == "transformer":
            self.encoder = TimeTransformerEncoder(
                enc_dim,
                dim_ctx=dim_ctx,
                num_heads=self.hparams.get("num_heads") or 4,
                use_time=True,
                num_layers=self.hparams.enc_num_layers,
                last_fc=True,
                last_fc_dim_out=self.hparams.input_dim,
                # Optional attention dropout; 0.0 when unset.
                dropout=self.hparams.get("attn_dropout", 0.0),
            )
        else:
            raise ValueError(f"Unknown encoder_type: {encoder_type}")

    def forward(self, x, beta):
        """
        Input:
            x: [B,G,D] latent
            beta: [B] diffusion timestep
        Output:
            eta: [B,G,D] predicted noise
        """
        B, G = x.shape[:2]
        if self.hparams.get("use_timestep_embedder"):
            time_emb = self.time_embedder(beta).unsqueeze(1)  # [B,1,dim_ctx]
        else:
            beta = beta.view(B, 1, 1)
            time_emb = torch.cat(
                [beta, torch.sin(beta), torch.cos(beta)], dim=-1
            )  # [B,1,3]
        ctx = time_emb

        x_emb = self.embedding(x)
        out = self.encoder(x_emb, ctx=ctx)
        if self.hparams.residual:
            out = out + x
        return out
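

# Usage sketch for UnCondDiffNetwork (illustrative, not from the original
# file; the dims and hparam values below are assumptions):
#
#   net = UnCondDiffNetwork(
#       input_dim=512, residual=True,
#       embedding_dim=512, enc_num_layers=6,
#       use_timestep_embedder=True, timestep_embedder_dim=128,
#   )
#   x = torch.randn(8, 16, 512)   # [B,G,D] set of latents
#   beta = torch.rand(8)          # [B] diffusion timesteps
#   eta = net(x, beta)            # [B,G,D] predicted noise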


class CondDiffNetwork(nn.Module):
    def __init__(self, input_dim, residual, **kwargs):
        """
        Transformer encoder + decoder: the encoder embeds the condition,
        and the decoder denoises the queries against it.
        """
        super().__init__()
        self.input_dim = input_dim
        self.residual = residual
        self.__dict__.update(kwargs)
        self.hparams = DotMap(self.__dict__)
        self._build_model()

    def _build_model(self):
        self.act = F.leaky_relu
        if self.hparams.get("use_timestep_embedder"):
            self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
            dim_ctx = self.hparams.timestep_embedder_dim
        else:
            # Fallback time context is [beta, sin(beta), cos(beta)], built in forward().
            dim_ctx = 3

        # Encoder part
        enc_dim = self.hparams.context_embedding_dim
        self.context_embedding = nn.Linear(self.hparams.context_dim, enc_dim)
        if self.hparams.encoder_type == "transformer":
            self.encoder = TimeTransformerEncoder(
                enc_dim,
                # Must match the width of the time context built in forward()
                # (a hardcoded 3 here breaks with use_timestep_embedder).
                dim_ctx=dim_ctx,
                num_heads=4,
                use_time=self.hparams.encoder_use_time,
                num_layers=self.hparams.get("enc_num_layers") or 3,
                last_fc=False,
            )
        elif self.hparams.encoder_type == "pointwise":
            self.encoder = TimePointWiseEncoder(
                enc_dim,
                dim_ctx=None,
                use_time=self.hparams.encoder_use_time,
                num_layers=self.hparams.enc_num_layers,
            )
        else:
            raise ValueError(f"Unknown encoder_type: {self.hparams.encoder_type}")
""" | |
Decoder part | |
""" | |
dec_dim = self.hparams.embedding_dim | |
input_dim = self.hparams.input_dim | |
self.query_embedding = nn.Linear(self.hparams.input_dim, dec_dim) | |
        if self.hparams.decoder_type == "transformer_decoder":
            # Condition enters via cross-attention; time context via ctx.
            self.decoder = TimeTransformerDecoder(
                dec_dim,
                enc_dim,
                dim_ctx=dim_ctx,
                num_heads=4,
                last_fc=True,
                last_fc_dim_out=input_dim,
                num_layers=self.hparams.get("dec_num_layers") or 3,
            )
        elif self.hparams.decoder_type == "transformer_encoder":
            # Condition is concatenated onto the time context (see forward()).
            self.decoder = TimeTransformerEncoder(
                dec_dim,
                dim_ctx=enc_dim + dim_ctx,
                num_heads=4,
                last_fc=True,
                last_fc_dim_out=input_dim,
                num_layers=self.hparams.get("dec_num_layers") or 3,
            )
        else:
            raise ValueError(f"Unknown decoder_type: {self.hparams.decoder_type}")
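
    # Construction sketch for the cross-attention decoder mode, complementing
    # the smoke test at the bottom of the file (illustrative, not from the
    # original file; all values are assumptions):
    #
    #   net = CondDiffNetwork(
    #       input_dim=32, residual=True,
    #       embedding_dim=128, context_embedding_dim=128, context_dim=64,
    #       encoder_type="pointwise", encoder_use_time=False,
    #       decoder_type="transformer_decoder",
    #       enc_num_layers=2, dec_num_layers=2,
    #   )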

    def forward(self, x, beta, context):
        """
        Input:
            x: [B,G,D] intrinsic
            beta: [B] diffusion timestep
            context: [B,G,D2] or [B,D2] condition
        Output:
            eta: [B,G,D] predicted noise
        """
        B, G = x.shape[:2]
        if self.hparams.get("use_timestep_embedder"):
            time_emb = self.time_embedder(beta).unsqueeze(1)  # [B,1,dim_ctx]
        else:
            beta = beta.view(B, 1, 1)
            time_emb = torch.cat(
                [beta, torch.sin(beta), torch.cos(beta)], dim=-1
            )  # [B,1,3]
        ctx = time_emb

        # Encoding: embed the condition, optionally attending over time.
        cout = self.context_embedding(context)
        cout = self.encoder(cout, ctx=ctx if self.hparams.encoder_use_time else None)
        if cout.ndim == 2:
            # Global condition [B,D2]: broadcast it across the G queries.
            cout = cout.unsqueeze(1).expand(-1, G, -1)
""" | |
Decoding | |
""" | |
out = self.query_embedding(x) | |
if self.hparams.get("use_pos_encoding"): | |
out = self.pos_encoding(out) | |
if self.hparams.decoder_type == "transformer_encoder": | |
try: | |
ctx = ctx.expand(-1, G, -1) | |
if cout.ndim == 2: | |
cout = cout.unsqueeze(1) | |
cout = cout.expand(-1, G, -1) | |
ctx = torch.cat([ctx, cout], -1) | |
except Exception as e: | |
print(e, G, ctx.shape, cout.shape) | |
out = self.decoder(out, ctx=ctx) | |
else: | |
out = self.decoder(out, cout, ctx=ctx) | |
        if self.hparams.residual:
            out = out + x
        return out
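

if __name__ == "__main__":
    # Smoke test (not part of the original file). The hyperparameter values
    # below are illustrative assumptions, not a known-good config.
    net = CondDiffNetwork(
        input_dim=32,
        residual=True,
        embedding_dim=128,
        context_embedding_dim=128,
        context_dim=64,
        encoder_type="transformer",
        encoder_use_time=False,
        decoder_type="transformer_encoder",
        enc_num_layers=2,
        dec_num_layers=2,
    )
    x = torch.randn(4, 16, 32)        # [B,G,D] intrinsics
    beta = torch.rand(4)              # [B] diffusion timesteps
    context = torch.randn(4, 16, 64)  # [B,G,D2] per-element condition
    eta = net(x, beta, context)
    assert eta.shape == x.shape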