Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from os.path import join | |
import os | |
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, save_path="."):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            save_path (str): Directory intended for checkpoints. Stored on the instance but
                            not read by this class while checkpoint writing in
                            ``save_checkpoint`` is disabled.
                            Default: "."
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without sufficient improvement
        self.best_score = None    # best negated validation loss seen so far
        self.early_stop = False   # becomes True once counter reaches patience
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.save_path = save_path

    def __call__(self, val_loss, model):
        """Record one validation loss and update the early-stopping state.

        Args:
            val_loss: Scalar validation loss (lower is better). Must support
                ``float()`` (e.g. a Python/NumPy float or a one-element tensor).
            model: Passed through to ``save_checkpoint`` when the loss improves.
        """
        # A NaN loss never counts as an improvement. Without this guard NaN
        # compares False against everything, lands in the "improved" branch,
        # becomes best_score, and then every later loss also looks like an
        # improvement -- so a diverged run would never be stopped.
        if np.isnan(float(val_loss)):
            self._register_no_improvement()
            return

        score = -val_loss  # negate so that a higher score is better
        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self._register_no_improvement()

    def _register_no_improvement(self):
        """Count one non-improving call and set early_stop once patience is exhausted."""
        self.counter += 1
        print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Record ``val_loss`` as the new best validation loss.

        Despite the name, nothing is written to disk: the persistence step
        below is commented out, so ``model`` is currently unused. When
        ``verbose`` is True, a message reporting the improvement is printed.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        # NOTE: checkpoint persistence is disabled; re-enable by uncommenting.
        # save_path = join(self.save_path, "best_model")
        # if not os.path.exists(save_path):
        #     os.mkdir(save_path)
        # model_to_save = model.module if hasattr(model, 'module') else model
        # model_to_save.save_pretrained(save_path)
        self.val_loss_min = val_loss