import math

import numpy as np
import torch


class GaussianMixtureModel:
    """Gaussian mixture model with helpers for a variance-preserving diffusion.

    Parameters are stored as ``mu`` with shape (K, d, 1), ``Sigma`` with shape
    (K, d, d) and ``pi`` with shape (K, 1, 1), all ``torch.float32``.

    Densities and scores are evaluated in log space through a Cholesky
    factorisation of ``Sigma``. The naive ``det``/``inverse``/``exp`` form
    overflows or underflows in float32 for moderate dimensions or for points
    far from every component; this form stays finite there.
    """

    def __init__(self, mu, Sigma, pi):
        """Validate and store the mixture parameters.

        Args:
            mu: Component means, shape (K, d) or (K, d, 1), dtype float32.
            Sigma: Component covariances, shape (K, d, d), dtype float32.
                ``sample``, ``log_prob`` and ``score`` factor it with a
                Cholesky decomposition, so it must be symmetric positive
                definite (only its lower triangle is read).
            pi: Mixture weights, shape (K,), (K, 1) or (K, 1, 1), dtype
                float32; must be non-negative and sum to 1.

        Raises:
            AssertionError: If an input has the wrong type, dtype or shape,
                or if the weights are negative or do not sum to 1.
        """
        # Ensure all inputs are torch tensors
        assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
        assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
        assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."

        # Ensure the dtype is torch.float32
        assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
        assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
        assert pi.dtype == torch.float32, "pi must have dtype torch.float32."

        self.K, self.d = mu.shape[:2]

        # Accept (K, d) means and promote them to column vectors (K, d, 1).
        if mu.shape == (self.K, self.d):
            mu = mu.unsqueeze(-1)
        assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."

        assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."

        # Normalise the weight layout to (K, 1, 1).
        if pi.shape == (self.K,):
            pi = pi.reshape(self.K, 1, 1)
        elif pi.shape == (self.K, 1):
            pi = pi.unsqueeze(-1)
        assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."

        # Negative weights would pass the sum check but make log(pi) NaN later.
        assert torch.all(pi >= 0), "Mixture weights must be non-negative."
        # Compare against 1.0 on pi's own device so non-CPU inputs also work.
        assert torch.isclose(torch.sum(pi), pi.new_tensor(1.0)), "Mixture weights must sum to 1."

        self.mu = mu
        self.Sigma = Sigma
        self.pi = pi

    def _component_terms(self, x):
        """Per-component log joint densities and precision-weighted residuals.

        Args:
            x: Points of shape (B, d).

        Returns:
            ``(log_joint, prec_diff)``:
            - ``log_joint[b, k] = log pi_k + log N(x_b | mu_k, Sigma_k)``, shape (B, K).
            - ``prec_diff[b, k] = Sigma_k^{-1} (x_b - mu_k)``, shape (B, K, d).
        """
        batch = x.shape[0]
        chol = torch.linalg.cholesky(self.Sigma)  # (K, d, d), lower triangular
        diff = (x.unsqueeze(1) - self.mu.squeeze(-1).unsqueeze(0)).unsqueeze(-1)  # (B, K, d, 1)
        # Solve Sigma_k v = diff through the Cholesky factor instead of forming an inverse.
        prec_diff = torch.cholesky_solve(diff, chol.unsqueeze(0).expand(batch, -1, -1, -1))  # (B, K, d, 1)
        maha = (diff * prec_diff).sum(dim=(-2, -1))  # (B, K) squared Mahalanobis distances
        # log det(2 pi Sigma) = d log(2 pi) + 2 * sum(log diag(L)); never forms det() itself.
        log_det = 2.0 * torch.log(torch.diagonal(chol, dim1=-2, dim2=-1)).sum(dim=-1)  # (K,)
        log_norm = 0.5 * (self.d * math.log(2.0 * math.pi) + log_det)  # (K,)
        log_joint = torch.log(self.pi.reshape(self.K)) - 0.5 * maha - log_norm  # (B, K)
        return log_joint, prec_diff.squeeze(-1)

    def sample(self, n_samples):
        """Draw ``n_samples`` i.i.d. points from the mixture.

        Args:
            n_samples: Number of points to draw (0 returns an empty tensor).

        Returns:
            Tensor of shape (n_samples, d).
        """
        if n_samples == 0:
            # torch.multinomial rejects num_samples == 0.
            return self.mu.new_empty((0, self.d))
        # Choose a component for every sample at once according to the weights.
        components = torch.multinomial(self.pi.reshape(self.K), n_samples, replacement=True)  # (n,)
        chol = torch.linalg.cholesky(self.Sigma)  # (K, d, d)
        means = self.mu.squeeze(-1)  # (K, d); squeeze(-1) keeps d == 1 one-dimensional
        noise = torch.randn(n_samples, self.d, dtype=self.mu.dtype, device=self.mu.device)
        samples = torch.empty_like(noise)
        for k in range(self.K):
            mask = components == k
            # Reparameterisation in row-vector form: x = mu_k + L_k z with z ~ N(0, I).
            samples[mask] = means[k] + noise[mask] @ chol[k].T
        return samples

    def log_prob(self, x):
        """Log density of the mixture.

        Args:
            x: Either a single point with ``d`` elements (any shape, e.g. (d,),
                (1, d) or (d, 1)) or a batch reshapeable to (B, d).

        Returns:
            A 0-dim tensor for a single point, otherwise a tensor of shape (B,).
        """
        points = x.reshape(-1, self.d)
        log_joint, _ = self._component_terms(points)
        log_p = torch.logsumexp(log_joint, dim=1)  # (B,)
        # A single point keeps the scalar return value.
        return log_p[0] if x.numel() == self.d else log_p

    def score(self, x):
        """Score function ``grad_x log p(x)`` for a batch of points.

        Args:
            x: Points with leading batch dimension B and ``d`` values per
                point, e.g. shape (B, d) or (B, d, 1).

        Returns:
            Tensor of shape (B, d).
        """
        points = x.reshape(x.shape[0], self.d)
        log_joint, prec_diff = self._component_terms(points)
        # Posterior responsibilities via a log-space softmax; normalising raw
        # exp() values gives 0/0 = NaN when every density underflows.
        resp = torch.softmax(log_joint, dim=1)  # (B, K)
        # grad log p(x) = -sum_k r_k(x) Sigma_k^{-1} (x - mu_k)
        return -(resp.unsqueeze(-1) * prec_diff).sum(dim=1)

    def forward_diffusion(self, t):
        """Return the mixture's marginal after diffusing for time ``t``.

        Each component N(mu, Sigma) becomes
        N(e^{-t/2} mu, e^{-t} Sigma + (1 - e^{-t}) I); the weights are unchanged.
        These are the marginals of dx = -x/2 dt + dW, the SDE whose
        probability-flow ODE ``probability_flow_ode`` integrates.

        Args:
            t: Diffusion time (Python number or 0-dim tensor).

        Returns:
            A new ``GaussianMixtureModel`` on the same device as ``self``.
        """
        decay = math.exp(-float(t))  # e^{-t}
        mu_t = self.mu * math.exp(-0.5 * float(t))  # (K, d, 1)
        eye = torch.eye(self.d, dtype=self.Sigma.dtype, device=self.Sigma.device)
        Sigma_t = self.Sigma * decay + eye * (1.0 - decay)  # (K, d, d)
        return GaussianMixtureModel(mu_t, Sigma_t, self.pi)

    def flow(self, x_t, t, dt, num_steps):
        """Integrate the probability-flow ODE with ``num_steps`` Euler steps.

        Args:
            x_t: Points of shape (B, d) at time ``t``.
            t: Start time.
            dt: Signed step size (negative integrates backwards in time).
            num_steps: Number of Euler steps.

        Returns:
            The points at time ``t + num_steps * dt``.
        """
        t0 = t
        for step in range(num_steps):
            # Derive time from the step index so rounding error does not accumulate.
            x_t = self.probability_flow_ode(x_t, t0 + step * dt, dt)
        return x_t

    def flow_gmm_to_normal(self, x_0, T=5, N=32):
        """Flow data-space points from t = 0 to t = T using N Euler steps.

        For large ``T`` the marginal at time ``T`` is close to N(0, I).
        """
        dt = T / N
        return self.flow(x_0, 0, dt, N)

    def flow_normal_to_gmm(self, x_T, T=5, N=32):
        """Flow points from t = T back to t = 0 using N Euler steps."""
        dt = -T / N
        return self.flow(x_T, T, dt, N)

    def probability_flow_ode(self, x_t, t, dt):
        """Take one Euler step of the probability-flow ODE.

        The drift is ``-x/2 - score_t(x)/2``, where ``score_t`` is the score
        of the diffused mixture ``forward_diffusion(t)``.

        Args:
            x_t: Points of shape (B, d) at time ``t``.
            t: Current time.
            dt: Signed step size.

        Returns:
            The points after one step, shape (B, d).
        """
        gmm_t = self.forward_diffusion(t)
        score = gmm_t.score(x_t)
        drift = -0.5 * x_t - 0.5 * score
        return x_t + drift * dt


# Example usage
if __name__ == "__main__":
    mu = torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)  # Shape: (K, d)
    Sigma = torch.stack([torch.eye(2), torch.eye(2)], dim=0).float()  # Shape: (K, d, d)
    pi = torch.tensor([0.5, 0.5], dtype=torch.float32)  # Shape: (K,)

    gmm = GaussianMixtureModel(mu, Sigma, pi)

    # Perform forward diffusion
    t = 1.0
    gmm_t = gmm.forward_diffusion(t)

    # Sampling from the diffused GMM
    samples = gmm_t.sample(10)
    print("Samples after forward diffusion:\n", samples)

    # Log probability of a sample after forward diffusion
    log_prob = gmm_t.log_prob(torch.tensor([0.5, 0.5], dtype=torch.float32))
    print("Log Probability after forward diffusion:", log_prob)

    # Score of a sample after forward diffusion
    score = gmm_t.score(torch.tensor([[0.5, 0.5]], dtype=torch.float32))
    print("Score after forward diffusion:", score)

    # Using the flow_gmm_to_normal
    x_0 = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
    x_T_normal = gmm.flow_gmm_to_normal(x_0)
    print("x_T_normal after flowing from GMM to normal:", x_T_normal)

    # Using the flow_normal_to_gmm
    x_T = torch.tensor([[0.0, 0.0]], dtype=torch.float32)
    x_0_gmm = gmm.flow_normal_to_gmm(x_T)
    print("x_0_gmm after flowing from normal to GMM:", x_0_gmm)