# Source: uploaded by oguzakif, commit d4b77ac ("init repo"), 8.35 kB.
import logging
import math
import os
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
import yaml
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

import parse
from utils import util
from data import create_dataset, create_dataloader
from models.utils.loss import *
from models.utils.flow_losses import AdversarialLoss
class Trainer(metaclass=ABCMeta):
    """Abstract base class that drives training.

    Handles output directories, logging, seeding, building the training
    data pipeline, optional DistributedDataParallel wrapping and the epoch
    loop. Subclasses supply model/optimizer construction, checkpoint
    resume/save, the per-epoch training step and validation.

    Args:
        opt: nested option dict (``path``, ``datasets``, ``train``,
            ``device``, ``distributed`` and related keys are read here).
        rank: process rank; ranks ``<= 0`` create the output directories,
            the logger and the TensorBoard writer.
    """

    def __init__(self, opt, rank):
        self.opt = opt
        self.rank = rank

        # Only the main process creates directories and loggers.
        if rank <= 0:
            self.mkdir()
            self.logger, self.tb_logger = self.setLogger()
        else:
            # Other ranks never log; define the attributes so they always exist.
            self.logger, self.tb_logger = None, None
        self.setSeed()

        (self.dataInfo, self.valInfo, self.trainSet, self.trainSize,
         self.totalIterations, self.totalEpochs, self.trainLoader,
         self.trainSampler) = self.prepareDataset()
        # `dist` is the second network returned by init_model
        # (presumably the discriminator, given the adversarial loss below).
        (self.model, self.dist, self.optimizer, self.dist_optim,
         self.scheduler, self.dist_scheduler) = self.init_model()
        self.flow_model = self.init_flow_model()
        self.model = self.model.to(self.opt['device'])
        self.dist = self.dist.to(self.opt['device'])

        # Resume (subclass-defined) before the networks are wrapped in DDP.
        if opt['path'].get('gen_state', None):
            self.startEpoch, self.currentStep = self.resume_training()
        else:
            self.startEpoch, self.currentStep = 0, 0

        if opt['distributed']:
            local_rank = self.opt['local_rank']
            self.model = DDP(
                self.model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True,
            )
            self.dist = DDP(
                self.dist,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True,
            )

        if self.rank <= 0:
            self.logger.info('Start training from epoch: {}, iter: {}'.format(
                self.startEpoch, self.currentStep))

        self.best_psnr = 0
        self.valid_best_psnr = 0
        self.maskedLoss = nn.L1Loss()
        self.validLoss = nn.L1Loss()
        self.adversarial_loss = AdversarialLoss(type='hinge')
        self.adversarial_loss = self.adversarial_loss.to(self.opt['device'])
        self.countDown = 0

        # metrics recorder
        self.total_loss = 0
        self.total_psnr = 0
        self.total_ssim = 0
        self.total_l1 = 0
        self.total_l2 = 0

    def get_lr(self):
        """Return the learning rates of all param groups.

        Returns:
            list: the groups of ``self.optimizer`` first, then those of
            ``self.dist_optim``.
        """
        return ([group['lr'] for group in self.optimizer.param_groups]
                + [group['lr'] for group in self.dist_optim.param_groups])

    def adjust_learning_rate(self, optimizer, target_lr):
        """Set ``target_lr`` on every param group of ``optimizer``.

        Note that ``self.dist_optim`` is always updated as well, regardless
        of which optimizer is passed in.
        """
        for param_group in optimizer.param_groups:
            param_group['lr'] = target_lr
        for param_group in self.dist_optim.param_groups:
            param_group['lr'] = target_lr

    def mkdir(self):
        """Create the output directories and save the config into them.

        If ``util.mkdir_and_rename(OUTPUT_ROOT)`` returns a truthy name, the
        TRAINING_STATE, LOG and VAL_IMAGES paths in ``opt['path']`` are
        redirected under it. All three directories are then created if
        missing, and ``opt`` is written as YAML to ``<LOG>/config.yaml``.
        """
        paths = self.opt['path']  # same dict object: updates land in self.opt
        new_name = util.mkdir_and_rename(paths['OUTPUT_ROOT'])
        if new_name:
            paths['TRAINING_STATE'] = os.path.join(new_name, 'training_state')
            paths['LOG'] = os.path.join(new_name, 'log')
            paths['VAL_IMAGES'] = os.path.join(new_name, 'val_images')
        for key in ('TRAINING_STATE', 'LOG', 'VAL_IMAGES'):
            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs(paths[key], exist_ok=True)
        # save config file for output
        with open(os.path.join(paths['LOG'], 'config.yaml'), 'w') as f:
            yaml.dump(self.opt, f)

    def setLogger(self):
        """Configure the ``'base'`` logger and, optionally, a TensorBoard writer.

        Returns:
            tuple: ``(logger, tb_logger)``. ``tb_logger`` is ``None`` unless
            ``opt['use_tb_logger']`` is truthy.
        """
        util.setup_logger('base', self.opt['path']['LOG'], 'train_' + self.opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(parse.toString(self.opt))
        logger.info('OUTPUT DIR IS: {}'.format(self.opt['path']['OUTPUT_ROOT']))
        if not self.opt['use_tb_logger']:
            return logger, None
        try:
            # Built into PyTorch >= 1.1; raises ImportError on older versions
            # or when the `tensorboard` package is not installed.
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            logger.info('You are using PyTorch {}, Tensorboard will use [tensorboardX]'.format(
                torch.__version__))
            from tensorboardX import SummaryWriter
        # NOTE(review): mkdir() may redirect opt['path']['LOG'] under a renamed
        # directory, but the writer always uses OUTPUT_ROOT/log -- confirm intended.
        tb_logger = SummaryWriter(os.path.join(self.opt['path']['OUTPUT_ROOT'], 'log'))
        return logger, tb_logger

    def setSeed(self):
        """Seed the RNGs from ``opt['train']['manual_seed']``.

        Seed ``0`` additionally requests reproducible cuDNN behaviour.
        """
        seed = self.opt['train']['manual_seed']
        if self.rank <= 0:
            self.logger.info('Random seed: {}'.format(seed))
        util.set_random_seed(seed)
        if seed == 0:
            # cuDNN autotuning (benchmark) may pick different kernels on each
            # run, which breaks reproducibility even with deterministic=True.
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
        else:
            torch.backends.cudnn.benchmark = True

    def prepareDataset(self):
        """Build the training dataset, sampler and loader.

        Every entry of ``opt['datasets']`` (including the dataInfo/valInfo
        entries) receives the shared data options copied from ``opt``. Only
        the entry whose key is ``'train'`` (case-insensitive) is turned into
        a dataset and loader.

        Returns:
            tuple: ``(dataInfo, valInfo, train_set, train_size,
            total_iterations, total_epochs, train_loader, train_sampler)``.
            ``train_size`` is the number of iterations per epoch: the dataset
            length divided by ``batch_size * world_size``, rounded up.

        Raises:
            ValueError: if no train entry exists or it yields zero batches.
        """
        dataInfo = self.opt['datasets']['dataInfo']
        valInfo = self.opt['datasets']['valInfo']
        valInfo['norm'] = self.opt['norm']
        if self.rank <= 0:
            self.logger.debug('Val info is: {}'.format(valInfo))

        shared_keys = ('norm', 'dataMode', 'num_frames', 'sample', 'flow2rgb',
                       'flow_direction', 'max_val', 'input_resolution')
        train_set, train_size, total_iterations, total_epochs = 0, 0, 0, 0
        train_loader, train_sampler = None, None
        for phase, dataset in self.opt['datasets'].items():
            for key in shared_keys:
                dataset[key] = self.opt[key]
            if phase.lower() != 'train':
                continue

            train_set = create_dataset(dataset, dataInfo, phase, self.opt['datasetName_train'])
            train_size = math.ceil(
                len(train_set) / (dataset['batch_size'] * self.opt['world_size']))
            # Check before dividing: an empty train set would otherwise
            # surface as a ZeroDivisionError below.
            if train_size == 0:
                raise ValueError('Train size cannot be zero')
            total_iterations = self.opt['train']['MAX_ITERS']
            total_epochs = int(math.ceil(total_iterations / train_size))
            if self.opt['distributed']:
                train_sampler = DistributedSampler(
                    train_set,
                    num_replicas=self.opt['world_size'],
                    rank=self.opt['global_rank'])
            else:
                train_sampler = None
            train_loader = create_dataloader(phase, train_set, dataset, self.opt, train_sampler)
            if self.rank <= 0:
                self.logger.info('Number of training samples: {}, batches per epoch: {}, iters: {}'.format(
                    len(train_set), train_size, total_iterations))
                self.logger.info('Total epoch needed: {} for iters {}'.format(total_epochs, total_iterations))

        if train_loader is None:
            raise ValueError('Cannot find train set, val set can be None')
        return dataInfo, valInfo, train_set, train_size, total_iterations, total_epochs, train_loader, train_sampler

    @abstractmethod
    def init_model(self):
        """Build the networks and their optimizers/schedulers.

        Must return ``(model, dist, optimizer, dist_optim, scheduler,
        dist_scheduler)``; ``model`` and ``dist`` must support ``.to(device)``.
        """

    @abstractmethod
    def init_flow_model(self):
        """Return the flow model stored as ``self.flow_model``."""

    @abstractmethod
    def resume_training(self):
        """Restore state when ``opt['path']['gen_state']`` is set.

        Must return ``(start_epoch, current_step)``.
        """

    def train(self):
        """Run the epoch loop from ``startEpoch`` to ``totalEpochs`` inclusive.

        Stops as soon as ``self.currentStep`` exceeds ``self.totalIterations``
        after an epoch. Validates every ``val_freq`` epochs when
        ``opt['use_valid']`` is set, then steps both schedulers with the epoch.
        """
        for epoch in range(self.startEpoch, self.totalEpochs + 1):
            if self.opt['distributed']:
                # Reshuffle the DistributedSampler differently every epoch.
                self.trainSampler.set_epoch(epoch)
            self._trainEpoch(epoch)
            if self.currentStep > self.totalIterations:
                break
            if self.opt['use_valid'] and (epoch + 1) % self.opt['train']['val_freq'] == 0:
                self._validate(epoch)
            # NOTE(review): passing `epoch` to scheduler.step is deprecated in
            # recent PyTorch; kept so resumed runs stay aligned to the epoch.
            self.scheduler.step(epoch)
            self.dist_scheduler.step(epoch)

    @abstractmethod
    def _trainEpoch(self, epoch):
        """Train for one epoch.

        NOTE(review): train() reads ``self.currentStep`` afterwards to decide
        when to stop, so this presumably advances it.
        """

    @abstractmethod
    def _printLog(self, logs, epoch, loss):
        """Subclass-defined logging hook; semantics not visible here."""

    @abstractmethod
    def save_checkpoint(self, epoch, is_best, metric, number):
        """Subclass-defined checkpoint writer; semantics not visible here."""

    @abstractmethod
    def _validate(self, epoch):
        """Run validation; called from train() every ``val_freq`` epochs."""