tivnanmatt commited on
Commit
7bc83e9
·
1 Parent(s): ca10114

initial commit of app

Browse files
Files changed (3) hide show
  1. app.py +42 -0
  2. gmm.py +151 -0
  3. utils.py +68 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ from utils import initialize_gmm, generate_grid, generate_contours, generate_intermediate_points, plot_samples_and_contours, create_animation
5
+ import matplotlib.pyplot as plt
6
+
7
+ def visualize_gmm(mu_list, Sigma_list, pi_list, dx, dtheta, T, N):
8
+ gmm = initialize_gmm(mu_list, Sigma_list, pi_list)
9
+ grid_points = generate_grid(dx)
10
+ std_normal_contours = generate_contours(dtheta)
11
+ gmm_samples = gmm.sample(500)
12
+ intermediate_points = generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, T, N)
13
+
14
+ fig1, ax1 = plot_samples_and_contours(gmm_samples, std_normal_contours, grid_points, "GMM Samples and Contours")
15
+ fig2, ax2 = plot_samples_and_contours(gmm_samples, std_normal_contours, grid_points, "Standard Normal Samples and Contours")
16
+
17
+ anim1 = create_animation(fig1, ax1, N, *intermediate_points[:3])
18
+ anim2 = create_animation(fig2, ax2, N, *intermediate_points[3:])
19
+
20
+ return fig1, fig2, anim1.to_jshtml(), anim2.to_jshtml()
21
+
22
+ demo = gr.Interface(
23
+ fn=visualize_gmm,
24
+ inputs=[
25
+ gr.Textbox(label="Mu List", placeholder="Enter means as a list of lists, e.g., [[0,0], [1,1]]"),
26
+ gr.Textbox(label="Sigma List", placeholder="Enter covariances as a list of lists, e.g., [[[0.2, 0.1], [0.1, 0.3]], [[1.0, -0.1], [-0.1, 0.1]]]"),
27
+ gr.Textbox(label="Pi List", placeholder="Enter weights as a list, e.g., [0.5, 0.5]"),
28
+ gr.Slider(minimum=0.01, maximum=1.0, label="dx", default=0.1),
29
+ gr.Slider(minimum=0.01, maximum=0.1, label="dtheta", default=0.01),
30
+ gr.Slider(minimum=1, maximum=100, label="T", default=10),
31
+ gr.Slider(minimum=1, maximum=500, label="N", default=100)
32
+ ],
33
+ outputs=[
34
+ gr.Plot(label="GMM to Normal Flow"),
35
+ gr.Plot(label="Normal to GMM Flow"),
36
+ gr.HTML(label="GMM to Normal Animation"),
37
+ gr.HTML(label="Normal to GMM Animation")
38
+ ],
39
+ live=True
40
+ )
41
+
42
+ demo.launch()
gmm.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ class GaussianMixtureModel:
5
+ def __init__(self, mu, Sigma, pi):
6
+ # Ensure all inputs are torch tensors
7
+ assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
8
+ assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
9
+ assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."
10
+
11
+ # Ensure the dtype is torch.float32
12
+ assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
13
+ assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
14
+ assert pi.dtype == torch.float32, "pi must have dtype torch.float32."
15
+
16
+ self.K, self.d = mu.shape[:2]
17
+
18
+ # Check the shape of mu
19
+ if mu.shape == (self.K, self.d):
20
+ mu = mu.unsqueeze(-1) # Shape: (K, d, 1)
21
+ assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."
22
+
23
+ # Check the shape of Sigma
24
+ assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."
25
+
26
+ # Check the shape of pi and fix it if necessary
27
+ if pi.shape == (self.K,):
28
+ pi = pi.view(self.K, 1, 1)
29
+ elif pi.shape == (self.K, 1):
30
+ pi = pi.unsqueeze(-1)
31
+ assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."
32
+
33
+ # Ensure pi sums to 1
34
+ assert torch.isclose(torch.sum(pi), torch.tensor(1.0)), "Mixture weights must sum to 1."
35
+
36
+ self.mu = mu
37
+ self.Sigma = Sigma
38
+ self.pi = pi
39
+
40
+ def sample(self, n_samples):
41
+ # Sample from the mixture model
42
+ samples = []
43
+ for _ in range(n_samples):
44
+ # Choose a component based on mixture weights
45
+ k = torch.multinomial(self.pi.squeeze(), 1).item()
46
+ sample = torch.distributions.MultivariateNormal(self.mu[k].squeeze(), self.Sigma[k]).sample()
47
+ samples.append(sample)
48
+ return torch.stack(samples)
49
+
50
+ def log_prob(self, x):
51
+ # Compute the log probability of a given sample x
52
+ x = x.view(1, self.d, 1) # Shape: (1, d, 1)
53
+ diff = x - self.mu # Shape: (K, d, 1)
54
+ inv_Sigma = torch.inverse(self.Sigma) # Shape: (K, d, d)
55
+
56
+ exponent = -0.5 * torch.bmm(torch.bmm(diff.transpose(1, 2), inv_Sigma), diff).squeeze() # Shape: (K,)
57
+ normalization = torch.log(torch.det(2 * torch.pi * self.Sigma)) / 2 # Shape: (K,)
58
+ log_probs = torch.log(self.pi.squeeze()) + exponent - normalization # Shape: (K,)
59
+ return torch.logsumexp(log_probs, dim=0)
60
+
61
+ def score(self, x):
62
+ # Compute the score function (gradient of log probability) for a batch of samples x
63
+ B = x.shape[0]
64
+ x = x.view(B, 1, self.d, 1) # Shape: (B, 1, d, 1)
65
+ diff = x - self.mu.unsqueeze(0) # Shape: (B, K, d, 1)
66
+ inv_Sigma = torch.inverse(self.Sigma).unsqueeze(0) # Shape: (1, K, d, d)
67
+
68
+ diff_t = diff.transpose(-2, -1).contiguous().view(B * self.K, 1, self.d) # Shape: (B*K, 1, d)
69
+ inv_Sigma_flat = inv_Sigma.view(self.K, self.d, self.d).expand(B, self.K, self.d, self.d).contiguous().view(B * self.K, self.d, self.d) # Shape: (B*K, d, d)
70
+ exponent = -0.5 * torch.bmm(torch.bmm(diff_t, inv_Sigma_flat), diff.view(B * self.K, self.d, 1)).view(B, self.K) # Shape: (B, K)
71
+
72
+ normalization = torch.log(torch.det(2 * torch.pi * self.Sigma)).unsqueeze(0) / 2 # Shape: (1, K)
73
+ probs = torch.exp(torch.log(self.pi.squeeze()).unsqueeze(0) + exponent - normalization) # Shape: (B, K)
74
+ norm_probs = probs / torch.sum(probs, dim=1, keepdim=True) # Shape: (B, K)
75
+
76
+ gradients = []
77
+ for k in range(self.K):
78
+ gradient = -torch.bmm(inv_Sigma[:, k].expand(B, self.d, self.d), diff[:, k]) # Shape: (B, d, 1)
79
+ gradients.append(norm_probs[:, k].unsqueeze(1) * gradient.squeeze(-1)) # Shape: (B, d)
80
+ return torch.sum(torch.stack(gradients, dim=1), dim=1) # Shape: (B, d)
81
+
82
+ def forward_diffusion(self, t):
83
+ # Compute the evolved mean and covariance
84
+ mu_t = self.mu * torch.tensor(np.exp(-0.5 * t), dtype=torch.float32) # Shape: (K, d, 1)
85
+ exp_neg_beta_t = torch.tensor(np.exp(-t), dtype=torch.float32).view(1, 1, 1) # Shape: (1, 1, 1)
86
+ Sigma_t = self.Sigma * exp_neg_beta_t + torch.eye(self.d, dtype=torch.float32) * (1 - exp_neg_beta_t) # Shape: (K, d, d)
87
+ return GaussianMixtureModel(mu_t, Sigma_t, self.pi)
88
+
89
+ def flow(self, x_t, t, dt, num_steps):
90
+ for _ in range(num_steps):
91
+ x_t = self.probability_flow_ode(x_t, t, dt)
92
+ t += dt
93
+ return x_t
94
+
95
+ def flow_gmm_to_normal(self, x_0, T=5, N=32):
96
+ dt = T / N
97
+ return self.flow(x_0, 0, dt, N)
98
+
99
+ def flow_normal_to_gmm(self, x_T, T=5, N=32):
100
+ dt = -T / N
101
+ return self.flow(x_T, T, dt, N)
102
+
103
+ def probability_flow_ode(self, x_t, t, dt):
104
+ # Compute the evolved x_t based on the probability flow ODE using the Euler method
105
+ # Forward diffusion to get the evolved GMM parameters
106
+ gmm_t = self.forward_diffusion(t)
107
+
108
+ # Compute the score function at x_t
109
+ score = gmm_t.score(x_t)
110
+
111
+ # Compute the drift term
112
+ drift = -0.5 * x_t - 0.5 * score
113
+
114
+ # Euler update
115
+ x_t_plus_dt = x_t + drift * dt
116
+
117
+ return x_t_plus_dt
118
+
119
+ # Example usage
120
+ if __name__ == "__main__":
121
+ mu = torch.tensor([[0, 0], [1, 1]], dtype=torch.float32) # Shape: (K, d)
122
+ Sigma = torch.stack([torch.eye(2), torch.eye(2)], dim=0).float() # Shape: (K, d, d)
123
+ pi = torch.tensor([0.5, 0.5], dtype=torch.float32) # Shape: (K,)
124
+
125
+ gmm = GaussianMixtureModel(mu, Sigma, pi)
126
+
127
+ # Perform forward diffusion
128
+ t = 1.0
129
+ gmm_t = gmm.forward_diffusion(t)
130
+
131
+ # Sampling from the diffused GMM
132
+ samples = gmm_t.sample(10)
133
+ print("Samples after forward diffusion:\n", samples)
134
+
135
+ # Log probability of a sample after forward diffusion
136
+ log_prob = gmm_t.log_prob(torch.tensor([0.5, 0.5], dtype=torch.float32))
137
+ print("Log Probability after forward diffusion:", log_prob)
138
+
139
+ # Score of a sample after forward diffusion
140
+ score = gmm_t.score(torch.tensor([[0.5, 0.5]], dtype=torch.float32))
141
+ print("Score after forward diffusion:", score)
142
+
143
+ # Using the flow_gmm_to_normal
144
+ x_0 = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
145
+ x_T_normal = gmm.flow_gmm_to_normal(x_0)
146
+ print("x_T_normal after flowing from GMM to normal:", x_T_normal)
147
+
148
+ # Using the flow_normal_to_gmm
149
+ x_T = torch.tensor([[0.0, 0.0]], dtype=torch.float32)
150
+ x_0_gmm = gmm.flow_normal_to_gmm(x_T)
151
+ print("x_0_gmm after flowing from normal to GMM:", x_0_gmm)
utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.animation import FuncAnimation
7
+ from gmm import GaussianMixtureModel
8
+
9
+ def initialize_gmm(mu_list, Sigma_list, pi_list):
10
+ mu = torch.tensor(mu_list, dtype=torch.float32)
11
+ Sigma = torch.tensor(Sigma_list, dtype=torch.float32)
12
+ pi = torch.tensor(pi_list, dtype=torch.float32)
13
+ return GaussianMixtureModel(mu, Sigma, pi)
14
+
15
+ def generate_grid(dx):
16
+ x_positions = np.arange(-10, 10.5, 0.5)
17
+ y_positions = np.arange(-10, 10.5, 0.5)
18
+ vertical_lines = [np.stack([np.full(int((10 - (-10))/ dx + 1), x), np.arange(-10, 10 + dx, dx)], axis=1) for x in x_positions]
19
+ horizontal_lines = [np.stack([np.arange(-10, 10 + dx, dx), np.full(int((10 - (-10)) / dx + 1), y)], axis=1) for y in y_positions]
20
+ grid_points = np.concatenate(vertical_lines + horizontal_lines, axis=0)
21
+ return torch.tensor(grid_points, dtype=torch.float32)
22
+
23
+ def generate_contours(dtheta):
24
+ angles = np.linspace(0, 2 * np.pi, int(2 * np.pi / dtheta))
25
+ std_normal_contours = np.concatenate([np.stack([r * np.cos(angles), r * np.sin(angles)], axis=1) for r in range(1, 4)], axis=0)
26
+ return torch.tensor(std_normal_contours, dtype=torch.float32)
27
+
28
+ def generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, T, N):
29
+ intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(grid_points, T, N)
30
+ contour_intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(std_normal_contours, T, N)
31
+ grid_intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(grid_points, T, N)
32
+
33
+ intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(gmm_samples, T, N)
34
+ contour_intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(std_normal_contours, T, N)
35
+ grid_intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(grid_points, T, N)
36
+
37
+ return (intermediate_points_gmm_to_normal, contour_intermediate_points_gmm_to_normal, grid_intermediate_points_gmm_to_normal,
38
+ intermediate_points_normal_to_gmm, contour_intermediate_points_normal_to_gmm, grid_intermediate_points_normal_to_gmm)
39
+
40
+ def plot_samples_and_contours(samples, contours, grid_points, title):
41
+ fig, ax = plt.subplots(figsize=(8, 6))
42
+ ax.scatter(grid_points[:, 0], grid_points[:, 1], alpha=0.5, c='black', s=1, label='Grid Points')
43
+ ax.scatter(contours[:, 0], contours[:, 1], alpha=0.5, s=3, c='blue', label='Contours')
44
+ ax.scatter(samples[:, 0], samples[:, 1], alpha=0.5, c='red', label='Samples')
45
+ ax.set_title(title)
46
+ ax.set_xlabel("x1")
47
+ ax.set_ylabel("x2")
48
+ ax.grid(True)
49
+ ax.legend(loc='upper right')
50
+ ax.set_xlim(-5, 5)
51
+ ax.set_ylim(-5, 5)
52
+ ax.set_aspect('equal', adjustable='box')
53
+ plt.close(fig)
54
+ return fig, ax
55
+
56
+ def create_animation(fig, ax, frames, intermediate_points, intermediate_samples, intermediate_contours, intermediate_grid):
57
+ scatter_grid = ax.scatter([], [], c='black', alpha=0.5, s=1, label='Grid Points')
58
+ contour_scatter = ax.scatter([], [], c='blue', alpha=0.5, s=3, label='Contours')
59
+ scatter_samples = ax.scatter([], [], c='red', alpha=0.5, label='Samples')
60
+
61
+ def update(frame):
62
+ scatter_grid.set_offsets(intermediate_points[frame].numpy())
63
+ scatter_samples.set_offsets(intermediate_samples[frame].numpy())
64
+ contour_scatter.set_offsets(intermediate_contours[frame].numpy())
65
+ return scatter_grid, scatter_samples, contour_scatter
66
+
67
+ anim = FuncAnimation(fig, update, frames=frames, blit=True)
68
+ return anim