import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

from salad.model_components.lstm import LSTM
from salad.models.language_phase1 import LangPhase1Model
from salad.utils import imageutil, nputil, visutil
from salad.utils.spaghetti_util import (
    generate_zc_from_sj_gaus,
    get_mesh_from_spaghetti,
    load_mesher,
    load_spaghetti,
)
from salad.utils.train_util import get_dropout_mask


class LangPhase2Model(LangPhase1Model):
    """Phase-2 language-conditioned diffusion model.

    The denoising network is conditioned on the concatenation of per-part
    Gaussian parameters and a text embedding (see ``cond_from_gaus_lang_f``).
    Text embedding, noise scheduling, dataset construction and logging helpers
    (``text_to_embedding``, ``add_noise``, ``var_sched``, ``_build_dataset``,
    ``get_wandb_logger``) are not defined here and are expected to come from
    the parent class.
    """

    def __init__(self, network, variance_schedule, **kwargs):
        super().__init__(network, variance_schedule, **kwargs)

    def random_mask_gaus_text(self, gaus, text):
        """Randomly drop the conditioning of whole samples.

        Only active when the ``classifier_free_guidance`` hyper-parameter is
        truthy. For every sample whose dropout-mask entry equals 0, the
        Gaussians are multiplied by zero and the text is replaced by ``""``.

        Args:
            gaus: batched Gaussian parameters; the first axis is the batch.
            text: sequence of B strings.

        Returns:
            ``(gaus, text)``; ``text`` is a new list when masking is active,
            otherwise both inputs are returned untouched.
        """
        if self.hparams.get("classifier_free_guidance"):
            text = list(text)  # copy so the caller's sequence is not mutated
            B = gaus.shape[0]
            # NOTE(review): the mask is assumed to have shape [B]; the two
            # unsqueeze calls below make it broadcast against a 3-D ``gaus``.
            random_dp_mask = get_dropout_mask(
                B, self.hparams.conditioning_dropout_prob, self.device
            )
            gaus = gaus * random_dp_mask.unsqueeze(1).unsqueeze(2)
            for i in range(B):
                if random_dp_mask[i] == 0:
                    text[i] = ""
        return gaus, text

    def forward(self, x, gaus, text):
        """Compute the training loss for one batch.

        Input:
            x: [B,G,512]
            gaus: [B,G,16]
            text: list of [B]

        Returns:
            The scalar loss produced by ``get_loss``.
        """
        gaus, text = self.random_mask_gaus_text(gaus, text)
        lang_emb = self.text_to_embedding(text)
        cond = self.cond_from_gaus_lang_f(gaus, lang_emb)
        return self.get_loss(x, cond)

    def step(self, batch, stage):
        """Run one train/val step on ``batch = (x, gaus, text)`` and log the loss.

        The loss is logged under ``"{stage}/loss"``; per-step logging is
        enabled only when ``stage == "train"``.
        """
        x, gaus, text = batch
        loss = self(x, gaus, text)
        self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
        return loss

    def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
        """Mean-squared error between the predicted and the true noise.

        Args:
            x0: 3-D tensor. Clean data when ``noisy_in`` is False, otherwise an
                already-noised sample.
            cond: conditioning passed to the network as its third argument.
            t: optional timesteps; sampled with
                ``var_sched.uniform_sample_t(B)`` when None. Ignored when
                ``noisy_in`` is True.
            noisy_in: if True, skip the noising step and use ``beta_in`` /
                ``e_rand_in`` as the noise level and the target noise.
            beta_in: noise level matching ``x0`` (used only with ``noisy_in``).
            e_rand_in: noise contained in ``x0`` (required with ``noisy_in``).

        Returns:
            Scalar tensor with the mean-reduced MSE.

        Raises:
            ValueError: if ``noisy_in`` is True but ``e_rand_in`` is None.
        """
        # Unpacking three values also enforces that x0 is exactly 3-D.
        B, _, _ = x0.shape
        if not noisy_in:
            if t is None:
                t = self.var_sched.uniform_sample_t(B)
            x_noisy, beta, e_rand = self.add_noise(x0, t)
        else:
            if e_rand_in is None:
                # Fail before the forward pass instead of on None.flatten().
                raise ValueError("e_rand_in must be provided when noisy_in=True")
            x_noisy = x0
            beta = beta_in
            e_rand = e_rand_in
        e_theta = self.net(x_noisy, beta, cond)
        loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
        return loss

    def cond_from_gaus_lang_f(self, gaus, lang_f):
        """Concatenate Gaussians and language features along the last axis.

        Both inputs are passed through ``nputil.np2th`` (presumably a
        numpy-to-torch conversion) and moved to ``self.device``.

        Args:
            gaus: 3-D input whose second axis has size G.
            lang_f: language features. A 2-D input gets a new axis 1, and the
                result is then expanded to size G on axis 1.

        Returns:
            ``torch.cat([gaus, lang_f], -1)``.
        """
        gaus = nputil.np2th(gaus).to(self.device)
        G = gaus.shape[1]
        lang_f = nputil.np2th(lang_f).to(self.device)
        assert gaus.ndim == 3
        if lang_f.ndim == 2:
            lang_f = lang_f.unsqueeze(1)
        lang_f = lang_f.expand(-1, G, -1)
        return torch.cat([gaus, lang_f], -1)

    def generate_null_cond(self, B, G):
        """Build the unconditional condition: empty texts and all-zero Gaussians.

        Returns the concatenated condition for a batch of ``B`` samples with
        ``G`` zero-filled Gaussians of width 16 each.
        """
        text = ["" for _ in range(B)]
        lang_emb = self.text_to_embedding(text)
        gaus = torch.zeros(B, G, 16, dtype=torch.float, device=self.device)
        return self.cond_from_gaus_lang_f(gaus, lang_emb)

    @torch.no_grad()
    def sample(
        self,
        num_samples_or_cond,
        return_traj=False,
        return_cond=False,
        classifier_free_guidance=False,
        free_guidance_weight=0.7,
    ):
        """Generate samples by running the reverse diffusion process.

        Args:
            num_samples_or_cond: either an ``int`` N, in which case the
                condition is built from the first N items of the "val"
                dataset, or a ready-made condition (``np.ndarray`` or
                ``torch.Tensor``).
            return_traj: if True, return the dict ``{t: x_t}`` of all steps
                instead of only the final sample. Entries for ``t >= 1`` are
                moved to CPU; entry 0 is left on the device it was computed on.
            return_cond: if True, additionally return the condition used.
            classifier_free_guidance: if True, mix the conditional and the
                null-condition noise predictions.
            free_guidance_weight: guidance weight ``w`` in
                ``(1 + w) * e_cond - w * e_null``.

        Returns:
            ``traj`` or ``traj[0]``, optionally paired with ``cond`` as a
            tuple when ``return_cond`` is True.

        Raises:
            TypeError: if ``num_samples_or_cond`` is none of the supported
                types.
        """
        if isinstance(num_samples_or_cond, int):
            batch_size = num_samples_or_cond
            ds = self._build_dataset("val")
            batch_gaus = []
            batch_text = []
            for i in range(batch_size):
                _, gaus, text = ds[i]
                batch_gaus.append(gaus)
                batch_text.append(text)
            batch_gaus = torch.stack(batch_gaus, 0)
            lang_emb = self.text_to_embedding(batch_text)
            cond = self.cond_from_gaus_lang_f(batch_gaus, lang_emb).to(self.device)
        elif isinstance(num_samples_or_cond, (np.ndarray, torch.Tensor)):
            cond = nputil.np2th(num_samples_or_cond).to(self.device)
            batch_size = len(cond)
        else:
            # Previously fell through and failed later on an unbound `cond`.
            raise TypeError(
                "num_samples_or_cond must be an int, np.ndarray or torch.Tensor, "
                f"got {type(num_samples_or_cond).__name__}"
            )

        G = cond.shape[1]
        if classifier_free_guidance:
            null_cond = self.generate_null_cond(batch_size, G)

        # NOTE(review): the noise shape is hard-coded to 16 parts x 512 dims
        # rather than derived from `cond` -- confirm G is always 16 here.
        x_T = torch.randn([batch_size, 16, 512]).to(self.device)
        traj = {self.var_sched.num_steps: x_T}
        for t in range(self.var_sched.num_steps, 0, -1):
            # No noise is injected on the final step (t == 1).
            z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
            alpha = self.var_sched.alphas[t]
            alpha_bar = self.var_sched.alpha_bars[t]
            sigma = self.var_sched.get_sigmas(t, flexibility=0)

            c0 = 1.0 / torch.sqrt(alpha)
            c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)

            x_t = traj[t]
            beta = self.var_sched.betas[[t] * batch_size]
            e_theta = self.net(x_t, beta=beta, context=cond)
            if classifier_free_guidance:
                null_e_theta = self.net(x_t, beta=beta, context=null_cond)
                w = free_guidance_weight
                e_theta = (1 + w) * e_theta - w * null_e_theta

            x_next = c0 * (x_t - c1 * e_theta) + sigma * z
            traj[t - 1] = x_next.detach()

            if return_traj:
                # Keep the consumed step, but off the accelerator.
                traj[t] = traj[t].cpu()
            else:
                # Drop it directly; no point copying to CPU first.
                del traj[t]

        result = traj if return_traj else traj[0]
        if return_cond:
            return result, cond
        return result

    def validation(self):
        """Render a fixed set of validation shapes and log them as images.

        For each hard-coded validation index, renders the input Gaussians, the
        ground-truth mesh and (when mesh extraction succeeds) the predicted
        mesh side by side, draws the text prompt onto the merged image and
        logs it under the key ``"vis"``.
        """
        vis_gt_sj = []
        vis_gaus = []
        vis_texts = []
        ds = self._build_dataset("val")
        vis_indices = [18453, 13036, 13204, 48244]
        for i in vis_indices:
            sj, gaus, text = ds[i]
            vis_gt_sj.append(sj)
            vis_gaus.append(gaus)
            vis_texts.append(text)

        vis_gt_sj = torch.stack(vis_gt_sj, 0)
        vis_gaus = torch.stack(vis_gaus, 0).to(self.device)
        vis_lang_f = self.text_to_embedding(vis_texts)
        vis_cond = self.cond_from_gaus_lang_f(vis_gaus, vis_lang_f)
        pred_sj = self.sample(vis_cond)

        # Load the decoder and the mesher once and cache them on the instance.
        if not hasattr(self, "spaghetti"):
            self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
        spaghetti = self.spaghetti

        if not hasattr(self, "mesher"):
            self.mesher = load_mesher(self.device)
        mesher = self.mesher

        gt_zcs = generate_zc_from_sj_gaus(spaghetti, vis_gt_sj, vis_gaus)
        pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, vis_gaus)

        wandb_logger = self.get_wandb_logger()
        # Derive the count from the index list so the two cannot drift apart.
        for i in range(len(vis_indices)):
            gaus_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
            vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
            gt_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
            img = [gaus_img, gt_mesh_img]
            try:
                vert, face = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i])
                pred_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
                img.append(pred_mesh_img)
            except Exception as e:
                # Best effort: a failed predicted mesh must not abort the whole
                # validation pass; the image is logged without it.
                print(e)

            img = imageutil.merge_images(img)
            img = imageutil.draw_text(
                img, vis_texts[i], font_size=14, max_seq_length=50
            )
            wandb_logger.log_image("vis", [img])