File size: 1,772 Bytes
99afdfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch


def to3d(poses, config):
    """Decode sin/cos-encoded joint rotations back into 3-value-per-joint poses.

    When ``config.Data.pose.convert_to_6d`` is falsy, ``poses`` is returned
    unchanged. Otherwise each row is read as groups of 5 values per joint:
    the first 3 are scaled by the angle ``atan2(v[3], v[4])``. The code names
    v[3] / v[4] ``sin`` / ``cos``; presumably the first 3 values form a
    rotation axis, which would make the output axis-angle. Confirm against
    the encoder.

    Args:
        poses: 2-D tensor ``(N, D)``. When ``config.Data.pose.expression`` is
            set, the last 100 columns are treated as expression values: they
            are split off before decoding and appended afterwards unchanged.
        config: object exposing ``Data.pose.convert_to_6d`` and
            ``Data.pose.expression``.

    Returns:
        Tensor ``(N, 3 * n_joints [+ 100])`` when decoding is enabled,
        otherwise ``poses`` itself.
    """
    pose_cfg = config.Data.pose
    if not pose_cfg.convert_to_6d:
        return poses

    keep_expression = pose_cfg.expression
    if keep_expression:
        # Expression coefficients ride along in the trailing 100 columns.
        expression = poses[:, -100:]
        poses = poses[:, :-100]

    per_joint = poses.reshape(poses.shape[0], -1, 5)
    angle = torch.atan2(per_joint[:, :, 3], per_joint[:, :, 4])
    scaled = per_joint[:, :, :3] * angle.unsqueeze(dim=-1)
    poses = scaled.reshape(per_joint.shape[0], -1)

    if keep_expression:
        poses = torch.cat([poses, expression], dim=-1)
    return poses


def get_joint(smplx_model, betas, pred):
    """Run ``smplx_model`` on a batch of flattened pose vectors; return joints.

    Each row of ``pred`` is split by column into the keyword arguments below.
    ``betas`` is repeated once per row, so it presumably has shape
    ``(1, n_betas)``. Confirm with callers.

    Args:
        smplx_model: callable accepting SMPL-X style keyword arguments and
            returning a mapping with a ``'joints'`` entry.
        betas: shape coefficients, tiled along dim 0 to ``pred.shape[0]`` rows.
        pred: 2-D tensor whose columns 0..264 hold the pose parameters.

    Returns:
        ``smplx_model(...)['joints']``.
    """
    # Column ranges [start, stop) of each parameter inside a pose row.
    column_layout = {
        'expression': (165, 265),
        'jaw_pose': (0, 3),
        'leye_pose': (3, 6),
        'reye_pose': (6, 9),
        'global_orient': (9, 12),
        'body_pose': (12, 75),
        'left_hand_pose': (75, 120),
        'right_hand_pose': (120, 165),
    }
    pose_kwargs = {key: pred[:, start:stop] for key, (start, stop) in column_layout.items()}

    n_rows = pred.shape[0]
    model_out = smplx_model(betas=betas.repeat(n_rows, 1),
                            return_verts=True,
                            **pose_kwargs)
    return model_out['joints']


def get_joints(smplx_model, betas, pred, seqs_per_chunk=4):
    """Compute joints for batched pose vectors via ``get_joint``.

    For a 3-D ``pred`` of shape ``(B, T, 265)``, frames are flattened and sent
    to the model in chunks of at most ``seqs_per_chunk`` whole sequences
    (``T`` frames each). This bounds the per-call batch size. Any other
    ``pred`` is sent in a single call.

    Args:
        smplx_model: model object passed through to ``get_joint``. Its
            ``batch_size`` attribute is overwritten with the number of rows in
            each forward call. Presumably the model uses it for default
            tensors; confirm against the SMPL-X implementation in use.
        betas: shape coefficients forwarded unchanged to ``get_joint``.
        pred: ``(B, T, 265)`` tensor of per-frame pose vectors, or a 2-D
            tensor of frames.
        seqs_per_chunk: maximum number of sequences per forward call for 3-D
            input. Defaults to 4.

    Returns:
        For 3-D input, joints reshaped to ``(B, T, n_joints, 3)``. Otherwise,
        the result of ``get_joint`` on all frames at once.

    Raises:
        ValueError: if ``seqs_per_chunk`` is less than 1.
    """
    if seqs_per_chunk < 1:
        raise ValueError(f"seqs_per_chunk must be >= 1, got {seqs_per_chunk}")

    if pred.dim() != 3:
        smplx_model.batch_size = pred.shape[0]
        return get_joint(smplx_model, betas, pred)

    B, T = pred.shape[0], pred.shape[1]
    frames = pred.reshape(-1, 265)
    rows_per_chunk = T * min(B, seqs_per_chunk)

    joints = []
    # torch.split yields a shorter final chunk when B is not a multiple of
    # seqs_per_chunk, so every frame is processed and the reshape below is
    # always valid.
    for chunk in torch.split(frames, rows_per_chunk, dim=0):
        # Keep the model's batch_size equal to the rows actually sent.
        smplx_model.batch_size = chunk.shape[0]
        joints.append(get_joint(smplx_model, betas, chunk))
    return torch.cat(joints, dim=0).reshape(B, T, -1, 3)