import os
import sys
# os.environ["PYOPENGL_PLATFORM"] = "egl"
from transformers import Wav2Vec2Processor
from visualise.rendering import RenderTool
sys.path.append(os.getcwd())
from glob import glob
import numpy as np
import json
import smplx as smpl
from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from scripts.diversity import init_model, init_dataloader, get_vertices
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
import time
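
# Hedged usage sketch (flag names assumed from the attributes read via
# trainer.options.parse_args in main(): config_file, gpu, speakers, exp_name):
#   python <this_script>.py --config_file ./config/<cfg>.json --gpu cuda:0

# Default axis-angle global orientation; only referenced by the commented-out
# override inside infer().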
global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
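

# Runs audio-to-body inference over the dataloader: rebuilds the ground-truth
# full-body pose from the dataset encoding, samples body motion from the audio,
# and renders the result. g_face and g_body2 are accepted but unused here (the
# face branch is commented out and replaced by zeros).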
def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
          smplx_model, rendertool, args=None, config=None, var=None):
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = 1
    face = False
    if face:
        body_static = torch.zeros([1, 162], device='cuda')
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
    stand = False
    j = 0
    gt_0 = None
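
    # iterate the evaluation clips: rebuild the ground-truth pose sequence, sample
    # body motion from audio, then mesh and render each result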
    for bat in infer_loader:
        poses_ = bat['poses'].to(torch.float32).to(device)
        # only handle 300-frame clips (10 s at 30 fps), and at most 1000 of them
        if poses_.shape[-1] == 300:
            j = j + 1
            if j > 1000:
                continue
            speaker_id = bat['speaker'].to('cuda') - 20
            if config.Data.pose.expression:
                expression = bat['expression'].to(device).to(torch.float32)
                poses = torch.cat([poses_, expression], dim=1)
            else:
                poses = poses_
            cur_wav_file = bat['aud_file'][0]
            betas = bat['betas'][0].to(torch.float64).to('cuda')
            # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
            gt = poses.to('cuda').squeeze().transpose(1, 0)
            if config.Data.pose.normalization:
                gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
            if config.Data.pose.convert_to_6d:
                # split off the expression block before converting 6d rotations back to axis-angle
                if config.Data.pose.expression:
                    gt_exp = gt[:, -100:]
                    gt = gt[:, :-100]
                gt = gt.reshape(gt.shape[0], -1, 6)
                gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
                if config.Data.pose.expression:
                    gt = torch.cat([gt, gt_exp], -1)
            if face:
                gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)

            result_list = [gt]
            # cur_wav_file = '.\\training_data\\french-V4.wav'
            # pred_face = g_face.infer_on_audio(cur_wav_file,
            #                                   initial_pose=poses_,
            #                                   norm_stats=None,
            #                                   w_pre=False,
            #                                   # id=speaker_id,
            #                                   frame=None,
            #                                   am=am,
            #                                   am_sr=am_sr
            #                                   )
            # pred_face = torch.tensor(pred_face).squeeze().to('cuda')

            # face prediction is disabled; use zeros for jaw pose (3) + expression (100)
            pred_face = torch.zeros([gt.shape[0], 103], device='cuda')
            pred_jaw = pred_face[:, :3]
            pred_face = pred_face[:, 3:]
            # speaker_id = torch.tensor([0], device='cuda')
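            # sample num_sample body motions from the audio, then assemble each into a
            # full pose vector ordered [jaw (3) | body + hands | expression (100)]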
            for i in range(num_sample):
                pred_res = g_body.infer_on_audio(cur_wav_file,
                                                 initial_pose=poses_,
                                                 norm_stats=norm_stats,
                                                 txgfile=None,
                                                 id=speaker_id,
                                                 var=var,
                                                 fps=30,
                                                 continuity=True,
                                                 smooth=False
                                                 )
                pred = torch.tensor(pred_res).squeeze().to('cuda')
                # pad with the last frame (or trim) so the body track matches the face track
                if pred.shape[0] < pred_face.shape[0]:
                    repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
                    pred = torch.cat([pred, repeat_frame], dim=0)
                else:
                    pred = pred[:pred_face.shape[0], :]
                if config.Data.pose.convert_to_6d:
                    pred = pred.reshape(pred.shape[0], -1, 6)
                    pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
                    pred = pred.reshape(pred.shape[0], -1)
                pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
                # pred[:, 9:12] = global_orient
                # expand the partial (upper-body) prediction to the full-body layout
                pred = part2full(pred, stand)
                if face:
                    pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
                # result_list[0] = poses2pred(result_list[0], stand)
                # if gt_0 is None:
                #     gt_0 = gt
                # pred = pred2poses(pred, gt_0)
                # result_list[0] = poses2poses(result_list[0], gt_0)
                result_list.append(pred)
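            # mesh the ground truth (index 0) and the prediction (index 1), save the
            # raw prediction, and render the continuity video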
            vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
            result_list = [res.to('cpu') for res in result_list]
            pred_poses = np.concatenate(result_list[1:], axis=0)
            file_name = 'visualise/video/' + config.Log.name + '/' + \
                        cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
            np.save(file_name, pred_poses)
            rendertool._render_continuity(cur_wav_file, vertices_list[1], frame=60)


def main():
    parser = parse_args()
    args = parser.parse_args()
    device = torch.device(args.gpu)
    torch.cuda.set_device(device)
    config = load_JsonConfig(args.config_file)

    smplx = True
    os.environ['smplx_npz_path'] = config.smplx_npz_path
    os.environ['extra_joint_path'] = config.extra_joint_path
    os.environ['j14_regressor_path'] = config.j14_regressor_path

    print('init model...')
    body_model_name = 's2g_body_pixel'
    body_model_path = './experiments/2022-12-31-smplx_S2G-body-pixel-conti-wide/ckpt-99.pth'  # './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
    generator = init_model(body_model_name, body_model_path, args, config)
    # face_model_name = 's2g_face'
    # face_model_path = './experiments/2022-10-15-smplx_S2G-face-sgd-3p-wv2/ckpt-99.pth'  # './experiments/2022-09-28-smplx_S2G-face-faceformer-3d/ckpt-99.pth'
    # generator_face = init_model(face_model_name, face_model_path, args, config)
    generator_face = None

    print('init dataloader...')
    infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)

    print('init smplx model...')
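    # SMPL-X with 300 shape betas and 100 expression coefficients, matching the
    # dataset annotations; model_path must point at a local SMPL-X model directory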
    dtype = torch.float64
    model_params = dict(model_path='E:/PycharmProjects/Motion-Projects/models',
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        # gender='ne',
                        dtype=dtype)
    smplx_model = smpl.create(**model_params).to('cuda')

    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    infer(config.Data.data_root, generator, generator_face, None, args.exp_name, infer_loader, infer_set, device,
          norm_stats, smplx, smplx_model, rendertool, args, config, (None, None))


if __name__ == '__main__':
    main()