Spaces:
Configuration error
Configuration error
import os | |
import numpy as np | |
import time | |
import sys | |
import argparse | |
import errno | |
from collections import OrderedDict | |
import tensorboardX | |
from tqdm import tqdm | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torch.optim.lr_scheduler import StepLR | |
from torch.utils.data import DataLoader | |
from lib.utils.tools import * | |
from lib.utils.learning import * | |
from lib.model.loss import * | |
from lib.data.dataset_action import NTURGBD | |
from lib.model.model_action import ActionNet | |
random.seed(0) | |
np.random.seed(0) | |
torch.manual_seed(0) | |
def parse_args(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.") | |
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory') | |
parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory') | |
parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)') | |
parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') | |
parser.add_argument('-freq', '--print_freq', default=100) | |
parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)') | |
opts = parser.parse_args() | |
return opts | |
def validate(test_loader, model, criterion): | |
model.eval() | |
batch_time = AverageMeter() | |
losses = AverageMeter() | |
top1 = AverageMeter() | |
top5 = AverageMeter() | |
with torch.no_grad(): | |
end = time.time() | |
for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)): | |
batch_size = len(batch_input) | |
if torch.cuda.is_available(): | |
batch_gt = batch_gt.cuda() | |
batch_input = batch_input.cuda() | |
output = model(batch_input) # (N, num_classes) | |
loss = criterion(output, batch_gt) | |
# update metric | |
losses.update(loss.item(), batch_size) | |
acc1, acc5 = accuracy(output, batch_gt, topk=(1, 5)) | |
top1.update(acc1[0], batch_size) | |
top5.update(acc5[0], batch_size) | |
# measure elapsed time | |
batch_time.update(time.time() - end) | |
end = time.time() | |
if (idx+1) % opts.print_freq == 0: | |
print('Test: [{0}/{1}]\t' | |
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' | |
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' | |
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' | |
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format( | |
idx, len(test_loader), batch_time=batch_time, | |
loss=losses, top1=top1, top5=top5)) | |
return losses.avg, top1.avg, top5.avg | |
def train_with_config(args, opts): | |
print(args) | |
try: | |
os.makedirs(opts.checkpoint) | |
except OSError as e: | |
if e.errno != errno.EEXIST: | |
raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint) | |
train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs")) | |
model_backbone = load_backbone(args) | |
if args.finetune: | |
if opts.resume or opts.evaluate: | |
pass | |
else: | |
chk_filename = os.path.join(opts.pretrained, opts.selection) | |
print('Loading backbone', chk_filename) | |
checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos'] | |
model_backbone = load_pretrained_weights(model_backbone, checkpoint) | |
if args.partial_train: | |
model_backbone = partial_train_layers(model_backbone, args.partial_train) | |
model = ActionNet(backbone=model_backbone, dim_rep=args.dim_rep, num_classes=args.action_classes, dropout_ratio=args.dropout_ratio, version=args.model_version, hidden_dim=args.hidden_dim, num_joints=args.num_joints) | |
criterion = torch.nn.CrossEntropyLoss() | |
if torch.cuda.is_available(): | |
model = nn.DataParallel(model) | |
model = model.cuda() | |
criterion = criterion.cuda() | |
best_acc = 0 | |
model_params = 0 | |
for parameter in model.parameters(): | |
model_params = model_params + parameter.numel() | |
print('INFO: Trainable parameter count:', model_params) | |
print('Loading dataset...') | |
trainloader_params = { | |
'batch_size': args.batch_size, | |
'shuffle': True, | |
'num_workers': 8, | |
'pin_memory': True, | |
'prefetch_factor': 4, | |
'persistent_workers': True | |
} | |
testloader_params = { | |
'batch_size': args.batch_size, | |
'shuffle': False, | |
'num_workers': 8, | |
'pin_memory': True, | |
'prefetch_factor': 4, | |
'persistent_workers': True | |
} | |
data_path = 'data/action/%s.pkl' % args.dataset | |
ntu60_xsub_train = NTURGBD(data_path=data_path, data_split=args.data_split+'_train', n_frames=args.clip_len, random_move=args.random_move, scale_range=args.scale_range_train) | |
ntu60_xsub_val = NTURGBD(data_path=data_path, data_split=args.data_split+'_val', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test) | |
train_loader = DataLoader(ntu60_xsub_train, **trainloader_params) | |
test_loader = DataLoader(ntu60_xsub_val, **testloader_params) | |
chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin") | |
if os.path.exists(chk_filename): | |
opts.resume = chk_filename | |
if opts.resume or opts.evaluate: | |
chk_filename = opts.evaluate if opts.evaluate else opts.resume | |
print('Loading checkpoint', chk_filename) | |
checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) | |
model.load_state_dict(checkpoint['model'], strict=True) | |
if not opts.evaluate: | |
optimizer = optim.AdamW( | |
[ {"params": filter(lambda p: p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone}, | |
{"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head}, | |
], lr=args.lr_backbone, | |
weight_decay=args.weight_decay | |
) | |
scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay) | |
st = 0 | |
print('INFO: Training on {} batches'.format(len(train_loader))) | |
if opts.resume: | |
st = checkpoint['epoch'] | |
if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: | |
optimizer.load_state_dict(checkpoint['optimizer']) | |
else: | |
print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.') | |
lr = checkpoint['lr'] | |
if 'best_acc' in checkpoint and checkpoint['best_acc'] is not None: | |
best_acc = checkpoint['best_acc'] | |
# Training | |
for epoch in range(st, args.epochs): | |
print('Training epoch %d.' % epoch) | |
losses_train = AverageMeter() | |
top1 = AverageMeter() | |
top5 = AverageMeter() | |
batch_time = AverageMeter() | |
data_time = AverageMeter() | |
model.train() | |
end = time.time() | |
iters = len(train_loader) | |
for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)): # (N, 2, T, 17, 3) | |
data_time.update(time.time() - end) | |
batch_size = len(batch_input) | |
if torch.cuda.is_available(): | |
batch_gt = batch_gt.cuda() | |
batch_input = batch_input.cuda() | |
output = model(batch_input) # (N, num_classes) | |
optimizer.zero_grad() | |
loss_train = criterion(output, batch_gt) | |
losses_train.update(loss_train.item(), batch_size) | |
acc1, acc5 = accuracy(output, batch_gt, topk=(1, 5)) | |
top1.update(acc1[0], batch_size) | |
top5.update(acc5[0], batch_size) | |
loss_train.backward() | |
optimizer.step() | |
batch_time.update(time.time() - end) | |
end = time.time() | |
if (idx + 1) % opts.print_freq == 0: | |
print('Train: [{0}][{1}/{2}]\t' | |
'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' | |
'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' | |
'loss {loss.val:.3f} ({loss.avg:.3f})\t' | |
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( | |
epoch, idx + 1, len(train_loader), batch_time=batch_time, | |
data_time=data_time, loss=losses_train, top1=top1)) | |
sys.stdout.flush() | |
test_loss, test_top1, test_top5 = validate(test_loader, model, criterion) | |
train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1) | |
train_writer.add_scalar('train_top1', top1.avg, epoch + 1) | |
train_writer.add_scalar('train_top5', top5.avg, epoch + 1) | |
train_writer.add_scalar('test_loss', test_loss, epoch + 1) | |
train_writer.add_scalar('test_top1', test_top1, epoch + 1) | |
train_writer.add_scalar('test_top5', test_top5, epoch + 1) | |
scheduler.step() | |
# Save latest checkpoint. | |
chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin') | |
print('Saving checkpoint to', chk_path) | |
torch.save({ | |
'epoch': epoch+1, | |
'lr': scheduler.get_last_lr(), | |
'optimizer': optimizer.state_dict(), | |
'model': model.state_dict(), | |
'best_acc' : best_acc | |
}, chk_path) | |
# Save best checkpoint. | |
best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin'.format(epoch)) | |
if test_top1 > best_acc: | |
best_acc = test_top1 | |
print("save best checkpoint") | |
torch.save({ | |
'epoch': epoch+1, | |
'lr': scheduler.get_last_lr(), | |
'optimizer': optimizer.state_dict(), | |
'model': model.state_dict(), | |
'best_acc' : best_acc | |
}, best_chk_path) | |
if opts.evaluate: | |
test_loss, test_top1, test_top5 = validate(test_loader, model, criterion) | |
print('Loss {loss:.4f} \t' | |
'Acc@1 {top1:.3f} \t' | |
'Acc@5 {top5:.3f} \t'.format(loss=test_loss, top1=test_top1, top5=test_top5)) | |
if __name__ == "__main__": | |
opts = parse_args() | |
args = get_config(opts.config) | |
train_with_config(args, opts) |