Spaces:
Sleeping
Sleeping
from typing import Union | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from salad.models.base_model import BaseModel | |
from salad.utils import imageutil, nputil, sysutil, thutil, visutil | |
from salad.utils.spaghetti_util import (clip_eigenvalues, | |
generate_zc_from_sj_gaus, | |
get_mesh_from_spaghetti, load_mesher, | |
load_spaghetti, project_eigenvectors) | |
class Phase2Model(BaseModel):
    """Conditional diffusion model trained with a noise-prediction MSE loss.

    The class relies on members that are not defined in this block and are
    presumably provided by ``BaseModel``: ``self.net`` (the denoising
    network), ``self.var_sched`` (the variance schedule), ``self.add_noise``,
    ``self._build_dataset``, ``self.get_wandb_logger``, ``self.hparams``,
    ``self.device`` and ``self.log``.
    """

    def __init__(self, network, variance_schedule, **kwargs):
        """Forward all arguments unchanged to ``BaseModel.__init__``."""
        super().__init__(network, variance_schedule, **kwargs)

    def forward(self, x, cond):
        """Return the training loss for data ``x`` under condition ``cond``."""
        return self.get_loss(x, cond)

    def step(self, batch, stage: str):
        """Compute and log the loss for one ``(x, cond)`` batch.

        Args:
            batch: A pair ``(x, cond)``.
            stage: Prefix of the logged metric name (``"{stage}/loss"``).
                Per-step logging is enabled only when it equals ``"train"``.

        Returns:
            The loss returned by ``forward``.
        """
        x, cond = batch
        loss = self(x, cond)
        self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
        return loss

    def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
        """Return the MSE between predicted and reference noise.

        Args:
            x0: Rank-3 tensor. Clean data when ``noisy_in`` is False,
                otherwise an already-noised input fed to the network as is.
            cond: Conditioning passed to the network.
            t: Optional timesteps. When omitted (and ``noisy_in`` is False)
                they are drawn with ``self.var_sched.uniform_sample_t``.
            noisy_in: If True, skip ``self.add_noise`` and use ``x0``,
                ``beta_in`` and ``e_rand_in`` directly.
            beta_in: Noise level passed to the network when ``noisy_in``.
            e_rand_in: Reference noise for the loss when ``noisy_in``.

        Returns:
            Scalar mean-squared-error loss.

        Raises:
            ValueError: If ``noisy_in`` is True but ``e_rand_in`` is None.
        """
        # Unpacking enforces a rank-3 input; only the batch size is used.
        batch_size, _, _ = x0.shape

        if noisy_in:
            if e_rand_in is None:
                # Fail early with a clear message instead of an AttributeError
                # on ``None.flatten()`` after the network has already run.
                raise ValueError(
                    "'e_rand_in' must be provided when 'noisy_in' is True."
                )
            x_noisy, beta, e_rand = x0, beta_in, e_rand_in
        else:
            if t is None:
                t = self.var_sched.uniform_sample_t(batch_size)
            x_noisy, beta, e_rand = self.add_noise(x0, t)

        e_theta = self.net(x_noisy, beta, cond)
        return F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")

    def sample(
        self,
        num_samples_or_gaus: Union[torch.Tensor, np.ndarray, int],
        return_traj=False,
        classifier_free_guidance=None,
        free_guidance_weight=-0.7,
        augment_condition_in_test=False,
        return_cond=False,
    ):
        """Draw samples by running the reverse diffusion process.

        Args:
            num_samples_or_gaus: Either an ``int`` (that many conditions are
                taken from the start of the ``"val"`` dataset, using element
                ``[1]`` of each item) or an array/tensor of conditions. A
                2-D input is treated as a single condition and gets a
                leading batch axis.
            return_traj: If True, return the whole trajectory as a dict
                mapping timestep -> state instead of only the final state.
            classifier_free_guidance: Accepted for interface compatibility;
                not used by this implementation.
            free_guidance_weight: Accepted for interface compatibility; not
                used by this implementation.
            augment_condition_in_test: Accepted for interface compatibility;
                not used by this implementation.
            return_cond: If True, also return the condition tensor (on
                ``self.device``).

        Returns:
            ``traj[0]`` (the final sample) or the trajectory dict, optionally
            paired with ``cond``. In the trajectory, every state except
            ``traj[0]`` has been moved to the CPU.

        Raises:
            ValueError: If ``num_samples_or_gaus`` has an unsupported type.
        """
        if isinstance(num_samples_or_gaus, int):
            batch_size = num_samples_or_gaus
            ds = self._build_dataset("val")
            cond = torch.stack([ds[i][1] for i in range(batch_size)], 0)
        elif isinstance(num_samples_or_gaus, (np.ndarray, torch.Tensor)):
            cond = nputil.np2th(num_samples_or_gaus)
            if cond.dim() == 2:
                cond = cond[None]
            batch_size = len(cond)
        else:
            raise ValueError(
                "'num_samples_or_gaus' should be int, torch.Tensor or np.ndarray."
            )

        # The latent shape is hard-coded to (batch_size, 16, 512).
        x_T = torch.randn([batch_size, 16, 512]).to(self.device)
        cond = cond.to(self.device)

        num_steps = self.var_sched.num_steps
        traj = {num_steps: x_T}

        # Every state is detached before it is stored, so no gradient can flow
        # through the loop; disabling autograd avoids building unused graphs.
        with torch.no_grad():
            for t in range(num_steps, 0, -1):
                # No noise is injected on the last step (t == 1).
                z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
                alpha = self.var_sched.alphas[t]
                alpha_bar = self.var_sched.alpha_bars[t]
                sigma = self.var_sched.get_sigmas(t, flexibility=0)

                c0 = 1.0 / torch.sqrt(alpha)
                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)

                x_t = traj[t]
                beta = self.var_sched.betas[[t] * batch_size]
                e_theta = self.net(x_t, beta=beta, context=cond)

                # DDPM-style reverse step from x_t to x_{t-1}.
                x_next = c0 * (x_t - c1 * e_theta) + sigma * z
                traj[t - 1] = x_next.detach()

                if return_traj:
                    # Keep the consumed state, but off the accelerator.
                    traj[t] = traj[t].cpu()
                else:
                    # Drop the consumed state without a needless host copy.
                    del traj[t]

        result = traj if return_traj else traj[0]
        if return_cond:
            return result, cond
        return result

    def validation(self):
        """Render ground-truth vs. sampled shapes and log them as images.

        For ``vis_num_shapes`` validation items (dataset indices 3, 4, 5),
        ``num_variations`` samples are drawn per item conditioned on that
        item's second element. Each ground-truth and sampled result is
        decoded to a mesh through the SPAGHETTI helpers, rendered next to
        its Gaussian rendering, and one merged image per shape is logged
        under the key ``"visualization"``. Rendering and logging failures
        are printed and skipped so validation stays best-effort.
        """
        latent_ds = self._build_dataset("val")
        vis_num_shapes = 3
        num_variations = 3
        sysutil.clean_gpu()

        # Load the decoder and mesher once and cache them on the instance.
        if not hasattr(self, "spaghetti"):
            self.spaghetti = load_spaghetti(
                self.device,
                self.hparams.get("spaghetti_tag") or "chairs_large",
            )
        spaghetti = self.spaghetti

        if not hasattr(self, "mesher"):
            self.mesher = load_mesher(self.device)
        mesher = self.mesher

        # ======== Sampling ========
        gt_zs, gt_gaus = zip(*[latent_ds[i + 3] for i in range(vis_num_shapes)])
        gt_zs = torch.stack(gt_zs)
        gt_gaus = torch.stack(gt_gaus)

        if self.hparams.get("sj_global_normalization"):
            gt_zs = thutil.th2np(gt_zs)
            gt_zs = latent_ds.unnormalize_sj_global_static(gt_zs)
            gt_zs = nputil.np2th(gt_zs).to(self.device)

        # One condition row per (shape, variation) pair, grouped by shape.
        gt_gaus_repeated = gt_gaus.repeat_interleave(num_variations, 0)
        clean_ldm_zs, clean_gaus = self.sample(gt_gaus_repeated, return_cond=True)
        clean_gaus = project_eigenvectors(clip_eigenvalues(clean_gaus))
        clean_zcs = generate_zc_from_sj_gaus(spaghetti, clean_ldm_zs, clean_gaus)
        gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_zs, gt_gaus)
        sysutil.clean_gpu()
        # ==========================

        # ---- Spaghetti decoding ----
        wandb_logger = self.get_wandb_logger()
        resolution = (256, 256)
        for i in range(vis_num_shapes):
            img_per_shape = []

            gaus_img = visutil.render_gaussians(gt_gaus[i], resolution=resolution)
            vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
            gt_mesh_img = visutil.render_mesh(vert, face, resolution=resolution)
            gt_img = imageutil.merge_images([gaus_img, gt_mesh_img])
            gt_img = imageutil.draw_text(gt_img, "GT", font_size=24)
            img_per_shape.append(gt_img)

            for j in range(num_variations):
                idx = i * num_variations + j
                try:
                    gaus_img = visutil.render_gaussians(
                        clean_gaus[idx], resolution=resolution
                    )
                    vert, face = get_mesh_from_spaghetti(
                        spaghetti, mesher, clean_zcs[idx], res=128
                    )
                    mesh_img = visutil.render_mesh(vert, face, resolution=resolution)
                    pred_img = imageutil.merge_images([gaus_img, mesh_img])
                    pred_img = imageutil.draw_text(
                        pred_img, f"{j}-th clean gaus", font_size=24
                    )
                    img_per_shape.append(pred_img)
                except Exception as e:
                    # Best-effort: one failed variation must not abort the rest.
                    print(e)

            try:
                image = imageutil.merge_images(img_per_shape)
                wandb_logger.log_image("visualization", [image])
            except Exception as e:
                # Best-effort: a logging failure must not abort validation.
                print(e)
        # ----------------------------