# salad/models/language_phase1.py
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from salad.model_components.lstm import LSTM
from salad.models.phase1 import Phase1Model
# NOTE: assumed module path for the phase-2 model referenced in validation();
# adjust if SpaghettiConditionSALDM lives elsewhere in this repo.
from salad.models.phase2 import SpaghettiConditionSALDM
from salad.utils import imageutil, visutil
from salad.utils.spaghetti_util import (clip_eigenvalues,
generate_zc_from_sj_gaus,
get_mesh_from_spaghetti, load_mesher,
load_spaghetti, project_eigenvectors)
from salad.utils.train_util import get_dropout_mask
from salad.data.dataset import LangSALADDataset
class LangPhase1Model(Phase1Model):
def __init__(self, network, variance_schedule, **kwargs):
super().__init__(network, variance_schedule, **kwargs)
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
if self.hparams.get("use_lstm"):
self.bertmodel = LSTM(
text_dim=768, embedding_dim=768, vocab_size=30522, padding_idx=0
)
else:
self.bertmodel = BertModel.from_pretrained("bert-base-uncased")
if self.hparams.get("text_encoder_freeze"):
for p in self.bertmodel.parameters():
p.requires_grad_(False)
def forward(self, x, text):
"""
Input:
x: [B,G,16]
text: list of length [B]
"""
B, G = x.shape[:2]
text = self.random_mask_text(text)
lang_emb = self.text_to_embedding(text)
return self.get_loss(x, lang_emb)
def tokenizing(self, text):
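        """
        Tokenize a batch of caption strings with padding/truncation and move
        the resulting tensors onto the model device.
        """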
tokenized = self.tokenizer(
text, return_tensors="pt", padding=True, truncation=True
).to(self.device)
return tokenized
def text_to_embedding(self, text):
"""
text: list of length [B]
return [B,768]
"""
tokenized = self.tokenizing(text)
if self.hparams.get("use_lstm"):
lang_emb, _ = self.bertmodel(tokenized.input_ids)
else:
if self.hparams.get("text_encoder_return_seq"):
lang_emb = self.bertmodel(**tokenized).last_hidden_state
else:
lang_emb = self.bertmodel(**tokenized).pooler_output
if lang_emb.ndim == 2:
lang_emb = lang_emb.unsqueeze(1)
return lang_emb
def random_mask_text(self, text):
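        """
        Conditioning dropout for classifier-free guidance: each caption is
        independently replaced by the empty string with probability
        `conditioning_dropout_prob`, so the network also learns an
        unconditional score.
        """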
text = list(text)
B = len(text)
if self.hparams.get("classifier_free_guidance"):
random_dp_mask = get_dropout_mask(
B, self.hparams.conditioning_dropout_prob, self.device
)
for i in range(B):
if random_dp_mask[i] == 0:
text[i] = ""
return text
def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
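        """
        Standard epsilon-prediction diffusion objective: corrupt x0 to x_t at a
        uniformly sampled timestep (unless noise is supplied via `noisy_in`),
        predict the injected noise, and return the MSE against the true noise.
        """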
B, G, D = x0.shape
if not noisy_in:
if t is None:
t = self.var_sched.uniform_sample_t(B)
x_noisy, beta, e_rand = self.add_noise(x0, t)
else:
x_noisy = x0
beta = beta_in
e_rand = e_rand_in
e_theta = self.net(x_noisy, beta, cond)
loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
return loss
def step(self, batch, stage: str):
x, text = batch
loss = self(x, text)
self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
return loss
@torch.no_grad()
def sample(
self,
num_samples_or_text,
return_traj=False,
return_cond=False,
classifier_free_guidance=True,
free_guidance_weight=2.0,
):
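        """
        Ancestral (DDPM) sampling conditioned on text. `num_samples_or_text`
        may be an int (captions are drawn from the validation set), a single
        caption string, or a list of caption strings.
        """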
if isinstance(num_samples_or_text, str):
num_samples_or_text = [num_samples_or_text]
if isinstance(num_samples_or_text, int):
batch_size = num_samples_or_text
ds = self._build_dataset("val")
texts = [ds[i][1] for i in range(batch_size)]
elif isinstance(num_samples_or_text, list):
texts = num_samples_or_text
batch_size = len(num_samples_or_text)
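        # Start from pure noise: [B, 16, 512] zc latents when `use_zc`,
        # otherwise [B, 16, 16] raw per-part Gaussian parameters.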
if self.hparams.get("use_zc"):
x_T = torch.randn([batch_size, 16, 512]).to(self.device)
else:
x_T = torch.randn([batch_size, 16, 16]).to(self.device)
G = x_T.shape[1]
lang_emb = self.text_to_embedding(texts)
if classifier_free_guidance:
null_texts = ["" for _ in range(batch_size)]
null_lang_emb = self.text_to_embedding(null_texts)
traj = {self.var_sched.num_steps: x_T}
for t in range(self.var_sched.num_steps, 0, -1):
z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
alpha = self.var_sched.alphas[t]
alpha_bar = self.var_sched.alpha_bars[t]
sigma = self.var_sched.get_sigmas(t, flexibility=0)
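            # DDPM ancestral step: x_{t-1} = c0 * (x_t - c1 * e_theta) + sigma * z,
            # with c0 = 1/sqrt(alpha_t) and c1 = (1 - alpha_t)/sqrt(1 - alpha_bar_t).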
c0 = 1.0 / torch.sqrt(alpha)
c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
x_t = traj[t]
beta = self.var_sched.betas[[t] * batch_size]
e_theta = self.net(x_t, beta=beta, context=lang_emb)
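            # Classifier-free guidance: extrapolate the conditional prediction
            # away from the unconditional one, e = (1 + w) * e_cond - w * e_null.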
if classifier_free_guidance:
null_e_theta = self.net(x_t, beta=beta, context=null_lang_emb)
w = free_guidance_weight
e_theta = (1 + w) * e_theta - w * null_e_theta
x_next = c0 * (x_t - c1 * e_theta) + sigma * z
traj[t - 1] = x_next.detach()
traj[t] = traj[t].cpu()
if not return_traj:
del traj[t]
if return_traj:
if return_cond:
return traj, lang_emb
return traj
else:
if return_cond:
return traj[0], lang_emb
return traj[0]
def sampling_gaussians(
self,
num_samples_or_text,
classifier_free_guidance=True,
free_guidance_weight=2.0,
return_cond=False,
):
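        """
        Sample part Gaussians from text, undo the dataset's global
        normalization if it was used, and clip/project the results back to
        valid GMM parameters.
        """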
gaus = self.sample(
num_samples_or_text,
classifier_free_guidance=classifier_free_guidance,
free_guidance_weight=free_guidance_weight,
return_cond=return_cond,
)
        if isinstance(gaus, tuple):
            gaus, cond = gaus
# gaus = reflect_and_concat_gmms(raw_gaus)
if self.hparams.get("global_normalization"):
if not hasattr(self, "data_val"):
self._build_dataset("val")
if self.hparams.get("global_normalization") == "partial":
gaus = self.data_val.unnormalize_global_static(gaus, slice(12, None))
elif self.hparams.get("global_normalization") == "all":
gaus = self.data_val.unnormalize_global_static(gaus, slice(None))
gaus = project_eigenvectors(clip_eigenvalues(gaus))
        if return_cond:
            return gaus, cond
return gaus
def _build_dataset(self, stage):
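        """
        Lazily construct and cache the dataset for `stage`; non-train splits
        are built with `repeat=1`.
        """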
if hasattr(self, f"data_{stage}"):
return getattr(self, f"data_{stage}")
        ds_class = LangSALADDataset
if stage == "train":
ds = ds_class(**self.hparams.dataset_kwargs)
else:
dataset_kwargs = self.hparams.dataset_kwargs.copy()
dataset_kwargs["repeat"] = 1
ds = ds_class(**dataset_kwargs)
setattr(self, f"data_{stage}", ds)
return ds
def validation_zc(self):
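        """
        Qualitative zc-space validation: sample shapes for the first four
        validation captions and log side-by-side GT/pred mesh renderings.
        """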
vis_num_shapes = 4
vis_zcs = []
vis_texts = []
ds = self._build_dataset("val")
        for i in range(vis_num_shapes):
zcs, text = ds[i]
vis_zcs.append(zcs)
vis_texts.append(text)
vis_zcs = torch.stack(vis_zcs, 0)
ldm_zcs = self.sample(vis_texts)
if not hasattr(self, "spaghetti"):
self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
spaghetti = self.spaghetti
if not hasattr(self, "mesher"):
self.mesher = load_mesher(self.device)
mesher = self.mesher
wandb_logger = self.get_wandb_logger()
images = []
        for i in range(vis_num_shapes):
            gt_img = pred_img = None
            try:
                v, f = get_mesh_from_spaghetti(spaghetti, mesher, vis_zcs[i], res=128)
                gt_img = visutil.render_mesh(v, f, resolution=(256, 256))
            except Exception:
                pass  # meshing/rendering can fail for degenerate samples
            try:
                v, f = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i], res=128)
                pred_img = visutil.render_mesh(v, f, resolution=(256, 256))
            except Exception:
                pass
            # Skip this sample if neither rendering succeeded.
            rendered = [im for im in (gt_img, pred_img) if im is not None]
            if not rendered:
                continue
            img = imageutil.merge_images(rendered)
img = imageutil.draw_text(
img,
f"Left: GT | Right: Pred \n{vis_texts[i]}",
font_size=14,
max_seq_length=50,
)
images.append([img])
images = imageutil.merge_images(images)
wandb_logger.log_image("vis", [images])
def validation(self):
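        """
        Qualitative validation in Gaussian space: sample Gaussians from text,
        complete them to intrinsics with the frozen phase-2 model, decode with
        SPAGHETTI, and log side-by-side GT/pred renderings.
        """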
if self.hparams.get("use_zc"):
self.validation_zc()
return
vis_num_shapes = 4
vis_gaus = []
vis_texts = []
ds = self._build_dataset("val")
vis_indices = [18453, 13036, 13204, 48244]
for i in vis_indices:
gaus, text = ds[i]
vis_gaus.append(gaus)
vis_texts.append(text)
vis_gaus = torch.stack(vis_gaus, 0)
if self.hparams.get("global_normalization"):
if self.hparams.get("global_normalization") == "partial":
vis_gaus = self.data_val.unnormalize_global_static(
vis_gaus, slice(12, None)
)
elif self.hparams.get("global_normalization") == "all":
                vis_gaus = self.data_val.unnormalize_global_static(vis_gaus, slice(None))
# vis_gaus = reflect_and_concat_gmms(vis_gaus)
pred_gaus = self.sampling_gaussians(vis_texts)
if not hasattr(self, "spaghetti"):
self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
spaghetti = self.spaghetti
if not hasattr(self, "mesher"):
self.mesher = load_mesher(self.device)
mesher = self.mesher
""" get intrinsics """
# TODO change the ckpt path.
if not hasattr(self, "phase2_model"):
phase2_ckpt = "/home/juil/pvddir/results/phase2/augment_final_0214/0214_202607/checkpoints/epoch=4999-val_loss=0.0000.ckpt"
self.phase2_model = SpaghettiConditionSALDM.load_from_checkpoint(
phase2_ckpt, strict=False
).to(self.device)
self.phase2_model.eval()
for p in self.phase2_model.parameters():
p.requires_grad_(False)
phase2_model = self.phase2_model
gt_sj = phase2_model.sample(vis_gaus)
pred_sj = phase2_model.sample(pred_gaus)
gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_sj, vis_gaus)
pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, pred_gaus)
wandb_logger = self.get_wandb_logger()
images = []
for i in range(vis_num_shapes):
gt_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
try:
v, f = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
gt_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
gt_img = imageutil.merge_images([gt_img, gt_mesh_img])
            except Exception:
                pass  # fall back to the Gaussian rendering alone
pred_img = visutil.render_gaussians(pred_gaus[i], resolution=(256, 256))
try:
v, f = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i], res=128)
pred_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
pred_img = imageutil.merge_images([pred_img, pred_mesh_img])
            except Exception:
                pass  # fall back to the Gaussian rendering alone
img = imageutil.merge_images([gt_img, pred_img])
img = imageutil.draw_text(
img,
f"Left: GT | Right: Pred \n{vis_texts[i]}",
font_size=14,
max_seq_length=50,
)
images.append([img])
images = imageutil.merge_images(images)
wandb_logger.log_image("vis", [images])
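

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training code):
# assuming a trained phase-1 checkpoint restorable via PyTorch Lightning's
# standard `load_from_checkpoint`, text-conditioned Gaussians could be
# sampled as follows. The checkpoint path and caption are hypothetical.
#
#   model = LangPhase1Model.load_from_checkpoint("phase1.ckpt").to("cuda")
#   model.eval()
#   gaus = model.sampling_gaussians(
#       ["a wide chair with thick armrests"],
#       classifier_free_guidance=True,
#       free_guidance_weight=2.0,
#   )  # -> [1, 16, 16] clipped/projected Gaussian parameters
# ----------------------------------------------------------------------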