Jonas Becker
1st try
7f19394
import math
from tqdm import trange, tqdm
import torch
def matrix_log_density_gaussian(x, mu, logvar):
    """Evaluate a Gaussian log density for every pairing of batch entries.

    Each row of `x` is scored against the Gaussian described by each row of
    `mu` / `logvar`, so the result has shape `(batch_size, batch_size, dim)`
    rather than the `(batch_size, dim)` of the usual element-wise log density.
    Entry `[i, j, :]` is the per-dimension log density of `x[i]` under the
    Gaussian parameterised by `mu[j]` and `logvar[j]`.

    Parameters
    ----------
    x: torch.Tensor
        Value at which to compute the density. Shape: (batch_size, dim).

    mu: torch.Tensor
        Mean. Shape: (batch_size, dim).

    logvar: torch.Tensor
        Log variance. Shape: (batch_size, dim).
    """
    n_samples, n_dims = x.shape

    # Samples vary along axis 0 and distributions along axis 1; broadcasting
    # inside `log_density_gaussian` then produces every pairing.
    distribution_shape = (1, n_samples, n_dims)
    samples = x.view(n_samples, 1, n_dims)
    means = mu.view(distribution_shape)
    log_variances = logvar.view(distribution_shape)

    return log_density_gaussian(samples, means, log_variances)
def log_density_gaussian(x, mu, logvar):
    """Calculates log density of a Gaussian.

    Evaluated element-wise as
    ``-0.5 * (log(2 * pi) + logvar) - 0.5 * (x - mu) ** 2 * exp(-logvar)``,
    with ordinary broadcasting between the three arguments.

    Parameters
    ----------
    x: torch.Tensor or np.ndarray or float
        Value at which to compute the density.

    mu: torch.Tensor or np.ndarray or float
        Mean.

    logvar: torch.Tensor or np.ndarray or float
        Log variance.

    Returns
    -------
    torch.Tensor
        Element-wise log density.
    """
    # `torch.exp` only accepts tensors, so a float or ndarray `logvar` used to
    # raise TypeError despite being documented. `as_tensor` returns tensor
    # inputs unchanged, so tensor callers are unaffected.
    logvar = torch.as_tensor(logvar)

    # Tensors and plain Python numbers already combine with tensors as-is;
    # anything else (e.g. a NumPy array) is converted so the arithmetic below
    # stays inside torch.
    x, mu = (
        v if isinstance(v, (torch.Tensor, int, float)) else torch.as_tensor(v)
        for v in (x, mu)
    )

    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((x - mu) ** 2 * inv_var)
    return log_density
def log_importance_weight_matrix(batch_size, dataset_size):
    """
    Calculates a log importance weight matrix

    The returned `(batch_size, batch_size)` tensor is the element-wise log of
    a matrix filled with `1 / (batch_size - 1)`, whose first column is then
    set to `1 / dataset_size`, whose second column is set to the
    stratification weight `(N - M) / (N * M)`, and whose element
    `[batch_size - 2, 0]` is finally set to that same stratification weight
    (with `N = dataset_size` and `M = batch_size - 1`).

    Parameters
    ----------
    batch_size: int
        number of training images in the batch

    dataset_size: int
        number of training images in the dataset
    """
    n_total = dataset_size
    n_others = batch_size - 1
    stratified = (n_total - n_others) / (n_total * n_others)

    weights = torch.full((batch_size, batch_size), 1 / n_others)

    # NOTE(review): these writes cover whole columns 0 and 1 (equivalent to a
    # stride of `batch_size` over the flattened matrix), not the diagonal.
    # Confirm this matches the intended sampling scheme.
    weights[:, 0] = 1 / n_total
    weights[:, 1] = stratified
    weights[n_others - 1, 0] = stratified

    return torch.log(weights)