Spaces:
Sleeping
Sleeping
tivnanmatt
commited on
Commit
·
7bc83e9
1
Parent(s):
ca10114
initial commit of app
Browse files
app.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from utils import initialize_gmm, generate_grid, generate_contours, generate_intermediate_points, plot_samples_and_contours, create_animation
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
def visualize_gmm(mu_list, Sigma_list, pi_list, dx, dtheta, T, N):
|
8 |
+
gmm = initialize_gmm(mu_list, Sigma_list, pi_list)
|
9 |
+
grid_points = generate_grid(dx)
|
10 |
+
std_normal_contours = generate_contours(dtheta)
|
11 |
+
gmm_samples = gmm.sample(500)
|
12 |
+
intermediate_points = generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, T, N)
|
13 |
+
|
14 |
+
fig1, ax1 = plot_samples_and_contours(gmm_samples, std_normal_contours, grid_points, "GMM Samples and Contours")
|
15 |
+
fig2, ax2 = plot_samples_and_contours(gmm_samples, std_normal_contours, grid_points, "Standard Normal Samples and Contours")
|
16 |
+
|
17 |
+
anim1 = create_animation(fig1, ax1, N, *intermediate_points[:3])
|
18 |
+
anim2 = create_animation(fig2, ax2, N, *intermediate_points[3:])
|
19 |
+
|
20 |
+
return fig1, fig2, anim1.to_jshtml(), anim2.to_jshtml()
|
21 |
+
|
22 |
+
demo = gr.Interface(
|
23 |
+
fn=visualize_gmm,
|
24 |
+
inputs=[
|
25 |
+
gr.Textbox(label="Mu List", placeholder="Enter means as a list of lists, e.g., [[0,0], [1,1]]"),
|
26 |
+
gr.Textbox(label="Sigma List", placeholder="Enter covariances as a list of lists, e.g., [[[0.2, 0.1], [0.1, 0.3]], [[1.0, -0.1], [-0.1, 0.1]]]"),
|
27 |
+
gr.Textbox(label="Pi List", placeholder="Enter weights as a list, e.g., [0.5, 0.5]"),
|
28 |
+
gr.Slider(minimum=0.01, maximum=1.0, label="dx", default=0.1),
|
29 |
+
gr.Slider(minimum=0.01, maximum=0.1, label="dtheta", default=0.01),
|
30 |
+
gr.Slider(minimum=1, maximum=100, label="T", default=10),
|
31 |
+
gr.Slider(minimum=1, maximum=500, label="N", default=100)
|
32 |
+
],
|
33 |
+
outputs=[
|
34 |
+
gr.Plot(label="GMM to Normal Flow"),
|
35 |
+
gr.Plot(label="Normal to GMM Flow"),
|
36 |
+
gr.HTML(label="GMM to Normal Animation"),
|
37 |
+
gr.HTML(label="Normal to GMM Animation")
|
38 |
+
],
|
39 |
+
live=True
|
40 |
+
)
|
41 |
+
|
42 |
+
demo.launch()
|
gmm.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class GaussianMixtureModel:
|
5 |
+
def __init__(self, mu, Sigma, pi):
|
6 |
+
# Ensure all inputs are torch tensors
|
7 |
+
assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
|
8 |
+
assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
|
9 |
+
assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."
|
10 |
+
|
11 |
+
# Ensure the dtype is torch.float32
|
12 |
+
assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
|
13 |
+
assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
|
14 |
+
assert pi.dtype == torch.float32, "pi must have dtype torch.float32."
|
15 |
+
|
16 |
+
self.K, self.d = mu.shape[:2]
|
17 |
+
|
18 |
+
# Check the shape of mu
|
19 |
+
if mu.shape == (self.K, self.d):
|
20 |
+
mu = mu.unsqueeze(-1) # Shape: (K, d, 1)
|
21 |
+
assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."
|
22 |
+
|
23 |
+
# Check the shape of Sigma
|
24 |
+
assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."
|
25 |
+
|
26 |
+
# Check the shape of pi and fix it if necessary
|
27 |
+
if pi.shape == (self.K,):
|
28 |
+
pi = pi.view(self.K, 1, 1)
|
29 |
+
elif pi.shape == (self.K, 1):
|
30 |
+
pi = pi.unsqueeze(-1)
|
31 |
+
assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."
|
32 |
+
|
33 |
+
# Ensure pi sums to 1
|
34 |
+
assert torch.isclose(torch.sum(pi), torch.tensor(1.0)), "Mixture weights must sum to 1."
|
35 |
+
|
36 |
+
self.mu = mu
|
37 |
+
self.Sigma = Sigma
|
38 |
+
self.pi = pi
|
39 |
+
|
40 |
+
def sample(self, n_samples):
|
41 |
+
# Sample from the mixture model
|
42 |
+
samples = []
|
43 |
+
for _ in range(n_samples):
|
44 |
+
# Choose a component based on mixture weights
|
45 |
+
k = torch.multinomial(self.pi.squeeze(), 1).item()
|
46 |
+
sample = torch.distributions.MultivariateNormal(self.mu[k].squeeze(), self.Sigma[k]).sample()
|
47 |
+
samples.append(sample)
|
48 |
+
return torch.stack(samples)
|
49 |
+
|
50 |
+
def log_prob(self, x):
|
51 |
+
# Compute the log probability of a given sample x
|
52 |
+
x = x.view(1, self.d, 1) # Shape: (1, d, 1)
|
53 |
+
diff = x - self.mu # Shape: (K, d, 1)
|
54 |
+
inv_Sigma = torch.inverse(self.Sigma) # Shape: (K, d, d)
|
55 |
+
|
56 |
+
exponent = -0.5 * torch.bmm(torch.bmm(diff.transpose(1, 2), inv_Sigma), diff).squeeze() # Shape: (K,)
|
57 |
+
normalization = torch.log(torch.det(2 * torch.pi * self.Sigma)) / 2 # Shape: (K,)
|
58 |
+
log_probs = torch.log(self.pi.squeeze()) + exponent - normalization # Shape: (K,)
|
59 |
+
return torch.logsumexp(log_probs, dim=0)
|
60 |
+
|
61 |
+
def score(self, x):
|
62 |
+
# Compute the score function (gradient of log probability) for a batch of samples x
|
63 |
+
B = x.shape[0]
|
64 |
+
x = x.view(B, 1, self.d, 1) # Shape: (B, 1, d, 1)
|
65 |
+
diff = x - self.mu.unsqueeze(0) # Shape: (B, K, d, 1)
|
66 |
+
inv_Sigma = torch.inverse(self.Sigma).unsqueeze(0) # Shape: (1, K, d, d)
|
67 |
+
|
68 |
+
diff_t = diff.transpose(-2, -1).contiguous().view(B * self.K, 1, self.d) # Shape: (B*K, 1, d)
|
69 |
+
inv_Sigma_flat = inv_Sigma.view(self.K, self.d, self.d).expand(B, self.K, self.d, self.d).contiguous().view(B * self.K, self.d, self.d) # Shape: (B*K, d, d)
|
70 |
+
exponent = -0.5 * torch.bmm(torch.bmm(diff_t, inv_Sigma_flat), diff.view(B * self.K, self.d, 1)).view(B, self.K) # Shape: (B, K)
|
71 |
+
|
72 |
+
normalization = torch.log(torch.det(2 * torch.pi * self.Sigma)).unsqueeze(0) / 2 # Shape: (1, K)
|
73 |
+
probs = torch.exp(torch.log(self.pi.squeeze()).unsqueeze(0) + exponent - normalization) # Shape: (B, K)
|
74 |
+
norm_probs = probs / torch.sum(probs, dim=1, keepdim=True) # Shape: (B, K)
|
75 |
+
|
76 |
+
gradients = []
|
77 |
+
for k in range(self.K):
|
78 |
+
gradient = -torch.bmm(inv_Sigma[:, k].expand(B, self.d, self.d), diff[:, k]) # Shape: (B, d, 1)
|
79 |
+
gradients.append(norm_probs[:, k].unsqueeze(1) * gradient.squeeze(-1)) # Shape: (B, d)
|
80 |
+
return torch.sum(torch.stack(gradients, dim=1), dim=1) # Shape: (B, d)
|
81 |
+
|
82 |
+
def forward_diffusion(self, t):
|
83 |
+
# Compute the evolved mean and covariance
|
84 |
+
mu_t = self.mu * torch.tensor(np.exp(-0.5 * t), dtype=torch.float32) # Shape: (K, d, 1)
|
85 |
+
exp_neg_beta_t = torch.tensor(np.exp(-t), dtype=torch.float32).view(1, 1, 1) # Shape: (1, 1, 1)
|
86 |
+
Sigma_t = self.Sigma * exp_neg_beta_t + torch.eye(self.d, dtype=torch.float32) * (1 - exp_neg_beta_t) # Shape: (K, d, d)
|
87 |
+
return GaussianMixtureModel(mu_t, Sigma_t, self.pi)
|
88 |
+
|
89 |
+
def flow(self, x_t, t, dt, num_steps):
|
90 |
+
for _ in range(num_steps):
|
91 |
+
x_t = self.probability_flow_ode(x_t, t, dt)
|
92 |
+
t += dt
|
93 |
+
return x_t
|
94 |
+
|
95 |
+
def flow_gmm_to_normal(self, x_0, T=5, N=32):
|
96 |
+
dt = T / N
|
97 |
+
return self.flow(x_0, 0, dt, N)
|
98 |
+
|
99 |
+
def flow_normal_to_gmm(self, x_T, T=5, N=32):
|
100 |
+
dt = -T / N
|
101 |
+
return self.flow(x_T, T, dt, N)
|
102 |
+
|
103 |
+
def probability_flow_ode(self, x_t, t, dt):
|
104 |
+
# Compute the evolved x_t based on the probability flow ODE using the Euler method
|
105 |
+
# Forward diffusion to get the evolved GMM parameters
|
106 |
+
gmm_t = self.forward_diffusion(t)
|
107 |
+
|
108 |
+
# Compute the score function at x_t
|
109 |
+
score = gmm_t.score(x_t)
|
110 |
+
|
111 |
+
# Compute the drift term
|
112 |
+
drift = -0.5 * x_t - 0.5 * score
|
113 |
+
|
114 |
+
# Euler update
|
115 |
+
x_t_plus_dt = x_t + drift * dt
|
116 |
+
|
117 |
+
return x_t_plus_dt
|
118 |
+
|
119 |
+
# Example usage
|
120 |
+
if __name__ == "__main__":
|
121 |
+
mu = torch.tensor([[0, 0], [1, 1]], dtype=torch.float32) # Shape: (K, d)
|
122 |
+
Sigma = torch.stack([torch.eye(2), torch.eye(2)], dim=0).float() # Shape: (K, d, d)
|
123 |
+
pi = torch.tensor([0.5, 0.5], dtype=torch.float32) # Shape: (K,)
|
124 |
+
|
125 |
+
gmm = GaussianMixtureModel(mu, Sigma, pi)
|
126 |
+
|
127 |
+
# Perform forward diffusion
|
128 |
+
t = 1.0
|
129 |
+
gmm_t = gmm.forward_diffusion(t)
|
130 |
+
|
131 |
+
# Sampling from the diffused GMM
|
132 |
+
samples = gmm_t.sample(10)
|
133 |
+
print("Samples after forward diffusion:\n", samples)
|
134 |
+
|
135 |
+
# Log probability of a sample after forward diffusion
|
136 |
+
log_prob = gmm_t.log_prob(torch.tensor([0.5, 0.5], dtype=torch.float32))
|
137 |
+
print("Log Probability after forward diffusion:", log_prob)
|
138 |
+
|
139 |
+
# Score of a sample after forward diffusion
|
140 |
+
score = gmm_t.score(torch.tensor([[0.5, 0.5]], dtype=torch.float32))
|
141 |
+
print("Score after forward diffusion:", score)
|
142 |
+
|
143 |
+
# Using the flow_gmm_to_normal
|
144 |
+
x_0 = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
|
145 |
+
x_T_normal = gmm.flow_gmm_to_normal(x_0)
|
146 |
+
print("x_T_normal after flowing from GMM to normal:", x_T_normal)
|
147 |
+
|
148 |
+
# Using the flow_normal_to_gmm
|
149 |
+
x_T = torch.tensor([[0.0, 0.0]], dtype=torch.float32)
|
150 |
+
x_0_gmm = gmm.flow_normal_to_gmm(x_T)
|
151 |
+
print("x_0_gmm after flowing from normal to GMM:", x_0_gmm)
|
utils.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# utils.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from matplotlib.animation import FuncAnimation
|
7 |
+
from gmm import GaussianMixtureModel
|
8 |
+
|
9 |
+
def initialize_gmm(mu_list, Sigma_list, pi_list):
|
10 |
+
mu = torch.tensor(mu_list, dtype=torch.float32)
|
11 |
+
Sigma = torch.tensor(Sigma_list, dtype=torch.float32)
|
12 |
+
pi = torch.tensor(pi_list, dtype=torch.float32)
|
13 |
+
return GaussianMixtureModel(mu, Sigma, pi)
|
14 |
+
|
15 |
+
def generate_grid(dx):
|
16 |
+
x_positions = np.arange(-10, 10.5, 0.5)
|
17 |
+
y_positions = np.arange(-10, 10.5, 0.5)
|
18 |
+
vertical_lines = [np.stack([np.full(int((10 - (-10))/ dx + 1), x), np.arange(-10, 10 + dx, dx)], axis=1) for x in x_positions]
|
19 |
+
horizontal_lines = [np.stack([np.arange(-10, 10 + dx, dx), np.full(int((10 - (-10)) / dx + 1), y)], axis=1) for y in y_positions]
|
20 |
+
grid_points = np.concatenate(vertical_lines + horizontal_lines, axis=0)
|
21 |
+
return torch.tensor(grid_points, dtype=torch.float32)
|
22 |
+
|
23 |
+
def generate_contours(dtheta):
|
24 |
+
angles = np.linspace(0, 2 * np.pi, int(2 * np.pi / dtheta))
|
25 |
+
std_normal_contours = np.concatenate([np.stack([r * np.cos(angles), r * np.sin(angles)], axis=1) for r in range(1, 4)], axis=0)
|
26 |
+
return torch.tensor(std_normal_contours, dtype=torch.float32)
|
27 |
+
|
28 |
+
def generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, T, N):
|
29 |
+
intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(grid_points, T, N)
|
30 |
+
contour_intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(std_normal_contours, T, N)
|
31 |
+
grid_intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(grid_points, T, N)
|
32 |
+
|
33 |
+
intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(gmm_samples, T, N)
|
34 |
+
contour_intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(std_normal_contours, T, N)
|
35 |
+
grid_intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(grid_points, T, N)
|
36 |
+
|
37 |
+
return (intermediate_points_gmm_to_normal, contour_intermediate_points_gmm_to_normal, grid_intermediate_points_gmm_to_normal,
|
38 |
+
intermediate_points_normal_to_gmm, contour_intermediate_points_normal_to_gmm, grid_intermediate_points_normal_to_gmm)
|
39 |
+
|
40 |
+
def plot_samples_and_contours(samples, contours, grid_points, title):
|
41 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
42 |
+
ax.scatter(grid_points[:, 0], grid_points[:, 1], alpha=0.5, c='black', s=1, label='Grid Points')
|
43 |
+
ax.scatter(contours[:, 0], contours[:, 1], alpha=0.5, s=3, c='blue', label='Contours')
|
44 |
+
ax.scatter(samples[:, 0], samples[:, 1], alpha=0.5, c='red', label='Samples')
|
45 |
+
ax.set_title(title)
|
46 |
+
ax.set_xlabel("x1")
|
47 |
+
ax.set_ylabel("x2")
|
48 |
+
ax.grid(True)
|
49 |
+
ax.legend(loc='upper right')
|
50 |
+
ax.set_xlim(-5, 5)
|
51 |
+
ax.set_ylim(-5, 5)
|
52 |
+
ax.set_aspect('equal', adjustable='box')
|
53 |
+
plt.close(fig)
|
54 |
+
return fig, ax
|
55 |
+
|
56 |
+
def create_animation(fig, ax, frames, intermediate_points, intermediate_samples, intermediate_contours, intermediate_grid):
|
57 |
+
scatter_grid = ax.scatter([], [], c='black', alpha=0.5, s=1, label='Grid Points')
|
58 |
+
contour_scatter = ax.scatter([], [], c='blue', alpha=0.5, s=3, label='Contours')
|
59 |
+
scatter_samples = ax.scatter([], [], c='red', alpha=0.5, label='Samples')
|
60 |
+
|
61 |
+
def update(frame):
|
62 |
+
scatter_grid.set_offsets(intermediate_points[frame].numpy())
|
63 |
+
scatter_samples.set_offsets(intermediate_samples[frame].numpy())
|
64 |
+
contour_scatter.set_offsets(intermediate_contours[frame].numpy())
|
65 |
+
return scatter_grid, scatter_samples, contour_scatter
|
66 |
+
|
67 |
+
anim = FuncAnimation(fig, update, frames=frames, blit=True)
|
68 |
+
return anim
|