# vae_drilling / disvae / training.py
# Author: Jonas Becker
# 1st try
# Commit: 7f19394
# import imageio
import logging
import os
from collections import defaultdict
from timeit import default_timer
import torch
from torch.nn import functional as F
from tqdm import trange
from disvae.utils.modelIO import save_model
# Filename (inside the trainer's save_dir) of the per-epoch loss log
# written by LossesLogger and read back for plotting.
TRAIN_LOSSES_LOGFILE = "train_losses.log"
class Trainer:
    """
    Class to handle training of model.

    Parameters
    ----------
    model: disvae.vae.VAE
        Model to train; it is moved onto ``device``.
    optimizer: torch.optim.Optimizer
    loss_f: disvae.models.BaseLoss
        Loss function.
    device: torch.device, optional
        Device on which to run the code.
    logger: logging.Logger, optional
        Logger for progress messages.
    save_dir : str, optional
        Directory for saving logs and model checkpoints.
    gif_visualizer : viz.Visualizer, optional
        Gif Visualizer that should return samples at every epoch.
    is_progress_bar: bool, optional
        Whether to use a progress bar for training.
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_f,
        device=torch.device("cpu"),
        logger=logging.getLogger(__name__),
        save_dir="results",
        gif_visualizer=None,
        is_progress_bar=True,
    ):
        self.device = device
        self.model = model.to(self.device)
        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        # Per-epoch losses go to a dedicated CSV-style file for later plotting.
        self.losses_logger = LossesLogger(
            os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE)
        )
        self.gif_visualizer = gif_visualizer
        self.logger.info("Training Device: {}".format(self.device))

    def __call__(self, data_loader, epochs=10, checkpoint_every=10):
        """
        Trains the model.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader
        epochs: int, optional
            Number of epochs to train the model for.
        checkpoint_every: int, optional
            Save a checkpoint of the trained model every n epoch.
        """
        start = default_timer()
        self.model.train()
        for epoch in range(epochs):
            storer = defaultdict(list)
            mean_epoch_loss = self._train_epoch(data_loader, storer, epoch)
            self.logger.info(
                "Epoch: {} Average loss per image: {:.2f}".format(
                    epoch + 1, mean_epoch_loss
                )
            )
            self.losses_logger.log(epoch, storer)
            if self.gif_visualizer is not None:
                self.gif_visualizer()
            # NOTE(review): this checkpoints at epoch 0 and may skip the final
            # epoch when ``epochs % checkpoint_every != 0`` — confirm intended.
            if epoch % checkpoint_every == 0:
                save_model(
                    self.model, self.save_dir, filename="model-{}.pt".format(epoch)
                )
        if self.gif_visualizer is not None:
            self.gif_visualizer.save_reset()
        # Leave the model in eval mode so post-training inference is safe.
        self.model.eval()
        delta_time = (default_timer() - start) / 60
        self.logger.info("Finished training after {:.1f} min.".format(delta_time))

    def _train_epoch(self, data_loader, storer, epoch):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader
        storer: dict
            Dictionary in which to store important variables for vizualisation.
        epoch: int
            Epoch number, used only for the progress-bar caption.

        Return
        ------
        mean_epoch_loss: float
            Mean loss per image.
        """
        epoch_loss = 0.0
        kwargs = dict(
            desc="Epoch {}".format(epoch + 1),
            leave=False,
            disable=not self.is_progress_bar,
        )
        with trange(len(data_loader), **kwargs) as t:
            # FIX: dropped a useless ``enumerate`` whose index was discarded;
            # the loader yields (data, label) pairs and the label is unused.
            for data, _ in data_loader:
                iter_loss = self._train_iteration(data, storer)
                epoch_loss += iter_loss
                t.set_postfix(loss=iter_loss)
                t.update()
        return epoch_loss / len(data_loader)

    def _train_iteration(self, data, storer):
        """
        Trains the model for one iteration on a batch of data.

        Parameters
        ----------
        data: torch.Tensor
            A batch of data. Shape : (batch_size, channel, height, width).
        storer: dict
            Dictionary in which to store important variables for vizualisation.

        Return
        ------
        float
            The scalar loss value for this batch.
        """
        # FIX: removed the unused unpack of data.size() into four dead locals.
        data = data.to(self.device)
        try:
            recon_batch, latent_dist, latent_sample = self.model(data)
            loss = self.loss_f(
                data,
                recon_batch,
                latent_dist,
                self.model.training,
                storer,
                latent_sample=latent_sample,
            )
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        except ValueError:
            # for losses that use multiple optimizers (e.g. Factor)
            loss = self.loss_f.call_optimize(data, self.model, self.optimizer, storer)
        return loss.item()
class LossesLogger(object):
    """Write per-epoch losses to a log file in a form easy to plot.

    Each call to :meth:`log` appends one ``epoch,loss_name,mean_value`` CSV
    row per stored loss; a header row is written on construction.

    Parameters
    ----------
    file_path_name : str
        Path of the log file. A pre-existing file at this path is removed.
    """

    def __init__(self, file_path_name):
        """Create a logger to store information for plotting."""
        if os.path.isfile(file_path_name):
            os.remove(file_path_name)
        self.logger = logging.getLogger("losses_logger")
        self.logger.setLevel(1)  # always store
        # BUGFIX: the named logger is process-global, so constructing a second
        # LossesLogger used to stack another FileHandler and duplicate every
        # line. Detach (and close) stale handlers before adding a fresh one.
        for stale_handler in list(self.logger.handlers):
            self.logger.removeHandler(stale_handler)
            stale_handler.close()
        file_handler = logging.FileHandler(file_path_name)
        file_handler.setLevel(1)
        self.logger.addHandler(file_handler)
        header = ",".join(["Epoch", "Loss", "Value"])
        self.logger.debug(header)

    def log(self, epoch, losses_storer):
        """Write one CSV row per loss: the epoch, the loss name, and the
        arithmetic mean of the values accumulated for that loss."""
        for loss_name, values in losses_storer.items():
            mean_value = sum(values) / len(values)  # inline mean of the list
            log_string = ",".join(
                str(item) for item in [epoch, loss_name, mean_value]
            )
            self.logger.debug(log_string)
# HELPERS
def mean(values):
    """Return the arithmetic mean of a non-empty sequence of numbers.

    Parameters
    ----------
    values : sequence of numbers

    Raises
    ------
    ZeroDivisionError
        If ``values`` is empty.
    """
    # FIX: renamed the ambiguous parameter ``l`` (PEP 8 / E741); the only
    # in-file caller passes positionally, so this is backward-compatible.
    return sum(values) / len(values)