# Copyright (c) 2020 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE

import os
from os.path import basename
import math
import argparse
import random
import logging

import cv2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from models import create_model
from utils.timer import Timer, TickTock
from utils.util import get_resume_paths

import wandb


def getEnv(name):
    """Return True if the environment variable ``name`` is set (to any value)."""
    return name in os.environ


def init_dist(backend='nccl', **kwargs):
    """Initialize torch.distributed for multi-GPU training.

    Switches multiprocessing to the 'spawn' start method, binds this process
    to GPU ``RANK % device_count`` and initializes the default process group.

    Args:
        backend: torch.distributed backend name (default 'nccl').
        **kwargs: forwarded unchanged to ``dist.init_process_group``.

    Raises:
        KeyError: if the ``RANK`` environment variable is not set.
    """
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    # Previously called the nonexistent torch.cuda.set_deviceDistIterSampler.
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _eta_hours(seconds_per_iter, current_step, total_iters):
    """Return per-iteration time * remaining iterations / 3600.

    This is hours remaining if the timers report seconds (not verifiable here).
    """
    return (seconds_per_iter * (total_iters - current_step)) / 3600


def _load_resume_state(opt):
    """Load the training state named by ``opt['path']['resume_state']``.

    Returns None when no resume state is configured, or when
    ``get_resume_paths`` finds no state file.
    """
    if not opt['path'].get('resume_state', None):
        return None
    resume_state_path, _ = get_resume_paths(opt)
    if resume_state_path is None:
        return None
    # distributed resuming: all load into default GPU.
    # NOTE: torch.load unpickles the file - only resume from trusted checkpoints.
    device_id = torch.cuda.current_device()
    resume_state = torch.load(resume_state_path,
                              map_location=lambda storage, loc: storage.cuda(device_id))
    option.check_resume(opt, resume_state['iter'])  # check resume options
    return resume_state


def _create_tb_loggers(opt, opt_path, logger):
    """Create tensorboard writers under ``<experiments_root>/tb/<conf_name>/{train,valid}``.

    ``conf_name`` is the option file's basename without ``.yml``.
    Returns a ``(train_writer, valid_writer)`` tuple.
    """
    version = float(torch.__version__[0:3])
    if version >= 1.1:  # PyTorch 1.1
        from torch.utils.tensorboard import SummaryWriter
    else:
        logger.info(
            'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
        from tensorboardX import SummaryWriter
    conf_name = basename(opt_path).replace(".yml", "")
    exp_dir = opt['path']['experiments_root']
    log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
    log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
    return SummaryWriter(log_dir=log_dir_train), SummaryWriter(log_dir=log_dir_valid)


def _validate(model, val_loader, opt, current_step):
    """Evaluate ``model`` on every batch of ``val_loader`` and save images.

    Writes the SR output(s), plus the LQ (nearest-upscaled by ``opt['scale']``)
    and GT images once each, into ``opt['path']['val_images']/<img_name>/``.

    Returns:
        ``(avg_psnr, avg_nll)``, or ``(None, None)`` if the loader yields
        nothing (e.g. the held-out split is empty).
    """
    psnr_sum = 0.0
    nlls = []
    for val_data in val_loader:
        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        model.feed_data(val_data)
        val_nll = model.test()
        nlls.append(0 if val_nll is None else val_nll)

        visuals = model.get_current_visuals()

        # Save SR images for reference; with sampling temperatures ("heats")
        # every (heat, sample) pair is saved and the last one is scored.
        sr_img = None
        if hasattr(model, 'heats'):
            for heat in model.heats:
                for i in range(model.n_sample):
                    sr_img = util.tensor2img(visuals['SR', heat, i])  # uint8
                    save_img_path = os.path.join(
                        img_dir, '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(
                            img_name, current_step, int(heat * 100), i))
                    util.save_img(sr_img, save_img_path)
        else:
            sr_img = util.tensor2img(visuals['SR'])  # uint8
            save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
            util.save_img(sr_img, save_img_path)
        assert sr_img is not None

        # Save LQ images for reference (only once per image)
        save_img_path_lq = os.path.join(img_dir, '{:s}_LQ.png'.format(img_name))
        if not os.path.isfile(save_img_path_lq):
            lq_img = util.tensor2img(visuals['LQ'])  # uint8
            util.save_img(
                cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
                           interpolation=cv2.INTER_NEAREST),
                save_img_path_lq)

        # Save GT images for reference (only once per image)
        gt_img = util.tensor2img(visuals['GT'])  # uint8
        save_img_path_gt = os.path.join(img_dir, '{:s}_GT.png'.format(img_name))
        if not os.path.isfile(save_img_path_gt):
            util.save_img(gt_img, save_img_path_gt)

        # PSNR on images with an `scale`-pixel border cropped away.
        crop_size = opt['scale']
        gt_img = gt_img / 255.
        sr_img = sr_img / 255.
        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
        cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
        psnr_sum += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

    if not nlls:
        return None, None
    return psnr_sum / len(nlls), sum(nlls) / len(nlls)


def main():
    """Train a model from the YAML option file given by ``-opt``.

    Holds out 5% of the 'train' dataset for validation. Logs loss to wandb,
    optionally to tensorboard. Validates every ``train.val_freq`` steps,
    checkpoints every ``logger.save_checkpoint_freq`` steps, and finally writes
    ``<root>/TRAIN_DONE`` and saves the 'latest' model. Distributed training is
    disabled: ``--launcher`` and ``--local_rank`` are parsed but not used.
    """
    wandb.init(project='srflow')

    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    opt['dist'] = False
    rank = -1
    print('Disabled distributed training.')

    #### loading resume state if exists
    resume_state = _load_resume_state(opt)

    #### mkdir and loggers
    # Both writers stay None unless tensorboard logging is enabled; every use is guarded.
    tb_logger_train = tb_logger_valid = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(path for key, path in opt['path'].items()
                        if key != 'experiments_root'
                        and 'pretrain_model' not in key and 'resume' not in key)

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
            tb_logger_train, tb_logger_valid = _create_tb_loggers(opt, args.opt, logger)
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # Set torch.backends.cudnn.deterministic = True instead for reproducible runs.
    torch.backends.cudnn.benchmark = True

    #### create train and val dataloader
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            full_dataset = create_dataset(dataset_opt)
            print('Dataset created')
            # Hold out 5% of the training dataset as the validation split. The split
            # uses the global torch RNG (presumably seeded by util.set_random_seed).
            train_len = int(len(full_dataset) * 0.95)
            val_len = len(full_dataset) - train_len
            train_set, val_set = torch.utils.data.random_split(full_dataset, [train_len, val_len])
            train_size = int(math.ceil(train_len / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
            val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False,
                                                     num_workers=1, pin_memory=True)
        elif phase == 'val':
            # Validation uses the held-out split above; a separate 'val' dataset is ignored.
            continue
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError("opt['datasets'] must define a 'train' phase.")

    #### create model
    current_step = 0 if resume_state is None else resume_state['iter']
    model = create_model(opt, current_step)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
    # Print every step for the first 25 steps after (re)start.
    resume_iter = resume_state['iter'] if resume_state else 0

    #### training
    timer = Timer()
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    timerData = TickTock()

    for epoch in range(start_epoch, total_epochs + 1):
        # Stop once all iterations are done instead of starting a new dataloader
        # epoch only to break on its first batch.
        if current_step >= total_iters:
            break
        if opt['dist']:
            train_sampler.set_epoch(epoch)

        timerData.tick()
        for train_data in train_loader:
            timerData.tock()
            current_step += 1
            if current_step > total_iters:
                break

            #### training
            model.feed_data(train_data)

            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            # Reset every step so a failed step never reports a stale (or unbound) loss.
            nll = None
            try:
                nll = model.optimize_parameters(current_step)
            except RuntimeError as e:
                logger.warning('Skipping ERROR caught in nll = model.optimize_parameters('
                               'current_step): %s', e)
            if nll is None:
                nll = 0
            wandb.log({"loss": nll})

            #### log
            if current_step % opt['logger']['print_freq'] == 0 \
                    or current_step - resume_iter < 25:
                avg_time = timer.get_average_and_reset()
                avg_data_time = timerData.get_average_and_reset()
                message = ('<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, '
                           'eta:{:.2e}, nll:{:.3e}> ').format(
                    epoch, current_step, model.get_current_learning_rate(), avg_time,
                    avg_data_time, _eta_hours(avg_time, current_step, total_iters), nll)
                print(message)
            timer.tick()

            # Reduce number of logs
            if tb_logger_train is not None and current_step % 5 == 0:
                last_iter_time = timer.get_last_iteration()
                tb_logger_train.add_scalar('loss/nll', nll, current_step)
                tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                tb_logger_train.add_scalar('time/iteration', last_iter_time, current_step)
                tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
                tb_logger_train.add_scalar('time/eta',
                                           _eta_hours(last_iter_time, current_step, total_iters),
                                           current_step)
                for k, v in model.get_current_log().items():
                    tb_logger_train.add_scalar(k, v, current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr, avg_nll = _validate(model, val_loader, opt, current_step)
                if avg_psnr is None:
                    logger.warning('# Validation # skipped: the validation split is empty.')
                else:
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                    # The two writers are created together, so one check covers both.
                    if tb_logger_valid is not None:
                        tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
                        tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)
                        tb_logger_train.flush()
                        tb_logger_valid.flush()

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0 and rank <= 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

            timerData.tick()

    with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
        f.write("TRAIN_DONE")

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')

    # Flush and release tensorboard event files.
    for writer in (tb_logger_train, tb_logger_valid):
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    main()