import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

from salad.data.dataset import LangSALADDataset
from salad.model_components.lstm import LSTM
from salad.models.phase1 import Phase1Model
from salad.utils import imageutil, nputil, visutil
from salad.utils.spaghetti_util import (
    clip_eigenvalues,
    generate_zc_from_sj_gaus,
    get_mesh_from_spaghetti,
    load_mesher,
    load_spaghetti,
    project_eigenvectors,
)
from salad.utils.train_util import get_dropout_mask

class LangPhase1Model(Phase1Model):
    def __init__(self, network, variance_schedule, **kwargs):
        super().__init__(network, variance_schedule, **kwargs)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Text encoder: either a from-scratch LSTM or a pretrained
        # (optionally frozen) BERT.
        if self.hparams.get("use_lstm"):
            self.bertmodel = LSTM(
                text_dim=768, embedding_dim=768, vocab_size=30522, padding_idx=0
            )
        else:
            self.bertmodel = BertModel.from_pretrained("bert-base-uncased")
        if self.hparams.get("text_encoder_freeze"):
            for p in self.bertmodel.parameters():
                p.requires_grad_(False)
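
    # hparams consumed by this model (most read via .get and therefore
    # optional; dataset_kwargs and conditioning_dropout_prob are accessed
    # directly): use_lstm, text_encoder_freeze, text_encoder_return_seq,
    # classifier_free_guidance, conditioning_dropout_prob, use_zc,
    # global_normalization, dataset_kwargs, spaghetti_tag.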

    def forward(self, x, text):
        """
        Input:
            x: [B,G,16]
            text: list of B strings
        """
        B, G = x.shape[:2]
        text = self.random_mask_text(text)
        lang_emb = self.text_to_embedding(text)
        return self.get_loss(x, lang_emb)

    def tokenizing(self, text):
        tokenized = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        return tokenized
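
    # The returned BatchEncoding holds input_ids, token_type_ids, and
    # attention_mask, padded to the longest prompt in the batch and truncated
    # to BERT's 512-token limit; e.g. tokenizer(["a chair"], ...) yields
    # input_ids of shape [1, 4]: [CLS] a chair [SEP].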

    def text_to_embedding(self, text):
        """
        text: list of B strings
        return: [B,1,768] (pooled) or [B,L,768] (sequence)
        """
        tokenized = self.tokenizing(text)
        if self.hparams.get("use_lstm"):
            lang_emb, _ = self.bertmodel(tokenized.input_ids)
        else:
            if self.hparams.get("text_encoder_return_seq"):
                lang_emb = self.bertmodel(**tokenized).last_hidden_state
            else:
                lang_emb = self.bertmodel(**tokenized).pooler_output
        if lang_emb.ndim == 2:
            lang_emb = lang_emb.unsqueeze(1)
        return lang_emb
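
    # Shape note: with BERT pooling (the default), the return is [B, 1, 768]
    # after the unsqueeze; with text_encoder_return_seq it is [B, L, 768],
    # where L is the padded token length. The LSTM branch is assumed to return
    # a compatible [B, *, 768] embedding.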

    def random_mask_text(self, text):
        """Randomly replace prompts with "" for classifier-free guidance training."""
        text = list(text)
        B = len(text)
        if self.hparams.get("classifier_free_guidance"):
            random_dp_mask = get_dropout_mask(
                B, self.hparams.conditioning_dropout_prob, self.device
            )
            for i in range(B):
                if random_dp_mask[i] == 0:
                    text[i] = ""
        return text
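
    # During CFG training, each prompt is independently replaced by the empty
    # string when its dropout-mask entry is 0, so the denoiser also learns the
    # unconditional ("") context it needs at sampling time. (Assumption:
    # conditioning_dropout_prob is the probability of dropping a prompt, per
    # get_dropout_mask's convention.)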

    def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
        B, G, D = x0.shape
        if not noisy_in:
            if t is None:
                t = self.var_sched.uniform_sample_t(B)
            x_noisy, beta, e_rand = self.add_noise(x0, t)
        else:
            x_noisy = x0
            beta = beta_in
            e_rand = e_rand_in
        e_theta = self.net(x_noisy, beta, cond)
        loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
        return loss
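
    # The MSE above is the standard epsilon-prediction DDPM objective,
    # assuming Phase1Model.add_noise implements the usual forward process:
    #   x_t  = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * e_rand
    #   loss = E_{t, x0, e_rand} || e_rand - e_theta(x_t, t, cond) ||^2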

    def step(self, batch, stage: str):
        x, text = batch
        loss = self(x, text)
        self.log(f"{stage}/loss", loss, on_step=(stage == "train"), prog_bar=True)
        return loss

    def sample(
        self,
        num_samples_or_text,
        return_traj=False,
        return_cond=False,
        classifier_free_guidance=True,
        free_guidance_weight=2.0,
    ):
        if isinstance(num_samples_or_text, str):
            num_samples_or_text = [num_samples_or_text]
        if isinstance(num_samples_or_text, int):
            batch_size = num_samples_or_text
            ds = self._build_dataset("val")
            texts = [ds[i][1] for i in range(batch_size)]
        elif isinstance(num_samples_or_text, list):
            texts = num_samples_or_text
            batch_size = len(num_samples_or_text)
        if self.hparams.get("use_zc"):
            x_T = torch.randn([batch_size, 16, 512]).to(self.device)
        else:
            x_T = torch.randn([batch_size, 16, 16]).to(self.device)
        G = x_T.shape[1]
        lang_emb = self.text_to_embedding(texts)
        if classifier_free_guidance:
            null_texts = ["" for _ in range(batch_size)]
            null_lang_emb = self.text_to_embedding(null_texts)
        traj = {self.var_sched.num_steps: x_T}
        for t in range(self.var_sched.num_steps, 0, -1):
            z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
            alpha = self.var_sched.alphas[t]
            alpha_bar = self.var_sched.alpha_bars[t]
            sigma = self.var_sched.get_sigmas(t, flexibility=0)
            c0 = 1.0 / torch.sqrt(alpha)
            c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
            x_t = traj[t]
            beta = self.var_sched.betas[[t] * batch_size]
            e_theta = self.net(x_t, beta=beta, context=lang_emb)
            if classifier_free_guidance:
                # Guided noise estimate: (1 + w) * conditional - w * unconditional.
                null_e_theta = self.net(x_t, beta=beta, context=null_lang_emb)
                w = free_guidance_weight
                e_theta = (1 + w) * e_theta - w * null_e_theta
            x_next = c0 * (x_t - c1 * e_theta) + sigma * z
            traj[t - 1] = x_next.detach()
            traj[t] = traj[t].cpu()
            if not return_traj:
                del traj[t]
        if return_traj:
            if return_cond:
                return traj, lang_emb
            return traj
        else:
            if return_cond:
                return traj[0], lang_emb
            return traj[0]
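
    # The loop above is DDPM ancestral sampling with classifier-free guidance:
    #   e       = (1 + w) * e_theta(x_t, cond) - w * e_theta(x_t, "")
    #   x_{t-1} = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * e) / sqrt(alpha_t)
    #             + sigma_t * z
    # Setting free_guidance_weight w = 0 recovers plain conditional sampling.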

    def sampling_gaussians(
        self,
        num_samples_or_text,
        classifier_free_guidance=True,
        free_guidance_weight=2.0,
        return_cond=False,
    ):
        gaus = self.sample(
            num_samples_or_text,
            classifier_free_guidance=classifier_free_guidance,
            free_guidance_weight=free_guidance_weight,
            return_cond=return_cond,
        )
        if isinstance(gaus, tuple):
            # With return_cond=True, sample() returns (gaus, lang_emb).
            cond = gaus[1]
            gaus = gaus[0]
        # gaus = reflect_and_concat_gmms(raw_gaus)
        if self.hparams.get("global_normalization"):
            if not hasattr(self, "data_val"):
                self._build_dataset("val")
            if self.hparams.get("global_normalization") == "partial":
                gaus = self.data_val.unnormalize_global_static(gaus, slice(12, None))
            elif self.hparams.get("global_normalization") == "all":
                gaus = self.data_val.unnormalize_global_static(gaus, slice(None))
        gaus = project_eigenvectors(clip_eigenvalues(gaus))
        if return_cond:
            return gaus, cond
        return gaus

    def _build_dataset(self, stage):
        if hasattr(self, f"data_{stage}"):
            return getattr(self, f"data_{stage}")
        ds_class = LangSALADDataset
        if stage == "train":
            ds = ds_class(**self.hparams.dataset_kwargs)
        else:
            dataset_kwargs = self.hparams.dataset_kwargs.copy()
            dataset_kwargs["repeat"] = 1
            ds = ds_class(**dataset_kwargs)
        setattr(self, f"data_{stage}", ds)
        return ds

    def validation_zc(self):
        vis_num_shapes = 4
        vis_zcs = []
        vis_texts = []
        ds = self._build_dataset("val")
        for i in range(vis_num_shapes):
            zcs, text = ds[i]
            vis_zcs.append(zcs)
            vis_texts.append(text)
        vis_zcs = torch.stack(vis_zcs, 0)
        ldm_zcs = self.sample(vis_texts)
        if not hasattr(self, "spaghetti"):
            self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
        spaghetti = self.spaghetti
        if not hasattr(self, "mesher"):
            self.mesher = load_mesher(self.device)
        mesher = self.mesher
        wandb_logger = self.get_wandb_logger()
        images = []
        for i in range(vis_num_shapes):
            gt_img = pred_img = None
            try:
                v, f = get_mesh_from_spaghetti(spaghetti, mesher, vis_zcs[i], res=128)
                gt_img = visutil.render_mesh(v, f, resolution=(256, 256))
            except Exception:
                pass
            try:
                v, f = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i], res=128)
                pred_img = visutil.render_mesh(v, f, resolution=(256, 256))
            except Exception:
                pass
            if gt_img is None or pred_img is None:
                # Skip shapes whose meshing/rendering failed; merging would
                # otherwise raise a NameError on the undefined image.
                continue
            img = imageutil.merge_images([gt_img, pred_img])
            img = imageutil.draw_text(
                img,
                f"Left: GT | Right: Pred \n{vis_texts[i]}",
                font_size=14,
                max_seq_length=50,
            )
            images.append([img])
        images = imageutil.merge_images(images)
        wandb_logger.log_image("vis", [images])

    def validation(self):
        if self.hparams.get("use_zc"):
            self.validation_zc()
            return
        vis_num_shapes = 4
        vis_gaus = []
        vis_texts = []
        ds = self._build_dataset("val")
        vis_indices = [18453, 13036, 13204, 48244]
        for i in vis_indices:
            gaus, text = ds[i]
            vis_gaus.append(gaus)
            vis_texts.append(text)
        vis_gaus = torch.stack(vis_gaus, 0)
        if self.hparams.get("global_normalization"):
            if self.hparams.get("global_normalization") == "partial":
                vis_gaus = self.data_val.unnormalize_global_static(
                    vis_gaus, slice(12, None)
                )
            elif self.hparams.get("global_normalization") == "all":
                vis_gaus = self.data_val.unnormalize_global_static(
                    vis_gaus, slice(None)
                )
        # vis_gaus = reflect_and_concat_gmms(vis_gaus)
        pred_gaus = self.sampling_gaussians(vis_texts)
        if not hasattr(self, "spaghetti"):
            self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
        spaghetti = self.spaghetti
        if not hasattr(self, "mesher"):
            self.mesher = load_mesher(self.device)
        mesher = self.mesher
        # Get intrinsics with the phase-2 model.
        # TODO: change the ckpt path.
        # NOTE: SpaghettiConditionSALDM is not imported in this file; it is
        # assumed to come from the project's phase-2 model module.
        if not hasattr(self, "phase2_model"):
            phase2_ckpt = "/home/juil/pvddir/results/phase2/augment_final_0214/0214_202607/checkpoints/epoch=4999-val_loss=0.0000.ckpt"
            self.phase2_model = SpaghettiConditionSALDM.load_from_checkpoint(
                phase2_ckpt, strict=False
            ).to(self.device)
            self.phase2_model.eval()
            for p in self.phase2_model.parameters():
                p.requires_grad_(False)
        phase2_model = self.phase2_model
        gt_sj = phase2_model.sample(vis_gaus)
        pred_sj = phase2_model.sample(pred_gaus)
        gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_sj, vis_gaus)
        pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, pred_gaus)
        wandb_logger = self.get_wandb_logger()
        images = []
        for i in range(vis_num_shapes):
            gt_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
            try:
                v, f = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
                gt_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
                gt_img = imageutil.merge_images([gt_img, gt_mesh_img])
            except Exception:
                pass
            pred_img = visutil.render_gaussians(pred_gaus[i], resolution=(256, 256))
            try:
                v, f = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i], res=128)
                pred_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
                pred_img = imageutil.merge_images([pred_img, pred_mesh_img])
            except Exception:
                pass
            img = imageutil.merge_images([gt_img, pred_img])
            img = imageutil.draw_text(
                img,
                f"Left: GT | Right: Pred \n{vis_texts[i]}",
                font_size=14,
                max_seq_length=50,
            )
            images.append([img])
        images = imageutil.merge_images(images)
        wandb_logger.log_image("vis", [images])
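

# --- Editor's usage sketch (illustrative only; the checkpoint path, prompt,
# and device below are placeholders, not values taken from this repository) ---
if __name__ == "__main__":
    # Load a trained checkpoint (Lightning-style) and sample Gaussians from text.
    model = LangPhase1Model.load_from_checkpoint(
        "path/to/lang_phase1.ckpt", strict=False
    ).to("cuda")
    model.eval()
    with torch.no_grad():
        gaus = model.sampling_gaussians(
            ["a chair with thin legs and a tall back"],
            classifier_free_guidance=True,
            free_guidance_weight=2.0,
        )
    print(gaus.shape)  # e.g. torch.Size([1, 16, 16]); [1, 16, 512] when use_zc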