Spaces:
Sleeping
Sleeping
tivnanmatt
commited on
Commit
·
7b1ccbd
1
Parent(s):
7bc83e9
working version
Browse files
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
app.py
CHANGED
@@ -1,42 +1,84 @@
|
|
1 |
# app.py
|
2 |
|
3 |
import gradio as gr
|
4 |
-
from utils import initialize_gmm, generate_grid, generate_contours, generate_intermediate_points, plot_samples_and_contours
|
5 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def visualize_gmm(mu_list, Sigma_list, pi_list, dx, dtheta, T, N):
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
demo = gr.Interface(
|
23 |
fn=visualize_gmm,
|
24 |
inputs=[
|
25 |
-
gr.Textbox(label="Mu List", placeholder="Enter means as a list of lists, e.g., [[0,0], [1,1]]"),
|
26 |
-
gr.Textbox(label="Sigma List", placeholder="Enter covariances as a list of lists, e.g., [[[0.2, 0.1], [0.1, 0.3]], [[1.0, -0.1], [-0.1, 0.1]]]"),
|
27 |
-
gr.Textbox(label="Pi List", placeholder="Enter weights as a list, e.g., [0.5, 0.5]"),
|
28 |
-
gr.Slider(minimum=0.01, maximum=1.0, label="dx",
|
29 |
-
gr.Slider(minimum=
|
30 |
-
gr.Slider(minimum=1, maximum=100, label="T",
|
31 |
-
gr.Slider(minimum=1, maximum=500, label="N",
|
32 |
],
|
33 |
outputs=[
|
34 |
-
gr.Plot(label="GMM to Normal Flow"),
|
35 |
-
gr.Plot(label="Normal to GMM Flow")
|
36 |
-
gr.HTML(label="GMM to Normal Animation"),
|
37 |
-
gr.HTML(label="Normal to GMM Animation")
|
38 |
],
|
39 |
live=True
|
40 |
)
|
41 |
|
42 |
-
demo.launch()
|
|
|
1 |
# app.py
|
2 |
|
3 |
import gradio as gr
|
4 |
+
from utils import initialize_gmm, generate_grid, generate_contours, generate_intermediate_points, plot_samples_and_contours
|
5 |
import matplotlib.pyplot as plt
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
def validate_inputs(mu_list, Sigma_list, pi_list):
|
10 |
+
try:
|
11 |
+
mu = eval(mu_list)
|
12 |
+
Sigma = eval(Sigma_list)
|
13 |
+
pi = eval(pi_list)
|
14 |
+
|
15 |
+
if not (isinstance(mu, list) and all(isinstance(i, list) for i in mu)):
|
16 |
+
return False, "Mu list is invalid."
|
17 |
+
if not (isinstance(Sigma, list) and all(isinstance(i, list) for i in Sigma)):
|
18 |
+
return False, "Sigma list is invalid."
|
19 |
+
if not isinstance(pi, list):
|
20 |
+
return False, "Pi list is invalid."
|
21 |
+
|
22 |
+
if not torch.isclose(torch.tensor(pi).sum(), torch.tensor(1.0)):
|
23 |
+
return False, "Mixture weights must sum to 1."
|
24 |
+
|
25 |
+
return True, ""
|
26 |
+
except Exception as e:
|
27 |
+
return False, str(e)
|
28 |
|
29 |
def visualize_gmm(mu_list, Sigma_list, pi_list, dx, dtheta, T, N):
|
30 |
+
is_valid, error_message = validate_inputs(mu_list, Sigma_list, pi_list)
|
31 |
+
if not is_valid:
|
32 |
+
fig, ax = plt.subplots()
|
33 |
+
ax.text(0.5, 0.5, f'Invalid input: {error_message}', horizontalalignment='center', verticalalignment='center')
|
34 |
+
ax.set_xlim(-5, 5)
|
35 |
+
ax.set_ylim(-5, 5)
|
36 |
+
ax.set_aspect('equal', adjustable='box')
|
37 |
+
plt.close(fig)
|
38 |
+
return fig, fig
|
39 |
+
|
40 |
+
try:
|
41 |
+
gmm = initialize_gmm(eval(mu_list), eval(Sigma_list), eval(pi_list))
|
42 |
+
grid_points = generate_grid(dx)
|
43 |
+
std_normal_contours = generate_contours(dtheta)
|
44 |
+
gmm_samples = gmm.sample(500)
|
45 |
+
normal_samples = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)).sample((500,))
|
46 |
+
(intermediate_points_gmm_to_normal, contour_intermediate_points_gmm_to_normal, grid_intermediate_points_gmm_to_normal,
|
47 |
+
intermediate_points_normal_to_gmm, contour_intermediate_points_normal_to_gmm, grid_intermediate_points_normal_to_gmm) = \
|
48 |
+
generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, normal_samples, T, N)
|
49 |
+
|
50 |
+
final_frame_gmm_to_normal = intermediate_points_gmm_to_normal.cpu().detach().numpy()
|
51 |
+
final_frame_normal_to_gmm = intermediate_points_normal_to_gmm.cpu().detach().numpy()
|
52 |
+
|
53 |
+
fig1, ax1 = plot_samples_and_contours(final_frame_gmm_to_normal, contour_intermediate_points_gmm_to_normal.cpu().detach().numpy(), grid_intermediate_points_gmm_to_normal.cpu().detach().numpy(), "GMM to Normal Final Frame")
|
54 |
+
fig2, ax2 = plot_samples_and_contours(final_frame_normal_to_gmm, contour_intermediate_points_normal_to_gmm.cpu().detach().numpy(), grid_intermediate_points_normal_to_gmm.cpu().detach().numpy(), "Normal to GMM Final Frame")
|
55 |
+
|
56 |
+
return fig1, fig2
|
57 |
+
except Exception as e:
|
58 |
+
fig, ax = plt.subplots()
|
59 |
+
ax.text(0.5, 0.5, f'Error during visualization: {str(e)}', horizontalalignment='center', verticalalignment='center')
|
60 |
+
ax.set_xlim(-5, 5)
|
61 |
+
ax.set_ylim(-5, 5)
|
62 |
+
ax.set_aspect('equal', adjustable='box')
|
63 |
+
plt.close(fig)
|
64 |
+
return fig, fig
|
65 |
|
66 |
demo = gr.Interface(
|
67 |
fn=visualize_gmm,
|
68 |
inputs=[
|
69 |
+
gr.Textbox(label="Mu List", value="[[2, 1], [-1, -2], [3, -2]]", placeholder="Enter means as a list of lists, e.g., [[0,0], [1,1]]"),
|
70 |
+
gr.Textbox(label="Sigma List", value="[[[0.2, 0.1], [0.1, 0.3]], [[1.0, -0.1], [-0.1, 0.1]], [[0.05, 0.0], [0.0, 0.05]]]", placeholder="Enter covariances as a list of lists, e.g., [[[0.2, 0.1], [0.1, 0.3]], [[1.0, -0.1], [-0.1, 0.1]]]"),
|
71 |
+
gr.Textbox(label="Pi List", value="[0.05, 0.8, 0.15]", placeholder="Enter weights as a list, e.g., [0.5, 0.5]"),
|
72 |
+
gr.Slider(minimum=0.01, maximum=1.0, label="dx", value=0.1),
|
73 |
+
gr.Slider(minimum=2*np.pi/3600, maximum=2*np.pi/36, label="dtheta", value=2*np.pi/360),
|
74 |
+
gr.Slider(minimum=1, maximum=100, label="T", value=10),
|
75 |
+
gr.Slider(minimum=1, maximum=500, label="N", value=100)
|
76 |
],
|
77 |
outputs=[
|
78 |
+
gr.Plot(label="GMM to Normal Flow Final Frame"),
|
79 |
+
gr.Plot(label="Normal to GMM Flow Final Frame")
|
|
|
|
|
80 |
],
|
81 |
live=True
|
82 |
)
|
83 |
|
84 |
+
demo.launch()
|
gmm.py
CHANGED
@@ -42,7 +42,7 @@ class GaussianMixtureModel:
|
|
42 |
samples = []
|
43 |
for _ in range(n_samples):
|
44 |
# Choose a component based on mixture weights
|
45 |
-
k = torch.multinomial(self.pi.
|
46 |
sample = torch.distributions.MultivariateNormal(self.mu[k].squeeze(), self.Sigma[k]).sample()
|
47 |
samples.append(sample)
|
48 |
return torch.stack(samples)
|
|
|
42 |
samples = []
|
43 |
for _ in range(n_samples):
|
44 |
# Choose a component based on mixture weights
|
45 |
+
k = torch.multinomial(self.pi.reshape(self.pi.shape[0]), 1).item()
|
46 |
sample = torch.distributions.MultivariateNormal(self.mu[k].squeeze(), self.Sigma[k]).sample()
|
47 |
samples.append(sample)
|
48 |
return torch.stack(samples)
|
utils.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import matplotlib.pyplot as plt
|
6 |
-
from matplotlib.animation import FuncAnimation
|
7 |
from gmm import GaussianMixtureModel
|
8 |
|
9 |
def initialize_gmm(mu_list, Sigma_list, pi_list):
|
@@ -15,8 +14,10 @@ def initialize_gmm(mu_list, Sigma_list, pi_list):
|
|
15 |
def generate_grid(dx):
|
16 |
x_positions = np.arange(-10, 10.5, 0.5)
|
17 |
y_positions = np.arange(-10, 10.5, 0.5)
|
18 |
-
|
19 |
-
|
|
|
|
|
20 |
grid_points = np.concatenate(vertical_lines + horizontal_lines, axis=0)
|
21 |
return torch.tensor(grid_points, dtype=torch.float32)
|
22 |
|
@@ -25,14 +26,22 @@ def generate_contours(dtheta):
|
|
25 |
std_normal_contours = np.concatenate([np.stack([r * np.cos(angles), r * np.sin(angles)], axis=1) for r in range(1, 4)], axis=0)
|
26 |
return torch.tensor(std_normal_contours, dtype=torch.float32)
|
27 |
|
28 |
-
def
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
return (intermediate_points_gmm_to_normal, contour_intermediate_points_gmm_to_normal, grid_intermediate_points_gmm_to_normal,
|
38 |
intermediate_points_normal_to_gmm, contour_intermediate_points_normal_to_gmm, grid_intermediate_points_normal_to_gmm)
|
@@ -52,17 +61,3 @@ def plot_samples_and_contours(samples, contours, grid_points, title):
|
|
52 |
ax.set_aspect('equal', adjustable='box')
|
53 |
plt.close(fig)
|
54 |
return fig, ax
|
55 |
-
|
56 |
-
def create_animation(fig, ax, frames, intermediate_points, intermediate_samples, intermediate_contours, intermediate_grid):
|
57 |
-
scatter_grid = ax.scatter([], [], c='black', alpha=0.5, s=1, label='Grid Points')
|
58 |
-
contour_scatter = ax.scatter([], [], c='blue', alpha=0.5, s=3, label='Contours')
|
59 |
-
scatter_samples = ax.scatter([], [], c='red', alpha=0.5, label='Samples')
|
60 |
-
|
61 |
-
def update(frame):
|
62 |
-
scatter_grid.set_offsets(intermediate_points[frame].numpy())
|
63 |
-
scatter_samples.set_offsets(intermediate_samples[frame].numpy())
|
64 |
-
contour_scatter.set_offsets(intermediate_contours[frame].numpy())
|
65 |
-
return scatter_grid, scatter_samples, contour_scatter
|
66 |
-
|
67 |
-
anim = FuncAnimation(fig, update, frames=frames, blit=True)
|
68 |
-
return anim
|
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import matplotlib.pyplot as plt
|
|
|
6 |
from gmm import GaussianMixtureModel
|
7 |
|
8 |
def initialize_gmm(mu_list, Sigma_list, pi_list):
|
|
|
14 |
def generate_grid(dx):
|
15 |
x_positions = np.arange(-10, 10.5, 0.5)
|
16 |
y_positions = np.arange(-10, 10.5, 0.5)
|
17 |
+
fine_points = np.arange(-10, 10 + dx, dx)
|
18 |
+
ones_same_size = np.ones_like(fine_points)
|
19 |
+
vertical_lines = [np.stack([x*ones_same_size, fine_points], axis=1) for x in x_positions]
|
20 |
+
horizontal_lines = [np.stack([fine_points, y*ones_same_size], axis=1) for y in y_positions]
|
21 |
grid_points = np.concatenate(vertical_lines + horizontal_lines, axis=0)
|
22 |
return torch.tensor(grid_points, dtype=torch.float32)
|
23 |
|
|
|
26 |
std_normal_contours = np.concatenate([np.stack([r * np.cos(angles), r * np.sin(angles)], axis=1) for r in range(1, 4)], axis=0)
|
27 |
return torch.tensor(std_normal_contours, dtype=torch.float32)
|
28 |
|
29 |
+
def transform_std_to_gmm_contours(std_contours, mu, Sigma):
|
30 |
+
gmm_contours = []
|
31 |
+
for k in range(mu.shape[0]):
|
32 |
+
L = torch.linalg.cholesky(Sigma[k])
|
33 |
+
gmm_contours.append(mu[k] + torch.matmul(std_contours, L.T))
|
34 |
+
return torch.cat(gmm_contours, dim=0)
|
35 |
+
|
36 |
+
def generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, normal_samples, T, N):
|
37 |
+
gmm_contours = transform_std_to_gmm_contours(std_normal_contours, gmm.mu.squeeze(), gmm.Sigma)
|
38 |
+
intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(gmm_samples.clone(), T, N)
|
39 |
+
contour_intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(gmm_contours.clone(), T, N)
|
40 |
+
grid_intermediate_points_gmm_to_normal = gmm.flow_gmm_to_normal(grid_points.clone(), T, N)
|
41 |
+
|
42 |
+
intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(normal_samples.clone(), T, N)
|
43 |
+
contour_intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(std_normal_contours.clone(), T, N)
|
44 |
+
grid_intermediate_points_normal_to_gmm = gmm.flow_normal_to_gmm(grid_points.clone(), T, N)
|
45 |
|
46 |
return (intermediate_points_gmm_to_normal, contour_intermediate_points_gmm_to_normal, grid_intermediate_points_gmm_to_normal,
|
47 |
intermediate_points_normal_to_gmm, contour_intermediate_points_normal_to_gmm, grid_intermediate_points_normal_to_gmm)
|
|
|
61 |
ax.set_aspect('equal', adjustable='box')
|
62 |
plt.close(fig)
|
63 |
return fig, ax
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|