import math

import torch
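
# This module implements a Gaussian mixture model (GMM) whose forward
# diffusion under the variance-preserving SDE dx = -0.5 * x dt + dw stays a
# GMM in closed form, so the score function and the probability flow ODE can
# be evaluated exactly.
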
class GaussianMixtureModel:
    def __init__(self, mu, Sigma, pi):
        # Ensure all inputs are torch tensors
        assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
        assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
        assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."
        # Ensure the dtype is torch.float32
        assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
        assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
        assert pi.dtype == torch.float32, "pi must have dtype torch.float32."
        self.K, self.d = mu.shape[:2]
        # Accept mu as (K, d) or (K, d, 1) and normalize it to (K, d, 1)
        if mu.shape == (self.K, self.d):
            mu = mu.unsqueeze(-1)  # Shape: (K, d, 1)
        assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."
        # Check the shape of Sigma
        assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."
        # Accept pi as (K,), (K, 1), or (K, 1, 1) and normalize it to (K, 1, 1)
        if pi.shape == (self.K,):
            pi = pi.view(self.K, 1, 1)
        elif pi.shape == (self.K, 1):
            pi = pi.unsqueeze(-1)
        assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."
        # Ensure the mixture weights sum to 1
        assert torch.isclose(torch.sum(pi), torch.tensor(1.0)), "Mixture weights must sum to 1."
        self.mu = mu
        self.Sigma = Sigma
        self.pi = pi

    def sample(self, n_samples):
        # Sample from the mixture model
        samples = []
        for _ in range(n_samples):
            # Choose a component based on the mixture weights
            k = torch.multinomial(self.pi.flatten(), 1).item()
            # Draw from the chosen Gaussian component
            sample = torch.distributions.MultivariateNormal(
                self.mu[k].squeeze(-1), self.Sigma[k]
            ).sample()
            samples.append(sample)
        return torch.stack(samples)

    def log_prob(self, x):
        # Compute the log probability of a single sample x of shape (d,)
        x = x.view(1, self.d, 1)  # Shape: (1, d, 1)
        diff = x - self.mu  # Shape: (K, d, 1)
        inv_Sigma = torch.inverse(self.Sigma)  # Shape: (K, d, d)
        # Quadratic form -0.5 * (x - mu)^T Sigma^{-1} (x - mu) per component
        exponent = -0.5 * torch.bmm(torch.bmm(diff.transpose(1, 2), inv_Sigma), diff).view(self.K)  # Shape: (K,)
        # Log normalization constant 0.5 * log det(2 * pi * Sigma) per component
        normalization = torch.logdet(2 * torch.pi * self.Sigma) / 2  # Shape: (K,)
        log_probs = torch.log(self.pi.view(self.K)) + exponent - normalization  # Shape: (K,)
        return torch.logsumexp(log_probs, dim=0)

    def score(self, x):
        # Compute the score function (gradient of the log density) for a batch x of shape (B, d)
        B = x.shape[0]
        x = x.view(B, 1, self.d, 1)  # Shape: (B, 1, d, 1)
        diff = x - self.mu.unsqueeze(0)  # Shape: (B, K, d, 1)
        inv_Sigma = torch.inverse(self.Sigma).unsqueeze(0)  # Shape: (1, K, d, d)
        diff_t = diff.transpose(-2, -1).contiguous().view(B * self.K, 1, self.d)  # Shape: (B*K, 1, d)
        inv_Sigma_flat = inv_Sigma.expand(B, self.K, self.d, self.d).contiguous().view(B * self.K, self.d, self.d)  # Shape: (B*K, d, d)
        exponent = -0.5 * torch.bmm(torch.bmm(diff_t, inv_Sigma_flat), diff.view(B * self.K, self.d, 1)).view(B, self.K)  # Shape: (B, K)
        normalization = torch.logdet(2 * torch.pi * self.Sigma).unsqueeze(0) / 2  # Shape: (1, K)
        # Component responsibilities, computed in log space for numerical stability
        log_probs = torch.log(self.pi.view(self.K)).unsqueeze(0) + exponent - normalization  # Shape: (B, K)
        norm_probs = torch.softmax(log_probs, dim=1)  # Shape: (B, K)
        gradients = []
        for k in range(self.K):
            # Per-component score: -Sigma_k^{-1} (x - mu_k)
            gradient = -torch.bmm(inv_Sigma[:, k].expand(B, self.d, self.d), diff[:, k])  # Shape: (B, d, 1)
            gradients.append(norm_probs[:, k].unsqueeze(1) * gradient.squeeze(-1))  # Shape: (B, d)
        # Responsibility-weighted sum of the per-component scores
        return torch.sum(torch.stack(gradients, dim=1), dim=1)  # Shape: (B, d)
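
    # Under the VP SDE dx = -0.5 * x dt + dw, the time-t marginal of each
    # Gaussian component has mean mu * e^{-t/2} and covariance
    # Sigma * e^{-t} + (1 - e^{-t}) * I, so the mixture remains a GMM with
    # unchanged weights.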
    def forward_diffusion(self, t):
        # Compute the evolved mean and covariance of each component at time t
        mean_decay = math.exp(-0.5 * t)  # e^{-t/2}
        var_decay = math.exp(-t)  # e^{-t}
        mu_t = self.mu * mean_decay  # Shape: (K, d, 1)
        Sigma_t = self.Sigma * var_decay + torch.eye(self.d, dtype=torch.float32) * (1.0 - var_decay)  # Shape: (K, d, d)
        return GaussianMixtureModel(mu_t, Sigma_t, self.pi)

    def flow(self, x_t, t, dt, num_steps):
        # Integrate the probability flow ODE with num_steps Euler steps of size dt
        for _ in range(num_steps):
            x_t = self.probability_flow_ode(x_t, t, dt)
            t += dt
        return x_t

    def flow_gmm_to_normal(self, x_0, T=5, N=32):
        # Flow forward in time from the GMM at t=0 toward N(0, I) at t=T
        dt = T / N
        return self.flow(x_0, 0, dt, N)

    def flow_normal_to_gmm(self, x_T, T=5, N=32):
        # Flow backward in time from N(0, I) at t=T toward the GMM at t=0
        dt = -T / N
        return self.flow(x_T, T, dt, N)
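
    # The probability flow ODE dx/dt = -0.5 * x - 0.5 * score(x, t) shares its
    # time marginals with the VP SDE: integrating it forward carries GMM
    # samples toward N(0, I), and integrating it backward carries Gaussian
    # noise back to the GMM.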
    def probability_flow_ode(self, x_t, t, dt):
        # Advance x_t by one Euler step of the probability flow ODE
        # Forward diffusion to get the evolved GMM parameters at time t
        gmm_t = self.forward_diffusion(t)
        # Compute the score function at x_t
        score = gmm_t.score(x_t)
        # Drift term of the probability flow ODE
        drift = -0.5 * x_t - 0.5 * score
        # Euler update
        x_t_plus_dt = x_t + drift * dt
        return x_t_plus_dt
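
# A small convenience sketch (an addition, not part of the original class):
# log_prob evaluates a single point of shape (d,), so a batch can be scored
# by looping over its rows.
def batched_log_prob(gmm, X):
    # X: (B, d) batch of points; returns a (B,) tensor of log densities.
    return torch.stack([gmm.log_prob(x) for x in X])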

# Example usage
if __name__ == "__main__":
    mu = torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)  # Shape: (K, d)
    Sigma = torch.stack([torch.eye(2), torch.eye(2)], dim=0).float()  # Shape: (K, d, d)
    pi = torch.tensor([0.5, 0.5], dtype=torch.float32)  # Shape: (K,)
    gmm = GaussianMixtureModel(mu, Sigma, pi)

    # Perform forward diffusion
    t = 1.0
    gmm_t = gmm.forward_diffusion(t)

    # Sampling from the diffused GMM
    samples = gmm_t.sample(10)
    print("Samples after forward diffusion:\n", samples)

    # Log probability of a sample after forward diffusion
    log_prob = gmm_t.log_prob(torch.tensor([0.5, 0.5], dtype=torch.float32))
    print("Log probability after forward diffusion:", log_prob)

    # Score of a sample after forward diffusion
    score = gmm_t.score(torch.tensor([[0.5, 0.5]], dtype=torch.float32))
    print("Score after forward diffusion:", score)

    # Using flow_gmm_to_normal
    x_0 = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
    x_T_normal = gmm.flow_gmm_to_normal(x_0)
    print("x_T_normal after flowing from GMM to normal:", x_T_normal)

    # Using flow_normal_to_gmm
    x_T = torch.tensor([[0.0, 0.0]], dtype=torch.float32)
    x_0_gmm = gmm.flow_normal_to_gmm(x_T)
    print("x_0_gmm after flowing from normal to GMM:", x_0_gmm)