Spaces:
Sleeping
Sleeping
File size: 6,768 Bytes
7bc83e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
import numpy as np
class GaussianMixtureModel:
    """Gaussian mixture model with closed-form forward diffusion and probability-flow transport.

    Parameters are stored as float32 tensors with shapes ``mu: (K, d, 1)``,
    ``Sigma: (K, d, d)`` and ``pi: (K, 1, 1)``, where ``K`` is the number of
    components and ``d`` the dimension (exposed as ``self.K`` / ``self.d``).

    The time evolution in :meth:`forward_diffusion` (mu_t = e^{-t/2} mu,
    Sigma_t = e^{-t} Sigma + (1 - e^{-t}) I) and the drift in
    :meth:`probability_flow_ode` (-x/2 - score/2) correspond to the process
    dx = -x/2 dt + dW, whose marginal tends to N(0, I) as t grows.
    """

    def __init__(self, mu, Sigma, pi):
        """Validate and store the mixture parameters.

        Args:
            mu: Component means, float32 tensor of shape (K, d) or (K, d, 1).
            Sigma: Component covariances, float32 tensor of shape (K, d, d).
            pi: Mixture weights, float32 tensor of shape (K,), (K, 1) or (K, 1, 1);
                must sum to 1.

        Raises:
            AssertionError: if an input has the wrong type, dtype or shape, or the
                weights do not sum to 1. (``assert`` is kept so existing callers see
                the same exception type; note it is skipped under ``python -O``.)
        """
        # Ensure all inputs are torch tensors
        assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
        assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
        assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."
        # Ensure the dtype is torch.float32
        assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
        assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
        assert pi.dtype == torch.float32, "pi must have dtype torch.float32."

        self.K, self.d = mu.shape[:2]

        # Normalise mu to (K, d, 1).
        if mu.shape == (self.K, self.d):
            mu = mu.unsqueeze(-1)
        assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."

        assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."

        # Normalise pi to (K, 1, 1).
        if pi.shape == (self.K,):
            pi = pi.view(self.K, 1, 1)
        elif pi.shape == (self.K, 1):
            pi = pi.unsqueeze(-1)
        assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."

        # Build the reference value on pi's own device so GPU weights do not
        # trigger a cross-device comparison error.
        assert torch.isclose(pi.sum(), pi.new_tensor(1.0)), "Mixture weights must sum to 1."

        self.mu = mu
        self.Sigma = Sigma
        self.pi = pi

    def sample(self, n_samples):
        """Draw ``n_samples`` i.i.d. samples from the mixture.

        Args:
            n_samples: Number of samples to draw.

        Returns:
            Tensor of shape (n_samples, d); an empty (0, d) tensor when no samples
            are requested.
        """
        # reshape(-1) (not squeeze) keeps K == 1 one-dimensional, as multinomial requires.
        weights = self.pi.reshape(-1)  # (K,)
        # Component distributions are built lazily and reused. Construction consumes
        # no randomness, so the RNG stream matches building one per draw.
        components = {}
        draws = []
        for _ in range(n_samples):
            # Choose a component according to the mixture weights.
            k = torch.multinomial(weights, 1).item()
            if k not in components:
                # reshape(-1) keeps d == 1 means one-dimensional, as MultivariateNormal requires.
                components[k] = torch.distributions.MultivariateNormal(
                    self.mu[k].reshape(-1), self.Sigma[k]
                )
            draws.append(components[k].sample())
        if not draws:
            # torch.stack([]) raises; return an empty batch with the right trailing dim.
            return self.mu.new_empty((0, self.d))
        return torch.stack(draws)

    def _component_log_joint(self, x):
        """Compute per-component log(pi_k N(x | mu_k, Sigma_k)) and Sigma_k^{-1} (x - mu_k).

        Args:
            x: Batch of points, shape (B, d).

        Returns:
            Tuple ``(log_joint, prec_diff)`` with shapes (B, K) and (B, K, d).
        """
        inv_sigma = torch.inverse(self.Sigma)  # (K, d, d)
        diff = x.unsqueeze(1) - self.mu.squeeze(-1).unsqueeze(0)  # (B, K, d)
        prec_diff = torch.einsum("kij,bkj->bki", inv_sigma, diff)  # (B, K, d)
        mahalanobis = (diff * prec_diff).sum(dim=-1)  # (B, K)
        # logdet avoids the overflow/underflow of log(det(...)) for larger d or
        # very small/large variances.
        log_norm = 0.5 * torch.logdet(2 * torch.pi * self.Sigma)  # (K,)
        log_pi = torch.log(self.pi.reshape(-1))  # (K,)
        return log_pi - log_norm - 0.5 * mahalanobis, prec_diff

    def log_prob(self, x):
        """Return the mixture log density at a single point.

        Args:
            x: Tensor with exactly ``d`` elements (e.g. shape (d,) or (1, d)).

        Returns:
            0-dim tensor holding log p(x).
        """
        log_joint, _ = self._component_log_joint(x.reshape(1, self.d))  # (1, K)
        return torch.logsumexp(log_joint[0], dim=0)

    def score(self, x):
        """Return the score grad_x log p(x) for a batch of points.

        Args:
            x: Batch of points whose first dimension is the batch size B and which
                holds B * d elements (e.g. shape (B, d) or (B, d, 1)).

        Returns:
            Tensor of shape (B, d).
        """
        B = x.shape[0]
        log_joint, prec_diff = self._component_log_joint(x.reshape(B, self.d))
        # Posterior responsibilities, normalised in log space. exp-then-divide gives
        # 0/0 = NaN when every component density underflows (points far from all means).
        resp = torch.softmax(log_joint, dim=1)  # (B, K)
        # grad log p(x) = -sum_k r_k(x) Sigma_k^{-1} (x - mu_k)
        return -(resp.unsqueeze(-1) * prec_diff).sum(dim=1)  # (B, d)

    def forward_diffusion(self, t):
        """Return the mixture marginal at time ``t`` of the forward diffusion.

        Each component evolves in closed form: mu_t = e^{-t/2} mu and
        Sigma_t = e^{-t} Sigma + (1 - e^{-t}) I. The weights are unchanged.

        Args:
            t: Diffusion time (non-negative scalar).

        Returns:
            A new :class:`GaussianMixtureModel` with the evolved parameters.
        """
        # Scalars and the identity live on the parameters' device so GPU models work.
        mean_scale = torch.tensor(np.exp(-0.5 * t), dtype=torch.float32, device=self.mu.device)
        decay = torch.tensor(np.exp(-t), dtype=torch.float32, device=self.Sigma.device)
        eye = torch.eye(self.d, dtype=torch.float32, device=self.Sigma.device)
        mu_t = self.mu * mean_scale  # (K, d, 1)
        Sigma_t = self.Sigma * decay + eye * (1 - decay)  # (K, d, d)
        return GaussianMixtureModel(mu_t, Sigma_t, self.pi)

    def flow(self, x_t, t, dt, num_steps):
        """Integrate the probability-flow ODE with ``num_steps`` explicit Euler steps.

        Args:
            x_t: Batch of points at time ``t``, shape (B, d).
            t: Start time.
            dt: Step size; positive moves toward noise, negative toward the data.
            num_steps: Number of Euler steps.

        Returns:
            Points at time ``t + num_steps * dt``, shape (B, d).
        """
        for _ in range(num_steps):
            x_t = self.probability_flow_ode(x_t, t, dt)
            t += dt
        return x_t

    def flow_gmm_to_normal(self, x_0, T=5, N=32):
        """Transport points from time 0 (this mixture) to time ``T`` (approximately N(0, I)) in ``N`` steps."""
        dt = T / N
        return self.flow(x_0, 0, dt, N)

    def flow_normal_to_gmm(self, x_T, T=5, N=32):
        """Transport points from time ``T`` back to time 0 (this mixture) in ``N`` steps."""
        dt = -T / N
        return self.flow(x_T, T, dt, N)

    def probability_flow_ode(self, x_t, t, dt):
        """Take one Euler step of the probability-flow ODE dx/dt = -x/2 - score_t(x)/2.

        Args:
            x_t: Batch of points at time ``t``, shape (B, d).
            t: Current time, used to evaluate the diffused mixture's score.
            dt: Step size (may be negative).

        Returns:
            Updated points, shape (B, d).
        """
        # Score of the time-t marginal, which is available in closed form.
        score = self.forward_diffusion(t).score(x_t)
        drift = -0.5 * x_t - 0.5 * score
        return x_t + drift * dt
# Example usage
if __name__ == "__main__":
    # Two unit-covariance components centred at (0, 0) and (1, 1), equally weighted.
    means = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float32)  # (K, d)
    covariances = torch.eye(2).repeat(2, 1, 1)  # (K, d, d)
    weights = torch.full((2,), 0.5, dtype=torch.float32)  # (K,)
    model = GaussianMixtureModel(means, covariances, weights)

    # Diffuse the mixture forward to t = 1 and inspect the resulting marginal.
    diffused = model.forward_diffusion(1.0)
    print("Samples after forward diffusion:\n", diffused.sample(10))

    probe = torch.tensor([0.5, 0.5], dtype=torch.float32)
    print("Log Probability after forward diffusion:", diffused.log_prob(probe))
    print("Score after forward diffusion:", diffused.score(probe.unsqueeze(0)))

    # Transport single points through the probability-flow ODE in both directions.
    start = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
    print("x_T_normal after flowing from GMM to normal:", model.flow_gmm_to_normal(start))

    origin = torch.tensor([[0.0, 0.0]], dtype=torch.float32)
    print("x_0_gmm after flowing from normal to GMM:", model.flow_normal_to_gmm(origin))
|