import os
import logging
import math
from functools import reduce
from collections import defaultdict
import json
from timeit import default_timer

from tqdm import trange, tqdm
import numpy as np
import torch

from disvae.models.losses import get_loss_f
from disvae.utils.math import log_density_gaussian
from disvae.utils.modelIO import save_metadata

TEST_LOSSES_FILE = "test_losses.log"
METRICS_FILENAME = "metrics.log"
METRIC_HELPERS_FILE = "metric_helpers.pth"


class Evaluator:
    """
    Class to handle evaluation and testing of a model.

    Parameters
    ----------
    model: disvae.vae.VAE

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs.

    is_progress_bar: bool, optional
        Whether to use a progress bar during evaluation.
    """

    def __init__(self, model, loss_f,
                 device=torch.device("cpu"),
                 logger=logging.getLogger(__name__),
                 save_dir="results",
                 is_progress_bar=True):
        self.device = device
        self.loss_f = loss_f
        self.model = model.to(self.device)
        self.logger = logger
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar

        self.logger.info("Testing Device: {}".format(self.device))

    def __call__(self, data_loader, is_metrics=False, is_losses=True):
        """Compute all test losses and metrics.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        is_metrics: bool, optional
            Whether to compute and store the disentangling metrics.

        is_losses: bool, optional
            Whether to compute and store the test losses.
        """
        start = default_timer()
        is_still_training = self.model.training
        self.model.eval()

        metrics, losses = None, None
        if is_metrics:
            self.logger.info('Computing metrics...')
            metrics = self.compute_metrics(data_loader)
            self.logger.info('Metrics: {}'.format(metrics))
            save_metadata(metrics, self.save_dir, filename=METRICS_FILENAME)

        if is_losses:
            self.logger.info('Computing losses...')
            losses = self.compute_losses(data_loader)
            self.logger.info('Losses: {}'.format(losses))
            save_metadata(losses, self.save_dir, filename=TEST_LOSSES_FILE)

        if is_still_training:
            self.model.train()

        self.logger.info('Finished evaluating after {:.1f} min.'.format((default_timer() - start) / 60))

        return metrics, losses

    def compute_losses(self, dataloader):
        """Compute all test losses.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader
        """
        storer = defaultdict(list)
        for data, _ in tqdm(dataloader, leave=False, disable=not self.is_progress_bar):
            data = data.to(self.device)

            try:
                recon_batch, latent_dist, latent_sample = self.model(data)
                _ = self.loss_f(data, recon_batch, latent_dist, self.model.training,
                                storer, latent_sample=latent_sample)
            except ValueError:
                # for losses that use multiple optimizers (e.g. Factor)
                _ = self.loss_f.call_optimize(data, self.model, None, storer)

        losses = {k: sum(v) / len(dataloader) for k, v in storer.items()}
        return losses

    def compute_metrics(self, dataloader):
        """Compute all the metrics.

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader
        """
        try:
            lat_sizes = dataloader.dataset.lat_sizes
            lat_names = dataloader.dataset.lat_names
        except AttributeError:
            raise ValueError("Dataset needs to have known true factors of variation to compute the metric."
                             " This does not seem to be the case for {}".format(type(dataloader.__dict__["dataset"]).__name__))

        self.logger.info("Computing the empirical distribution q(z|x).")
        samples_zCx, params_zCx = self._compute_q_zCx(dataloader)
        len_dataset, latent_dim = samples_zCx.shape

        self.logger.info("Estimating the marginal entropy.")
        # marginal entropy H(z_j)
        H_z = self._estimate_latent_entropies(samples_zCx, params_zCx)

        # conditional entropy H(z|v)
        samples_zCx = samples_zCx.view(*lat_sizes, latent_dim)
        params_zCx = tuple(p.view(*lat_sizes, latent_dim) for p in params_zCx)
        H_zCv = self._estimate_H_zCv(samples_zCx, params_zCx, lat_sizes, lat_names)

        H_z = H_z.cpu()
        H_zCv = H_zCv.cpu()

        # I[z_j;v_k] = E[log \sum_x q(z_j|x)p(x|v_k)] + H[z_j] = - H[z_j|v_k] + H[z_j]
        mut_info = - H_zCv + H_z
        sorted_mut_info = torch.sort(mut_info, dim=1, descending=True)[0].clamp(min=0)

        metric_helpers = {'marginal_entropies': H_z, 'cond_entropies': H_zCv}
        mig = self._mutual_information_gap(sorted_mut_info, lat_sizes, storer=metric_helpers)
        aam = self._axis_aligned_metric(sorted_mut_info, storer=metric_helpers)

        metrics = {'MIG': mig.item(), 'AAM': aam.item()}
        torch.save(metric_helpers, os.path.join(self.save_dir, METRIC_HELPERS_FILE))

        return metrics
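
    # Note: `compute_metrics` assumes the dataset exposes its ground-truth generative
    # factors as `lat_sizes` (number of values per factor) and `lat_names`. For a
    # dSprites-like dataset this would be, e.g., lat_sizes = [3, 6, 40, 32, 32] for
    # (shape, scale, orientation, posX, posY); these exact values are illustrative.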

    def _mutual_information_gap(self, sorted_mut_info, lat_sizes, storer=None):
        """Compute the mutual information gap as in [1].

        References
        ----------
           [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
           autoencoders." Advances in Neural Information Processing Systems. 2018.
        """
        # difference between the largest and second largest mutual info
        delta_mut_info = sorted_mut_info[:, 0] - sorted_mut_info[:, 1]
        # NOTE: currently only works if the dataset is balanced for every factor of variation;
        # then H(v_k) = - sum_{i=1}^{|V_k|} (1/|V_k|) log(1/|V_k|) = log(|V_k|)
        H_v = torch.from_numpy(lat_sizes).float().log()
        mig_k = delta_mut_info / H_v
        mig = mig_k.mean()  # mean over factors of variation

        if storer is not None:
            storer["mig_k"] = mig_k
            storer["mig"] = mig

        return mig
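
    # Worked example (illustrative numbers, not from any dataset): with two factors,
    # sorted_mut_info = [[0.9, 0.1, 0.0], [0.6, 0.5, 0.2]] (nats) and lat_sizes = [3, 4],
    # then H_v = [log 3, log 4] ≈ [1.099, 1.386] and
    # mig_k = [(0.9 - 0.1) / 1.099, (0.6 - 0.5) / 1.386] ≈ [0.728, 0.072],
    # so MIG = mean(mig_k) ≈ 0.40: the first factor is captured mostly by a single latent,
    # while information about the second factor is split across latents.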

    def _axis_aligned_metric(self, sorted_mut_info, storer=None):
        """Compute the proposed axis aligned metric."""
        numerator = (sorted_mut_info[:, 0] - sorted_mut_info[:, 1:].sum(dim=1)).clamp(min=0)
        aam_k = numerator / sorted_mut_info[:, 0]
        aam_k[torch.isnan(aam_k)] = 0
        aam = aam_k.mean()  # mean over factors of variation

        if storer is not None:
            storer["aam_k"] = aam_k
            storer["aam"] = aam

        return aam
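
    # Worked example (illustrative numbers): for a factor whose sorted_mut_info row is
    # [0.9, 0.1, 0.0], AAM_k = clamp(0.9 - (0.1 + 0.0), min=0) / 0.9 ≈ 0.889. If the
    # remaining latents together carry more information than the top one, the clamp makes
    # AAM_k = 0, and the NaN guard handles factors whose top mutual information is 0.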

    def _compute_q_zCx(self, dataloader):
        """Compute the empirical distribution of q(z|x).

        Parameters
        ----------
        dataloader: torch.utils.data.DataLoader
            Batch data iterator.

        Returns
        -------
        samples_zCx: torch.tensor
            Tensor of shape (len_dataset, latent_dim) containing a sample of
            q(z|x) for every x in the dataset.

        params_zCX: tuple of torch.Tensor
            Sufficient statistics of q(z|x) for each training example. E.g. for
            a gaussian, (mean, log_var), each of shape (len_dataset, latent_dim).
        """
        len_dataset = len(dataloader.dataset)
        latent_dim = self.model.latent_dim
        n_suff_stat = 2

        q_zCx = torch.zeros(len_dataset, latent_dim, n_suff_stat, device=self.device)

        n = 0
        with torch.no_grad():
            for x, label in dataloader:
                batch_size = x.size(0)
                idcs = slice(n, n + batch_size)
                q_zCx[idcs, :, 0], q_zCx[idcs, :, 1] = self.model.encoder(x.to(self.device))
                n += batch_size

        params_zCX = q_zCx.unbind(-1)
        samples_zCx = self.model.reparameterize(*params_zCX)

        return samples_zCx, params_zCX

    def _estimate_latent_entropies(self, samples_zCx, params_zCX,
                                   n_samples=10000):
        r"""Estimate :math:`H(z_j) = E_{q(z_j)} [-\log q(z_j)] = E_{p(x)} E_{q(z_j|x)} [-\log q(z_j)]`
        using the empirical distribution of :math:`p(x)`.

        Note
        ----
        - the expectation over the empirical distribution is :math:`q(z) = \frac{1}{N} \sum_{n=1}^N q(z|x_n)`.
        - we assume that q(z|x) is factorial, i.e. :math:`q(z|x) = \prod_j q(z_j|x)`.
        - computes a numerically stable NLL: :math:`- \log q(z) = \log N - \mathrm{logsumexp}_{n=1}^N \log q(z|x_n)`.

        Parameters
        ----------
        samples_zCx: torch.tensor
            Tensor of shape (len_dataset, latent_dim) containing a sample of
            q(z|x) for every x in the dataset.

        params_zCX: tuple of torch.Tensor
            Sufficient statistics of q(z|x) for each training example. E.g. for
            a gaussian, (mean, log_var), each of shape (len_dataset, latent_dim).

        n_samples: int, optional
            Number of samples to use to estimate the entropies.

        Returns
        -------
        H_z: torch.Tensor
            Tensor of shape (latent_dim,) containing the marginal entropies H(z_j).
        """
        len_dataset, latent_dim = samples_zCx.shape
        device = samples_zCx.device
        H_z = torch.zeros(latent_dim, device=device)

        # sample from p(x)
        samples_x = torch.randperm(len_dataset, device=device)[:n_samples]
        # sample from q(z|x)
        samples_zCx = samples_zCx.index_select(0, samples_x).view(latent_dim, n_samples)

        mini_batch_size = 10
        samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
        mean = params_zCX[0].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
        log_var = params_zCX[1].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
        log_N = math.log(len_dataset)
        with trange(n_samples, leave=False, disable=not self.is_progress_bar) as t:
            for k in range(0, n_samples, mini_batch_size):
                # log q(z_j|x) for n_samples
                idcs = slice(k, k + mini_batch_size)
                log_q_zCx = log_density_gaussian(samples_zCx[..., idcs],
                                                 mean[..., idcs],
                                                 log_var[..., idcs])
                # numerically stable log q(z_j) for n_samples:
                # log q(z_j) = -log N + logsumexp_{n=1}^N log q(z_j|x_n)
                # As we don't know q(z) we approximate it with the Monte Carlo
                # expectation of q(z_j|x_n) over x, i.e. fix a single z and look at
                # the probability that each x generates it.
                log_q_z = -log_N + torch.logsumexp(log_q_zCx, dim=0, keepdim=False)
                # H(z_j) = E_{z_j}[- log q(z_j)]
                # sum over the sample dimension (dim 1; dim 0 was already reduced by logsumexp).
                H_z += (-log_q_z).sum(1)

                t.update(mini_batch_size)

        H_z /= n_samples

        return H_z
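
    # Sanity-check sketch of the estimator above (illustrative standalone snippet, not part
    # of this class): when q(z|x_n) = N(0, 1) for every x_n, q(z) = N(0, 1) and the logsumexp
    # estimate should approach the analytic entropy 0.5 * log(2 * pi * e) ≈ 1.419 nats.
    #
    #     import math, torch
    #     from disvae.utils.math import log_density_gaussian
    #     N = 2000
    #     z = torch.randn(N, 1)                                   # samples from q(z)
    #     mean, log_var = torch.zeros(N, 1), torch.zeros(N, 1)    # q(z|x_n) = N(0, 1)
    #     log_q = log_density_gaussian(z.view(1, 1, -1).expand(N, 1, N),
    #                                  mean.unsqueeze(-1).expand(N, 1, N),
    #                                  log_var.unsqueeze(-1).expand(N, 1, N))
    #     H = (math.log(N) - torch.logsumexp(log_q, dim=0)).mean()  # ≈ 1.419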

    def _estimate_H_zCv(self, samples_zCx, params_zCx, lat_sizes, lat_names):
        """Estimate the conditional entropies :math:`H[z|v]`."""
        latent_dim = samples_zCx.size(-1)
        len_dataset = reduce((lambda x, y: x * y), lat_sizes)
        H_zCv = torch.zeros(len(lat_sizes), latent_dim, device=self.device)
        for i_fac_var, (lat_size, lat_name) in enumerate(zip(lat_sizes, lat_names)):
            idcs = [slice(None)] * len(lat_sizes)
            for i in range(lat_size):
                self.logger.info("Estimating conditional entropies for the {}th value of {}.".format(i, lat_name))
                idcs[i_fac_var] = i
                # samples from q(z,x|v)
                samples_zxCv = samples_zCx[tuple(idcs)].contiguous().view(len_dataset // lat_size,
                                                                          latent_dim)
                params_zxCv = tuple(p[tuple(idcs)].contiguous().view(len_dataset // lat_size, latent_dim)
                                    for p in params_zCx)

                H_zCv[i_fac_var] += self._estimate_latent_entropies(samples_zxCv, params_zxCv
                                                                    ) / lat_size
        return H_zCv
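

# Usage sketch (hypothetical setup; a trained `model`, a matching `loss_f` built with
# `get_loss_f`, and a `test_loader` over a dataset with known factors of variation are
# assumed to come from the rest of the disvae package):
#
#     evaluator = Evaluator(model, loss_f, device=torch.device("cpu"), save_dir="results/run")
#     metrics, losses = evaluator(test_loader, is_metrics=True, is_losses=True)
#     # metrics -> {'MIG': ..., 'AAM': ...}, also written to results/run/metrics.log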