Spaces:
Runtime error
Runtime error
import contextlib | |
import numpy as np | |
import random | |
import shutil | |
import os | |
import torch | |
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    When CUDA is available, every CUDA device is seeded as well and cuDNN is
    put into deterministic mode with its autotuner (benchmark) disabled.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    """Serialize ``state`` with ``torch.save`` into ``checkpoint_path/filename``.

    If ``is_best`` is truthy, the freshly written file is also copied to
    ``checkpoint_path/model_best.pt``, overwriting any previous copy.
    """
    target = os.path.join(checkpoint_path, filename)
    torch.save(state, target)
    if not is_best:
        return
    best_target = os.path.join(checkpoint_path, "model_best.pt")
    shutil.copyfile(target, best_target)
def load_checkpoint(model, path, map_location="cpu"):
    """Load the ``"state_dict"`` entry of a saved checkpoint into ``model``.

    Args:
        model: Object whose ``load_state_dict`` receives the stored weights.
        path: Filesystem path of a file written by ``torch.save`` whose loaded
            value is indexable by ``"state_dict"``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from a GPU can still be restored on a CPU-only
            host. ``load_state_dict`` copies the values into the model's
            existing parameters, so the model keeps its own device. Pass
            ``None`` to restore tensors onto the devices they were saved from.

    Raises:
        KeyError: if the loaded checkpoint has no ``"state_dict"`` entry.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
def log_metrics(set_name, metrics, logger):
    """Emit one INFO line with the loss and the two accuracies of a split.

    ``metrics`` must provide the keys ``"loss"``, ``"spec_acc"`` and
    ``"rgb_acc"``. Each value is rendered with five decimal places.
    """
    template = "{}: Loss: {:.5f} | spec_acc: {:.5f}, rgb_acc: {:.5f}"
    values = (metrics["loss"], metrics["spec_acc"], metrics["rgb_acc"])
    logger.info(template.format(set_name, *values))
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward.

    Args:
        seed: Seed for ``np.random.seed``. If ``None``, the global NumPy RNG
            is left untouched for the duration of the block.
        *addl_seeds: Extra values mixed with ``seed`` via ``hash`` to derive
            the actual seed, reduced into ``[0, 1e6)``.

    The previous global NumPy RNG state is restored on exit, including when
    the body raises.

    NOTE: ``hash`` of ``str``/``bytes`` is randomized per process
    (PYTHONHASHSEED), so pass integer seeds if runs must be reproducible.
    """
    if seed is None:
        yield
        return
    if len(addl_seeds) > 0:
        # Kept as float modulus on purpose: changing it would change the
        # derived seeds and break reproducibility with earlier runs.
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)