# Miscellaneous utilities: filesystem helpers, logging setup, tensor/image
# conversion, and PSNR metric.
import glob
import os
import sys
import time
import math
from datetime import datetime
import random
import logging
from collections import OrderedDict

import natsort
import numpy as np
import cv2
import torch
from torchvision.utils import make_grid
from shutil import get_terminal_size

import yaml
try:
    # Prefer the fast C-accelerated loader/dumper when libyaml is available.
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
def OrderedYaml():
    """Patch the module-level yaml Loader/Dumper for OrderedDict support.

    Registers a representer so OrderedDict dumps as a plain mapping, and a
    constructor so every YAML mapping loads as an OrderedDict (preserving
    key order). Returns the patched (Loader, Dumper) pair.
    """
    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def represent_ordered_dict(dumper, data):
        return dumper.represent_dict(data.items())

    def construct_ordered_mapping(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, represent_ordered_dict)
    Loader.add_constructor(mapping_tag, construct_ordered_mapping)
    return Loader, Dumper
#################### | |
# miscellaneous | |
#################### | |
def get_timestamp():
    """Return the current local time as a 'YYMMDD-HHMMSS' string."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
def mkdir(path):
    """Create directory `path` (including parents) if it does not exist.

    Uses os.makedirs(..., exist_ok=True) instead of the previous
    exists()-then-makedirs sequence, which had a check-then-act race:
    another process creating the path between the two calls raised
    FileExistsError.
    """
    os.makedirs(path, exist_ok=True)
def mkdirs(paths):
    """Create a single directory or every directory in an iterable of paths."""
    if isinstance(paths, str):
        paths = [paths]
    for p in paths:
        mkdir(p)
def mkdir_and_rename(path):
    """Create a fresh directory at `path`, archiving any existing one.

    If `path` already exists it is renamed to '<path>_archived_<timestamp>'
    (announced on stdout and via the 'base' logger) before a new empty
    directory is created in its place.
    """
    if os.path.exists(path):
        archived = path + '_archived_' + get_timestamp()
        message = 'Path already exists. Rename it to [{:s}]'.format(archived)
        print(message)
        logging.getLogger('base').info(message)
        os.rename(path, archived)
    os.makedirs(path)
def set_random_seed(seed):
    """Seed Python's `random`, NumPy, and PyTorch (CPU + all CUDA devices)."""
    seeders = (random.seed, np.random.seed,
               torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
    """Set up the named logger with file and/or stream handlers.

    Args:
        logger_name: name passed to logging.getLogger.
        root: directory for the log file (only used when tofile=True).
        phase: log-file name prefix; file is '<root>/<phase>_<timestamp>.log'.
        level: logging level applied to the logger.
        screen: if True, attach a StreamHandler (console output).
        tofile: if True, attach a FileHandler (mode 'w', truncates).

    Fix: the original unconditionally appended handlers, so calling this
    twice for the same logger_name duplicated every log record. If the
    logger already has handlers we only (re)apply the level and return.
    """
    lg = logging.getLogger(logger_name)
    lg.setLevel(level)
    if lg.handlers:
        # Already configured earlier in this process; avoid duplicate output.
        return
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    if tofile:
        log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
        fh = logging.FileHandler(log_file, mode='w')
        fh.setFormatter(formatter)
        lg.addHandler(fh)
    if screen:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        lg.addHandler(sh)
#################### | |
# image convert | |
#################### | |
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert a torch Tensor into an image numpy array.

    Input: 4D (B,C,H,W), 3D (C,H,W) or 2D (H,W) tensor, any value range,
    RGB channel order. A 4D batch is tiled into a square grid via
    make_grid. Output: 3D (H,W,C) in BGR order, or 2D (H,W); values are
    clamped to `min_max`, rescaled to [0, 1], and — when out_type is
    np.uint8 (default) — scaled to [0, 255] and rounded.
    """
    if hasattr(tensor, 'detach'):
        tensor = tensor.detach()
    lo, hi = min_max
    tensor = tensor.squeeze().float().cpu().clamp_(lo, hi)
    tensor = (tensor - lo) / (hi - lo)  # rescale to [0, 1]
    n_dim = tensor.dim()
    if n_dim == 2:
        img_np = tensor.numpy()
    elif n_dim == 3:
        # CHW -> HWC, channels flipped RGB -> BGR
        img_np = np.transpose(tensor.numpy()[[2, 1, 0], :, :], (1, 2, 0))
    elif n_dim == 4:
        grid = make_grid(tensor, nrow=int(math.sqrt(len(tensor))), normalize=False)
        img_np = np.transpose(grid.numpy()[[2, 1, 0], :, :], (1, 2, 0))
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Important. Unlike matlab, numpy's astype(np.uint8) truncates,
        # so round explicitly before the cast.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)
def save_img(img, img_path, mode='RGB'):
    """Write an image array to `img_path` via cv2.imwrite.

    Args:
        img: image array in the layout cv2 expects (BGR for 3-channel).
        img_path: destination path; format is inferred from the extension.
        mode: kept only for backward compatibility with existing callers;
            it is NOT used — cv2.imwrite ignores any notion of color mode.

    Raises:
        IOError: if cv2.imwrite reports failure (bad path, unsupported
            extension, unwritable directory). The original silently
            discarded the boolean result, so failed saves went unnoticed.
    """
    if not cv2.imwrite(img_path, img):
        raise IOError('Failed to write image to {}'.format(img_path))
#################### | |
# metric | |
#################### | |
def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two images in range [0, 255].

    Returns float('inf') when the images are identical (zero MSE).
    """
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
def get_resume_paths(opt):
    """Resolve checkpoint paths for resuming training from `opt`.

    When opt['path']['resume_state'] is the string "auto" and a
    'training_state' directory is configured, the newest state file
    (natural sort order) is selected and the matching generator weights
    path is derived by swapping 'training_state' -> 'models' and
    '.state' -> '_G.pth'. Otherwise the configured resume_state value is
    passed through unchanged.

    Returns:
        (resume_state_path, resume_model_path) — either may be None.
    """
    state_dir = opt_get(opt, ['path', 'training_state'])
    configured = opt.get('path', {}).get('resume_state', None)
    if configured == "auto" and state_dir is not None:
        candidates = natsort.natsorted(glob.glob(os.path.join(state_dir, "*")))
        if not candidates:
            return None, None
        state_path = candidates[-1]
        model_path = state_path.replace('training_state', 'models').replace('.state', '_G.pth')
        return state_path, model_path
    return configured, None
def opt_get(opt, keys, default=None):
    """Walk nested dict-like `opt` along the key path `keys`.

    Returns the value found at the end of the path, or `default` when
    `opt` is None or any intermediate lookup yields None. Falsy values
    such as 0 or '' are returned as-is, not replaced by `default`.
    """
    node = opt
    for key in keys:
        if node is None:
            return default
        node = node.get(key, None)
    return default if node is None else node