Spaces:
Sleeping
Sleeping
from typing import Optional, Callable, Tuple | |
import warnings | |
from abc import ABC, abstractmethod | |
from einops import rearrange, repeat | |
import torch | |
import torch.nn as nn | |
def relaxed_one_hot_categorical_without_replacement(temperature, logits, num_samples=1):
    """Draw relaxed categorical weights and Gumbel-top-k indices.

    The logits are perturbed with reparameterized Gumbel noise
    (``Gumbel(loc=logits, scale=1).rsample()``), divided by ``temperature`` and
    passed through a softmax over the last dimension. Taking the ``num_samples``
    largest perturbed scores is the Gumbel-top-k trick for sampling without
    replacement, see "Stochastic Beams and Where to Find Them: The Gumbel-Top-k
    Trick for Sampling Sequences Without Replacement"
    (https://arxiv.org/pdf/1903.06059.pdf).

    Args:
        temperature: divisor applied to the perturbed logits before the softmax.
        logits: tensor of logits; categories are indexed by the last dimension.
        num_samples: number of distinct categories to select. Defaults to 1.

    Returns:
        scores: relaxed one-hot weights with the same shape as ``logits``. They are
            clamped from below at 1e-10 after the softmax, so they are strictly
            positive and may sum to very slightly more than 1.
        top_indices: indices of the ``num_samples`` largest scores along the last
            dimension, shape ``(*logits.shape[:-1], num_samples)``.
    """
    perturbed_logits = torch.distributions.Gumbel(logits, 1).rsample()
    scores = (perturbed_logits / temperature).softmax(-1).clamp_min(1e-10)
    # Only the indices are needed by callers; the top-k values are already in scores.
    top_indices = torch.topk(scores, num_samples, dim=-1).indices
    return scores, top_indices
class AbstractLatentDistribution(nn.Module, ABC):
    """Base class for latent distributions.

    Declares the interface every latent distribution must provide. All five
    methods are abstract: a subclass has to implement each of them before it can
    be instantiated.
    """

    @abstractmethod
    def sample(
        self, num_samples: int, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample from the latent distribution.

        Args:
            num_samples: number of samples to draw.

        Returns:
            A pair of tensors. In the implementations of this file the first is
            the latent samples and the second a weight or probability per sample.
        """

    @abstractmethod
    def kl_loss(
        self,
        other: "AbstractLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the KL divergence between two latent distributions.

        Args:
            other: distribution to compare ``self`` with.
            threshold: lower bound that implementations may apply to the loss.
            mask_z: optional mask that implementations may use to select which
                elements contribute to the loss.

        Returns:
            The loss as a tensor.
        """

    @abstractmethod
    def sampling_loss(self) -> torch.Tensor:
        """Loss of the latent distribution."""

    @abstractmethod
    def average(
        self, other: "AbstractLatentDistribution", weight_other: torch.Tensor
    ) -> "AbstractLatentDistribution":
        """Average of the latent distribution.

        Args:
            other: distribution to average ``self`` with.
            weight_other: weight given to ``other`` in the average.

        Returns:
            A new distribution representing the weighted average.
        """

    @abstractmethod
    def log_dict(self, type: str) -> dict:
        """Log the latent distribution values.

        Args:
            type: label identifying the distribution in the logged keys.

        Returns:
            Dictionary of values to log.
        """
class GaussianLatentDistribution(AbstractLatentDistribution):
    """Diagonal Gaussian latent distribution.

    Args:
        latent_representation: tensor whose last dimension is split into two equal
            chunks; the first chunk is the mean and the second the log-variance.
    """

    def __init__(self, latent_representation: torch.Tensor):
        super().__init__()
        mu, logvar = torch.chunk(latent_representation, 2, dim=-1)
        # Non-persistent buffers: they follow the module's device but are not
        # part of the state dict.
        self.register_buffer("mu", mu, persistent=False)
        self.register_buffer("logvar", logvar, persistent=False)

    def sample(
        self, n_samples: int = 0, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample from the Gaussian with the reparametrization trick.

        Args:
            n_samples (optional): number of samples to make, (if 0 one sample with
                no extra dimension). Defaults to 0.

        Returns:
            latent_samples: random Gaussian samples of shape
                (some_shape, (n_samples), latent_dim).
            weights: uniform weights of shape (some_shape, (n_samples)); all ones
                when ``n_samples <= 0``, otherwise ``1 / n_samples``.
        """
        std = (self.logvar / 2).exp()
        if n_samples <= 0:
            eps = torch.randn_like(std)
            latent_samples = self.mu + eps * std
            weights = torch.ones_like(latent_samples[..., 0])
        else:
            # Draw the noise with the dtype of std so that both branches return
            # samples of the same dtype, as randn_like does above.
            eps = torch.randn(
                [*std.shape[:-1], n_samples, self.mu.shape[-1]],
                device=std.device,
                dtype=std.dtype,
            )
            # Insert the sample dimension before the latent dimension.
            latent_samples = self.mu.unsqueeze(-2) + eps * std.unsqueeze(-2)
            weights = torch.ones_like(latent_samples[..., 0]) / n_samples
        return latent_samples, weights

    def kl_loss(
        self,
        other: "GaussianLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the KL divergence KL(self || other) between two diagonal Gaussians.

        Args:
            other: Gaussian distribution to compare with.
            threshold: each element-wise KL term is clamped from below at this value.
            mask_z: optional mask multiplied with the KL averaged over the latent
                dimension; the result is normalized by ``mask_z.sum()``.

        Returns:
            Scalar tensor: the mean of the clamped KL terms, or the masked mean
            when ``mask_z`` is given.
        """
        assert isinstance(other, GaussianLatentDistribution)
        kl_loss = (
            (
                other.logvar
                - self.logvar
                + ((self.mu - other.mu).square() + self.logvar.exp())
                / other.logvar.exp()
                - 1
            )
            * 0.5
        ).clamp_min(threshold)
        if mask_z is None:
            return kl_loss.mean()
        # An all-false mask would divide by zero below.
        assert mask_z.any()
        return torch.sum(kl_loss.mean(-1) * mask_z) / torch.sum(mask_z)

    def sampling_loss(self) -> torch.Tensor:
        """Return a zero loss: sampling a Gaussian latent adds no loss term."""
        return torch.zeros(1, device=self.mu.device)

    def average(
        self, other: "GaussianLatentDistribution", weight_other: torch.Tensor
    ) -> "GaussianLatentDistribution":
        """Return the weighted average of two Gaussian distributions.

        Means and variances are each linearly interpolated with weights
        ``1 - weight_other`` (self) and ``weight_other`` (other).

        Args:
            other: Gaussian distribution with the same mean shape as ``self``.
            weight_other: weight given to ``other``.

        Returns:
            A new GaussianLatentDistribution built from the interpolated mean and
            the log of the interpolated variance.
        """
        assert isinstance(other, GaussianLatentDistribution)
        assert other.mu.shape == self.mu.shape
        average_log_var = (
            self.logvar.exp() * (1 - weight_other) + other.logvar.exp() * weight_other
        ).log()
        average_mu = self.mu * (1 - weight_other) + other.mu * weight_other
        return GaussianLatentDistribution(
            torch.cat((average_mu, average_log_var), dim=-1)
        )

    def log_dict(self, type: str) -> dict:
        """Return the mean absolute mu and the mean standard deviation for logging.

        Args:
            type: label inserted in the logged keys (``latent/{type}/...``).
        """
        return {
            f"latent/{type}/abs_mean": self.mu.abs().mean(),
            f"latent/{type}/std": (self.logvar * 0.5).exp().mean(),
        }
class QuantizedLatentDistribution(AbstractLatentDistribution):
    """Quantized latent distribution.

    It is defined with a codebook of quantized latents and a continuous latent.
    The distribution is based on distances of the continuous latent to the codebook.
    Sampling is only quantizing the continuous latent.

    Args:
        continuous_latent : Continuous latent representation of shape (some_shape, latent_dim)
        codebook : Codebook of shape (num_embeddings, latent_dim)
        flush_weights : Callback resetting the codebook usage weights
        get_weights : Callback returning the codebook usage weights
        index_add_one_weights : Callback incrementing the usage weight of the
            given codebook indices
    """

    def __init__(
        self,
        continuous_latent: torch.Tensor,
        codebook: torch.Tensor,
        flush_weights: Callable[[], None],
        get_weights: Callable[[], torch.Tensor],
        index_add_one_weights: Callable[[torch.Tensor], None],
    ):
        super().__init__()
        self.register_buffer("continuous_latent", continuous_latent, persistent=False)
        self.register_buffer("codebook", codebook, persistent=False)
        self.flush_weights = flush_weights
        self.get_weights = get_weights
        self.index_add_one_weights = index_add_one_weights
        # Everything below is filled in by sample(); initialized here so that the
        # attributes always exist.
        self.quantization_loss = None
        self.soft_one_hot = None
        self.uselessness = None
        self.latent_diversity = None
        # Set by kl_loss().
        self.accuracy = None

    def _ensure_sampled(self) -> None:
        """Run sample() once if the quantities it computes are not available yet."""
        if self.quantization_loss is None:
            self.sample()

    def sample(
        self, n_samples: int = 0, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize the continuous latent with the codebook.

        Besides returning the quantized latent, this updates the codebook usage
        weights through the callbacks and stores ``soft_one_hot``,
        ``uselessness``, ``quantization_loss`` and ``latent_diversity`` on the
        instance. The continuous latent is unpacked as
        (batch_size, num_agents, latent_dim).

        Args:
            n_samples: must be 0, only one sample is supported.

        Returns:
            quantized: nearest codebook vector of each continuous latent, with a
                straight-through gradient to the continuous latent.
            weights: tensor shaped like ``quantized[..., 0]`` filled with
                ``1 / num_embeddings``.
        """
        assert n_samples == 0, "Only one sample is supported for quantized latent"
        distances_to_quantized = (
            (
                self.codebook.view(1, 1, *self.codebook.shape)
                - self.continuous_latent.unsqueeze(-2)
            )
            .square()
            .sum(-1)
        )
        batch_size, num_agents, num_vq = distances_to_quantized.shape
        # Sharp soft assignment over the codebook entries, used by kl_loss().
        self.soft_one_hot = (
            (-100 * distances_to_quantized)
            .softmax(dim=-1)
            .view(batch_size, num_agents, num_vq)
        )
        _, args_selected = torch.min(distances_to_quantized, dim=-1)
        quantized = self.codebook[args_selected, :]
        args_selected = args_selected.view(-1)
        # Update weights
        self.index_add_one_weights(args_selected)
        distances_to_quantized = distances_to_quantized.view(
            batch_size * num_agents, num_vq
        )
        # Resample useless latent vectors: one random continuous latent per
        # codebook entry.
        random_latents = self.continuous_latent.view(
            batch_size * num_agents, self.codebook.shape[-1]
        )[torch.randint(batch_size * num_agents, (num_vq,))]
        codebook_weights = self.get_weights()
        total_samples = codebook_weights.sum()
        # TODO: The value 100 is arbitrary, should it be a parameter?
        # A codebook vector is considered useless when its usage weight is below
        # 1% of the uniform share (total_samples / num_vq). Such a vector is
        # pushed towards a random continuous latent sample; this prevents the
        # codebook from being dominated by a few vectors.
        self.uselessness = (
            (
                torch.where(
                    (codebook_weights < total_samples / (100 * num_vq)).unsqueeze(-1),
                    random_latents.detach() - self.codebook,
                    torch.zeros_like(self.codebook),
                ).abs()
                + 1
            )
            .log()
            .sum(-1)
            .mean()
        )
        # TODO: The value 1e6 is arbitrary, should it be a parameter?
        if total_samples > 1e6 * num_vq:
            # Flush the codebook weights when the number of samples is too high
            # This prevents the codebook from being dominated by its history
            # if a few vectors were visited a lot and also prevents overflows
            self.flush_weights()
        self.quantization_loss = (
            (self.continuous_latent - quantized).square().sum(-1).mean()
        )
        # Straight-through estimator: the value is the quantized latent, the
        # gradient goes to the continuous latent.
        quantized = (
            quantized.detach()
            + self.continuous_latent
            - self.continuous_latent.detach()
        )
        # Mean squared distance between all pairs of latents along the first
        # dimension.
        self.latent_diversity = (
            (self.continuous_latent[None, ...] - self.continuous_latent[:, None, ...])
            .square()
            .sum(-1)
            .mean()
        )
        return quantized, torch.ones_like(quantized[..., 0]) / num_vq

    def kl_loss(
        self,
        other: "ClassifiedLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the cross entropy between two latent distributions.

        The soft assignment computed by sample() is the reference; sample() is
        run first if it has not been called yet. Also stores in ``accuracy`` the
        rate at which both distributions agree on the most likely codebook entry.

        Args:
            other: classified distribution whose logits are evaluated.
            threshold: not used by this implementation.
            mask_z: not used by this implementation.

        Returns:
            Scalar tensor: twice the mean cross entropy, with the logits of
            ``other`` clamped to [-10, 10].
        """
        assert isinstance(other, ClassifiedLatentDistribution)
        self._ensure_sampled()
        min_logits = -10
        max_logits = 10
        pred_log = other.logits.clamp(min_logits, max_logits).log_softmax(-1)
        self_pred = self.soft_one_hot
        self.accuracy = (self_pred.argmax(-1) == other.logits.argmax(-1)).float().mean()
        return -2 * (pred_log * self_pred).sum(-1).mean()

    def sampling_loss(self) -> torch.Tensor:
        """Return the loss attached to quantization.

        Half the sum of the quantization loss, the uselessness term and 0.001
        times the latent diversity. sample() is run first if needed.
        """
        self._ensure_sampled()
        return 0.5 * (
            self.quantization_loss + self.uselessness + 0.001 * self.latent_diversity
        )

    def average(
        self, other: "QuantizedLatentDistribution", weight_other: torch.Tensor
    ) -> "QuantizedLatentDistribution":
        """Not supported for quantized latents.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Average is not implemented for QuantizedLatentDistribution"
        )

    def log_dict(self, type: str) -> dict:
        """Return the quantization statistics for logging.

        sample() is run first if needed so that the logged losses exist. The
        accuracy is only included once kl_loss() has been called.

        Args:
            type: label inserted in the logged keys (``latent/{type}/...``).
        """
        self._ensure_sampled()
        log_dict = {
            f"latent/{type}/quantization_loss": self.quantization_loss,
            f"latent/{type}/uselessness": self.uselessness,
            f"latent/{type}/latent_diversity": self.latent_diversity,
            f"latent/{type}/codebook_abs_mean": self.codebook.abs().mean(),
            f"latent/{type}/codebook_std": self.codebook.std(),
            f"latent/{type}/latent_abs_mean": self.continuous_latent.abs().mean(),
            f"latent/{type}/latent_std": self.continuous_latent.std(),
        }
        if self.accuracy is not None:
            log_dict[f"latent/{type}/accuracy"] = self.accuracy
        return log_dict
class ClassifiedLatentDistribution(AbstractLatentDistribution):
    """Classified latent distribution.

    It is defined with a codebook of quantized latents and a probability distribution over the codebook elements.

    Args:
        logits : Logits of shape (some_shape, num_embeddings)
        codebook : Codebook of shape (num_embeddings, latent_dim)
    """

    def __init__(self, logits: torch.Tensor, codebook: torch.Tensor):
        super().__init__()
        self.register_buffer("logits", logits, persistent=False)
        self.register_buffer("codebook", codebook, persistent=False)

    def sample(
        self, n_samples: int = 0, replacement: bool = True, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample codebook vectors according to the logits.

        In training mode the samples are drawn without replacement with the
        Gumbel-top-k trick and carry a straight-through gradient to the logits;
        ``replacement`` is not used on that path. In eval mode the samples are
        drawn with ``torch.multinomial``. The logits are unpacked as
        (batch_size, num_agents, num_embeddings).

        Args:
            n_samples: number of samples to make (if 0, one sample and the latent
                samples have no sample dimension). Defaults to 0.
            replacement: whether the eval-mode draw is made with replacement.
                Forced to True, with a warning, when more samples than codebook
                entries are requested.

        Returns:
            latent_samples: (batch_size, num_agents, (n_samples), latent_dim).
            probs: softmax probability of each sampled codebook entry, of shape
                (batch_size, num_agents, n_samples). It keeps a sample dimension
                of size 1 when ``n_samples`` is 0.
        """
        batch_size, num_agents, num_vq = self.logits.shape
        squeeze_out = False
        if n_samples == 0:
            squeeze_out = True
            n_samples = 1
        elif n_samples > self.codebook.shape[0]:
            warnings.warn(
                f"Requested {n_samples} samples but only {self.codebook.shape[0]} are available in the descrete latent space. Switching to replacement=True to support it."
            )
            replacement = True

        if self.training:
            # TODO: should we make the temperature a parameter?
            all_weights, indices = relaxed_one_hot_categorical_without_replacement(
                logits=self.logits, temperature=1, num_samples=n_samples
            )
            selected_latents = self.codebook[indices, :]
            # For each sample k, mask of the indices selected by samples 0..k-1
            # (exclusive cumulative sum along the sample dimension). Computed
            # out of place, without writing a tensor from an overlapping view
            # of itself.
            one_hot_selection = torch.nn.functional.one_hot(indices, num_vq)
            mask_selection = one_hot_selection.cumsum(-2) - one_hot_selection
            # Remove the probability of previous samples to account for sampling without replacement
            masked_weights = all_weights.unsqueeze(-2) * (1 - mask_selection.float())
            # Renormalize the probabilities to sum to 1
            masked_weights = masked_weights / masked_weights.sum(-1, keepdim=True)
            latent_samples = (
                masked_weights.unsqueeze(-1)
                * self.codebook[None, None, None, ...].detach()
            ).sum(-2)
            # Straight-through estimator: the value is the selected codebook
            # vector, the gradient goes through the relaxed weights.
            latent_samples = (
                selected_latents.detach() + latent_samples - latent_samples.detach()
            )
            probs = torch.gather(self.logits.softmax(-1), -1, indices)
        else:
            probs = self.logits.softmax(-1)
            samples = torch.multinomial(
                probs.view(batch_size * num_agents, num_vq),
                n_samples,
                replacement=replacement,
            )
            latent_samples = self.codebook[samples]
            # samples holds n_samples indices per (batch, agent) element.
            probs = torch.gather(
                probs, -1, samples.view(batch_size, num_agents, n_samples)
            )

        if squeeze_out:
            latent_samples = latent_samples.view(
                batch_size, num_agents, self.codebook.shape[-1]
            )
        else:
            latent_samples = latent_samples.view(
                batch_size, num_agents, n_samples, self.codebook.shape[-1]
            )
        return latent_samples, probs

    def kl_loss(
        self,
        other: "ClassifiedLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the cross entropy between two latent distributions. Self being the reference distribution and other the distribution to compare.

        Args:
            other: classified distribution whose logits are evaluated.
            threshold: not used by this implementation.
            mask_z: not used by this implementation.

        Returns:
            Scalar tensor: twice the mean cross entropy, with both sets of logits
            clamped to [-10, 10].
        """
        assert isinstance(other, ClassifiedLatentDistribution)
        min_logits = -10
        max_logits = 10
        pred_log = other.logits.clamp(min_logits, max_logits).log_softmax(-1)
        # Averaging the logits with their detached copy keeps the value and
        # halves the gradient flowing to self.logits.
        self_pred = (
            (0.5 * (self.logits.detach() + self.logits))
            .clamp(min_logits, max_logits)
            .softmax(-1)
        )
        return -2 * (pred_log * self_pred).sum(-1).mean()

    def sampling_loss(self) -> torch.Tensor:
        """Return a zero loss: sampling a classified latent adds no loss term."""
        return torch.zeros(1, device=self.logits.device)

    def average(
        self, other: "ClassifiedLatentDistribution", weight_other: torch.Tensor
    ) -> "ClassifiedLatentDistribution":
        """Return the weighted average of two classified distributions.

        The exponentials of the raw logits are linearly interpolated with
        weights ``1 - weight_other`` (self) and ``weight_other`` (other), and
        the log of the result is used as the new logits.

        NOTE(review): the logits are not normalized before the interpolation, so
        the result depends on their offsets; confirm this is intended.

        Args:
            other: classified distribution sharing the same codebook.
            weight_other: weight given to ``other``.
        """
        assert isinstance(other, ClassifiedLatentDistribution)
        assert (self.codebook == other.codebook).all()
        return ClassifiedLatentDistribution(
            (
                self.logits.exp() * (1 - weight_other)
                + other.logits.exp() * weight_other
            ).log(),
            self.codebook,
        )

    def log_dict(self, type: str) -> dict:
        """Return codebook statistics and statistics of the maximum class probability.

        Args:
            type: label inserted in the logged keys (``latent/{type}/...``).
        """
        max_probs, _ = self.logits.softmax(-1).max(-1)
        return {
            f"latent/{type}/codebook_abs_mean": self.codebook.abs().mean(),
            f"latent/{type}/codebook_std": self.codebook.std(),
            f"latent/{type}/class_max_mean": max_probs.mean(),
            f"latent/{type}/class_max_std": max_probs.std(),
        }
class QuantizedDistributionCreator(nn.Module):
    """Creates a distribution from a latent vector.

    Holds the learnable codebook and a usage weight per codebook entry. The
    usage weights start at one and are a non-persistent buffer.

    Args:
        latent_dim: dimension of the codebook vectors.
        num_embeddings: number of codebook vectors.
    """

    def __init__(
        self,
        latent_dim: int,
        num_embeddings: int,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        self.codebook = nn.Parameter(torch.randn(num_embeddings, latent_dim))
        self.register_buffer(
            "codebook_weights",
            torch.ones(num_embeddings, requires_grad=False),
            persistent=False,
        )

    def _flush_codebook_weights(self):
        """Reset every codebook usage weight to one."""
        self.codebook_weights = torch.ones_like(self.codebook_weights)

    def _get_codebook_weights(self):
        """Return the codebook usage weights."""
        return self.codebook_weights

    def _index_add_one_codebook_weight(self, indices: torch.Tensor):
        """Add one to the usage weight of every index in ``indices``.

        An index that appears several times is counted several times.

        Args:
            indices: integer tensor of codebook indices, of any shape.
        """
        # index_add is given a 1-D index and a source with one entry per index,
        # so the indices are flattened and the source is sized from the
        # flattened indices.
        flat_indices = indices.flatten()
        self.codebook_weights = self.codebook_weights.index_add(
            0,
            flat_indices,
            self.codebook_weights.new_ones(flat_indices.shape),
        )

    def forward(self, latent: torch.Tensor) -> "AbstractLatentDistribution":
        """Build the latent distribution matching the last dimension of ``latent``.

        Args:
            latent: continuous latent with last dimension ``latent_dim``, or
                logits with last dimension ``num_embeddings``. The ``latent_dim``
                test comes first, so it wins when both sizes are equal.

        Returns:
            A QuantizedLatentDistribution for a continuous latent, a
            ClassifiedLatentDistribution for logits.

        Raises:
            ValueError: if the last dimension matches neither size.
        """
        if latent.shape[-1] == self.latent_dim:
            return QuantizedLatentDistribution(
                latent,
                self.codebook,
                self._flush_codebook_weights,
                self._get_codebook_weights,
                self._index_add_one_codebook_weight,
            )
        elif latent.shape[-1] == self.num_embeddings:
            return ClassifiedLatentDistribution(
                latent,
                self.codebook,
            )
        else:
            raise ValueError(f"Latent vector has wrong dimension: {latent.shape[-1]}")