import os
import sys

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
sys.path.append(os.getcwd())

import numpy as np
import torch
from torch.utils import data

from nets import *
from repro_nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig


def init_model(model_name, model_path, args, config):
    """Build the generator named by `model_name` and load its checkpoint."""
    if model_name == 'freeMo':
        # generator = freeMo_Generator(args)
        generator = freeMo_dev(args, config)
        # generator.load_state_dict(torch.load(model_path)['generator'])
    elif model_name == 'smplx_S2G':
        generator = smplx_S2G(args, config)
    elif model_name == 'StyleGestures':
        generator = StyleGesture_Generator(args, config)
    elif model_name == 'Audio2Gestures':
        config.Train.using_mspec_stat = False
        generator = Audio2Gesture_Generator(
            args,
            config,
            torch.zeros([1, 1, 108]),
            torch.ones([1, 1, 108])
        )
    elif model_name in ('S2G', 'Tmpt'):
        generator = S2G_Generator(args, config)
    else:
        raise NotImplementedError(f'unknown model name: {model_name}')

    # Load on CPU first; the caller moves the model to the target device.
    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])
    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        # Raw state dict: wrap it under the 'generator' key, which the
        # wrapper's load_state_dict is assumed to expect.
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

    return generator


def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
    """Encode the 'pre' split with the pre-pose encoder, then save the
    per-dimension mean/std of the latent codes as pre_variable.npy next to
    the checkpoint directory."""
    path = model_path.split('ckpt')[0]
    file = os.path.join(os.path.dirname(path), 'pre_variable.npy')
    data_base = torch_data(
        data_root=data_root,
        speakers=speakers,
        split='pre',
        limbscaling=False,
        normalization=config.Data.pose.normalization,
        norm_method=config.Data.pose.norm_method,
        split_trans_zero=False,
        num_pre_frames=config.Data.pose.pre_pose_length,
        num_generate_length=config.Data.pose.generate_length,
        num_frames=15,
        aud_feat_win_size=config.Data.aud.aud_feat_win_size,
        aud_feat_dim=config.Data.aud.aud_feat_dim,
        feat_method=config.Data.aud.feat_method,
        smplx=True,
        audio_sr=22000,
        convert_to_6d=config.Data.pose.convert_to_6d,
        expression=config.Data.pose.expression
    )
    data_base.get_dataset()
    pre_set = data_base.all_dataset
    pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size,
                                 shuffle=False, drop_last=True)

    total_pose = []
    with torch.no_grad():
        for bat in pre_loader:
            pose = bat['poses'].to(device).to(torch.float32)
            expression = bat['expression'].to(device).to(torch.float32)
            # (B, C, T) -> (B, T, C), then split each window into five
            # 15-frame clips stacked along the batch axis.
            pose = pose.permute(0, 2, 1)
            pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45],
                              pose[:, 45:60], pose[:, 60:]], dim=0)
            expression = expression.permute(0, 2, 1)
            expression = torch.cat([expression[:, :15], expression[:, 15:30],
                                    expression[:, 30:45], expression[:, 45:60],
                                    expression[:, 60:]], dim=0)
            # Concatenate pose and expression channels, then flatten each clip
            # to the (N, D, 1) shape the pre-pose encoder expects.
            pose = torch.cat([pose, expression], dim=-1)
            pose = pose.reshape(pose.shape[0], -1, 1)
            pose_code = generator.generator.pre_pose_encoder(pose).squeeze().cpu()
            total_pose.append(np.asarray(pose_code))

    total_pose = np.concatenate(total_pose, axis=0)
    mean = np.mean(total_pose, axis=0)
    std = np.std(total_pose, axis=0)
    np.save(file, (mean, std), allow_pickle=True)

    return mean, std
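
# Illustrative sketch, not part of the original pipeline: the statistics saved
# by prevar_loader() are typically consumed at generation time to sample
# pre-pose latent codes from a Gaussian over the learned code space. The file
# layout matches the np.save call above; the helper itself is hypothetical.
def sample_pre_pose_code(stats_file, rng=np.random):
    mean, std = np.load(stats_file, allow_pickle=True)
    return rng.randn(*mean.shape) * std + mean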

def main():
    parser = parse_args()
    args = parser.parse_args()
    device = torch.device(args.gpu)
    torch.cuda.set_device(device)
    config = load_JsonConfig(args.config_file)

    print('init model...')
    generator = init_model(config.Model.model_name, args.model_path, args, config)
    print('init pre-pose vectors...')
    mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config,
                              args.model_path, device, generator)


if __name__ == '__main__':
    main()
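
# Example invocation (a sketch: flag spellings assume trainer/options.py
# exposes the attributes used above -- gpu, config_file, model_path,
# speakers -- under the same names; paths and speaker names are placeholders):
#
#     python path/to/this_script.py --gpu 0 \
#         --config_file ./config/body_pixel.json \
#         --model_path ./experiments/ckpt-model.pth \
#         --speakers oliver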