import torch

from ..builder import ARCHITECTURES, build_loss, build_submodule
from .base_architecture import BaseArchitecture


@ARCHITECTURES.register_module()
class PoseVAE(BaseArchitecture):
    """VAE over single-frame poses.

    Encodes each pose with ``encoder``, draws a latent code via the
    reparameterization trick, and reconstructs the pose with ``decoder``.
    """

    def __init__(self,
                 encoder=None,
                 decoder=None,
                 loss_recon=None,
                 kl_div_loss_weight=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.encoder = build_submodule(encoder)
        self.decoder = build_submodule(decoder)
        self.loss_recon = build_loss(loss_recon)
        self.kl_div_loss_weight = kl_div_loss_weight

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
        # where std = exp(logvar / 2).
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        latent_code = eps * std + mu
        return latent_code

    def encode(self, pose):
        # Deterministic encoding: return only the posterior mean.
        mu, logvar = self.encoder(pose)
        return mu

    def forward(self, **kwargs):
        motion = kwargs['motion'].float()
        B, T = motion.shape[:2]
        # Flatten batch and time so every frame is treated as one pose.
        pose = motion.reshape(B * T, -1)
        # Exclude the last 4 feature channels, which are not part of the
        # pose representation used here.
        pose = pose[:, :-4]

        mu, logvar = self.encoder(pose)
        z = self.reparameterize(mu, logvar)
        pred = self.decoder(z)

        loss = dict()
        recon_loss = self.loss_recon(pred, pose, reduction_override='none')
        loss['recon_loss'] = recon_loss
        if self.kl_div_loss_weight is not None:
            # Closed-form KL divergence between the posterior N(mu, sigma^2)
            # and the standard-normal prior:
            # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss['kl_div_loss'] = loss_kl * self.kl_div_loss_weight

        return loss
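
# A minimal usage sketch for PoseVAE. The submodule type names and feature
# sizes below are illustrative assumptions, not modules guaranteed to be
# registered in this repo:
#
#   pose_vae = PoseVAE(
#       encoder=dict(type='PoseEncoder', input_dim=259, latent_dim=256),
#       decoder=dict(type='PoseDecoder', latent_dim=256, output_dim=259),
#       loss_recon=dict(type='MSELoss', reduction='none'),
#       kl_div_loss_weight=1e-4)
#   # motion: (B, T, C) float tensor; forward returns a dict of losses.
#   losses = pose_vae(motion=batch['motion'])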


@ARCHITECTURES.register_module()
class MotionVAE(BaseArchitecture):
    """VAE over whole motion sequences.

    Handles variable-length sequences: a binary ``motion_mask`` marks the
    valid frames and masks the reconstruction loss accordingly.
    """

    def __init__(self,
                 encoder=None,
                 decoder=None,
                 loss_recon=None,
                 kl_div_loss_weight=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.encoder = build_submodule(encoder)
        self.decoder = build_submodule(decoder)
        self.loss_recon = build_loss(loss_recon)
        self.kl_div_loss_weight = kl_div_loss_weight

    def sample(self, std=1, latent_code=None):
        if latent_code is not None:
            z = latent_code
        else:
            # Draw a latent sequence from the scaled standard-normal prior;
            # the hard-coded length of 7 latent tokens presumably matches
            # the decoder's expected latent sequence length. Note this
            # assumes a CUDA device is available.
            z = torch.randn(1, 7, self.decoder.latent_dim).cuda() * std
        output = self.decoder(z)
        # `use_normalization`, `motion_std` and `motion_mean` are not set in
        # __init__; they are expected to be attached to the model externally
        # (e.g. from dataset statistics) before sampling.
        if self.use_normalization:
            output = output * self.motion_std
            output = output + self.motion_mean
        return output
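
    # Sampling sketch (assumes a trained model and, if normalization is
    # enabled, that `use_normalization`, `motion_mean` and `motion_std`
    # have been attached to the instance):
    #
    #   motion_vae.eval()
    #   with torch.no_grad():
    #       new_motion = motion_vae.sample(std=1.0)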

    def reparameterize(self, mu, logvar):
        # Same reparameterization trick as in PoseVAE.
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        latent_code = eps * std + mu
        return latent_code

    def encode(self, motion, motion_mask):
        # Stochastic encoding: sample a latent code from the posterior.
        mu, logvar = self.encoder(motion, motion_mask)
        return self.reparameterize(mu, logvar)

    def decode(self, z, motion_mask):
        return self.decoder(z, motion_mask)

    def forward(self, **kwargs):
        motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask']

        mu, logvar = self.encoder(motion, motion_mask)
        z = self.reparameterize(mu, logvar)
        pred = self.decoder(z, motion_mask)

        loss = dict()
        recon_loss = self.loss_recon(pred, motion, reduction_override='none')
        # Average the per-feature loss, zero out padded frames with the
        # mask, then normalize by the number of valid frames.
        recon_loss = recon_loss.mean(dim=-1) * motion_mask
        recon_loss = recon_loss.sum() / motion_mask.sum()
        loss['recon_loss'] = recon_loss
        if self.kl_div_loss_weight is not None:
            # Closed-form KL divergence against the standard-normal prior,
            # as in PoseVAE.
            loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss['kl_div_loss'] = loss_kl * self.kl_div_loss_weight

        return loss
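
# Training-step sketch for MotionVAE (assumptions: `batch` provides a float
# motion tensor of shape (B, T, C) and a binary mask of shape (B, T); the
# submodule type names are illustrative placeholders):
#
#   motion_vae = MotionVAE(
#       encoder=dict(type='MotionEncoder', ...),
#       decoder=dict(type='MotionDecoder', ...),
#       loss_recon=dict(type='MSELoss', reduction='none'),
#       kl_div_loss_weight=1e-4)
#   losses = motion_vae(motion=batch['motion'],
#                       motion_mask=batch['motion_mask'])
#   total = losses['recon_loss'] + losses['kl_div_loss']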