Spaces:
Sleeping
Sleeping
import os | |
import sys | |
sys.path.append(os.getcwd()) | |
from data_utils import torch_data | |
from trainer.options import parse_args | |
from trainer.config import load_JsonConfig | |
from nets.init_model import init_model | |
import torch | |
import torch.utils.data as data | |
import torch.optim as optim | |
import numpy as np | |
import random | |
import logging | |
import time | |
import shutil | |
def prn_obj(obj):
    """Print each entry of ``obj.__dict__`` as ``name:value``, one entry per line."""
    rendered = ('%s:%s' % (attr_name, attr_value)
                for attr_name, attr_value in obj.__dict__.items())
    print('\n'.join(rendered))
class Trainer():
    """Drive training of the generator built by ``init_model``.

    On construction it parses CLI arguments, loads the JSON config, prepares
    the run directory and logging, builds the generator and data loaders, and
    optionally resumes from a checkpoint. ``train()`` then runs the epoch loop
    and periodically writes ``ckpt-<epoch>.pth`` files into ``self.train_dir``.
    """

    def __init__(self) -> None:
        parser = parse_args()
        self.args = parser.parse_args()
        self.config = load_JsonConfig(self.args.config_file)
        # Exported as environment variables, presumably so model/data code
        # elsewhere can locate SMPL-X assets -- TODO confirm the consumers.
        os.environ['smplx_npz_path'] = self.config.smplx_npz_path
        os.environ['extra_joint_path'] = self.config.extra_joint_path
        os.environ['j14_regressor_path'] = self.config.j14_regressor_path

        self.device = torch.device(self.args.gpu)
        # torch.cuda.set_device rejects non-CUDA devices; only select a GPU
        # when one was actually requested.
        if self.device.type == 'cuda':
            torch.cuda.set_device(self.device)
        self.setup_seed(self.args.seed)
        self.set_train_dir()
        # Keep a copy of the config next to the checkpoints it produced.
        shutil.copy(self.args.config_file, self.train_dir)
        self.generator = init_model(self.config.Model.model_name, self.args, self.config)
        self.init_dataloader()
        self.start_epoch = 0
        self.global_steps = 0
        if self.args.resume:
            self.resume()

    def setup_seed(self, seed):
        """Seed torch (CPU and all GPUs), numpy and ``random``; force deterministic cuDNN."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        # Autotuned algorithm selection can differ between runs; disable it so
        # the deterministic flag above is effective.
        torch.backends.cudnn.benchmark = False

    def set_train_dir(self):
        """Create the run directory and route INFO logging to stdout and ``train.log``.

        The directory is ``<cwd>/<save_dir>/<YYYY-MM-DD>-<exp_name>-<Log.name>``.
        Because the stamp has day resolution and ``exist_ok=True`` is used, runs
        started on the same day with the same names share this directory.
        """
        time_stamp = time.strftime('%Y-%m-%d', time.localtime(time.time()))
        run_name = time_stamp + '-' + self.args.exp_name + '-' + self.config.Log.name
        train_dir = os.path.join(os.getcwd(), self.args.save_dir, os.path.normpath(run_name))
        os.makedirs(train_dir, exist_ok=True)

        fmt = "%(asctime)s-%(lineno)d-%(message)s"
        datefmt = '%m/%d %I:%M:%S %p'
        # basicConfig is a no-op if the root logger already has handlers.
        logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=fmt, datefmt=datefmt)
        fh = logging.FileHandler(os.path.join(train_dir, 'train.log'))
        # Use the same date format as the stdout handler so both logs match.
        fh.setFormatter(logging.Formatter(fmt, datefmt=datefmt))
        logging.getLogger().addHandler(fh)
        self.train_dir = train_dir

    def resume(self):
        """Restore generator weights and step counters from ``args.pretrained_pth``.

        The checkpoint must contain ``generator``, ``epoch`` and ``global_steps``,
        as written by ``save_model``.
        """
        print('resume from a previous ckpt')
        # map_location lets a checkpoint written on another device load here.
        ckpt = torch.load(self.args.pretrained_pth, map_location=self.device)
        self.generator.load_state_dict(ckpt['generator'])
        # save_model records the epoch that had just *finished*, so continue
        # with the following one instead of repeating it.
        self.start_epoch = ckpt['epoch'] + 1
        self.global_steps = ckpt['global_steps']
        self.generator.global_step = self.global_steps

    def _dataset_kwargs(self, split_trans_zero):
        """Return the ``torch_data`` constructor kwargs shared by every model family."""
        data_cfg = self.config.Data
        return dict(
            data_root=data_cfg.data_root,
            speakers=self.args.speakers,
            split='train',
            limbscaling=data_cfg.pose.augmentation,
            normalization=data_cfg.pose.normalization,
            norm_method=data_cfg.pose.norm_method,
            split_trans_zero=split_trans_zero,
            num_pre_frames=data_cfg.pose.pre_pose_length,
            num_frames=data_cfg.pose.generate_length,
            aud_feat_win_size=data_cfg.aud.aud_feat_win_size,
            aud_feat_dim=data_cfg.aud.aud_feat_dim,
            feat_method=data_cfg.aud.feat_method,
            context_info=data_cfg.aud.context_info,
        )

    def _save_norm_stats(self):
        """If pose normalization is enabled, store (mean, std) and save them to ``norm_stats.npy``."""
        if self.config.Data.pose.normalization:
            self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
            save_file = os.path.join(self.train_dir, 'norm_stats.npy')
            np.save(save_file, self.norm_stats, allow_pickle=True)

    def _make_loader(self, dataset):
        """Wrap ``dataset`` in a shuffled, drop-last DataLoader configured from ``config.DataLoader``."""
        loader_cfg = self.config.DataLoader
        return data.DataLoader(dataset, batch_size=loader_cfg.batch_size, shuffle=True,
                               num_workers=loader_cfg.num_workers, drop_last=True)

    def init_dataloader(self):
        """Build the training set and its loader(s) according to the model family.

        * ``freeMo`` models: the set is split into trans/zero subsets, exposed as
          ``trans_loader`` and ``zero_loader``.
        * everything else: a single ``train_loader`` over ``all_dataset``;
          ``smplx``/``s2g`` models pass extra SMPL-X options.

        Raises:
            NotImplementedError: for ``freeMo`` models with a ``.csv`` data root.
        """
        model_name = self.config.Model.model_name
        if 'freeMo' in model_name:
            if self.config.Data.data_root.endswith('.csv'):
                raise NotImplementedError
            self.train_set = torch_data(**self._dataset_kwargs(split_trans_zero=True))
            # Stats are saved before get_dataset(), matching the original order.
            self._save_norm_stats()
            self.train_set.get_dataset()
            self.trans_set = self.train_set.trans_dataset
            self.zero_set = self.train_set.zero_dataset
            self.trans_loader = self._make_loader(self.trans_set)
            self.zero_loader = self._make_loader(self.zero_set)
            return

        kwargs = self._dataset_kwargs(split_trans_zero=False)
        if 'smplx' in model_name or 's2g' in model_name:
            pose_cfg = self.config.Data.pose
            kwargs.update(
                num_generate_length=pose_cfg.generate_length,
                smplx=True,
                # NOTE(review): 22000 is passed verbatim; verify against
                # data_utils whether 22050 Hz was intended.
                audio_sr=22000,
                convert_to_6d=pose_cfg.convert_to_6d,
                expression=pose_cfg.expression,
                config=self.config,
            )
        self.train_set = torch_data(**kwargs)
        self._save_norm_stats()
        self.train_set.get_dataset()
        self.train_loader = self._make_loader(self.train_set.all_dataset)

    def init_optimizer(self):
        """No-op; the call in ``__init__`` is disabled.

        Presumably the generator manages its own optimizer -- TODO confirm.
        """
        pass

    def print_func(self, loss_dict, steps):
        """Log ``global_steps`` and each accumulated loss averaged over ``steps``."""
        info_str = ['global_steps:%d' % (self.global_steps)]
        info_str += ['%s:%.4f' % (key, value / steps) for key, value in loss_dict.items()]
        logging.info(','.join(info_str))

    def save_model(self, epoch):
        """Write generator weights and counters to ``<train_dir>/ckpt-<epoch>.pth``.

        ``epoch`` is the index of the epoch that has just completed.
        """
        state_dict = {
            'generator': self.generator.state_dict(),
            'epoch': epoch,
            'global_steps': self.global_steps,
        }
        save_name = os.path.join(self.train_dir, 'ckpt-%d.pth' % (epoch))
        torch.save(state_dict, save_name)

    @staticmethod
    def _accumulate_losses(totals, losses):
        """Add each value of ``losses`` into ``totals`` (in place) and return ``totals``.

        Keys first seen in a later batch are added rather than raising
        ``KeyError``. ``totals[key] + value`` builds a new object, so the
        generator's returned loss objects (e.g. tensors) are never mutated.
        """
        for key, value in losses.items():
            if key in totals:
                totals[key] = totals[key] + value
            else:
                totals[key] = value
        return totals

    def train_epoch(self, epoch):
        """Run one pass over the training data; the generator call performs the update.

        For ``freeMo`` models each step consumes a (trans, zero) batch pair, and
        the epoch ends when the shorter loader is exhausted. Otherwise each batch
        dict is tagged with ``'epoch'`` before being passed to the generator.
        """
        epoch_loss_dict = {}  # running sum of each loss over this epoch
        epoch_steps = 0
        paired = 'freeMo' in self.config.Model.model_name
        batches = zip(self.trans_loader, self.zero_loader) if paired else self.train_loader
        for bat in batches:
            self.global_steps += 1
            epoch_steps += 1
            if not paired:
                bat['epoch'] = epoch
            _, loss_dict = self.generator(bat)
            self._accumulate_losses(epoch_loss_dict, loss_dict)
            if self.global_steps % self.config.Log.print_every == 0:
                self.print_func(epoch_loss_dict, epoch_steps)

    def train(self):
        """Train from ``start_epoch`` up to ``config.Train.epochs``.

        A checkpoint is saved every ``Log.save_every`` epochs and always after the
        final epoch.
        """
        logging.info('start_training')
        self.total_loss_dict = {}
        num_epochs = self.config.Train.epochs
        for epoch in range(self.start_epoch, num_epochs):
            logging.info('epoch:%d' % (epoch))
            self.train_epoch(epoch)
            finished = epoch + 1
            # Previously compared against a hard-coded 30; use the configured
            # epoch count so the final model is always written.
            if finished % self.config.Log.save_every == 0 or finished == num_epochs:
                self.save_model(epoch)