File size: 3,722 Bytes
373af33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch

from ..builder import ARCHITECTURES, build_loss, build_submodule
from .base_architecture import BaseArchitecture


@ARCHITECTURES.register_module()
class PoseVAE(BaseArchitecture):
    """Variational auto-encoder that treats every frame as an independent pose.

    ``forward`` flattens the first two axes of the input motion, encodes each
    resulting row to a Gaussian posterior ``(mu, logvar)``, draws a latent code
    with the reparameterization trick and decodes it back to a pose.

    Args:
        encoder (dict, optional): Config passed to ``build_submodule``. The
            built module's output is unpacked as ``(mu, logvar)``.
        decoder (dict, optional): Config passed to ``build_submodule``.
        loss_recon (dict, optional): Config passed to ``build_loss``. The
            built loss is called with a ``reduction_override`` keyword.
        kl_div_loss_weight (float, optional): Weight of the KL-divergence
            term. When ``None`` the KL loss is not computed at all.
        init_cfg (dict, optional): Forwarded to ``BaseArchitecture``.
        **kwargs: Forwarded to ``BaseArchitecture``.
    """

    def __init__(self,
                 encoder=None,
                 decoder=None,
                 loss_recon=None,
                 kl_div_loss_weight=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.encoder = build_submodule(encoder)
        self.decoder = build_submodule(decoder)
        self.loss_recon = build_loss(loss_recon)
        self.kl_div_loss_weight = kl_div_loss_weight

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        Args:
            mu (torch.Tensor): Posterior mean.
            logvar (torch.Tensor): Posterior log-variance, same shape as
                ``mu``.

        Returns:
            torch.Tensor: A latent sample that stays differentiable w.r.t.
            ``mu`` and ``logvar``.
        """
        std = torch.exp(logvar / 2)
        # randn_like matches dtype/device of ``std`` without going through the
        # discouraged ``Tensor.data`` attribute.
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, pose):
        """Return the posterior mean for ``pose``.

        This path is deterministic: no sampling is performed and the
        log-variance produced by the encoder is discarded.
        """
        mu, _logvar = self.encoder(pose)
        return mu

    def forward(self, **kwargs):
        """Compute the training losses for a batch of motions.

        Args:
            **kwargs: Must contain ``'motion'``, a tensor whose first two axes
                are merged into one sample axis. NOTE(review): presumably
                (batch, time, features) — confirm against the data pipeline.

        Returns:
            dict: ``'recon_loss'`` is the output of ``loss_recon`` called with
            ``reduction_override='none'`` (left unreduced here).
            ``'kl_div_loss'`` is present only when ``kl_div_loss_weight`` is
            not ``None``. It holds the weighted KL divergence summed over all
            frames and latent dimensions.
        """
        motion = kwargs['motion'].float()
        B, T = motion.shape[:2]
        # Every frame becomes an independent sample.
        pose = motion.reshape(B * T, -1)
        # Drop the last 4 feature channels. NOTE(review): presumably auxiliary
        # per-frame features (e.g. foot contacts) the pose VAE should not
        # model — confirm against the motion representation.
        pose = pose[:, :-4]

        mu, logvar = self.encoder(pose)
        z = self.reparameterize(mu, logvar)
        pred = self.decoder(z)

        loss = dict()
        loss['recon_loss'] = self.loss_recon(
            pred, pose, reduction_override='none')
        if self.kl_div_loss_weight is not None:
            # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over everything.
            loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss['kl_div_loss'] = loss_kl * self.kl_div_loss_weight

        return loss


@ARCHITECTURES.register_module()
class MotionVAE(BaseArchitecture):
    """Variational auto-encoder over whole, possibly padded, motion sequences.

    The encoder and decoder are both conditioned on ``motion_mask``. The
    reconstruction loss is averaged over valid (unmasked) frames only.

    Args:
        encoder (dict, optional): Config passed to ``build_submodule``. The
            built module is called as ``encoder(motion, motion_mask)`` and
            its output is unpacked as ``(mu, logvar)``.
        decoder (dict, optional): Config passed to ``build_submodule``. The
            built module must expose a ``latent_dim`` attribute (read by
            ``sample``).
        loss_recon (dict, optional): Config passed to ``build_loss``. The
            built loss is called with a ``reduction_override`` keyword.
        kl_div_loss_weight (float, optional): Weight of the KL-divergence
            term. When ``None`` the KL loss is not computed at all.
        init_cfg (dict, optional): Forwarded to ``BaseArchitecture``.
        **kwargs: Forwarded to ``BaseArchitecture``.
    """

    def __init__(self,
                 encoder=None,
                 decoder=None,
                 loss_recon=None,
                 kl_div_loss_weight=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.encoder = build_submodule(encoder)
        self.decoder = build_submodule(decoder)
        self.loss_recon = build_loss(loss_recon)
        self.kl_div_loss_weight = kl_div_loss_weight

    def sample(self, std=1, latent_code=None, num_latents=7):
        """Decode a given or randomly drawn latent code into a motion.

        Args:
            std (float): Scale applied to standard-normal noise when
                ``latent_code`` is not given.
            latent_code (torch.Tensor, optional): Latent to decode directly.
                When given, ``std`` and ``num_latents`` are ignored.
            num_latents (int): Length of the second axis of the random latent.
                Noise is drawn with shape ``(1, num_latents, decoder.latent_dim)``.
                The default of 7 matches the earlier hard-coded value.

        Returns:
            torch.Tensor: The decoder output. It is de-normalized with
            ``motion_std`` / ``motion_mean`` when ``use_normalization`` is
            truthy.
        """
        if latent_code is not None:
            z = latent_code
        else:
            # Allocate the noise on the model's own device instead of the
            # hard-coded ``.cuda()``, which fails on CPU-only hosts.
            param = next(self.parameters(), None)
            device = param.device if param is not None else None
            z = torch.randn(
                1, num_latents, self.decoder.latent_dim, device=device) * std
        # NOTE(review): the decoder is called without a mask here, unlike
        # ``decode``/``forward``; confirm it has a sensible default.
        output = self.decoder(z)
        # NOTE(review): ``use_normalization``, ``motion_std`` and
        # ``motion_mean`` are never assigned in this class; presumably they
        # come from ``BaseArchitecture`` or are set externally — confirm.
        if self.use_normalization:
            output = output * self.motion_std
            output = output + self.motion_mean
        return output

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        Args:
            mu (torch.Tensor): Posterior mean.
            logvar (torch.Tensor): Posterior log-variance, same shape as
                ``mu``.

        Returns:
            torch.Tensor: A latent sample that stays differentiable w.r.t.
            ``mu`` and ``logvar``.
        """
        std = torch.exp(logvar / 2)
        # randn_like matches dtype/device of ``std`` without going through the
        # discouraged ``Tensor.data`` attribute.
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, motion, motion_mask):
        """Encode ``motion`` and return a reparameterized (stochastic) latent."""
        mu, logvar = self.encoder(motion, motion_mask)
        return self.reparameterize(mu, logvar)

    def decode(self, z, motion_mask):
        """Decode latent ``z`` conditioned on ``motion_mask``."""
        return self.decoder(z, motion_mask)

    def forward(self, **kwargs):
        """Compute the training losses for a batch of masked motions.

        Args:
            **kwargs: Must contain ``'motion'`` and ``'motion_mask'``.
                NOTE(review): ``motion_mask`` is assumed to be a 0/1 mask that
                broadcasts against ``motion`` with its last (feature) axis
                removed — confirm against the data pipeline.

        Returns:
            dict: ``'recon_loss'`` is the reconstruction loss, averaged over
            the feature axis and then over valid frames. ``'kl_div_loss'`` is
            present only when ``kl_div_loss_weight`` is not ``None``. It holds
            the weighted KL divergence summed over all elements.
        """
        motion = kwargs['motion'].float()
        motion_mask = kwargs['motion_mask']

        mu, logvar = self.encoder(motion, motion_mask)
        z = self.reparameterize(mu, logvar)
        pred = self.decoder(z, motion_mask)

        loss = dict()
        recon_loss = self.loss_recon(pred, motion, reduction_override='none')
        # Average over features, zero out padded frames, then average over the
        # valid frames. Clamping the denominator keeps an all-padded batch at
        # 0 instead of producing NaN (0 / 0).
        recon_loss = recon_loss.mean(dim=-1) * motion_mask
        recon_loss = recon_loss.sum() / motion_mask.sum().clamp(min=1)
        loss['recon_loss'] = recon_loss
        if self.kl_div_loss_weight is not None:
            # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over everything.
            loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss['kl_div_loss'] = loss_kl * self.kl_div_loss_weight

        return loss