import numpy as np
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from salad.model_components.lstm import LSTM
from salad.models.language_phase1 import LangPhase1Model
from salad.utils import imageutil, nputil, visutil
from salad.utils.spaghetti_util import (generate_zc_from_sj_gaus,
                                        get_mesh_from_spaghetti, load_mesher,
                                        load_spaghetti)
from salad.utils.train_util import get_dropout_mask


class LangPhase2Model(LangPhase1Model):
    def __init__(self, network, variance_schedule, **kwargs):
        super().__init__(network, variance_schedule, **kwargs)

    def random_mask_gaus_text(self, gaus, text):
        # Randomly drop the Gaussian/text condition so the model also learns an
        # unconditional branch for classifier-free guidance.
        if self.hparams.get("classifier_free_guidance"):
            text = list(text)
            B = gaus.shape[0]
            random_dp_mask = get_dropout_mask(
                B, self.hparams.conditioning_dropout_prob, self.device
            )
            # Dropped samples get zeroed Gaussians and an empty caption.
            gaus = gaus * random_dp_mask.unsqueeze(1).unsqueeze(2)
            for i in range(B):
                if random_dp_mask[i] == 0:
                    text[i] = ""
        return gaus, text
    def forward(self, x, gaus, text):
        """
        Input:
            x: [B,G,512] per-part SPAGHETTI intrinsic latents (s_j)
            gaus: [B,G,16] per-part Gaussian (extrinsic) parameters
            text: list of B caption strings
        """
        B, G = x.shape[:2]
        gaus, text = self.random_mask_gaus_text(gaus, text)
        lang_emb = self.text_to_embedding(text)
        cond = self.cond_from_gaus_lang_f(gaus, lang_emb)
        return self.get_loss(x, cond)

    def step(self, batch, stage):
        x, gaus, text = batch
        loss = self(x, gaus, text)
        self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
        return loss
    def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
        B, G, D = x0.shape
        if not noisy_in:
            if t is None:
                t = self.var_sched.uniform_sample_t(B)
            x_noisy, beta, e_rand = self.add_noise(x0, t)
        else:
            x_noisy = x0
            beta = beta_in
            e_rand = e_rand_in
        # Epsilon-prediction objective: regress the injected noise with an MSE loss.
        e_theta = self.net(x_noisy, beta, cond)
        loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
        return loss
    def cond_from_gaus_lang_f(self, gaus, lang_f):
        # Broadcast the per-shape language feature over the G Gaussians and
        # concatenate, giving a [B, G, 16 + lang_dim] condition tensor.
        gaus = nputil.np2th(gaus).to(self.device)
        G = gaus.shape[1]
        lang_f = nputil.np2th(lang_f).to(self.device)
        assert gaus.ndim == 3
        if lang_f.ndim == 2:
            lang_f = lang_f.unsqueeze(1)
        lang_f = lang_f.expand(-1, G, -1)
        return torch.cat([gaus, lang_f], -1)

    def generate_null_cond(self, B, G):
        # Null condition for classifier-free guidance: empty captions, zero Gaussians.
        text = ["" for _ in range(B)]
        lang_emb = self.text_to_embedding(text)
        gaus = torch.zeros(B, G, 16, dtype=torch.float, device=self.device)
        return self.cond_from_gaus_lang_f(gaus, lang_emb)
    @torch.no_grad()
    def sample(
        self,
        num_samples_or_cond,
        return_traj=False,
        return_cond=False,
        classifier_free_guidance=False,
        free_guidance_weight=0.7,
    ):
        if isinstance(num_samples_or_cond, int):
            # Build the condition from the first N validation shapes.
            batch_size = num_samples_or_cond
            ds = self._build_dataset("val")
            batch_gaus = []
            batch_text = []
            for i in range(batch_size):
                _, gaus, text = ds[i]
                batch_gaus.append(gaus)
                batch_text.append(text)
            batch_gaus = torch.stack(batch_gaus, 0)
            lang_emb = self.text_to_embedding(batch_text)
            cond = self.cond_from_gaus_lang_f(batch_gaus, lang_emb).to(self.device)
        elif isinstance(num_samples_or_cond, (np.ndarray, torch.Tensor)):
            # A precomputed condition tensor of shape [B, G, 16 + lang_dim].
            cond = nputil.np2th(num_samples_or_cond).to(self.device)
            batch_size = len(cond)

        G = cond.shape[1]
        if classifier_free_guidance:
            null_cond = self.generate_null_cond(batch_size, G)

        # Reverse diffusion from Gaussian noise over the 16 per-part 512-d latents.
        x_T = torch.randn([batch_size, 16, 512]).to(self.device)
        traj = {self.var_sched.num_steps: x_T}
        for t in range(self.var_sched.num_steps, 0, -1):
            z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
            alpha = self.var_sched.alphas[t]
            alpha_bar = self.var_sched.alpha_bars[t]
            sigma = self.var_sched.get_sigmas(t, flexibility=0)

            c0 = 1.0 / torch.sqrt(alpha)
            c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)

            x_t = traj[t]
            beta = self.var_sched.betas[[t] * batch_size]
            e_theta = self.net(x_t, beta=beta, context=cond)
            if classifier_free_guidance:
                # Guided estimate: (1 + w) * conditional - w * unconditional noise.
                null_e_theta = self.net(x_t, beta=beta, context=null_cond)
                w = free_guidance_weight
                e_theta = (1 + w) * e_theta - w * null_e_theta

            x_next = c0 * (x_t - c1 * e_theta) + sigma * z
            traj[t - 1] = x_next.detach()
            traj[t] = traj[t].cpu()
            if not return_traj:
                del traj[t]

        if return_traj:
            if return_cond:
                return traj, cond
            return traj
        else:
            if return_cond:
                return traj[0], cond
            return traj[0]
    def validation(self):
        vis_num_shapes = 4
        vis_gt_sj = []
        vis_gaus = []
        vis_texts = []
        ds = self._build_dataset("val")
        vis_indices = [18453, 13036, 13204, 48244]
        for i in vis_indices:
            sj, gaus, text = ds[i]
            vis_gt_sj.append(sj)
            vis_gaus.append(gaus)
            vis_texts.append(text)
        vis_gt_sj = torch.stack(vis_gt_sj, 0)
        vis_gaus = torch.stack(vis_gaus, 0).to(self.device)
        vis_lang_f = self.text_to_embedding(vis_texts)
        vis_cond = self.cond_from_gaus_lang_f(vis_gaus, vis_lang_f)
        pred_sj = self.sample(vis_cond)

        # Lazily load the SPAGHETTI decoder and the mesher used for visualization.
        if not hasattr(self, "spaghetti"):
            self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
        spaghetti = self.spaghetti
        if not hasattr(self, "mesher"):
            self.mesher = load_mesher(self.device)
        mesher = self.mesher

        gt_zcs = generate_zc_from_sj_gaus(spaghetti, vis_gt_sj, vis_gaus)
        pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, vis_gaus)

        wandb_logger = self.get_wandb_logger()
        for i in range(vis_num_shapes):
            # Render the condition Gaussians, the ground-truth mesh, and (if meshing
            # succeeds) the predicted mesh side by side, captioned with the text.
            gaus_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
            vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
            gt_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
            img = [gaus_img, gt_mesh_img]
            try:
                vert, face = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i])
                pred_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
                img.append(pred_mesh_img)
            except Exception as e:
                print(e)
            img = imageutil.merge_images(img)
            img = imageutil.draw_text(
                img, vis_texts[i], font_size=14, max_seq_length=50
            )
            wandb_logger.log_image("vis", [img])
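

# --------------------------------------------------------------------------- #
# Usage sketch (illustrative, not part of the training code): one plausible way
# to drive the model above at inference time. It assumes LangPhase2Model is a
# PyTorch Lightning module (it uses self.log / self.hparams / self.device), so
# `load_from_checkpoint` is available; the checkpoint path and the zero Gaussian
# tensor below are placeholders only.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = LangPhase2Model.load_from_checkpoint("path/to/phase2.ckpt")  # placeholder path
    model = model.to(device).eval()

    # Condition: 16 part Gaussians (e.g. phase-1 output) plus one caption,
    # concatenated by cond_from_gaus_lang_f into a [B, G, 16 + lang_dim] tensor.
    gaus = torch.zeros(1, 16, 16)  # placeholder extrinsics; use real Gaussians in practice
    lang_emb = model.text_to_embedding(["a chair with armrests"])
    cond = model.cond_from_gaus_lang_f(gaus, lang_emb)

    # Reverse diffusion with classifier-free guidance, as implemented in sample().
    pred_sj = model.sample(cond, classifier_free_guidance=True, free_guidance_weight=0.7)
    print(pred_sj.shape)  # expected: [1, 16, 512]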