import torch


def to3d(poses, config):
    """Collapse a 5-value-per-joint pose encoding into 3 values per joint.

    When ``config.Data.pose.convert_to_6d`` is false, ``poses`` is returned
    unchanged.  Otherwise each group of 5 columns ``(a0, a1, a2, s, c)`` is
    mapped to ``(a0, a1, a2) * atan2(s, c)``.  This is presumably
    axis-angle when ``(a0, a1, a2)`` is a unit axis — TODO confirm against
    the encoder that produces this layout.

    Args:
        poses: 2-D tensor; row count is preserved.  The pose part must have
            a column count divisible by 5.  When
            ``config.Data.pose.expression`` is true, the last 100 columns
            are treated as expression values.  They are passed through
            untouched and re-appended at the end.
        config: object exposing ``Data.pose.convert_to_6d`` and
            ``Data.pose.expression`` flags.

    Returns:
        Tensor with 3 columns per joint, followed by the 100 expression
        columns if enabled.
    """
    if config.Data.pose.convert_to_6d:
        if config.Data.pose.expression:
            # Split off the trailing expression block so only joint
            # rotations are reshaped into 5-tuples.
            poses_exp = poses[:, -100:]
            poses = poses[:, :-100]

        poses = poses.reshape(poses.shape[0], -1, 5)
        sin, cos = poses[:, :, 3], poses[:, :, 4]
        # atan2 recovers the angle from its (sin, cos) pair.
        pose_angle = torch.atan2(sin, cos)
        poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1)

        if config.Data.pose.expression:
            poses = torch.cat([poses, poses_exp], dim=-1)
    return poses


def get_joint(smplx_model, betas, pred):
    """Run ``smplx_model`` on a batch of per-frame parameters; return joints.

    ``pred`` is 2-D ``(N, >=265)`` with this column layout:

    - ``0:3`` jaw
    - ``3:6`` left eye
    - ``6:9`` right eye
    - ``9:12`` global orientation
    - ``12:75`` body
    - ``75:120`` left hand
    - ``120:165`` right hand
    - ``165:265`` expression

    ``betas`` is repeated ``N`` times along dim 0.  This presumably expects
    shape ``(1, num_betas)`` — verify against callers.

    Returns:
        The ``'joints'`` entry of the model's output.
    """
    joint = smplx_model(
        betas=betas.repeat(pred.shape[0], 1),
        expression=pred[:, 165:265],
        jaw_pose=pred[:, 0:3],
        leye_pose=pred[:, 3:6],
        reye_pose=pred[:, 6:9],
        global_orient=pred[:, 9:12],
        body_pose=pred[:, 12:75],
        left_hand_pose=pred[:, 75:120],
        right_hand_pose=pred[:, 120:165],
        return_verts=True,
    )['joints']
    return joint


def get_joints(smplx_model, betas, pred, max_seqs_per_batch=4):
    """Compute joints for a single batch of frames or a batch of sequences.

    Side effect: sets ``smplx_model.batch_size`` to the size of every chunk
    passed to the model.  On return it holds the size of the last chunk.

    Args:
        smplx_model: callable model accepted by ``get_joint``.
        betas: shape coefficients forwarded to ``get_joint``.
        pred: either ``(N, D)`` frames, or ``(B, T, D)`` sequences.
        max_seqs_per_batch: upper bound on how many whole sequences are sent
            to the model per call in the 3-D case.  Defaults to 4 (the
            original hard-coded value); presumably this bounds memory use.

    Returns:
        For 2-D ``pred``: the ``get_joint`` output directly.
        For 3-D ``pred``: the joints reshaped to ``(B, T, -1, 3)``.

    Raises:
        ValueError: if a 3-D ``pred`` has no sequences or no frames, or
            ``max_seqs_per_batch`` < 1, so that no chunk can be formed.
    """
    if pred.ndim != 3:
        smplx_model.batch_size = pred.shape[0]
        return get_joint(smplx_model, betas, pred)

    num_seqs, num_frames = pred.shape[0], pred.shape[1]
    # Each model call covers up to `max_seqs_per_batch` whole sequences.
    chunk_len = num_frames * min(num_seqs, max_seqs_per_batch)
    if chunk_len <= 0:
        raise ValueError(
            "get_joints needs at least one sequence and one frame "
            "(and max_seqs_per_batch >= 1)"
        )

    # Flatten using the real feature width so frames never get mixed.
    flat = pred.reshape(-1, pred.shape[-1])
    joints = []
    # Step through all rows, including a shorter trailing chunk when the
    # number of sequences is not a multiple of the chunk size.
    for start in range(0, flat.shape[0], chunk_len):
        chunk = flat[start:start + chunk_len]
        smplx_model.batch_size = chunk.shape[0]
        joints.append(get_joint(smplx_model, betas, chunk))
    return torch.cat(joints, dim=0).reshape(num_seqs, num_frames, -1, 3)