File size: 3,142 Bytes
99afdfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import sys

# Expose only GPU 0 to CUDA. This runs before `torch` is imported further down
# in this file.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Add the current working directory to the import path; presumably so the
# project-local packages imported below resolve when the script is launched
# from the repository root -- confirm against how the script is invoked.
sys.path.append(os.getcwd())

from tqdm import tqdm
from transformers import Wav2Vec2Processor

from evaluation.metrics import LVD

import numpy as np
import smplx as smpl

from data_utils.lower_body import part2full, poses2pred, c_index_3d
from nets import *
from nets.utils import get_path, get_dpath
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
from torch.utils import data
from data_utils.get_j import to3d, get_joints
from scripts.test_body import init_model, init_dataloader


def test(test_loader, generator, config):
    """Run ``generator`` over ``test_loader`` and report the averaged loss.

    For each batch the ground-truth poses and expressions are concatenated,
    converted with ``to3d`` and compared against the output of
    ``generator.infer_on_audio`` on the batch's audio file. The per-batch
    ``'capacity'`` value is the absolute difference summed over the last
    dimension and averaged over the remaining ones. After the loop each
    accumulated value is divided by the number of batches and printed as
    ``key=value``.

    Args:
        test_loader: Iterable of batch dicts. The keys read here are
            ``'poses'``, ``'expression'``, ``'speaker'`` and ``'aud_file'``.
        generator: Object exposing ``infer_on_audio(wav_file, initial_pose,
            id, fps, B)``.
        config: Configuration object forwarded to ``to3d``.

    Returns:
        dict: Metric name mapped to its averaged value (a tensor). Empty when
        ``test_loader`` yields no batches.
    """
    print('start testing')

    loss_dict = {}
    B = 1
    count = 0
    with torch.no_grad():
        for bat in tqdm(test_loader, desc="Testing......"):
            count += 1
            poses = bat['poses'].to('cuda').to(torch.float32)
            exp = bat['expression'].to('cuda').to(torch.float32)
            # Speaker ids are shifted down by 20 -- presumably to map dataset
            # ids onto the model's index range; confirm against the dataset.
            speaker_id = bat['speaker'].to('cuda') - 20

            # Join poses and expressions along the second-to-last axis, then
            # convert with to3d and restore a leading batch axis of size 1.
            poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
            poses = to3d(poses, config).unsqueeze(dim=0).transpose(1, 2)

            cur_wav_file = bat['aud_file'][0]

            pred = generator.infer_on_audio(cur_wav_file,
                                            initial_pose=poses,
                                            id=speaker_id,
                                            fps=30,
                                            B=B
                                            )
            pred = torch.tensor(pred, device='cuda')

            # Keep only the c_index_3d entries and trim the last axis to the
            # length of the prediction before comparing.
            target = poses[:, c_index_3d, :pred.shape[0]].transpose(1, 2)
            bat_loss_dict = {'capacity': (target - pred).abs().sum(-1).mean()}

            # Accumulate per-batch values; the first batch seeds each key.
            for key, value in bat_loss_dict.items():
                if key in loss_dict:
                    loss_dict[key] += value
                else:
                    loss_dict[key] = value

        for key in loss_dict:
            loss_dict[key] = loss_dict[key] / count
            print(key + '=' + str(loss_dict[key].item()))

    return loss_dict


def main():
    """Parse the CLI options, build the data loader and model, then run ``test``.

    Uses the GPU named by ``args.gpu``, loads the JSON config from
    ``args.config_file`` and copies three path settings from it into
    environment variables before building the loader and the model.
    """
    args = parse_args().parse_args()
    torch.cuda.set_device(torch.device(args.gpu))

    config = load_JsonConfig(args.config_file)

    # Copy the asset paths from the config into same-named environment variables.
    for env_key in ('smplx_npz_path', 'extra_joint_path', 'j14_regressor_path'):
        os.environ[env_key] = getattr(config, env_key)

    print('init dataloader...')
    # Only the loader is needed here; the dataset and norm stats are unused.
    _, test_loader, _ = init_dataloader(
        config.Data.data_root, args.speakers, args, config
    )

    print('init model...')
    model_name = 's2g_body_vq'
    generator = init_model(
        model_name, get_path(model_name, 'n_com_8192'), args, config
    )

    test(test_loader, generator, config)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()