# Copyright (c) 2020 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
import os
from os.path import basename
import math
import argparse
import random
import logging

import cv2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from models import create_model
from utils.timer import Timer, TickTock
from utils.util import get_resume_paths

import wandb


def getEnv(name):
    """Return True if the environment variable `name` is set (unused in this script)."""
    return name in os.environ


def init_dist(backend='nccl', **kwargs):
    """Initialize distributed training."""
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
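
# NOTE (illustrative, not part of the original script): with the 'pytorch'
# launcher, init_dist() would typically be reached via something like
#   python -m torch.distributed.launch --nproc_per_node=8 train.py -opt <config> --launcher pytorch
# As written, main() below hard-codes opt['dist'] = False, so this helper is unused.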


def main():
    wandb.init(project='srflow')

    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
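    # Illustrative invocation (the config filename is an assumption; see the
    # repository's confs/ directory for the actual files):
    #   python train.py -opt confs/SRFlow_DF2K_4X.yml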

    #### distributed training settings
    opt['dist'] = False
    rank = -1
    print('Disabled distributed training.')

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        resume_state_path, _ = get_resume_paths(opt)

        # distributed resuming: all load into default GPU
        if resume_state_path is None:
            resume_state = None
        else:
            device_id = torch.cuda.current_device()
            resume_state = torch.load(resume_state_path,
                                      map_location=lambda storage, loc: storage.cuda(device_id))
            option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # Configure loggers; logging does not work before this point.
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        # tensorboard logger
        if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
            # parse "major.minor" robustly (string slicing breaks on e.g. "1.10")
            version = tuple(int(x) for x in torch.__version__.split('.')[:2])
            if version >= (1, 1):  # torch.utils.tensorboard is available since PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(
                        torch.__version__))
                from tensorboardX import SummaryWriter
            conf_name = basename(args.opt).replace(".yml", "")
            exp_dir = opt['path']['experiments_root']
            log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
            log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
            tb_logger_train = SummaryWriter(log_dir=log_dir_train)
            tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
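            # The resulting event files can be inspected with, e.g.:
            #   tensorboard --logdir <experiments_root>/tb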
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
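    # cudnn.benchmark auto-tunes convolution kernels for speed at the cost of
    # run-to-run reproducibility; enable the deterministic flag above instead
    # if exact repeatability matters more than throughput.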

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            full_dataset = create_dataset(dataset_opt)
            print('Dataset created')
            # hold out 5% of the training split for validation
            train_len = int(len(full_dataset) * 0.95)
            val_len = len(full_dataset) - train_len
            train_set, val_set = torch.utils.data.random_split(full_dataset, [train_len, val_len])
            train_size = int(math.ceil(train_len / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
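            # Worked example (hypothetical numbers): 50,000 images give
            # train_len = 47,500; with batch_size 16 that is ceil(47500 / 16) = 2,969
            # iterations per epoch, so niter = 200,000 needs ceil(200000 / 2969) = 68 epochs.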
            train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
            val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False,
                                                     num_workers=1, pin_memory=True)
        elif phase == 'val':
            continue
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    current_step = 0 if resume_state is None else resume_state['iter']
    model = create_model(opt, current_step)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    timer = Timer()
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    timerData = TickTock()
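    # timerData brackets the data-loading portion of each step: tick() before
    # waiting on the loader, tock() once a batch arrives (see utils.timer).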

    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)

        timerData.tick()
        for _, train_data in enumerate(train_loader):
            timerData.tock()
            current_step += 1
            if current_step > total_iters:
                break

            #### training
            model.feed_data(train_data)

            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            try:
                nll = model.optimize_parameters(current_step)
            except RuntimeError as e:
                nll = None  # ensure nll is defined even when the step fails
                print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
                print(e)

            if nll is None:
                nll = 0

            wandb.log({"loss": nll})

            #### log
            def eta(t_iter):
                # seconds per iteration * remaining iterations, converted to hours
                return (t_iter * (opt['train']['niter'] - current_step)) / 3600

            if current_step % opt['logger']['print_freq'] == 0 \
                    or current_step - (resume_state['iter'] if resume_state else 0) < 25:
                avg_time = timer.get_average_and_reset()
                avg_data_time = timerData.get_average_and_reset()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
                    eta(avg_time), nll)
                print(message)

            timer.tick()
            # Reduce the number of logs
            if current_step % 5 == 0:
                tb_logger_train.add_scalar('loss/nll', nll, current_step)
                tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
                tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
                for k, v in model.get_current_log().items():
                    tb_logger_train.add_scalar(k, v, current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                nlls = []
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)

                    nll = model.test()
                    if nll is None:
                        nll = 0
                    nlls.append(nll)

                    visuals = model.get_current_visuals()

                    sr_img = None
                    # Save SR images for reference
                    if hasattr(model, 'heats'):
                        for heat in model.heats:
                            for i in range(model.n_sample):
                                sr_img = util.tensor2img(visuals['SR', heat, i])  # uint8
                                save_img_path = os.path.join(
                                    img_dir, '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(
                                        img_name, current_step, int(heat * 100), i))
                                util.save_img(sr_img, save_img_path)
                    else:
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        save_img_path = os.path.join(
                            img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)
                    assert sr_img is not None

                    # Save LQ images for reference
                    save_img_path_lq = os.path.join(img_dir, '{:s}_LQ.png'.format(img_name))
                    if not os.path.isfile(save_img_path_lq):
                        lq_img = util.tensor2img(visuals['LQ'])  # uint8
                        util.save_img(
                            cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
                                       interpolation=cv2.INTER_NEAREST),
                            save_img_path_lq)

                    # Save GT images for reference
                    gt_img = util.tensor2img(visuals['GT'])  # uint8
                    save_img_path_gt = os.path.join(img_dir, '{:s}_GT.png'.format(img_name))
                    if not os.path.isfile(save_img_path_gt):
                        util.save_img(gt_img, save_img_path_gt)

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
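                    # Cropping `scale` pixels from each border before PSNR follows
                    # the usual SR evaluation convention, since border pixels carry
                    # upsampling artifacts.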

                avg_psnr = avg_psnr / idx
                avg_nll = sum(nlls) / len(nlls)

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                    epoch, current_step, avg_psnr))

                # tensorboard logger
                tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
                tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)

                tb_logger_train.flush()
                tb_logger_valid.flush()

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

            timerData.tick()

    with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
        f.write("TRAIN_DONE")

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')


if __name__ == '__main__':
    main()