|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
from human_body_prior.models.model_components import BatchFlatten |
|
from human_body_prior.tools.rotation_tools import matrot2aa |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
|
|
class ContinousRotReprDecoder(nn.Module):
    '''
    Decode a 6D continuous rotation representation into 3x3 rotation matrices.

    Every group of six input values is viewed as a 3x2 matrix (row-major, via
    ``view(-1, 3, 2)``). Its two columns are orthonormalised with Gram-Schmidt
    to give the first two basis vectors, and their cross product gives the
    third. The three vectors become the columns of the output matrix.
    '''

    def __init__(self):
        super(ContinousRotReprDecoder, self).__init__()

    def forward(self, module_input):
        '''
        :param module_input: tensor whose elements can be viewed as (-1, 3, 2);
            ``view`` requires a layout compatible with that shape
        :return: tensor of shape (module_input.numel() // 6, 3, 3) whose columns
            are the orthonormal basis vectors e1, e2, e3
        '''
        cols = module_input.view(-1, 3, 2)
        first_col = cols[:, :, 0]
        second_col = cols[:, :, 1]

        e1 = F.normalize(first_col, dim=1)

        # Gram-Schmidt: subtract the projection of the second column onto e1.
        projection = (e1 * second_col).sum(dim=1, keepdim=True)
        e2 = F.normalize(second_col - projection * e1, dim=-1)

        # The third axis completes a right-handed frame.
        e3 = torch.cross(e1, e2, dim=1)

        return torch.stack((e1, e2, e3), dim=-1)
|
|
|
|
|
class NormalDistDecoder(nn.Module):
    '''
    Linear head that turns a feature vector into a diagonal Normal distribution.

    Two independent ``nn.Linear(num_feat_in, latentD)`` layers produce the
    location and (after softplus) the scale of the distribution.
    '''

    def __init__(self, num_feat_in, latentD):
        super(NormalDistDecoder, self).__init__()

        # Registration order (mu, then logvar) determines parameter order and
        # state_dict key names.
        self.mu = nn.Linear(num_feat_in, latentD)
        self.logvar = nn.Linear(num_feat_in, latentD)

    def forward(self, Xout):
        '''
        :param Xout: features whose last dimension is ``num_feat_in``
        :return: torch.distributions.normal.Normal with loc = mu(Xout) and
            scale = softplus(logvar(Xout)), both with last dimension ``latentD``
        '''
        loc = self.mu(Xout)
        # NOTE: despite its name, the ``logvar`` head's output is passed through
        # softplus and used directly as the standard deviation (scale), not as
        # a log-variance.
        scale = F.softplus(self.logvar(Xout))
        return torch.distributions.normal.Normal(loc, scale)
|
|
|
|
|
class VPoser(nn.Module):
    '''
    Variational autoencoder over body poses.

    The encoder maps a batch of poses to a diagonal Normal distribution over a
    ``latentD``-dimensional latent space. The decoder maps latent codes back to
    ``num_joints`` (21) rotation matrices per sample, plus their conversion via
    ``matrot2aa``.
    '''

    def __init__(self, model_ps):
        '''
        :param model_ps: configuration object read as
            ``model_ps.model_params.num_neurons`` (hidden layer width) and
            ``model_ps.model_params.latentD`` (latent dimensionality)
        '''
        super(VPoser, self).__init__()

        num_neurons, self.latentD = model_ps.model_params.num_neurons, model_ps.model_params.latentD

        self.num_joints = 21
        # Encoder input width: 3 values per joint (presumably axis-angle — TODO confirm).
        n_features = self.num_joints * 3

        # NOTE(review): the two consecutive nn.Linear layers below have no
        # nonlinearity between them. The layer list is intentionally left as-is:
        # nn.Sequential names parameters by position, so inserting a layer would
        # change state_dict keys (presumably breaking existing checkpoints).
        self.encoder_net = nn.Sequential(
            BatchFlatten(),
            nn.BatchNorm1d(n_features),
            nn.Linear(n_features, num_neurons),
            nn.LeakyReLU(),
            nn.BatchNorm1d(num_neurons),
            nn.Dropout(0.1),
            nn.Linear(num_neurons, num_neurons),
            nn.Linear(num_neurons, num_neurons),
            NormalDistDecoder(num_neurons, self.latentD)
        )

        # Ends in ContinousRotReprDecoder: num_joints * 6 values per sample are
        # decoded into num_joints 3x3 rotation matrices.
        self.decoder_net = nn.Sequential(
            nn.Linear(self.latentD, num_neurons),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(num_neurons, num_neurons),
            nn.LeakyReLU(),
            nn.Linear(num_neurons, self.num_joints * 6),
            ContinousRotReprDecoder(),
        )

    def encode(self, pose_body):
        '''
        Encode body poses into the approximate posterior over latent codes.

        :param pose_body: batch of N body poses; after BatchFlatten each sample
            must have num_joints * 3 (= 63) features, as required by the first
            BatchNorm1d (presumably Nx63 or Nx21x3 — TODO confirm)
        :return: torch.distributions.normal.Normal with batch shape (N, latentD)
        '''
        return self.encoder_net(pose_body)

    def decode(self, Zin):
        '''
        Decode latent codes into body joint rotations.

        :param Zin: latent codes of shape (bs, latentD)
        :return: dict with
            'pose_body': ``matrot2aa`` of every rotation matrix, viewed as
                (bs, -1, 3) (presumably (bs, num_joints, 3) axis-angle)
            'pose_body_matrot': rotation matrices flattened to (bs, num_joints, 9)
        '''
        bs = Zin.shape[0]

        # (bs * num_joints, 3, 3) rotation matrices
        prec = self.decoder_net(Zin)

        return {
            'pose_body': matrot2aa(prec.view(-1, 3, 3)).view(bs, -1, 3),
            'pose_body_matrot': prec.view(bs, -1, 9)
        }

    def forward(self, pose_body):
        '''
        Encode, draw a reparameterized latent sample, and decode it.

        :param pose_body: body poses accepted by ``encode``
        :return: the dict produced by ``decode``, extended with
            'poZ_body_mean' (posterior mean), 'poZ_body_std' (posterior scale)
            and 'q_z' (the posterior distribution itself)
        '''
        q_z = self.encode(pose_body)

        # rsample keeps the sample differentiable w.r.t. the posterior parameters.
        q_z_sample = q_z.rsample()

        decode_results = self.decode(q_z_sample)

        decode_results.update({'poZ_body_mean': q_z.mean, 'poZ_body_std': q_z.scale, 'q_z': q_z})

        return decode_results

    def sample_poses(self, num_poses, seed=None):
        '''
        Sample random poses by decoding latent codes drawn from N(0, I).

        Decoding runs in eval mode and without gradient tracking. The model's
        previous train/eval mode is restored before returning. Latent codes come
        from a private numpy RandomState, so the global ``np.random`` state is
        not modified. A given ``seed`` still yields the same codes as
        ``np.random.seed(seed); np.random.normal(...)``.

        :param num_poses: number of poses to sample
        :param seed: optional seed for reproducible sampling; None uses fresh entropy
        :return: the dict produced by ``decode`` for the sampled codes
        '''
        rng = np.random.RandomState(seed)

        # Create the latent codes with the same dtype/device as the model weights.
        some_weight = next(iter(self.parameters()))
        dtype, device = some_weight.dtype, some_weight.device

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                Zgen = torch.tensor(rng.normal(0., 1., size=(num_poses, self.latentD)), dtype=dtype, device=device)
                return self.decode(Zgen)
        finally:
            # Don't leave a model that was training stuck in eval mode.
            self.train(was_training)
|
|