|
import sys |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from lib.utils.utils_smpl import SMPL |
|
from lib.utils.utils_mesh import rotation_matrix_to_angle_axis, rot6d_to_rotmat |
|
|
|
class SMPLRegressor(nn.Module):
    """Regress SMPL pose/shape from per-joint features and run the SMPL layer.

    Pose is predicted for every frame from that frame's flattened joint
    features. Shape (betas) is predicted once per sequence from features
    average-pooled over time and is then repeated for every frame. Both heads
    predict residuals that are added to the mean SMPL parameters loaded from
    ``self.smpl.smpl_mean_params``.
    """

    def __init__(self, args, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.):
        """Build the pose/shape heads and the SMPL layer.

        Args:
            args: config object; only ``args.data_root`` is read here (it is
                forwarded to ``SMPL``).
            dim_rep: channels per joint of the input features.
            num_joints: joints per frame of the input features.
            hidden_dim: width of the hidden layer of both heads.
            dropout_ratio: dropout probability applied to both heads' inputs.
        """
        super(SMPLRegressor, self).__init__()
        # 24 joint rotations, each as a 6-D vector later decoded by rot6d_to_rotmat.
        param_pose_dim = 24 * 6

        # NOTE: construction order is kept as-is so random initialisation and
        # module registration order match previously trained checkpoints.
        self.dropout = nn.Dropout(p=dropout_ratio)
        # Per-frame pose branch input projection.
        self.fc1 = nn.Linear(num_joints * dim_rep, hidden_dim)
        # Sequence-level shape branch: average over the last (time) axis, then project.
        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))
        self.fc2 = nn.Linear(num_joints * dim_rep, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
        self.bn2 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.head_pose = nn.Linear(hidden_dim, param_pose_dim)
        self.head_shape = nn.Linear(hidden_dim, 10)
        # Tiny initial weights so the first predictions stay close to the mean params.
        nn.init.xavier_uniform_(self.head_pose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.head_shape.weight, gain=0.01)

        self.smpl = SMPL(
            args.data_root,
            batch_size=64,
            create_transl=False,
        )

        # Use the archive as a context manager so the underlying file is closed
        # (np.load on an .npz keeps the file open until close()).
        # Both tensors are cast to float32 so the residual sums in forward()
        # do not get promoted to float64.
        with np.load(self.smpl.smpl_mean_params) as mean_params:
            init_pose = torch.from_numpy(mean_params['pose'].astype('float32')).unsqueeze(0)
            init_shape = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)

        # Joint regressor applied to the output vertices; presumably the
        # Human3.6M regressor, per the attribute name -- confirm in SMPL wrapper.
        self.J_regressor = self.smpl.J_regressor_h36m

    def forward(self, feat, init_pose=None, init_shape=None):
        """Predict SMPL parameters, mesh vertices and regressed joints.

        Args:
            feat: tensor of shape (N, T, J, C) with ``J * C`` equal to
                ``num_joints * dim_rep``.
            init_pose, init_shape: currently unused; the mean parameters
                registered as buffers are always the residual starting point.

        Returns:
            A one-element list holding a dict with one row per frame (N * T rows):
                'theta': (N*T, 82) -- 72 axis-angle pose values followed by 10 betas.
                'verts': SMPL output vertices multiplied by 1000.
                'kp_3d': joints obtained by applying ``self.J_regressor`` to 'verts'.
        """
        N, T, J, C = feat.shape
        NT = N * T
        feat = feat.reshape(N, T, -1)

        # Pose branch: one prediction per frame.
        feat_pose = feat.reshape(NT, -1)
        feat_pose = self.dropout(feat_pose)
        feat_pose = self.relu1(self.bn1(self.fc1(feat_pose)))

        # Shape branch: (N, J*C, T) -> average over T -> one prediction per sequence.
        feat_shape = feat.permute(0, 2, 1)
        feat_shape = self.pool2(feat_shape).reshape(N, -1)
        feat_shape = self.dropout(feat_shape)
        feat_shape = self.relu2(self.bn2(self.fc2(feat_shape)))

        # Residual prediction on top of the mean SMPL parameters.
        pred_pose = self.head_pose(feat_pose) + self.init_pose.expand(NT, -1)
        pred_shape = self.head_shape(feat_shape) + self.init_shape.expand(N, -1)
        # Repeat each sequence's betas for all of its T frames -> (N*T, 10).
        pred_shape = pred_shape.unsqueeze(1).expand(-1, T, -1).reshape(NT, -1)

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(-1, 24, 3, 3)
        pred_output = self.smpl(
            betas=pred_shape,
            body_pose=pred_rotmat[:, 1:],
            global_orient=pred_rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
        pred_vertices = pred_output.vertices * 1000.0

        assert self.J_regressor is not None
        # (R, V) @ (B, V, 3) broadcasts over the batch -> (B, R, 3); no need to
        # materialise an expanded per-sample copy of the regressor.
        J_regressor = self.J_regressor.to(pred_vertices.device)
        pred_joints = torch.matmul(J_regressor, pred_vertices)

        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)
        output = [{
            'theta': torch.cat([pose, pred_shape], dim=1),
            'verts': pred_vertices,
            'kp_3d': pred_joints,
        }]
        return output
|
|
|
class MeshRegressor(nn.Module):
    """Backbone feature extractor followed by an ``SMPLRegressor`` head.

    The backbone must provide ``get_representation(x)``. Its output is
    unflattened to (N, T, num_joints, -1) before reaching the head, and the
    head's per-frame outputs are folded back into an (N, T, ...) layout.
    """

    def __init__(self, args, backbone, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.5):
        super().__init__()
        self.backbone = backbone
        # Joint count used to unflatten the backbone representation in forward().
        self.feat_J = num_joints
        self.head = SMPLRegressor(
            args,
            dim_rep=dim_rep,
            num_joints=num_joints,
            hidden_dim=hidden_dim,
            dropout_ratio=dropout_ratio,
        )

    def forward(self, x, init_pose=None, init_shape=None, n_iter=3):
        """Run backbone and head, then restore the sequence/frame axes.

        Input: (N x T x 17 x 3)

        ``init_pose``, ``init_shape`` and ``n_iter`` are accepted but unused.

        Returns the head's list of dicts, with 'theta' reshaped to (N, T, -1)
        and 'verts' / 'kp_3d' reshaped to (N, T, -1, 3).
        """
        batch, frames, _, _ = x.shape
        rep = self.backbone.get_representation(x).reshape([batch, frames, self.feat_J, -1])
        outputs = self.head(rep)

        # (output key, trailing shape after the (N, T) prefix), applied in order.
        trailing_dims = (('theta', (-1,)), ('verts', (-1, 3)), ('kp_3d', (-1, 3)))
        for out in outputs:
            for key, tail in trailing_dims:
                out[key] = out[key].reshape(batch, frames, *tail)
        return outputs