# TalkSHOWLIVE / trainer / Trainer.py
import os
import sys
sys.path.append(os.getcwd())
from data_utils import torch_data
from trainer.options import parse_args
from trainer.config import load_JsonConfig
from nets.init_model import init_model
import torch
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import random
import logging
import time
import shutil
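

# Debug helper: print every attribute of an object as "name:value", one per line.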
def prn_obj(obj):
print('\n'.join(['%s:%s' % item for item in obj.__dict__.items()]))
class Trainer():
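    """Training harness: parses CLI arguments, loads the JSON config, builds the
    generator and dataloaders, then runs the epoch / checkpoint loop."""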
def __init__(self) -> None:
parser = parse_args()
self.args = parser.parse_args()
self.config = load_JsonConfig(self.args.config_file)
        os.environ['smplx_npz_path'] = self.config.smplx_npz_path
        os.environ['extra_joint_path'] = self.config.extra_joint_path
        os.environ['j14_regressor_path'] = self.config.j14_regressor_path
# torch.set_default_dtype(torch.float64)
# wandb_run = wandb.init(project=f's2g_sweep')
# if self.args.use_wandb:
# print('starting wandb sweep agent...')
# wandb_key = 'e3d537403fce5c8a99893c2cbe20a8d49a79358d'
# os.environ['WANDB_API_KEY'] = wandb_key
#
# default_config=dict(w_b=1,w_h=10)
# wandb.init(config=default_config)
# self.config.param.w_b=wandb.config.w_b
# self.config.param.w_h=wandb.config.w_h
# self.config.Train.epochs=30
# if self.args.use_wandb:
# print('starting wandb sweep agent...')
# wandb_key = 'e3d537403fce5c8a99893c2cbe20a8d49a79358d'
# os.environ['WANDB_API_KEY'] = wandb_key
#
# wandb.init(config=self.args, project="s2g_sweep")
# # wandb.config.update(self.args)
#
# self.config.param.w_b=self.args.w_b
# self.config.param.w_h=self.args.w_h
# self.config.Train.epochs=30
self.device = torch.device(self.args.gpu)
torch.cuda.set_device(self.device)
self.setup_seed(self.args.seed)
self.set_train_dir()
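        # Snapshot the config file into the run directory for reproducibility.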
shutil.copy(self.args.config_file, self.train_dir)
self.generator = init_model(self.config.Model.model_name, self.args, self.config)
self.init_dataloader()
self.start_epoch = 0
self.global_steps = 0
if self.args.resume:
self.resume()
# self.init_optimizer()
def setup_seed(self, seed):
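        # Seed Python, NumPy and PyTorch (CPU + all GPUs); cudnn.deterministic
        # trades some speed for run-to-run reproducibility.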
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def set_train_dir(self):
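        # Create a date-stamped run directory under save_dir and route logging
        # to both stdout and a train.log file inside it.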
        time_stamp = time.strftime('%Y-%m-%d', time.localtime(time.time()))
train_dir = os.path.join(os.getcwd(), self.args.save_dir, os.path.normpath(
time_stamp + '-' + self.args.exp_name + '-' + self.config.Log.name))
# train_dir= os.path.join(os.getcwd(), self.args.save_dir, os.path.normpath(time_stamp+'-'+self.args.exp_name+'-'+time.strftime("%H:%M:%S")))
os.makedirs(train_dir, exist_ok=True)
        log_file = os.path.join(train_dir, 'train.log')
        fmt = "%(asctime)s-%(lineno)d-%(message)s"
        logging.basicConfig(
            stream=sys.stdout, level=logging.INFO, format=fmt, datefmt='%m/%d %I:%M:%S %p'
        )
        fh = logging.FileHandler(log_file)
        fh.setFormatter(logging.Formatter(fmt))
        logging.getLogger().addHandler(fh)
self.train_dir = train_dir
def resume(self):
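        # Restore generator weights and the epoch/step counters from the
        # checkpoint at args.pretrained_pth.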
print('resume from a previous ckpt')
        ckpt = torch.load(self.args.pretrained_pth, map_location=self.device)
self.generator.load_state_dict(ckpt['generator'])
self.start_epoch = ckpt['epoch']
self.global_steps = ckpt['global_steps']
self.generator.global_step = self.global_steps
def init_dataloader(self):
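        # Build the training dataset/dataloaders. Three layouts are supported:
        #   * freeMo models: separate "trans" and "zero" splits, iterated in lockstep;
        #   * smplx / s2g models: a single loader with SMPL-X specific options;
        #   * everything else: a single plain loader.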
if 'freeMo' in self.config.Model.model_name:
if self.config.Data.data_root.endswith('.csv'):
raise NotImplementedError
else:
data_class = torch_data
self.train_set = data_class(
data_root=self.config.Data.data_root,
speakers=self.args.speakers,
split='train',
limbscaling=self.config.Data.pose.augmentation,
normalization=self.config.Data.pose.normalization,
norm_method=self.config.Data.pose.norm_method,
split_trans_zero=True,
num_pre_frames=self.config.Data.pose.pre_pose_length,
num_frames=self.config.Data.pose.generate_length,
aud_feat_win_size=self.config.Data.aud.aud_feat_win_size,
aud_feat_dim=self.config.Data.aud.aud_feat_dim,
feat_method=self.config.Data.aud.feat_method,
context_info=self.config.Data.aud.context_info
)
if self.config.Data.pose.normalization:
self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
save_file = os.path.join(self.train_dir, 'norm_stats.npy')
np.save(save_file, self.norm_stats, allow_pickle=True)
self.train_set.get_dataset()
self.trans_set = self.train_set.trans_dataset
self.zero_set = self.train_set.zero_dataset
self.trans_loader = data.DataLoader(self.trans_set, batch_size=self.config.DataLoader.batch_size, shuffle=True, num_workers=self.config.DataLoader.num_workers, drop_last=True)
self.zero_loader = data.DataLoader(self.zero_set, batch_size=self.config.DataLoader.batch_size, shuffle=True, num_workers=self.config.DataLoader.num_workers, drop_last=True)
elif 'smplx' in self.config.Model.model_name or 's2g' in self.config.Model.model_name:
data_class = torch_data
self.train_set = data_class(
data_root=self.config.Data.data_root,
speakers=self.args.speakers,
split='train',
limbscaling=self.config.Data.pose.augmentation,
normalization=self.config.Data.pose.normalization,
norm_method=self.config.Data.pose.norm_method,
split_trans_zero=False,
num_pre_frames=self.config.Data.pose.pre_pose_length,
num_frames=self.config.Data.pose.generate_length,
num_generate_length=self.config.Data.pose.generate_length,
aud_feat_win_size=self.config.Data.aud.aud_feat_win_size,
aud_feat_dim=self.config.Data.aud.aud_feat_dim,
feat_method=self.config.Data.aud.feat_method,
context_info=self.config.Data.aud.context_info,
smplx=True,
audio_sr=22000,
convert_to_6d=self.config.Data.pose.convert_to_6d,
expression=self.config.Data.pose.expression,
config=self.config
)
if self.config.Data.pose.normalization:
self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
save_file = os.path.join(self.train_dir, 'norm_stats.npy')
np.save(save_file, self.norm_stats, allow_pickle=True)
self.train_set.get_dataset()
self.train_loader = data.DataLoader(self.train_set.all_dataset,
batch_size=self.config.DataLoader.batch_size, shuffle=True,
num_workers=self.config.DataLoader.num_workers, drop_last=True)
else:
data_class = torch_data
self.train_set = data_class(
data_root=self.config.Data.data_root,
speakers=self.args.speakers,
split='train',
limbscaling=self.config.Data.pose.augmentation,
normalization=self.config.Data.pose.normalization,
norm_method=self.config.Data.pose.norm_method,
split_trans_zero=False,
num_pre_frames=self.config.Data.pose.pre_pose_length,
num_frames=self.config.Data.pose.generate_length,
aud_feat_win_size=self.config.Data.aud.aud_feat_win_size,
aud_feat_dim=self.config.Data.aud.aud_feat_dim,
feat_method=self.config.Data.aud.feat_method,
context_info=self.config.Data.aud.context_info
)
if self.config.Data.pose.normalization:
self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
save_file = os.path.join(self.train_dir, 'norm_stats.npy')
np.save(save_file, self.norm_stats, allow_pickle=True)
self.train_set.get_dataset()
self.train_loader = data.DataLoader(self.train_set.all_dataset, batch_size=self.config.DataLoader.batch_size, shuffle=True, num_workers=self.config.DataLoader.num_workers, drop_last=True)
def init_optimizer(self):
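        # Intentionally a no-op: the wrapper returned by init_model manages its
        # own optimizer (note that train_epoch never steps one here).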
pass
def print_func(self, loss_dict, steps):
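        # Log the per-step average of every accumulated loss term.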
        info_str = ['global_steps:%d' % self.global_steps]
        info_str += ['%s:%.4f' % (key, loss_dict[key] / steps) for key in loss_dict.keys()]
logging.info(','.join(info_str))
def save_model(self, epoch):
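        # Checkpoint the generator together with the counters needed to resume.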
# if 'vq' in self.config.Model.model_name:
# state_dict = {
# 'g_body': self.g_body.state_dict(),
# 'g_hand': self.g_hand.state_dict(),
# 'epoch': epoch,
# 'global_steps': self.global_steps
# }
# else:
state_dict = {
'generator': self.generator.state_dict(),
'epoch': epoch,
'global_steps': self.global_steps
}
save_name = os.path.join(self.train_dir, 'ckpt-%d.pth'%(epoch))
torch.save(state_dict, save_name)
def train_epoch(self, epoch):
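        # One full pass over the training data; losses are summed into
        # epoch_loss_dict and reported as running averages every Log.print_every steps.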
        epoch_loss_dict = {}  # best to track how each loss term changes over the epoch
epoch_steps = 0
if 'freeMo' in self.config.Model.model_name:
for bat in zip(self.trans_loader, self.zero_loader):
self.global_steps += 1
epoch_steps += 1
_, loss_dict = self.generator(bat)
                if epoch_loss_dict:  # non-empty: accumulate into the running totals
for key in list(loss_dict.keys()):
epoch_loss_dict[key] += loss_dict[key]
else:
for key in list(loss_dict.keys()):
epoch_loss_dict[key] = loss_dict[key]
if self.global_steps % self.config.Log.print_every == 0:
self.print_func(epoch_loss_dict, epoch_steps)
else:
# self.config.Model.model_name==smplx_S2G
for bat in self.train_loader:
# if epoch_steps == 1000:
# break
self.global_steps += 1
epoch_steps += 1
bat['epoch'] = epoch
_, loss_dict = self.generator(bat)
                if epoch_loss_dict:  # non-empty: accumulate into the running totals
for key in list(loss_dict.keys()):
epoch_loss_dict[key] += loss_dict[key]
else:
for key in list(loss_dict.keys()):
epoch_loss_dict[key] = loss_dict[key]
if self.global_steps % self.config.Log.print_every == 0:
self.print_func(epoch_loss_dict, epoch_steps)
def train(self):
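        # Outer loop: run train_epoch from start_epoch and checkpoint on the
        # Log.save_every schedule (and at epoch 30).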
logging.info('start_training')
self.total_loss_dict = {}
for epoch in range(self.start_epoch, self.config.Train.epochs):
logging.info('epoch:%d'%(epoch))
self.train_epoch(epoch)
# self.generator.scheduler.step()
# logging.info('learning rate:%d' % (self.generator.scheduler.get_lr()[0]))
            if (epoch + 1) % self.config.Log.save_every == 0 or (epoch + 1) == 30:
self.save_model(epoch)
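

if __name__ == '__main__':
    # Minimal usage sketch (hypothetical entry point; the project may drive the
    # Trainer from a separate launch script): construct the Trainer and call train().
    trainer = Trainer()
    trainer.train()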