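# NOTE(review): this capture began with the page residue
# "Spaces: / Sleeping / Sleeping" (apparently a Hugging Face Spaces status
# banner). It is not Python, so it is preserved here only as a comment.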
import math

import numpy as np
import torch
class GaussianMixtureModel:
    """Gaussian mixture model with K components in d dimensions.

    Besides density evaluation and sampling, the class provides the
    closed-form marginals of the forward diffusion dx = -x/2 dt + dW
    (``forward_diffusion``) and an explicit-Euler integrator for the matching
    probability-flow ODE dx/dt = -x/2 - score/2 (``flow`` and its wrappers).

    Attributes:
        K: number of mixture components.
        d: dimensionality of each component.
        mu: component means, shape (K, d, 1).
        Sigma: component covariances, shape (K, d, d).
        pi: mixture weights, shape (K, 1, 1), non-negative and summing to 1.
    """

    def __init__(self, mu, Sigma, pi):
        """Validate, normalise the shapes of, and store the mixture parameters.

        Args:
            mu: float32 means of shape (K, d) or (K, d, 1).
            Sigma: float32 covariances of shape (K, d, d). They are assumed to be
                symmetric positive definite; this is not checked.
            pi: float32 weights of shape (K,), (K, 1) or (K, 1, 1). They must be
                non-negative and sum to 1.

        Raises:
            AssertionError: if any check fails. The checks are ``assert``
                statements, kept for compatibility with existing callers, so
                they are skipped under ``python -O``.
        """
        assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
        assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
        assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."
        assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
        assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
        assert pi.dtype == torch.float32, "pi must have dtype torch.float32."

        self.K, self.d = mu.shape[:2]

        # Accept (K, d) means and store them as column vectors (K, d, 1).
        if mu.shape == (self.K, self.d):
            mu = mu.unsqueeze(-1)
        assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."
        assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."

        # Accept (K,) or (K, 1) weights and store them as (K, 1, 1).
        if pi.shape in ((self.K,), (self.K, 1)):
            pi = pi.reshape(self.K, 1, 1)
        assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."
        assert torch.all(pi >= 0), "Mixture weights must be non-negative."
        # new_tensor keeps the reference value on pi's device and dtype.
        assert torch.isclose(pi.sum(), pi.new_tensor(1.0)), "Mixture weights must sum to 1."

        self.mu = mu
        self.Sigma = Sigma
        self.pi = pi

    def sample(self, n_samples):
        """Draw ``n_samples`` i.i.d. points from the mixture.

        A component index is drawn for every sample, with probabilities ``pi``.
        All Gaussian draws are then made in one batched call.

        Args:
            n_samples: number of points to draw (non-negative int).

        Returns:
            Tensor of shape (n_samples, d). It is empty, (0, d), when
            ``n_samples`` is 0.
        """
        if n_samples == 0:
            return self.mu.new_empty((0, self.d))
        # reshape (not squeeze) so a single-component mixture stays 1-D, which
        # torch.multinomial requires.
        components = torch.multinomial(self.pi.reshape(self.K), n_samples, replacement=True)
        # Factor each covariance once, then gather the factor for every draw.
        scale_tril = torch.linalg.cholesky(self.Sigma)
        mvn = torch.distributions.MultivariateNormal(
            loc=self.mu[components, :, 0],  # (n_samples, d); indexing keeps d == 1 intact
            scale_tril=scale_tril[components],  # (n_samples, d, d)
        )
        return mvn.sample()

    def _component_terms(self, x):
        """Return per-component terms for points ``x`` (any shape with d-sized rows).

        Returns:
            log_joint: (B, K) tensor of log(pi_k) + log N(x_b | mu_k, Sigma_k).
            prec_diff: (B, K, d) tensor of Sigma_k^{-1} (x_b - mu_k).
        """
        points = x.reshape(-1, self.d)  # (B, d)
        diff = points[:, None, :] - self.mu[None, :, :, 0]  # (B, K, d)
        precision = torch.inverse(self.Sigma)  # (K, d, d)
        prec_diff = torch.einsum("kij,bkj->bki", precision, diff)  # (B, K, d)
        mahalanobis = (diff * prec_diff).sum(dim=-1)  # (B, K)
        # 0.5 * log det(2*pi*Sigma). logdet avoids the under/overflow that
        # det() hits for larger d.
        log_norm = 0.5 * (self.d * math.log(2 * math.pi) + torch.logdet(self.Sigma))  # (K,)
        log_joint = torch.log(self.pi.reshape(self.K)) - log_norm - 0.5 * mahalanobis
        return log_joint, prec_diff

    def log_prob(self, x):
        """Log-density of the mixture.

        Args:
            x: a single point (any tensor with exactly d elements), or a batch
                of shape (B, d).

        Returns:
            A 0-d tensor for a single point, otherwise a (B,) tensor.
        """
        log_joint, _ = self._component_terms(x)
        log_px = torch.logsumexp(log_joint, dim=1)  # (B,)
        return log_px[0] if x.numel() == self.d else log_px

    def score(self, x):
        """Score (gradient of the log-density with respect to x) for a batch of points.

        Args:
            x: points of shape (B, d).

        Returns:
            Tensor of shape (B, d): sum_k r_k(x) * (-Sigma_k^{-1} (x - mu_k)).
            The responsibilities r_k are the component posteriors.
        """
        log_joint, prec_diff = self._component_terms(x)
        # softmax in log space: exponentiating first underflows to 0/0 = NaN
        # for points far from every component.
        responsibilities = torch.softmax(log_joint, dim=1)  # (B, K)
        return -(responsibilities.unsqueeze(-1) * prec_diff).sum(dim=1)  # (B, d)

    def forward_diffusion(self, t):
        """Return this mixture diffused to time ``t`` under dx = -x/2 dt + dW.

        Each component stays Gaussian, with mu_t = e^{-t/2} mu and
        Sigma_t = e^{-t} Sigma + (1 - e^{-t}) I. The weights are unchanged.

        Args:
            t: diffusion time (a Python number or anything ``math.exp`` accepts).

        Returns:
            A new GaussianMixtureModel.
        """
        decay = math.exp(-t)  # variance decay factor e^{-t}
        # Python-float coefficients and a device-matched identity keep the
        # result on the parameters' device and dtype.
        mu_t = self.mu * math.exp(-0.5 * t)  # (K, d, 1)
        eye = torch.eye(self.d, dtype=self.Sigma.dtype, device=self.Sigma.device)
        Sigma_t = self.Sigma * decay + eye * (1.0 - decay)  # (K, d, d)
        return GaussianMixtureModel(mu_t, Sigma_t, self.pi)

    def flow(self, x_t, t, dt, num_steps):
        """Integrate the probability-flow ODE with ``num_steps`` Euler steps.

        Args:
            x_t: starting points of shape (B, d) at time ``t``.
            t: starting time.
            dt: signed step size; negative values integrate backwards in time.
            num_steps: number of steps. Integration ends at t + num_steps * dt.

        Returns:
            The points at the final time, shape (B, d).
        """
        for step in range(num_steps):
            # t + step*dt rather than repeated `t += dt` avoids accumulating
            # floating-point drift in the time variable.
            x_t = self.probability_flow_ode(x_t, t + step * dt, dt)
        return x_t

    def flow_gmm_to_normal(self, x_0, T=5, N=32):
        """Flow points from time 0 forward to time ``T`` in ``N`` Euler steps.

        As T grows, the diffused mixture approaches N(0, I): means decay as
        e^{-T/2} and covariances approach I. Points drawn from the mixture
        therefore end up approximately standard normal, up to discretisation
        error.
        """
        dt = T / N
        return self.flow(x_0, 0, dt, N)

    def flow_normal_to_gmm(self, x_T, T=5, N=32):
        """Flow points from time ``T`` back to time 0 in ``N`` Euler steps (dt < 0)."""
        dt = -T / N
        return self.flow(x_T, T, dt, N)

    def probability_flow_ode(self, x_t, t, dt):
        """Take one explicit Euler step of the probability-flow ODE.

        The ODE is dx/dt = -x/2 - score_t(x)/2, where score_t is the score of
        this mixture diffused to time ``t``.

        Args:
            x_t: points of shape (B, d) at time ``t``.
            t: current time.
            dt: step size; negative values step backwards in time.

        Returns:
            The points after the step, shape (B, d).
        """
        gmm_t = self.forward_diffusion(t)
        drift = -0.5 * x_t - 0.5 * gmm_t.score(x_t)
        return x_t + drift * dt
# Example usage: a two-component, 2-D mixture with identity covariances,
# diffused to t = 1 and pushed through the probability-flow ODE both ways.
if __name__ == "__main__":
    # Mixture parameters (means given as (K, d), weights as (K,)).
    means = torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)  # (K, d)
    covariances = torch.stack([torch.eye(2), torch.eye(2)], dim=0).float()  # (K, d, d)
    weights = torch.tensor([0.5, 0.5], dtype=torch.float32)  # (K,)
    model = GaussianMixtureModel(means, covariances, weights)

    # Diffuse the mixture forward to time 1.0.
    diffused = model.forward_diffusion(1.0)

    # Sample, evaluate the log-density and evaluate the score of the diffused mixture.
    print("Samples after forward diffusion:\n", diffused.sample(10))
    print(
        "Log Probability after forward diffusion:",
        diffused.log_prob(torch.tensor([0.5, 0.5], dtype=torch.float32)),
    )
    print(
        "Score after forward diffusion:",
        diffused.score(torch.tensor([[0.5, 0.5]], dtype=torch.float32)),
    )

    # Integrate the probability-flow ODE from the mixture towards the normal...
    start = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
    print("x_T_normal after flowing from GMM to normal:", model.flow_gmm_to_normal(start))

    # ...and from the normal back towards the mixture.
    latent = torch.tensor([[0.0, 0.0]], dtype=torch.float32)
    print("x_0_gmm after flowing from normal to GMM:", model.flow_normal_to_gmm(latent))