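# Training / evaluation script for SMPL mesh regression: a pretrained motion backbone is
# wrapped by MeshRegressor, trained with a weighted multi-term mesh loss (MeshLoss), and
# evaluated with MPJPE / PA-MPJPE / MPVE on Human3.6M and, optionally, COCO and 3DPW.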
import os
import random
import copy
import time
import sys
import shutil
import argparse
import errno
import math
import numpy as np
from collections import defaultdict, OrderedDict
import tensorboardX
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

from lib.utils.tools import *
from lib.model.loss import *
from lib.model.loss_mesh import *
from lib.utils.utils_mesh import *
from lib.utils.utils_smpl import *
from lib.utils.utils_data import *
from lib.utils.learning import *
from lib.data.dataset_mesh import MotionSMPL
from lib.model.model_mesh import MeshRegressor

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('-freq', '--print_freq', default=100, type=int, help='print frequency (in batches)')
    parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
    parser.add_argument('-sd', '--seed', default=0, type=int, help='random seed')
    opts = parser.parse_args()
    return opts
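
# Note: for stricter reproducibility one could additionally seed all CUDA devices and
# force deterministic cuDNN kernels (not done in set_random_seed below, and it can slow
# training), e.g.:
#   torch.cuda.manual_seed_all(seed)
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False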
def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
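
# Evaluate the model on a test split: accumulate the weighted mesh loss terms, optionally
# average predictions with those from a horizontally flipped copy of the input (args.flip),
# then report MPJPE / PA-MPJPE / MPVE over the concatenated results.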
def validate(test_loader, model, criterion, dataset_name='h36m'):
    model.eval()
    print(f'===========> validating {dataset_name}')
    batch_time = AverageMeter()
    losses = AverageMeter()
    losses_dict = {
        'loss_3d_pos': AverageMeter(),
        'loss_3d_scale': AverageMeter(),
        'loss_3d_velocity': AverageMeter(),
        'loss_lv': AverageMeter(),
        'loss_lg': AverageMeter(),
        'loss_a': AverageMeter(),
        'loss_av': AverageMeter(),
        'loss_pose': AverageMeter(),
        'loss_shape': AverageMeter(),
        'loss_norm': AverageMeter(),
    }
    mpjpes = AverageMeter()
    mpves = AverageMeter()
    results = defaultdict(list)
    smpl = SMPL(args.data_root, batch_size=1).cuda()
    J_regressor = smpl.J_regressor_h36m
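    # In the loop below, test-time flip augmentation (args.flip) runs the model on a
    # horizontally flipped copy, flips the predicted SMPL thetas back, re-runs the SMPL
    # layer to recover vertices/joints, and averages the flipped and unflipped outputs.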
    with torch.no_grad():
        end = time.time()
        for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)):
            batch_size, clip_len = batch_input.shape[:2]
            if torch.cuda.is_available():
                batch_gt['theta'] = batch_gt['theta'].cuda().float()
                batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float()
                batch_gt['verts'] = batch_gt['verts'].cuda().float()
                batch_input = batch_input.cuda().float()
            output = model(batch_input)
            output_final = output
            if args.flip:
                batch_input_flip = flip_data(batch_input)
                output_flip = model(batch_input_flip)
                output_flip_pose = output_flip[0]['theta'][:, :, :72]
                output_flip_shape = output_flip[0]['theta'][:, :, 72:]
                output_flip_pose = flip_thetas_batch(output_flip_pose)
                output_flip_pose = output_flip_pose.reshape(-1, 72)
                output_flip_shape = output_flip_shape.reshape(-1, 10)
                output_flip_smpl = smpl(
                    betas=output_flip_shape,
                    body_pose=output_flip_pose[:, 3:],
                    global_orient=output_flip_pose[:, :3],
                    pose2rot=True
                )
                output_flip_verts = output_flip_smpl.vertices.detach() * 1000.0
                J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device)
                output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts)  # (NT, 17, 3)
                output_flip_back = [{
                    'theta': torch.cat((output_flip_pose.reshape(batch_size, clip_len, -1), output_flip_shape.reshape(batch_size, clip_len, -1)), dim=-1),
                    'verts': output_flip_verts.reshape(batch_size, clip_len, -1, 3),
                    'kp_3d': output_flip_kp3d.reshape(batch_size, clip_len, -1, 3),
                }]
                output_final = [{}]
                for k, v in output_flip[0].items():
                    output_final[0][k] = (output[0][k] + output_flip_back[0][k]) * 0.5
                output = output_final
            loss_dict = criterion(output, batch_gt)
            loss = args.lambda_3d * loss_dict['loss_3d_pos'] + \
                   args.lambda_scale * loss_dict['loss_3d_scale'] + \
                   args.lambda_3dv * loss_dict['loss_3d_velocity'] + \
                   args.lambda_lv * loss_dict['loss_lv'] + \
                   args.lambda_lg * loss_dict['loss_lg'] + \
                   args.lambda_a * loss_dict['loss_a'] + \
                   args.lambda_av * loss_dict['loss_av'] + \
                   args.lambda_shape * loss_dict['loss_shape'] + \
                   args.lambda_pose * loss_dict['loss_pose'] + \
                   args.lambda_norm * loss_dict['loss_norm']
            # update metric
            losses.update(loss.item(), batch_size)
            loss_str = ''
            for k, v in loss_dict.items():
                losses_dict[k].update(v.item(), batch_size)
                loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k])
            mpjpe, mpve = compute_error(output, batch_gt)
            mpjpes.update(mpjpe, batch_size)
            mpves.update(mpve, batch_size)
            for keys in output[0].keys():
                output[0][keys] = output[0][keys].detach().cpu().numpy()
                batch_gt[keys] = batch_gt[keys].detach().cpu().numpy()
            results['kp_3d'].append(output[0]['kp_3d'])
            results['verts'].append(output[0]['verts'])
            results['kp_3d_gt'].append(batch_gt['kp_3d'])
            results['verts_gt'].append(batch_gt['verts'])
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if idx % int(opts.print_freq) == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      '{2}'
                      'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
                      'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
                          idx, len(test_loader), loss_str, batch_time=batch_time,
                          loss=losses, mpves=mpves, mpjpes=mpjpes))
    print(f'==> start concating results of {dataset_name}')
    for term in results.keys():
        results[term] = np.concatenate(results[term])
    print(f'==> start evaluating {dataset_name}...')
    error_dict = evaluate_mesh(results)
    err_str = ''
    for err_key, err_val in error_dict.items():
        err_str += '{}: {:.2f}mm \t'.format(err_key, err_val)
    print(f'=======================> {dataset_name} validation done: ', loss_str)
    print(f'=======================> {dataset_name} validation done: ', err_str)
    return losses.avg, error_dict['mpjpe'], error_dict['pa_mpjpe'], error_dict['mpve'], losses_dict
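
# Run one training epoch over a single DataLoader, optimizing the same weighted mesh loss
# terms as validate(); the meters (losses, MPJPE, MPVE, timing) are passed in so the caller
# can aggregate statistics across the H36M / COCO / 3DPW loaders within one epoch.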
def train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch):
    model.train()
    end = time.time()
    for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):
        data_time.update(time.time() - end)
        batch_size = len(batch_input)
        if torch.cuda.is_available():
            batch_gt['theta'] = batch_gt['theta'].cuda().float()
            batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float()
            batch_gt['verts'] = batch_gt['verts'].cuda().float()
            batch_input = batch_input.cuda().float()
        output = model(batch_input)
        optimizer.zero_grad()
        loss_dict = criterion(output, batch_gt)
        loss_train = args.lambda_3d * loss_dict['loss_3d_pos'] + \
                     args.lambda_scale * loss_dict['loss_3d_scale'] + \
                     args.lambda_3dv * loss_dict['loss_3d_velocity'] + \
                     args.lambda_lv * loss_dict['loss_lv'] + \
                     args.lambda_lg * loss_dict['loss_lg'] + \
                     args.lambda_a * loss_dict['loss_a'] + \
                     args.lambda_av * loss_dict['loss_av'] + \
                     args.lambda_shape * loss_dict['loss_shape'] + \
                     args.lambda_pose * loss_dict['loss_pose'] + \
                     args.lambda_norm * loss_dict['loss_norm']
        losses_train.update(loss_train.item(), batch_size)
        loss_str = ''
        for k, v in loss_dict.items():
            losses_dict[k].update(v.item(), batch_size)
            loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k])
        mpjpe, mpve = compute_error(output, batch_gt)
        mpjpes.update(mpjpe, batch_size)
        mpves.update(mpve, batch_size)
        loss_train.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if idx % int(opts.print_freq) == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  '{3}'
                  'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
                  'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), loss_str, batch_time=batch_time,
                      data_time=data_time, loss=losses_train, mpves=mpves, mpjpes=mpjpes))
            sys.stdout.flush()
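
# Full training / evaluation driver: build the backbone and the MeshRegressor head, create
# whichever datasets the config provides (dt_file_h36m / dt_file_coco / dt_file_pw3d), then
# run the epoch loop with per-epoch validation, TensorBoard logging, and checkpointing
# (latest / periodic / best by joint position error).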
def train_with_config(args, opts):
    print(args)
    try:
        os.makedirs(opts.checkpoint)
        shutil.copy(opts.config, opts.checkpoint)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint)
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
    model_backbone = load_backbone(args)
    if args.finetune:
        if opts.resume or opts.evaluate:
            pass
        else:
            chk_filename = os.path.join(opts.pretrained, opts.selection)
            print('Loading backbone', chk_filename)
            checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos']
            model_backbone = load_pretrained_weights(model_backbone, checkpoint)
    if args.partial_train:
        model_backbone = partial_train_layers(model_backbone, args.partial_train)
    model = MeshRegressor(args, backbone=model_backbone, dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout, num_joints=args.num_joints)
    criterion = MeshLoss(loss_type=args.loss_type)
    best_jpe = 9999.0
    model_params = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            model_params = model_params + parameter.numel()
    print('INFO: Trainable parameter count:', model_params)
    print('Loading dataset...')
    trainloader_params = {
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    testloader_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    if hasattr(args, "dt_file_h36m"):
        mesh_train = MotionSMPL(args, data_split='train', dataset="h36m")
        mesh_val = MotionSMPL(args, data_split='test', dataset="h36m")
        train_loader = DataLoader(mesh_train, **trainloader_params)
        test_loader = DataLoader(mesh_val, **testloader_params)
        print('INFO: Training on {} batches (h36m)'.format(len(train_loader)))
    if hasattr(args, "dt_file_pw3d"):
        if args.train_pw3d:
            mesh_train_pw3d = MotionSMPL(args, data_split='train', dataset="pw3d")
            train_loader_pw3d = DataLoader(mesh_train_pw3d, **trainloader_params)
            print('INFO: Training on {} batches (pw3d)'.format(len(train_loader_pw3d)))
        mesh_val_pw3d = MotionSMPL(args, data_split='test', dataset="pw3d")
        test_loader_pw3d = DataLoader(mesh_val_pw3d, **testloader_params)
    trainloader_img_params = {
        'batch_size': args.batch_size_img,
        'shuffle': True,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    testloader_img_params = {
        'batch_size': args.batch_size_img,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    if hasattr(args, "dt_file_coco"):
        mesh_train_coco = MotionSMPL(args, data_split='train', dataset="coco")
        mesh_val_coco = MotionSMPL(args, data_split='test', dataset="coco")
        train_loader_coco = DataLoader(mesh_train_coco, **trainloader_img_params)
        test_loader_coco = DataLoader(mesh_val_coco, **testloader_img_params)
        print('INFO: Training on {} batches (coco)'.format(len(train_loader_coco)))
    if torch.cuda.is_available():
        model = nn.DataParallel(model)
        model = model.cuda()

    chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
    if os.path.exists(chk_filename):
        opts.resume = chk_filename
    if opts.resume or opts.evaluate:
        chk_filename = opts.evaluate if opts.evaluate else opts.resume
        print('Loading checkpoint', chk_filename)
        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model'], strict=True)

    if not opts.evaluate:
        optimizer = optim.AdamW(
            [{"params": filter(lambda p: p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone},
             {"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head}],
            lr=args.lr_backbone,
            weight_decay=args.weight_decay
        )
        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
        st = 0
        if opts.resume:
            st = checkpoint['epoch']
            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
            lr = checkpoint['lr']
            if 'best_jpe' in checkpoint and checkpoint['best_jpe'] is not None:
                best_jpe = checkpoint['best_jpe']
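
        # Per-epoch schedule: train on H36M while epoch < args.warmup_h36m (with validation
        # on the H36M test split), mix in COCO while epoch < args.warmup_coco, then train
        # and/or validate on 3DPW if configured; metrics go to TensorBoard and the learning
        # rate is decayed once per epoch via StepLR.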
        # Training
        for epoch in range(st, args.epochs):
            print('Training epoch %d.' % epoch)
            losses_train = AverageMeter()
            losses_dict = {
                'loss_3d_pos': AverageMeter(),
                'loss_3d_scale': AverageMeter(),
                'loss_3d_velocity': AverageMeter(),
                'loss_lv': AverageMeter(),
                'loss_lg': AverageMeter(),
                'loss_a': AverageMeter(),
                'loss_av': AverageMeter(),
                'loss_pose': AverageMeter(),
                'loss_shape': AverageMeter(),
                'loss_norm': AverageMeter(),
            }
            mpjpes = AverageMeter()
            mpves = AverageMeter()
            batch_time = AverageMeter()
            data_time = AverageMeter()

            if hasattr(args, "dt_file_h36m") and epoch < args.warmup_h36m:
                train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
                test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, test_losses_dict = validate(test_loader, model, criterion, 'h36m')
                for k, v in test_losses_dict.items():
                    train_writer.add_scalar('test_loss/' + k, v.avg, epoch + 1)
                train_writer.add_scalar('test_loss', test_loss, epoch + 1)
                train_writer.add_scalar('test_mpjpe', test_mpjpe, epoch + 1)
                train_writer.add_scalar('test_pa_mpjpe', test_pa_mpjpe, epoch + 1)
                train_writer.add_scalar('test_mpve', test_mpve, epoch + 1)
            if hasattr(args, "dt_file_coco") and epoch < args.warmup_coco:
                train_epoch(args, opts, model, train_loader_coco, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
            if hasattr(args, "dt_file_pw3d"):
                if args.train_pw3d:
                    train_epoch(args, opts, model, train_loader_pw3d, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
                test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, test_losses_dict_pw3d = validate(test_loader_pw3d, model, criterion, 'pw3d')
                for k, v in test_losses_dict_pw3d.items():
                    train_writer.add_scalar('test_loss_pw3d/' + k, v.avg, epoch + 1)
                train_writer.add_scalar('test_loss_pw3d', test_loss_pw3d, epoch + 1)
                train_writer.add_scalar('test_mpjpe_pw3d', test_mpjpe_pw3d, epoch + 1)
                train_writer.add_scalar('test_pa_mpjpe_pw3d', test_pa_mpjpe_pw3d, epoch + 1)
                train_writer.add_scalar('test_mpve_pw3d', test_mpve_pw3d, epoch + 1)
            for k, v in losses_dict.items():
                train_writer.add_scalar('train_loss/' + k, v.avg, epoch + 1)
            train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1)
            train_writer.add_scalar('train_mpjpe', mpjpes.avg, epoch + 1)
            train_writer.add_scalar('train_mpve', mpves.avg, epoch + 1)

            # Decay learning rate exponentially
            scheduler.step()
            # Save latest checkpoint.
            chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
            print('Saving checkpoint to', chk_path)
            torch.save({
                'epoch': epoch + 1,
                'lr': scheduler.get_last_lr(),
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'best_jpe': best_jpe
            }, chk_path)

            # Save checkpoint if necessary.
            if (epoch + 1) % args.checkpoint_frequency == 0:
                chk_path = os.path.join(opts.checkpoint, 'epoch_{}.bin'.format(epoch))
                print('Saving checkpoint to', chk_path)
                torch.save({
                    'epoch': epoch + 1,
                    'lr': scheduler.get_last_lr(),
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'best_jpe': best_jpe
                }, chk_path)

            if hasattr(args, "dt_file_pw3d"):
                best_jpe_cur = test_mpjpe_pw3d
            else:
                best_jpe_cur = test_mpjpe
            # Save best checkpoint.
            best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin')
            if best_jpe_cur < best_jpe:
                best_jpe = best_jpe_cur
                print("save best checkpoint")
                torch.save({
                    'epoch': epoch + 1,
                    'lr': scheduler.get_last_lr(),
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'best_jpe': best_jpe
                }, best_chk_path)
    if opts.evaluate:
        if hasattr(args, "dt_file_h36m"):
            test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, _ = validate(test_loader, model, criterion, 'h36m')
        if hasattr(args, "dt_file_pw3d"):
            test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, _ = validate(test_loader_pw3d, model, criterion, 'pw3d')

if __name__ == "__main__":
    opts = parse_args()
    set_random_seed(opts.seed)
    args = get_config(opts.config)
    train_with_config(args, opts)
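
# Example usage (script name and paths are illustrative; adjust to your setup):
#   python train_mesh.py --config configs/pretrain.yaml --checkpoint checkpoint/mesh
#   python train_mesh.py --config configs/pretrain.yaml --checkpoint checkpoint/mesh --evaluate checkpoint/mesh/best_epoch.bin
# Training auto-resumes from <checkpoint>/latest_epoch.bin if that file exists.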