#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import functools
import os
import shutil

import numpy as np
import torch


def logging(s, log_path, print_=True, log_=True):
    """Emit a message to stdout and/or append it to a log file.

    Args:
        s: The message. Any object is accepted; it is converted with
            ``str()`` before being written to the file.
        log_path: Path of the log file to append to. Only used when
            ``log_`` is true.
        print_: When true, print the message to stdout.
        log_: When true, append the message plus a newline to ``log_path``.
    """
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            # str() keeps the file branch consistent with print(), which
            # accepts non-string messages (numbers, exceptions, ...).
            f_log.write(str(s) + '\n')


def get_logger(log_path, **kwargs):
    """Return a ``logging`` callable bound to ``log_path``.

    Extra keyword arguments (e.g. ``print_`` / ``log_``) are bound as
    defaults of the returned partial and can still be overridden per call.
    """
    return functools.partial(logging, log_path=log_path, **kwargs)


def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    """Create an experiment directory and return a logger writing into it.

    Args:
        dir_path: Directory to create (including missing parents).
        scripts_to_save: Optional iterable of file paths; each file is copied
            into ``<dir_path>/scripts`` under its base name.
        debug: When true, nothing is created on disk and the returned logger
            only prints.

    Returns:
        A callable ``log(s, print_=..., log_=...)``. Outside debug mode it
        appends to ``<dir_path>/log.txt``.
    """
    if debug:
        print('Debug Mode : no experiment dir created')
        return functools.partial(logging, log_path=None, log_=False)

    # exist_ok avoids the check-then-create race when several processes
    # start with the same experiment directory.
    os.makedirs(dir_path, exist_ok=True)

    print('Experiment dir : {}'.format(dir_path))
    if scripts_to_save is not None:
        script_path = os.path.join(dir_path, 'scripts')
        os.makedirs(script_path, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(script_path, os.path.basename(script))
            shutil.copyfile(script, dst_file)

    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))


def save_checkpoint(model, optimizer, path, epoch):
    """Save the model and the optimizer state for ``epoch`` under ``path``.

    Writes ``model_<epoch>.pt`` (the whole ``model`` object, via
    ``torch.save``) and ``optimizer_<epoch>.pt`` (the optimizer's
    ``state_dict()``). ``path`` must already exist.
    """
    torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))