|
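"""VPoser training module built on PyTorch Lightning.

VPoserTrainer bundles the VPoser model with its data loading, loss computation and
train/validation loops; train_vposer_once() wires up the Lightning Trainer, callbacks
and loggers and runs a single training job.
"""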
import glob
import os
import os.path as osp
from datetime import datetime as dt

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.utilities import rank_zero_only
from torch import optim as optim_module
from torch.optim import lr_scheduler as lr_sched_module
from torch.utils.data import DataLoader

from human_body_prior.body_model.body_model import BodyModel
from human_body_prior.data.dataloader import VPoserDS
from human_body_prior.data.prepare_data import dataset_exists, prepare_vposer_datasets
from human_body_prior.models.vposer_model import VPoser
from human_body_prior.tools.angle_continuous_repres import geodesic_loss_R
from human_body_prior.tools.configurations import load_config, dump_config
from human_body_prior.tools.omni_tools import copy2cpu as c2c
from human_body_prior.tools.omni_tools import get_support_data_dir
from human_body_prior.tools.omni_tools import log2file
from human_body_prior.tools.omni_tools import make_deterministic
from human_body_prior.tools.omni_tools import makepath
from human_body_prior.tools.rotation_tools import aa2matrot
from human_body_prior.visualizations.training_visualization import vposer_trainer_renderer
|
|
|
|
|
class VPoserTrainer(LightningModule): |
|
""" |
|
|
|
    A PyTorch Lightning module for VPoser. It includes all data loading and train/val logic,
    and it is used for both training and testing models.
|
""" |
|
|
|
def __init__(self, _config): |
|
super(VPoserTrainer, self).__init__() |
|
|
|
_support_data_dir = get_support_data_dir() |
|
|
|
vp_ps = load_config(**_config) |
|
|
|
make_deterministic(vp_ps.general.rnd_seed) |
|
|
|
self.expr_id = vp_ps.general.expr_id |
|
self.dataset_id = vp_ps.general.dataset_id |
|
|
|
self.work_dir = vp_ps.logging.work_dir = makepath(vp_ps.general.work_basedir, self.expr_id) |
|
self.dataset_dir = vp_ps.logging.dataset_dir = osp.join(vp_ps.general.dataset_basedir, vp_ps.general.dataset_id) |
|
|
|
self._log_prefix = '[{}]'.format(self.expr_id) |
|
self.text_logger = log2file(prefix=self._log_prefix) |
|
|
|
self.seq_len = vp_ps.data_parms.num_timeseq_frames |
|
|
|
self.vp_model = VPoser(vp_ps) |
|
|
|
with torch.no_grad(): |
|
|
|
self.bm_train = BodyModel(vp_ps.body_model.bm_fname) |
|
|
|
if vp_ps.logging.render_during_training: |
|
self.renderer = vposer_trainer_renderer(self.bm_train, vp_ps.logging.num_bodies_to_display) |
|
else: |
|
self.renderer = None |
|
|
|
self.example_input_array = {'pose_body':torch.ones(vp_ps.train_parms.batch_size, 63),} |
|
self.vp_ps = vp_ps |
|
|
|
def forward(self, pose_body): |
|
|
|
return self.vp_model(pose_body) |
|
|
|
def _get_data(self, split_name): |
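        """Build a DataLoader for one of the prepared dataset splits ('train', 'vald' or 'test')."""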
|
|
|
assert split_name in ('train', 'vald', 'test') |
|
|
|
        # note: the validation split directory is named 'vald' on disk
|
|
|
        if not dataset_exists(self.dataset_dir):
            raise FileNotFoundError('Dataset does not exist: dataset_dir = {}'.format(self.dataset_dir))
|
dataset = VPoserDS(osp.join(self.dataset_dir, split_name), data_fields = ['pose_body']) |
|
|
|
        if len(dataset) == 0:
            raise ValueError('Dataset has nothing in it!')
|
|
|
return DataLoader(dataset, |
|
batch_size=self.vp_ps.train_parms.batch_size, |
|
                          shuffle=(split_name == 'train'),
|
num_workers=self.vp_ps.data_parms.num_workers, |
|
pin_memory=True) |
|
|
|
@rank_zero_only |
|
def on_train_start(self): |
|
|
self.train_starttime = dt.now().replace(microsecond=0) |
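        # archive the tracked files of the git repository into <work_dir>/code so the exact
        # training code for this run can be recovered later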
|
|
|
|
|
git_repo_dir = os.path.abspath(__file__).split('/') |
|
git_repo_dir = '/'.join(git_repo_dir[:git_repo_dir.index('human_body_prior') + 1]) |
|
starttime = dt.strftime(self.train_starttime, '%Y_%m_%d_%H_%M_%S') |
|
archive_path = makepath(self.work_dir, 'code', 'vposer_{}.tar.gz'.format(starttime), isfile=True) |
|
cmd = 'cd %s && git ls-files -z | xargs -0 tar -czf %s' % (git_repo_dir, archive_path) |
|
os.system(cmd) |
|
|
|
self.text_logger('Created a git archive backup at {}'.format(archive_path)) |
|
dump_config(self.vp_ps, osp.join(self.work_dir, '{}.yaml'.format(self.expr_id))) |
|
|
|
def train_dataloader(self): |
|
return self._get_data('train') |
|
|
|
def val_dataloader(self): |
|
return self._get_data('vald') |
|
|
|
def configure_optimizers(self): |
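        """Instantiate the optimizer and LR scheduler whose types and arguments are given in the config."""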
|
params_count = lambda params: sum(p.numel() for p in params if p.requires_grad) |
|
|
|
gen_params = [a[1] for a in self.vp_model.named_parameters() if a[1].requires_grad] |
|
gen_optimizer_class = getattr(optim_module, self.vp_ps.train_parms.gen_optimizer.type) |
|
gen_optimizer = gen_optimizer_class(gen_params, **self.vp_ps.train_parms.gen_optimizer.args) |
|
|
|
        self.text_logger('Total number of trainable parameters in vp_model: %2.2f M.' % (params_count(gen_params) * 1e-6))
|
|
|
lr_sched_class = getattr(lr_sched_module, self.vp_ps.train_parms.lr_scheduler.type) |
|
|
|
gen_lr_scheduler = lr_sched_class(gen_optimizer, **self.vp_ps.train_parms.lr_scheduler.args) |
|
|
|
schedulers = [ |
|
{ |
|
'scheduler': gen_lr_scheduler, |
|
'monitor': 'val_loss', |
|
'interval': 'epoch', |
|
'frequency': 1 |
|
}, |
|
] |
|
return [gen_optimizer], schedulers |
|
|
|
def _compute_loss(self, dorig, drec): |
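        """Compute the weighted training losses (KL, mesh reconstruction and, during early epochs,
        matrot and joint terms) together with unweighted evaluation metrics (mean per-vertex error)."""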
|
l1_loss = torch.nn.L1Loss(reduction='mean') |
|
geodesic_loss = geodesic_loss_R(reduction='mean') |
|
|
|
bs, latentD = drec['poZ_body_mean'].shape |
|
device = drec['poZ_body_mean'].device |
|
|
|
loss_kl_wt = self.vp_ps.train_parms.loss_weights.loss_kl_wt |
|
loss_rec_wt = self.vp_ps.train_parms.loss_weights.loss_rec_wt |
|
loss_matrot_wt = self.vp_ps.train_parms.loss_weights.loss_matrot_wt |
|
loss_jtr_wt = self.vp_ps.train_parms.loss_weights.loss_jtr_wt |
|
|
|
|
|
q_z = drec['q_z'] |
|
|
|
|
|
|
|
with torch.no_grad(): |
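            # pose the original bodies (no grad) so the reconstruction can be scored on mesh vertices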
|
bm_orig = self.bm_train(pose_body=dorig['pose_body']) |
|
|
|
bm_rec = self.bm_train(pose_body=drec['pose_body'].contiguous().view(bs, -1)) |
|
|
|
v2v = l1_loss(bm_rec.v, bm_orig.v) |
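        # the KL term below is measured against a standard-normal prior over the latent code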
|
|
|
|
|
p_z = torch.distributions.normal.Normal( |
|
loc=torch.zeros((bs, latentD), device=device, requires_grad=False), |
|
scale=torch.ones((bs, latentD), device=device, requires_grad=False)) |
|
weighted_loss_dict = { |
|
'loss_kl':loss_kl_wt * torch.mean(torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1])), |
|
'loss_mesh_rec': loss_rec_wt * v2v |
|
} |
|
|
|
        if self.current_epoch < self.vp_ps.train_parms.keep_extra_loss_terms_until_epoch:
|
|
|
weighted_loss_dict['matrot'] = loss_matrot_wt * geodesic_loss(drec['pose_body_matrot'].view(-1,3,3), aa2matrot(dorig['pose_body'].view(-1, 3))) |
|
weighted_loss_dict['jtr'] = loss_jtr_wt * l1_loss(bm_rec.Jtr, bm_orig.Jtr) |
|
|
|
weighted_loss_dict['loss_total'] = torch.stack(list(weighted_loss_dict.values())).sum() |
|
|
|
with torch.no_grad(): |
|
unweighted_loss_dict = {'v2v': torch.sqrt(torch.pow(bm_rec.v-bm_orig.v, 2).sum(-1)).mean()} |
|
unweighted_loss_dict['loss_total'] = torch.cat( |
|
list({k: v.view(-1) for k, v in unweighted_loss_dict.items()}.values()), dim=-1).sum().view(1) |
|
|
|
return {'weighted_loss': weighted_loss_dict, 'unweighted_loss': unweighted_loss_dict} |
|
|
|
def training_step(self, batch, batch_idx, optimizer_idx=None): |
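        """Encode/decode a batch of body poses and return the weighted total loss for optimization."""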
|
|
|
drec = self(batch['pose_body'].view(-1, 63)) |
|
|
|
loss = self._compute_loss(batch, drec) |
|
|
|
train_loss = loss['weighted_loss']['loss_total'] |
|
|
|
tensorboard_logs = {'train_loss': train_loss} |
|
progress_bar = {k: c2c(v) for k, v in loss['weighted_loss'].items()} |
|
return {'loss': train_loss, 'progress_bar':progress_bar, 'log': tensorboard_logs} |
|
|
|
def validation_step(self, batch, batch_idx): |
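        """Compute the unweighted validation loss and occasionally render reconstructions and samples."""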
|
|
|
drec = self(batch['pose_body'].view(-1, 63)) |
|
|
|
loss = self._compute_loss(batch, drec) |
|
val_loss = loss['unweighted_loss']['loss_total'] |
|
|
|
        if self.renderer is not None and self.global_rank == 0 and batch_idx % 500 == 0 and np.random.rand() > 0.5:
|
out_fname = makepath(self.work_dir, 'renders/vald_rec_E{:03d}_It{:04d}_val_loss_{:.2f}.png'.format(self.current_epoch, batch_idx, val_loss.item()), isfile=True) |
|
self.renderer([batch, drec], out_fname = out_fname) |
|
dgen = self.vp_model.sample_poses(self.vp_ps.logging.num_bodies_to_display) |
|
out_fname = makepath(self.work_dir, 'renders/vald_gen_E{:03d}_I{:04d}.png'.format(self.current_epoch, batch_idx), isfile=True) |
|
self.renderer([dgen], out_fname = out_fname) |
|
|
|
|
|
progress_bar = {'v2v': val_loss} |
|
return {'val_loss': c2c(val_loss), 'progress_bar': progress_bar, 'log': progress_bar} |
|
|
|
def validation_epoch_end(self, outputs): |
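        """Aggregate per-batch validation losses into epoch-level metrics and log them on rank zero."""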
|
metrics = {'val_loss': np.nanmean(np.concatenate([v['val_loss'] for v in outputs])) } |
|
|
|
if self.global_rank == 0: |
|
|
|
self.text_logger('Epoch {}: {}'.format(self.current_epoch, ', '.join('{}:{:.2f}'.format(k, v) for k, v in metrics.items()))) |
|
self.text_logger('lr is {}'.format([pg['lr'] for opt in self.trainer.optimizers for pg in opt.param_groups])) |
|
|
|
metrics = {k: torch.as_tensor(v) for k, v in metrics.items()} |
|
|
|
return {'val_loss': metrics['val_loss'], 'log': metrics} |
|
|
|
|
|
@rank_zero_only |
|
def on_train_end(self): |
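        """Record the training duration and best checkpoint path, then dump the final config to disk."""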
|
|
|
self.train_endtime = dt.now().replace(microsecond=0) |
|
endtime = dt.strftime(self.train_endtime, '%Y_%m_%d_%H_%M_%S') |
|
elapsedtime = self.train_endtime - self.train_starttime |
|
self.vp_ps.logging.best_model_fname = self.trainer.checkpoint_callback.best_model_path |
|
|
|
self.text_logger('Epoch {} - Finished training at {} after {}'.format(self.current_epoch, endtime, elapsedtime)) |
|
self.text_logger('best_model_fname: {}'.format(self.vp_ps.logging.best_model_fname)) |
|
|
|
dump_config(self.vp_ps, osp.join(self.work_dir, '{}_{}.yaml'.format(self.expr_id, self.dataset_id))) |
|
self.hparams = self.vp_ps.toDict() |
|
|
|
@rank_zero_only |
|
def prepare_data(self): |
|
        ''' Similar to the standard AMASS dataset preparation pipeline:
        download the body-data npz files from https://amass.is.tue.mpg.de/ and place them under amass_dir.
        '''
|
self.text_logger = log2file(makepath(self.work_dir, '{}.log'.format(self.expr_id), isfile=True), prefix=self._log_prefix) |
|
|
|
prepare_vposer_datasets(self.dataset_dir, self.vp_ps.data_parms.amass_splits, self.vp_ps.data_parms.amass_dir, logger=self.text_logger) |
|
|
|
|
|
def create_expr_message(ps): |
|
expr_msg = '[{}] batch_size = {}.'.format(ps.general.expr_id, ps.train_parms.batch_size) |
|
|
|
return expr_msg |
|
|
|
|
|
def train_vposer_once(_config): |
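    """Train a single VPoser model: build the Lightning trainer, callbacks and loggers,
    optionally resume from the latest snapshot in the work directory, and fit."""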
|
|
|
resume_training_if_possible = True |
|
|
|
model = VPoserTrainer(_config) |
|
model.vp_ps.logging.expr_msg = create_expr_message(model.vp_ps) |
|
|
|
dump_config(model.vp_ps, osp.join(model.work_dir, '{}.yaml'.format(model.expr_id))) |
|
|
|
logger = TensorBoardLogger(model.work_dir, name='tensorboard') |
|
lr_monitor = LearningRateMonitor() |
|
|
|
snapshots_dir = osp.join(model.work_dir, 'snapshots') |
|
checkpoint_callback = ModelCheckpoint( |
|
dirpath=makepath(snapshots_dir, isfile=True), |
|
filename="%s_{epoch:02d}_{val_loss:.2f}" % model.expr_id, |
|
save_top_k=1, |
|
verbose=True, |
|
monitor='val_loss', |
|
mode='min', |
|
) |
|
early_stop_callback = EarlyStopping(**model.vp_ps.train_parms.early_stopping) |
|
|
|
resume_from_checkpoint = None |
|
if resume_training_if_possible: |
|
available_ckpts = sorted(glob.glob(osp.join(snapshots_dir, '*.ckpt')), key=os.path.getmtime) |
|
if len(available_ckpts)>0: |
|
resume_from_checkpoint = available_ckpts[-1] |
|
model.text_logger('Resuming the training from {}'.format(resume_from_checkpoint)) |
|
|
|
trainer = pl.Trainer(gpus=1, |
|
weights_summary='top', |
|
                         distributed_backend='ddp',
                         plugins=[DDPPlugin(find_unused_parameters=False)],
|
|
|
callbacks=[lr_monitor, early_stop_callback, checkpoint_callback], |
|
|
|
max_epochs=model.vp_ps.train_parms.num_epochs, |
|
logger=logger, |
|
resume_from_checkpoint=resume_from_checkpoint |
|
) |
|
|
|
trainer.fit(model) |
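
# Hypothetical usage sketch (the real experiment scripts and yaml config schema live elsewhere in
# this repository; the placeholder below only names the config sections this module reads):
#
#   if __name__ == '__main__':
#       _config = {...}  # kwargs forwarded to load_config(); must provide the general, logging,
#                        # body_model, data_parms and train_parms sections referenced above
#       train_vposer_once(_config)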
|
|