File size: 3,722 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import torch
from ..builder import ARCHITECTURES, build_loss, build_submodule
from .base_architecture import BaseArchitecture
@ARCHITECTURES.register_module()
class PoseVAE(BaseArchitecture):
    """Variational autoencoder applied to individual pose frames.

    Args:
        encoder: Config handed to ``build_submodule``. The built module is
            called on a batch of flattened poses and must return a
            ``(mu, logvar)`` pair.
        decoder: Config handed to ``build_submodule``. The built module is
            called on the sampled latent code.
        loss_recon: Config handed to ``build_loss``. The built loss is
            called as ``loss(pred, target, reduction_override='none')``.
        kl_div_loss_weight: Multiplier applied to the KL term. When it is
            ``None`` the KL term is left out of the returned losses.
        init_cfg: Forwarded to ``BaseArchitecture.__init__``.
        **kwargs: Forwarded to ``BaseArchitecture.__init__``.
    """

    def __init__(self, encoder=None, decoder=None, loss_recon=None,
                 kl_div_loss_weight=None, init_cfg=None, **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.encoder = build_submodule(encoder)
        self.decoder = build_submodule(decoder)
        self.loss_recon = build_loss(loss_recon)
        self.kl_div_loss_weight = kl_div_loss_weight

    def reparameterize(self, mu, logvar):
        """Sample ``mu + sigma * noise`` with standard-normal ``noise``.

        ``sigma`` is ``exp(logvar / 2)``; the noise tensor has the shape,
        dtype and device of ``sigma``.
        """
        sigma = (logvar / 2).exp()
        noise = sigma.new_empty(sigma.size()).normal_()
        return noise * sigma + mu

    def encode(self, pose):
        """Return only the first encoder output (``mu``) for ``pose``."""
        mu, _ = self.encoder(pose)
        return mu

    def forward(self, **kwargs):
        """Compute the training losses for a batch of motions.

        Args:
            **kwargs: Must contain ``'motion'``, a tensor whose first two
                dimensions are merged into one before encoding.

        Returns:
            dict: ``'recon_loss'`` (the unreduced reconstruction loss) and,
            when ``kl_div_loss_weight`` is set, ``'kl_div_loss'``.
        """
        motion = kwargs['motion'].float()
        num_seqs, num_frames = motion.shape[:2]
        # Merge the two leading dimensions so every frame is one sample.
        frames = motion.reshape(num_seqs * num_frames, -1)
        # The last four features of each frame are excluded from the VAE
        # (presumably auxiliary channels such as contacts -- TODO confirm).
        pose = frames[:, :-4]

        mu, logvar = self.encoder(pose)
        latent = self.reparameterize(mu, logvar)
        pred = self.decoder(latent)

        loss = {
            'recon_loss': self.loss_recon(
                pred, pose, reduction_override='none'),
        }
        if self.kl_div_loss_weight is not None:
            # Summed over every element, including the batch dimension.
            kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
            kl_total = -0.5 * kl_terms.sum()
            loss['kl_div_loss'] = kl_total * self.kl_div_loss_weight
        return loss
@ARCHITECTURES.register_module()
class MotionVAE(BaseArchitecture):
    """Variational autoencoder over masked motion sequences.

    Args:
        encoder: Config handed to ``build_submodule``. The built module is
            called as ``encoder(motion, motion_mask)`` and must return a
            ``(mu, logvar)`` pair.
        decoder: Config handed to ``build_submodule``. The built module is
            called as ``decoder(z, motion_mask)`` in ``decode``/``forward``
            and as ``decoder(z)`` in ``sample``; ``sample`` also reads its
            ``latent_dim`` attribute.
        loss_recon: Config handed to ``build_loss``. The built loss is
            called as ``loss(pred, target, reduction_override='none')``.
        kl_div_loss_weight: Multiplier applied to the KL term. When it is
            ``None`` the KL term is left out of the returned losses.
        init_cfg: Forwarded to ``BaseArchitecture.__init__``.
        **kwargs: Forwarded to ``BaseArchitecture.__init__``.
    """

    def __init__(self,
                 encoder=None,
                 decoder=None,
                 loss_recon=None,
                 kl_div_loss_weight=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.encoder = build_submodule(encoder)
        self.decoder = build_submodule(decoder)
        self.loss_recon = build_loss(loss_recon)
        self.kl_div_loss_weight = kl_div_loss_weight

    def _decoder_device(self):
        """Return the device of the decoder's first parameter.

        Falls back to CUDA (the device that was previously hard-coded)
        when the decoder exposes no parameters.
        """
        get_parameters = getattr(self.decoder, 'parameters', None)
        if callable(get_parameters):
            first_param = next(iter(get_parameters()), None)
            if first_param is not None:
                return first_param.device
        return torch.device('cuda')

    def sample(self, std=1, latent_code=None):
        """Decode a latent code, drawing a random one if none is given.

        Args:
            std: Scale applied to the randomly drawn latent code. Ignored
                when ``latent_code`` is provided.
            latent_code: Optional latent tensor passed to the decoder
                unchanged.

        Returns:
            The decoder output, de-normalized with ``motion_std`` and
            ``motion_mean`` when ``use_normalization`` is truthy.
        """
        if latent_code is not None:
            z = latent_code
        else:
            # Draw on CPU (keeps the CPU RNG stream) and move the noise to
            # wherever the decoder lives instead of assuming a GPU.
            noise = torch.randn(1, 7, self.decoder.latent_dim)
            z = noise.to(self._decoder_device()) * std
        # NOTE(review): the decoder is called without a mask here, unlike
        # in ``decode``/``forward`` -- confirm the decoder's mask argument
        # is optional.
        output = self.decoder(z)
        # NOTE(review): ``use_normalization``, ``motion_std`` and
        # ``motion_mean`` are not assigned in this class; presumably they
        # come from the base class or are set externally -- confirm.
        if self.use_normalization:
            output = output * self.motion_std
            output = output + self.motion_mean
        return output

    def reparameterize(self, mu, logvar):
        """Sample ``mu + std * eps`` with standard-normal ``eps``.

        ``std`` is ``exp(logvar / 2)``; ``eps`` has the shape, dtype and
        device of ``std``.
        """
        std = torch.exp(logvar / 2)
        # ``new_empty`` replaces the legacy ``.data.new`` constructor.
        eps = std.new_empty(std.size()).normal_()
        return eps * std + mu

    def encode(self, motion, motion_mask):
        """Encode ``motion`` and return a sampled latent code."""
        mu, logvar = self.encoder(motion, motion_mask)
        return self.reparameterize(mu, logvar)

    def decode(self, z, motion_mask):
        """Decode latent code ``z`` using ``motion_mask``."""
        return self.decoder(z, motion_mask)

    def forward(self, **kwargs):
        """Compute the training losses for a batch of masked motions.

        Args:
            **kwargs: Must contain ``'motion'`` and ``'motion_mask'``. The
                mask is multiplied with the reconstruction loss after the
                loss has been averaged over its last dimension.

        Returns:
            dict: ``'recon_loss'`` (masked loss sum divided by the mask
            sum) and, when ``kl_div_loss_weight`` is set,
            ``'kl_div_loss'``.
        """
        motion = kwargs['motion'].float()
        motion_mask = kwargs['motion_mask']

        mu, logvar = self.encoder(motion, motion_mask)
        z = self.reparameterize(mu, logvar)
        pred = self.decoder(z, motion_mask)

        loss = {}
        recon_loss = self.loss_recon(pred, motion, reduction_override='none')
        recon_loss = recon_loss.mean(dim=-1) * motion_mask
        # A fully masked batch would divide by zero and yield NaN; use a
        # denominator of one there so the loss is zero. Non-zero mask sums
        # are left untouched.
        mask_total = motion_mask.sum()
        mask_total = torch.where(mask_total == 0,
                                 torch.ones_like(mask_total), mask_total)
        loss['recon_loss'] = recon_loss.sum() / mask_total

        if self.kl_div_loss_weight is not None:
            # Summed over every element, including the batch dimension.
            loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss['kl_div_loss'] = loss_kl * self.kl_div_loss_weight
        return loss
|