# Spaces:
# Running
# Running
# import imageio
import logging
import os
from collections import defaultdict
from timeit import default_timer

import torch
from torch.nn import functional as F
from tqdm import trange

from disvae.utils.modelIO import save_model
# Name of the log file that receives the per-epoch training losses; Trainer
# joins it onto its ``save_dir`` to build the full path.
TRAIN_LOSSES_LOGFILE = "train_losses.log"
class Trainer:
    """
    Class to handle training of model.

    Parameters
    ----------
    model: disvae.vae.VAE
        Model to train. It is moved to ``device`` on construction.

    optimizer: torch.optim.Optimizer

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs and model checkpoints.

    gif_visualizer : viz.Visualizer, optional
        Gif Visualizer that should return samples at every epochs. It is
        called without arguments after each epoch, and its ``save_reset()``
        method is called once after the last epoch.

    is_progress_bar: bool, optional
        Whether to use a progress bar for training.
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_f,
        device=torch.device("cpu"),
        logger=logging.getLogger(__name__),
        save_dir="results",
        gif_visualizer=None,
        is_progress_bar=True,
    ):
        self.device = device
        self.model = model.to(self.device)
        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        # Per-epoch losses go to a dedicated file inside ``save_dir``.
        self.losses_logger = LossesLogger(
            os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE)
        )
        self.gif_visualizer = gif_visualizer
        self.logger.info("Training Device: {}".format(self.device))

    def __call__(self, data_loader, epochs=10, checkpoint_every=10):
        """
        Trains the model.

        The model is put in training mode before the first epoch and in
        evaluation mode once all epochs are done.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        epochs: int, optional
            Number of epochs to train the model for.

        checkpoint_every: int, optional
            Save a checkpoint of the trained model every n epoch. The check
            uses the zero-based epoch index, so the first epoch (index 0) is
            always checkpointed, and the index is used in the file name.
        """
        start = default_timer()
        self.model.train()

        for epoch in range(epochs):
            # Fresh storer each epoch so only this epoch's values are logged.
            storer = defaultdict(list)
            mean_epoch_loss = self._train_epoch(data_loader, storer, epoch)
            self.logger.info(
                "Epoch: {} Average loss per image: {:.2f}".format(
                    epoch + 1, mean_epoch_loss
                )
            )
            self.losses_logger.log(epoch, storer)

            if self.gif_visualizer is not None:
                self.gif_visualizer()

            if epoch % checkpoint_every == 0:
                save_model(
                    self.model, self.save_dir, filename="model-{}.pt".format(epoch)
                )

        if self.gif_visualizer is not None:
            self.gif_visualizer.save_reset()

        self.model.eval()

        delta_time = (default_timer() - start) / 60
        self.logger.info("Finished training after {:.1f} min.".format(delta_time))

    def _train_epoch(self, data_loader, storer, epoch):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader
            Must yield pairs; only the first element of each pair is used.

        storer: dict
            Dictionary in which to store important variables for vizualisation.

        epoch: int
            Epoch number (zero-based; displayed as ``epoch + 1``).

        Return
        ------
        mean_epoch_loss: float
            Sum of the per-batch loss values divided by the number of
            batches (``len(data_loader)``).
        """
        epoch_loss = 0.0
        kwargs = dict(
            desc="Epoch {}".format(epoch + 1),
            leave=False,
            disable=not self.is_progress_bar,
        )
        with trange(len(data_loader), **kwargs) as t:
            # The progress bar is advanced manually, once per batch.
            for data, _ in data_loader:
                iter_loss = self._train_iteration(data, storer)
                epoch_loss += iter_loss

                t.set_postfix(loss=iter_loss)
                t.update()

        mean_epoch_loss = epoch_loss / len(data_loader)
        return mean_epoch_loss

    def _train_iteration(self, data, storer):
        """
        Trains the model for one iteration on a batch of data.

        Parameters
        ----------
        data: torch.Tensor
            A batch of data, typically of shape
            (batch_size, channel, height, width). The trainer itself does
            not depend on the shape; the batch is passed to the model as is.

        storer: dict
            Dictionary in which to store important variables for vizualisation.

        Return
        ------
        loss: float
            Value of the loss for this batch (``loss.item()``).
        """
        data = data.to(self.device)

        try:
            recon_batch, latent_dist, latent_sample = self.model(data)
            loss = self.loss_f(
                data,
                recon_batch,
                latent_dist,
                self.model.training,
                storer,
                latent_sample=latent_sample,
            )
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        except ValueError:
            # for losses that use multiple optimizers (e.g. Factor):
            # ``call_optimize`` is expected to run the optimization itself.
            # NOTE(review): any ValueError raised in the block above also
            # ends up here, not only the one signalling such a loss.
            loss = self.loss_f.call_optimize(data, self.model, self.optimizer, storer)

        return loss.item()
class LossesLogger(object):
    """Class definition for objects to write data to log files in a
    form which is then easy to be plotted.

    The file starts with an ``Epoch,Loss,Value`` header line, followed by
    one comma-separated line per logged loss.
    """

    def __init__(self, file_path_name):
        """Create a logger to store information for plotting.

        Parameters
        ----------
        file_path_name : str
            Path of the log file. An existing file at this path is deleted
            before the new log is started.
        """
        self.logger = logging.getLogger("losses_logger")
        self.logger.setLevel(1)  # always store

        # "losses_logger" is one process-wide logger, so handlers attached by
        # earlier instances are still on it. Close and detach any that target
        # this same file *before* deleting it; otherwise each re-creation
        # leaks an open file handle and leaves a stale handler behind.
        # NOTE: handlers for other files are left in place, so records
        # logged through this instance are delivered to those files as well.
        target_path = os.path.abspath(file_path_name)
        for handler in list(self.logger.handlers):
            if (
                isinstance(handler, logging.FileHandler)
                and handler.baseFilename == target_path
            ):
                self.logger.removeHandler(handler)
                handler.close()

        if os.path.isfile(file_path_name):
            os.remove(file_path_name)

        file_handler = logging.FileHandler(file_path_name)
        file_handler.setLevel(1)
        self.logger.addHandler(file_handler)

        header = ",".join(["Epoch", "Loss", "Value"])
        self.logger.debug(header)

    def log(self, epoch, losses_storer):
        """Write to the log file.

        Parameters
        ----------
        epoch : int
            Epoch identifier written in the first column.

        losses_storer : dict
            Maps a loss name to the list of values collected for it; the
            mean of each list is written as one ``epoch,name,mean`` line.
        """
        for k, v in losses_storer.items():
            log_string = ",".join(str(item) for item in [epoch, k, mean(v)])
            self.logger.debug(log_string)
# HELPERS | |
def mean(l):
    """Compute the arithmetic mean of a collection of numbers.

    Parameters
    ----------
    l : iterable
        Values to average. Sized containers are used as they are; any other
        iterable (e.g. a generator) is first consumed into a list.

    Return
    ------
    The sum of the values divided by how many there are.

    Raises
    ------
    ZeroDivisionError
        If ``l`` contains no values.
    """
    # Generators have no len(); materialise them so both sum and count work.
    values = l if hasattr(l, "__len__") else list(l)
    count = len(values)
    if count == 0:
        raise ZeroDivisionError("mean() of an empty sequence")
    return sum(values) / count