Spaces:
Runtime error
Runtime error
import argparse | |
import numpy as np | |
import os | |
import sys | |
import time | |
from tqdm import tqdm | |
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
from torch.utils.tensorboard import SummaryWriter | |
from config import init_args | |
import data | |
import models | |
from models import * | |
from utils import utils, torch_utils | |
# Default compute device: CUDA when torch reports it as available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def validation(args, net, criterion, data_loader, device='cuda'):
    """Run ``net`` over ``data_loader`` without gradients and score the result.

    Every batch goes through :func:`predict`. All predictions and targets are
    concatenated along dim 0 and handed to ``criterion.evaluate`` in a single
    call.

    Args:
        args: Parsed run configuration, forwarded to :func:`predict`.
        net: Model to evaluate. It is put in eval mode for the pass and then
            restored to whichever train/eval mode it had on entry, even if
            an exception is raised.
        criterion: Object exposing ``evaluate(pred_all, target_all)``.
        data_loader: Iterable of batches that supports ``len()``.
        device: Device on which the concatenated tensors are built.

    Returns:
        Whatever ``criterion.evaluate`` returns. ``train()`` passes it to
        ``SummaryWriter.add_scalars``.
    """
    was_training = net.training
    net.eval()
    # Seed both lists with the same empty 1-D tensor that the previous
    # running-``torch.cat`` version started from. An empty loader and the
    # dtype of the concatenated result then behave exactly as before.
    preds = [torch.tensor([]).to(device)]
    targets = [torch.tensor([]).to(device)]
    try:
        with torch.no_grad():
            for batch in tqdm(data_loader, total=len(data_loader), desc="Validation"):
                pred, target = predict(args, net, batch, device)
                preds.append(pred)
                targets.append(target)
            # Concatenate once at the end. Concatenating per batch re-copied
            # every earlier batch on each step (quadratic in #batches).
            pred_all = torch.cat(preds, dim=0)
            target_all = torch.cat(targets, dim=0)
            res = criterion.evaluate(pred_all, target_all)
    finally:
        torch.cuda.empty_cache()
        net.train(was_training)
    return res
def predict(args, net, batch, device):
    """Forward one batch through ``net`` and return ``(pred, target)``.

    Only ``batch['frames']`` is fed to the network, wrapped in a dict under
    the same ``'frames'`` key. ``batch['label']`` is moved to ``device`` and
    returned as the target. ``args`` is not read here.
    """
    frames = batch['frames'].to(device)
    prediction = net({'frames': frames})
    return prediction, batch['label'].to(device)
def train(args, device):
    """Train ``models.VideoOnsetNet`` on the GreatestHit ``train`` split.

    Side effects:
      * creates ``./checkpoints/<args.exp>/`` and redirects ``sys.stdout`` to
        ``utils.LoggerOutput`` on ``log.txt`` inside it;
      * writes TensorBoard scalars under ``<that dir>/visualization``;
      * saves ``checkpoint_ep<N>.pth.tar`` (epoch, step, model and optimizer
        state) every ``args.save_step`` epochs;
      * sets ``args.start_epoch``: 0, or the second value returned by
        ``torch_utils.load_model`` when resuming.

    Args:
        args: Parsed configuration. The fields read here are:
            * ``exp``, ``batch_size``, ``num_workers``;
            * ``resume``, ``resume_optim``;
            * ``epochs``, ``valid_step``, ``save_step``;
            * the optional ``board_step``: the number of steps averaged per
              TensorBoard training-loss point. It defaults to 3, the value
              that was previously hard-coded.
            ``args`` is also passed to the dataset, loss, optimizer and
            LR-schedule helpers.
        device: Device for the model and every batch.
    """
    gpu_ids = list(range(torch.cuda.device_count()))

    # ----- make dirs for checkpoints ----- #
    # The directory must exist before the log file inside it is handed to
    # LoggerOutput. The old order (redirect first, makedirs second) broke on
    # a fresh experiment name if LoggerOutput opens its file immediately.
    os.makedirs('./checkpoints/' + args.exp, exist_ok=True)
    sys.stdout = utils.LoggerOutput(os.path.join('checkpoints', args.exp, 'log.txt'))
    writer = SummaryWriter(os.path.join('./checkpoints', args.exp, 'visualization'))
    try:
        tqdm.write('{}'.format(args))

        # ----- Dataset and Dataloader ----- #
        train_dataset = data.GreatestHitDataset(args, split='train')
        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False)
        val_dataset = data.GreatestHitDataset(args, split='val')
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False)

        # ----- Network ----- #
        net = models.VideoOnsetNet(pretrained=False).to(device)
        criterion = models.BCLoss(args)
        optimizer = torch_utils.make_optimizer(net, args)

        # -------- Loading checkpoint weights ------------- #
        if args.resume:
            resume = './checkpoints/' + args.resume
            net, args.start_epoch = torch_utils.load_model(resume, net, device=device, strict=True)
            if args.resume_optim:
                tqdm.write('loading optimizer...')
                # Without map_location, a checkpoint holding GPU tensors
                # cannot be loaded on a CPU-only machine.
                optim_state = torch.load(resume, map_location=device)['optimizer']
                optimizer.load_state_dict(optim_state)
                tqdm.write('loaded optimizer!')
        else:
            args.start_epoch = 0
        net = nn.DataParallel(net, device_ids=gpu_ids)

        # --------- Random or resume validation ------------ #
        res = validation(args, net, criterion, val_loader, device)
        writer.add_scalars('VideoOnset' + '/validation', res, args.start_epoch)
        tqdm.write("Beginning, Validation results: {}".format(res))
        tqdm.write('\n')

        # ----------------- Training ---------------- #
        board_step = getattr(args, 'board_step', 3)
        steps_per_epoch = len(train_loader)
        for epoch in range(args.start_epoch, args.epochs):
            running_loss = 0.0
            torch_utils.adjust_learning_rate(optimizer, epoch, args)
            for step, batch in tqdm(enumerate(train_loader), total=steps_per_epoch, desc="Training"):
                pred, target = predict(args, net, batch, device)
                loss = criterion(pred, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Log the Python float rather than the tensor repr, which
                # also carries device / grad_fn noise.
                loss_value = loss.item()
                tqdm.write("Epoch: {}/{}, step: {}/{}, loss: {}".format(
                    epoch + 1, args.epochs, step + 1, steps_per_epoch, loss_value))
                running_loss += loss_value
                current_step = epoch * steps_per_epoch + step + 1
                # Every board_step steps, log the mean loss of that window.
                if (step + 1) % board_step == 0:
                    writer.add_scalar('VideoOnset' + '/training loss', running_loss / board_step, current_step)
                    running_loss = 0.0

            # ----------- Validation -------------- #
            if (epoch + 1) % args.valid_step == 0:
                res = validation(args, net, criterion, val_loader, device)
                writer.add_scalars('VideoOnset' + '/validation', res, epoch + 1)
                tqdm.write("Epoch: {}/{}, Validation results: {}".format(epoch + 1, args.epochs, res))

            # ---------- Save model ----------- #
            if (epoch + 1) % args.save_step == 0:
                path = os.path.join('./checkpoints', args.exp, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')
                torch.save({'epoch': epoch + 1,
                            # Equals the last current_step of this epoch, but
                            # is still defined when train_loader is empty.
                            'step': (epoch + 1) * steps_per_epoch,
                            'state_dict': net.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            },
                           path)
            torch.cuda.empty_cache()
        tqdm.write('Training Complete!')
    finally:
        # Close the TensorBoard writer even if training raises.
        writer.close()
def test(args, device):
    """Evaluate ``models.VideoOnsetNet`` on the GreatestHit ``test`` split.

    Creates ``./results/<args.exp>/`` and redirects ``sys.stdout`` to
    ``utils.LoggerOutput`` on ``log.txt`` inside it. Optionally loads weights
    from ``./checkpoints/<args.resume>``, then scores the model with
    :func:`validation`.

    Args:
        args: Parsed configuration. The fields read here are ``exp``,
            ``batch_size``, ``num_workers`` and ``resume``. ``args`` is also
            passed to the dataset and loss constructors.
        device: Device for the model and every batch.

    Returns:
        The result of :func:`validation` on the test loader. The function
        previously returned ``None``; callers that ignore the value are
        unaffected.
    """
    gpu_ids = list(range(torch.cuda.device_count()))

    # ----- make dirs for results ----- #
    # Create the directory first so the log file inside it can be opened.
    # The old order (redirect first, makedirs second) broke on a fresh
    # experiment name if LoggerOutput opens its file immediately.
    os.makedirs('./results/' + args.exp, exist_ok=True)
    sys.stdout = utils.LoggerOutput(os.path.join('results', args.exp, 'log.txt'))
    tqdm.write('{}'.format(args))

    # ----- Dataset and Dataloader ----- #
    test_dataset = data.GreatestHitDataset(args, split='test')
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False)

    # ----- Network ----- #
    net = models.VideoOnsetNet(pretrained=False).to(device)
    criterion = models.BCLoss(args)

    # -------- Loading checkpoint weights ------------- #
    if args.resume:
        resume = './checkpoints/' + args.resume
        net, _ = torch_utils.load_model(resume, net, device=device, strict=True)
    else:
        # Nothing is loaded, so the scores below are for the weights exactly
        # as constructed with pretrained=False. Make that visible in the log.
        tqdm.write('Warning: args.resume is empty; testing the network without loaded weights.')
    net = nn.DataParallel(net, device_ids=gpu_ids)

    # --------- Testing ------------ #
    res = validation(args, net, criterion, test_loader, device)
    tqdm.write("Testing results: {}".format(res))
    return res
# Example launch command:
# CUDA_VISIBLE_DEVICES=1 python main.py --exp='EXP1' --epochs=100 --batch_size=12 --num_workers=8 --save_step=10 --valid_step=1 --lr=0.0001 --optim='Adam' --repeat=1 --schedule='cos'
if __name__ == '__main__':
    args = init_args()
    # args.test_mode picks evaluation (test) over training (train).
    entry_point = test if args.test_mode else train
    entry_point(args, DEVICE)