Spaces:
Sleeping
Sleeping
File size: 7,726 Bytes
d4b77ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import math
import parse
import logging
from utils import util
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from data import create_dataset, create_dataloader
from models.utils.loss import *
import yaml
from models.utils.edgeLoss import EdgeLoss
from abc import abstractmethod, ABCMeta
class Trainer(metaclass=ABCMeta):
    """Abstract base class that prepares everything around a training loop.

    The constructor creates the output directories and loggers (master rank
    only), seeds the RNGs, builds the training dataset/loader, asks the
    subclass for the model/optimizer/scheduler, optionally resumes from a
    saved state, wraps the model in ``DistributedDataParallel`` when
    requested, and instantiates the loss modules.

    Subclasses supply the model-specific parts through the abstract methods
    ``init_model``, ``resume_training``, ``_trainEpoch``, ``_printLog``,
    ``save_checkpoint`` and ``_validate``.

    NOTE(review): this class uses ``os``, ``torch`` and ``nn`` although the
    visible import block does not import them explicitly; they presumably
    come in through ``from models.utils.loss import *`` -- confirm.
    """

    # Top-level option keys that are copied verbatim into every entry of
    # ``opt['datasets']`` before datasets are created.
    _SHARED_DATASET_KEYS = (
        'norm',
        'dataMode',
        'edge_loss',
        'ternary',
        'num_flows',
        'sample',
        'use_edges',
        'flow_interval',
    )

    def __init__(self, opt, rank):
        """Build loggers, data pipeline, model and losses for one process.

        Args:
            opt: Nested option mapping, indexed as ``opt['path']``,
                ``opt['train']``, ``opt['datasets']``, ``opt['device']`` etc.
            rank: Rank of this process. Ranks ``<= 0`` create directories and
                loggers (presumably ``-1`` means non-distributed -- confirm).
        """
        self.opt = opt
        self.rank = rank

        # Only ranks <= 0 create output directories and loggers. The other
        # ranks get explicit ``None`` placeholders so the attributes always
        # exist instead of raising AttributeError on access.
        self.logger, self.tb_logger = None, None
        if rank <= 0:
            self.mkdir()
            self.logger, self.tb_logger = self.setLogger()

        self.setSeed()
        (self.dataInfo, self.valInfo, self.trainSet, self.trainSize,
         self.totalIterations, self.totalEpochs, self.trainLoader,
         self.trainSampler) = self.prepareDataset()

        self.model, self.optimizer, self.scheduler = self.init_model()
        self.model = self.model.to(self.opt['device'])

        # Resume only when a saved optimizer/training state path is configured.
        if opt['path'].get('opt_state', None):
            self.startEpoch, self.currentStep = self.resume_training()
        else:
            self.startEpoch, self.currentStep = 0, 0

        # The DDP wrapper is applied after the (optional) resume above.
        if opt['distributed']:
            self.model = DDP(
                self.model,
                device_ids=[self.opt['local_rank']],
                output_device=self.opt['local_rank'],
            )

        if self.rank <= 0:
            self.logger.info('Start training from epoch: {}, iter: {}'.format(
                self.startEpoch, self.currentStep))

        self.maskedLoss = nn.L1Loss()
        self.validLoss = nn.L1Loss()
        self.edgeLoss = EdgeLoss(self.opt['device'])
        self.countDown = 0

        # metrics recorder
        self.total_loss = 0
        self.total_psnr = 0
        self.total_ssim = 0
        self.total_l1 = 0
        self.total_l2 = 0

    def get_lr(self):
        """Return the current learning rate of every optimizer param group."""
        return [param_group['lr'] for param_group in self.optimizer.param_groups]

    def adjust_learning_rate(self, optimizer, target_lr):
        """Set the learning rate of every param group of ``optimizer``.

        Args:
            optimizer: Object exposing ``param_groups`` (a sequence of dicts).
            target_lr: Value written to each group's ``'lr'`` entry.
        """
        for param_group in optimizer.param_groups:
            param_group['lr'] = target_lr

    def mkdir(self):
        """Create the output directories and dump the options as YAML.

        ``util.mkdir_and_rename`` is called on ``opt['path']['OUTPUT_ROOT']``;
        when it returns a truthy value, that value becomes the parent of the
        training-state, log and validation-image directories. The directories
        are then created and ``config.yaml`` is written into the log
        directory.
        """
        paths = self.opt['path']
        new_name = util.mkdir_and_rename(paths['OUTPUT_ROOT'])
        if new_name:
            paths['TRAINING_STATE'] = os.path.join(new_name, 'training_state')
            paths['LOG'] = os.path.join(new_name, 'log')
            paths['VAL_IMAGES'] = os.path.join(new_name, 'val_images')

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        for key in ('TRAINING_STATE', 'LOG', 'VAL_IMAGES'):
            os.makedirs(paths[key], exist_ok=True)

        # save config file for output
        with open(os.path.join(paths['LOG'], 'config.yaml'), 'w') as f:
            yaml.dump(self.opt, f)

    @staticmethod
    def _has_builtin_tensorboard(version):
        """Return True when PyTorch ``version`` is at least 1.1.

        The check decides whether ``torch.utils.tensorboard`` is used instead
        of ``tensorboardX``. The version is compared as a ``(major, minor)``
        tuple rather than as a float, so ``'1.10.0'`` is read as 1.10 and not
        truncated to 1.1.

        Args:
            version: Version string such as ``'1.13.1'`` or ``'2.0.1+cu117'``.
        """
        release = version.split('+', 1)[0]  # drop a local build tag
        try:
            major, minor = (int(part) for part in release.split('.')[:2])
        except ValueError:
            # Unparseable version string: assume a modern build. The
            # subsequent import fails loudly if that assumption is wrong.
            return True
        return (major, minor) >= (1, 1)

    def setLogger(self):
        """Configure the ``'base'`` logger and the optional TensorBoard writer.

        Returns:
            Tuple ``(logger, tb_logger)``; ``tb_logger`` is ``None`` when
            ``opt['use_tb_logger']`` is falsy.
        """
        util.setup_logger('base', self.opt['path']['LOG'], 'train_' + self.opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(parse.toString(self.opt))
        logger.info('OUTPUT DIR IS: {}'.format(self.opt['path']['OUTPUT_ROOT']))

        if not self.opt['use_tb_logger']:
            return logger, None

        if self._has_builtin_tensorboard(torch.__version__):
            from torch.utils.tensorboard import SummaryWriter
        else:
            logger.info('You are using PyTorch {}, Tensorboard will use [tensorboardX)'.format(
                torch.__version__))
            from tensorboardX import SummaryWriter
        # NOTE(review): the writer targets OUTPUT_ROOT/log, whereas mkdir()
        # may have pointed opt['path']['LOG'] somewhere else -- confirm this
        # is intended.
        tb_logger = SummaryWriter(os.path.join(self.opt['path']['OUTPUT_ROOT'], 'log'))
        return logger, tb_logger

    def setSeed(self):
        """Seed the RNGs via ``util.set_random_seed`` and configure cuDNN.

        ``cudnn.benchmark`` is always enabled; ``cudnn.deterministic`` is
        enabled only when the configured seed equals 0.
        """
        seed = self.opt['train']['manual_seed']
        if self.rank <= 0:
            self.logger.info('Random seed: {}'.format(seed))
        util.set_random_seed(seed)
        torch.backends.cudnn.benchmark = True
        if seed == 0:
            torch.backends.cudnn.deterministic = True

    def prepareDataset(self):
        """Create the training dataset, sampler and loader from the options.

        Edge-detection settings and the normalisation flag are copied into
        ``valInfo``. A fixed set of top-level options is copied into every
        entry of ``opt['datasets']``, and the entry whose key is ``'train'``
        (case-insensitive) is turned into a dataset and data loader.

        Returns:
            Tuple ``(dataInfo, valInfo, train_set, train_size,
            total_iterations, total_epochs, train_loader, train_sampler)``.
            ``train_size`` is the number of iterations per epoch and
            ``train_sampler`` is ``None`` unless ``opt['distributed']`` is set.

        Raises:
            AssertionError: If no ``'train'`` entry exists or the training
                set yields zero iterations per epoch.
        """
        dataInfo = self.opt['datasets']['dataInfo']
        valInfo = self.opt['datasets']['valInfo']
        for key in ('sigma', 'low_threshold', 'high_threshold'):
            valInfo[key] = dataInfo['edge'][key]
        valInfo['norm'] = self.opt['norm']
        if self.rank <= 0:
            self.logger.debug('Val info is: {}'.format(valInfo))

        train_set, train_size, total_iterations, total_epochs = None, 0, 0, 0
        train_loader, train_sampler = None, None
        for phase, dataset in self.opt['datasets'].items():
            # Every entry (including dataInfo/valInfo) receives the shared options.
            for key in self._SHARED_DATASET_KEYS:
                dataset[key] = self.opt[key]
            if phase.lower() != 'train':
                continue

            train_set = create_dataset(dataset, dataInfo, phase, self.opt['datasetName_train'])
            # number of iterations in one epoch
            train_size = math.ceil(
                len(train_set) / (dataset['batch_size'] * self.opt['world_size']))
            # Checked before the division below: an empty training set used
            # to surface as a ZeroDivisionError instead of this message.
            assert train_size != 0, "Train size cannot be zero"
            total_iterations = self.opt['train']['MAX_ITERS']
            total_epochs = int(math.ceil(total_iterations / train_size))
            if self.opt['distributed']:
                train_sampler = DistributedSampler(
                    train_set,
                    num_replicas=self.opt['world_size'],
                    rank=self.opt['global_rank'])
            else:
                train_sampler = None
            train_loader = create_dataloader(phase, train_set, dataset, self.opt, train_sampler)
            if self.rank <= 0:
                self.logger.info('Number of training batches: {}, iters: {}'.format(
                    len(train_set), total_iterations))
                self.logger.info('Total epoch needed: {} for iters {}'.format(
                    total_epochs, total_iterations))

        assert train_set is not None and train_size != 0, "Train size cannot be zero"
        assert train_loader is not None, "Cannot find train set, val set can be None"
        return (dataInfo, valInfo, train_set, train_size, total_iterations,
                total_epochs, train_loader, train_sampler)

    @abstractmethod
    def init_model(self):
        """Build and return ``(model, optimizer, scheduler)``."""

    @abstractmethod
    def resume_training(self):
        """Restore saved state and return ``(start_epoch, current_step)``."""

    def train(self):
        """Run epochs from ``startEpoch`` through ``totalEpochs`` inclusive.

        Each epoch reseeds the distributed sampler (when distributed), runs
        ``_trainEpoch``, stops as soon as ``currentStep`` exceeds
        ``totalIterations``, validates every ``opt['train']['val_freq']``
        epochs when ``opt['use_valid']`` is set, and finally steps the
        scheduler.
        """
        for epoch in range(self.startEpoch, self.totalEpochs + 1):
            if self.opt['distributed']:
                # Makes the sampler's shuffling differ between epochs.
                self.trainSampler.set_epoch(epoch)
            self._trainEpoch(epoch)
            if self.currentStep > self.totalIterations:
                break
            if self.opt['use_valid'] and (epoch + 1) % self.opt['train']['val_freq'] == 0:
                self._validate(epoch)
            # NOTE(review): passing ``epoch`` to ``step`` is deprecated in
            # recent PyTorch releases; kept as-is because dropping it could
            # change the schedule when training is resumed.
            self.scheduler.step(epoch)

    @abstractmethod
    def _trainEpoch(self, epoch):
        """Run one training epoch; called by ``train`` with the epoch index."""

    @abstractmethod
    def _printLog(self, logs, epoch, loss):
        """Report training progress (contract defined by the subclass)."""

    @abstractmethod
    def save_checkpoint(self, epoch, metric, number):
        """Persist a checkpoint (contract defined by the subclass)."""

    @abstractmethod
    def _validate(self, epoch):
        """Run validation; called by ``train`` with the epoch index."""
|