|
|
|
|
|
''' |
|
@File : summary.py |
|
@Time : 2022/10/15 23:38:13 |
|
@Author : BQH |
|
@Version : 1.0 |
|
@Contact : [email protected] |
|
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
|
@Desc : 运行时日志文件 |
|
''' |
|
|
|
|
|
|
|
import os |
|
import sys |
|
import torch |
|
import logging |
|
from datetime import datetime |
|
|
|
|
|
|
|
try:
    from tensorboardX import SummaryWriter
except ImportError:
    class SummaryWriter:
        """Minimal stand-in for ``tensorboardX.SummaryWriter``.

        Defined only when tensorboardX cannot be imported. Scalars are kept
        in memory in ``self.logs`` and written with ``torch.save`` as one
        pickle file inside ``log_dir`` when :meth:`close` is called.

        Args:
            log_dir: Output directory (default ``'./logs'``); created if missing.
            comment: Stored under the ``'comment'`` key of the saved dict.
            **kwargs: Accepted for signature compatibility and ignored.
        """

        def __init__(self, log_dir=None, comment='', **kwargs):
            print('\nunable to import tensorboardX, log will be recorded by pytorch!\n')
            self.log_dir = log_dir if log_dir is not None else './logs'
            # Create the directory that close() actually writes into.
            if self.log_dir:
                os.makedirs(self.log_dir, exist_ok=True)
            # tag -> list of (scalar_value, global_step, walltime) tuples,
            # plus the run comment under the 'comment' key.
            self.logs = {'comment': comment}

        def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
            """Append one ``(scalar_value, global_step, walltime)`` record under ``tag``."""
            self.logs.setdefault(tag, []).append((scalar_value, global_step, walltime))

        def close(self):
            """Dump all buffered records to ``log_<timestamp>.pickle`` in ``log_dir``."""
            timestamp = str(datetime.now()).replace(' ', '_').replace(':', '_')
            torch.save(self.logs, os.path.join(self.log_dir, 'log_%s.pickle' % timestamp))
|
|
|
|
|
class EmptySummaryWriter:
    """Do-nothing summary writer for processes that must not log.

    Mirrors the ``add_scalar`` / ``close`` surface of the real writer so
    calling code does not need to branch on rank; every call is ignored.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted only so this class can be built with
        # the same arguments as the real writer; they are discarded.
        del kwargs

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Ignore the scalar."""
        return None

    def close(self):
        """Nothing to flush or release."""
        return None
|
|
|
|
|
def create_summary(distributed_rank=0, **kwargs):
    """Build the summary writer appropriate for this process.

    Processes with ``distributed_rank > 0`` get an ``EmptySummaryWriter``.
    All others get a ``SummaryWriter``. Keyword arguments are forwarded
    unchanged to the chosen class's constructor.
    """
    writer_cls = EmptySummaryWriter if distributed_rank > 0 else SummaryWriter
    return writer_cls(**kwargs)
|
|
|
|
|
def create_logger(distributed_rank=0, save_dir=None):
    """Return the shared ``'logger'`` logger, configured for this process.

    The logger level is always set to DEBUG. For ``distributed_rank > 0`` it
    is returned without attaching any handlers. Otherwise, every handler
    currently attached to it is removed and closed. That stops repeated calls
    from stacking duplicate handlers, which would print each record several
    times. Then the following handlers are attached:

    * a DEBUG-level handler writing to ``sys.stdout``;
    * when ``save_dir`` is given, a DEBUG-level file handler writing to
      ``save_dir/log_<YYYY_MM_DD_HH_MM_SS>.txt``. The directory is created
      if missing.

    Both handlers format records as ``"<message> [<asctime>]"``.

    Args:
        distributed_rank: Rank of this process.
        save_dir: Directory for the log file, or ``None`` for stdout only.

    Returns:
        The ``logging.Logger`` named ``'logger'``.
    """
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)

    if distributed_rank > 0:
        return logger

    # Reset handlers from a previous call so records are emitted once and
    # previously opened log files are released.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter("%(message)s [%(asctime)s]")

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir is not None:
        # FileHandler does not create parent directories itself.
        os.makedirs(save_dir, exist_ok=True)
        filename = "log_%s.txt" % (datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
        fh = logging.FileHandler(os.path.join(save_dir, filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
|
|
|
|
|
class Saver:
    """Write checkpoints with ``torch.save``, only from rank 0.

    Args:
        distributed_rank: Rank of this process; only rank 0 writes files.
        save_dir: Directory that receives the checkpoints. It is created on
            construction (on every rank) if it does not already exist.
    """

    def __init__(self, distributed_rank, save_dir):
        self.distributed_rank = distributed_rank
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def save(self, obj, save_name):
        """Save ``obj`` to ``<save_dir>/<save_name>.t7`` when on rank 0.

        Args:
            obj: Anything ``torch.save`` accepts.
            save_name: File name without the ``.t7`` extension.

        Returns:
            A status message with the path of the file written, or ``''`` on
            any other rank (nothing is written there).
        """
        if self.distributed_rank != 0:
            return ''
        path = os.path.join(self.save_dir, save_name + '.t7')
        torch.save(obj, path)
        # Report the real file path, including the '.t7' suffix.
        return 'checkpoint saved in %s !' % path
|
|
|
|
|
def create_saver(distributed_rank, save_dir):
    """Factory returning a ``Saver`` bound to ``distributed_rank`` and ``save_dir``."""
    saver = Saver(distributed_rank=distributed_rank, save_dir=save_dir)
    return saver
|
|
|
|
|
class DisablePrint:
    """Context manager that silences stdout on non-zero local ranks.

    When ``local_rank != 0``, ``sys.stdout`` is redirected to ``os.devnull``
    for the duration of the ``with`` block and then restored. On rank 0 it
    does nothing. Exceptions raised inside the block are not suppressed.

    Args:
        local_rank: Rank of this process on the local node.
    """

    def __init__(self, local_rank=0):
        self.local_rank = local_rank
        self._original_stdout = None
        self._devnull = None

    def __enter__(self):
        if self.local_rank != 0:
            self._original_stdout = sys.stdout
            self._devnull = open(os.devnull, 'w')
            sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._devnull is None:
            return
        # Close the handle opened in __enter__, not whatever sys.stdout is
        # now (code in the block may have replaced it). Restore stdout
        # even if closing fails.
        try:
            self._devnull.close()
        finally:
            sys.stdout = self._original_stdout
            self._devnull = None