# vae_drilling / disvae / evaluate.py
# Jonas Becker
# 1st try
# 7f19394
import os
import logging
import math
from functools import reduce
from collections import defaultdict
import json
from timeit import default_timer
from tqdm import trange, tqdm
import numpy as np
import torch
from disvae.models.losses import get_loss_f
from disvae.utils.math import log_density_gaussian
from disvae.utils.modelIO import save_metadata
# File name (inside the evaluator's save directory) passed to `save_metadata`
# for the averaged test losses.
TEST_LOSSES_FILE = "test_losses.log"
# File name (inside the evaluator's save directory) passed to `save_metadata`
# for the disentanglement metrics.
METRICS_FILENAME = "metrics.log"
# File name (inside the evaluator's save directory) of the `torch.save`d dict
# holding the intermediate quantities used to compute the metrics.
METRIC_HELPERS_FILE = "metric_helpers.pth"
class Evaluator:
    """
    Class to handle the evaluation (test losses and disentanglement metrics)
    of a model.

    Parameters
    ----------
    model: disvae.vae.VAE

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs.

    is_progress_bar: bool, optional
        Whether to use a progress bar for evaluation.
    """

    def __init__(self, model, loss_f,
                 device=torch.device("cpu"),
                 logger=logging.getLogger(__name__),
                 save_dir="results",
                 is_progress_bar=True):
        self.device = device
        self.loss_f = loss_f
        self.model = model.to(self.device)
        self.logger = logger
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger.info("Testing Device: {}".format(self.device))

    def __call__(self, data_loader, is_metrics=False, is_losses=True):
        """Compute the requested test losses and/or disentangling metrics.

        The model is switched to evaluation mode for the duration of the call
        and put back in training mode afterwards if it was training before.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        is_metrics: bool, optional
            Whether to compute and store the disentangling metrics.

        is_losses: bool, optional
            Whether to compute and store the test losses.

        Return
        ------
        metrics: dict or None
            Computed metrics, or `None` if `is_metrics` is False.

        losses: dict or None
            Averaged test losses, or `None` if `is_losses` is False.
        """
        start = default_timer()
        is_still_training = self.model.training
        self.model.eval()

        metrics, losses = None, None
        try:
            if is_metrics:
                self.logger.info('Computing metrics...')
                metrics = self.compute_metrics(data_loader)
                self.logger.info('Metrics: {}'.format(metrics))
                save_metadata(metrics, self.save_dir, filename=METRICS_FILENAME)

            if is_losses:
                self.logger.info('Computing losses...')
                losses = self.compute_losses(data_loader)
                self.logger.info('Losses: {}'.format(losses))
                save_metadata(losses, self.save_dir, filename=TEST_LOSSES_FILE)
        finally:
            # restore the original mode even if the evaluation failed
            if is_still_training:
                self.model.train()

        self.logger.info('Finished evaluating after {:.1f} min.'.format((default_timer() - start) / 60))

        return metrics, losses

    def compute_losses(self, dataloader):
        """Compute all test losses, averaged over the batches.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader

        Return
        ------
        losses: dict
            Maps every key written to the storer by the loss function to the
            sum of its values divided by the number of batches.
        """
        storer = defaultdict(list)
        for data, _ in tqdm(dataloader, leave=False, disable=not self.is_progress_bar):
            data = data.to(self.device)

            try:
                recon_batch, latent_dist, latent_sample = self.model(data)
                _ = self.loss_f(data, recon_batch, latent_dist, self.model.training,
                                storer, latent_sample=latent_sample)
            except ValueError:
                # for losses that use multiple optimizers (e.g. Factor)
                _ = self.loss_f.call_optimize(data, self.model, None, storer)

        losses = {k: sum(v) / len(dataloader) for k, v in storer.items()}
        return losses

    def compute_metrics(self, dataloader):
        """Compute all the metrics (MIG and AAM) and save the metric helpers.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader
            Its dataset must expose `lat_sizes` and `lat_names`.
            NOTE(review): the encodings are reshaped to `(*lat_sizes, latent_dim)`,
            so the loader presumably has to iterate the dataset in order
            (no shuffling) - confirm against the caller.

        Return
        ------
        metrics: dict
            Dictionary with the keys `'MIG'` and `'AAM'` (python floats).

        Raises
        ------
        ValueError
            If the dataset does not expose its true factors of variation.
        """
        try:
            lat_sizes = dataloader.dataset.lat_sizes
            lat_names = dataloader.dataset.lat_names
        except AttributeError as err:
            raise ValueError("Dataset needs to have known true factors of variations to compute the metric. This does not seem to be the case for {}".format(type(dataloader.__dict__["dataset"]).__name__)) from err

        self.logger.info("Computing the empirical distribution q(z|x).")
        samples_zCx, params_zCx = self._compute_q_zCx(dataloader)
        _, latent_dim = samples_zCx.shape

        self.logger.info("Estimating the marginal entropy.")
        # marginal entropy H(z_j)
        H_z = self._estimate_latent_entropies(samples_zCx, params_zCx)

        # conditional entropy H(z|v)
        samples_zCx = samples_zCx.view(*lat_sizes, latent_dim)
        params_zCx = tuple(p.view(*lat_sizes, latent_dim) for p in params_zCx)
        H_zCv = self._estimate_H_zCv(samples_zCx, params_zCx, lat_sizes, lat_names)

        H_z = H_z.cpu()
        H_zCv = H_zCv.cpu()

        # I[z_j;v_k] = E[log \sum_x q(z_j|x)p(x|v_k)] + H[z_j] = - H[z_j|v_k] + H[z_j]
        mut_info = - H_zCv + H_z
        sorted_mut_info = torch.sort(mut_info, dim=1, descending=True)[0].clamp(min=0)

        metric_helpers = {'marginal_entropies': H_z, 'cond_entropies': H_zCv}
        mig = self._mutual_information_gap(sorted_mut_info, lat_sizes, storer=metric_helpers)
        aam = self._axis_aligned_metric(sorted_mut_info, storer=metric_helpers)

        metrics = {'MIG': mig.item(), 'AAM': aam.item()}
        torch.save(metric_helpers, os.path.join(self.save_dir, METRIC_HELPERS_FILE))

        return metrics

    def _mutual_information_gap(self, sorted_mut_info, lat_sizes, storer=None):
        """Compute the mutual information gap as in [1].

        Parameters
        ----------
        sorted_mut_info: torch.Tensor
            Mutual information between factors and latents, sorted in
            descending order along dimension 1 (at least 2 columns).

        lat_sizes: array-like of int
            Number of values of each factor of variation.

        storer: dict, optional
            If given, `"mig_k"` and `"mig"` are stored in it.

        References
        ----------
           [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
           autoencoders." Advances in Neural Information Processing Systems. 2018.
        """
        # difference between the largest and second largest mutual info
        delta_mut_info = sorted_mut_info[:, 0] - sorted_mut_info[:, 1]
        # NOTE: currently only works if balanced dataset for every factor of variation
        # then H(v_k) = - |V_k|/|V_k| log(1/|V_k|) = log(|V_k|)
        # `as_tensor` accepts numpy arrays as well as lists / tuples
        H_v = torch.as_tensor(lat_sizes).float().log()
        mig_k = delta_mut_info / H_v
        mig = mig_k.mean()  # mean over factor of variations

        if storer is not None:
            storer["mig_k"] = mig_k
            storer["mig"] = mig

        return mig

    def _axis_aligned_metric(self, sorted_mut_info, storer=None):
        """Compute the proposed axis aligned metrics.

        Parameters
        ----------
        sorted_mut_info: torch.Tensor
            Mutual information between factors and latents, sorted in
            descending order along dimension 1.

        storer: dict, optional
            If given, `"aam_k"` and `"aam"` are stored in it.
        """
        numerator = (sorted_mut_info[:, 0] - sorted_mut_info[:, 1:].sum(dim=1)).clamp(min=0)
        aam_k = numerator / sorted_mut_info[:, 0]
        # 0 / 0 when a factor shares no information with any latent
        aam_k[torch.isnan(aam_k)] = 0
        aam = aam_k.mean()  # mean over factor of variations

        if storer is not None:
            storer["aam_k"] = aam_k
            storer["aam"] = aam

        return aam

    def _compute_q_zCx(self, dataloader):
        """Compute the empirical distribution of q(z|x).

        Parameter
        ---------
        dataloader: torch.utils.data.DataLoader
            Batch data iterator.

        Return
        ------
        samples_zCx: torch.tensor
            Tensor of shape (len_dataset, latent_dim) containing a sample of
            q(z|x) for every x in the dataset.

        params_zCX: tuple of torch.Tensor
            Sufficient statistics q(z|x) for each training example. E.g. for
            gaussian (mean, log_var) each of shape : (len_dataset, latent_dim).
        """
        len_dataset = len(dataloader.dataset)
        latent_dim = self.model.latent_dim
        n_suff_stat = 2

        q_zCx = torch.zeros(len_dataset, latent_dim, n_suff_stat, device=self.device)

        n = 0
        with torch.no_grad():
            for x, _ in dataloader:
                batch_size = x.size(0)
                idcs = slice(n, n + batch_size)
                q_zCx[idcs, :, 0], q_zCx[idcs, :, 1] = self.model.encoder(x.to(self.device))
                n += batch_size

        params_zCX = q_zCx.unbind(-1)
        samples_zCx = self.model.reparameterize(*params_zCX)

        return samples_zCx, params_zCX

    def _estimate_latent_entropies(self, samples_zCx, params_zCX,
                                   n_samples=10000):
        r"""Estimate :math:`H(z_j) = E_{q(z_j)} [-log q(z_j)] = E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]`
        using the empirical distribution of :math:`p(x)`.

        Note
        ----
        - the expectation over the empirical distribution is: :math:`q(z) = 1/N sum_{n=1}^N q(z|x_n)`.
        - we assume that q(z|x) is factorial i.e. :math:`q(z|x) = \prod_j q(z_j|x)`.
        - computes numerically stable NLL: :math:`- log q(z) = log N - logsumexp_n=1^N log q(z|x_n)`.

        Parameters
        ----------
        samples_zCx: torch.tensor
            Tensor of shape (len_dataset, latent_dim) containing a sample of
            q(z|x) for every x in the dataset.

        params_zCX: tuple of torch.Tensor
            Sufficient statistics q(z|x) for each training example. E.g. for
            gaussian (mean, log_var) each of shape : (len_dataset, latent_dim).

        n_samples: int, optional
            Number of samples to use to estimate the entropies. Capped at
            `len_dataset` because samples are drawn without replacement.

        Return
        ------
        H_z: torch.Tensor
            Tensor of shape (latent_dim) containing the marginal entropies H(z_j)
        """
        len_dataset, latent_dim = samples_zCx.shape
        device = samples_zCx.device
        # `randperm` yields at most `len_dataset` indices
        n_samples = min(n_samples, len_dataset)
        H_z = torch.zeros(latent_dim, device=device)

        # sample from p(x)
        samples_x = torch.randperm(len_dataset, device=device)[:n_samples]
        # sample from p(z|x): (n_samples, latent_dim) -> (latent_dim, n_samples).
        # This has to be a transpose: a `view` would reinterpret the memory and
        # mix values of different latent dimensions.
        samples_zCx = samples_zCx.index_select(0, samples_x).t()

        mini_batch_size = 10
        samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
        mean = params_zCX[0].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
        log_var = params_zCX[1].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
        log_N = math.log(len_dataset)
        with trange(n_samples, leave=False, disable=not self.is_progress_bar) as t:
            for k in range(0, n_samples, mini_batch_size):
                # log q(z_j|x) for n_samples
                idcs = slice(k, k + mini_batch_size)
                log_q_zCx = log_density_gaussian(samples_zCx[..., idcs],
                                                 mean[..., idcs],
                                                 log_var[..., idcs])
                # numerically stable log q(z_j) for n_samples:
                # log q(z_j) = -log N + logsumexp_{n=1}^N log q(z_j|x_n)
                # As we don't know q(z) we approximate it with the monte carlo
                # expectation of q(z_j|x_n) over x. => fix a single z and look at
                # proba for every x to generate it. n_samples is not used here !
                log_q_z = -log_N + torch.logsumexp(log_q_zCx, dim=0, keepdim=False)
                # H(z_j) = E_{z_j}[- log q(z_j)]
                # mean over n_samples (i.e. dimension 1 because already summed over 0).
                H_z += (-log_q_z).sum(1)

                # the last mini-batch may be smaller than `mini_batch_size`
                t.update(log_q_z.size(1))

        H_z /= n_samples

        return H_z

    def _estimate_H_zCv(self, samples_zCx, params_zCx, lat_sizes, lat_names):
        """Estimate conditional entropies :math:`H[z|v]`.

        Parameters
        ----------
        samples_zCx: torch.Tensor
            Samples of q(z|x) of shape (*lat_sizes, latent_dim).

        params_zCx: tuple of torch.Tensor
            Sufficient statistics of q(z|x), each of shape (*lat_sizes, latent_dim).

        lat_sizes: array-like of int
            Number of values of each factor of variation.

        lat_names: iterable of str
            Names of the factors of variation (only used for logging).

        Return
        ------
        H_zCv: torch.Tensor
            Tensor of shape (len(lat_sizes), latent_dim) where row k is the
            average over the values of factor k of the latent entropies
            estimated on the examples sharing that value.
        """
        latent_dim = samples_zCx.size(-1)
        len_dataset = reduce((lambda x, y: x * y), lat_sizes)
        H_zCv = torch.zeros(len(lat_sizes), latent_dim, device=self.device)
        for i_fac_var, (lat_size, lat_name) in enumerate(zip(lat_sizes, lat_names)):
            idcs = [slice(None)] * len(lat_sizes)
            for i in range(lat_size):
                self.logger.info("Estimating conditional entropies for the {}th value of {}.".format(i, lat_name))
                idcs[i_fac_var] = i
                # multidimensional indexing requires a tuple, not a list
                selector = tuple(idcs)
                # samples from q(z,x|v)
                samples_zxCv = samples_zCx[selector].contiguous().view(len_dataset // lat_size,
                                                                       latent_dim)
                params_zxCv = tuple(p[selector].contiguous().view(len_dataset // lat_size, latent_dim)
                                    for p in params_zCx)

                H_zCv[i_fac_var] += self._estimate_latent_entropies(samples_zxCv, params_zxCv
                                                                    ) / lat_size
        return H_zCv