Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.mixture import GaussianMixture
|
5 |
+
from sklearn.utils.extmath import row_norms
|
6 |
+
from sklearn.datasets._samples_generator import make_blobs
|
7 |
+
from timeit import default_timer as timer
|
8 |
+
|
9 |
+
def initialize_gmm(X, method, n_components, max_iter):
|
10 |
+
n_samples = X.shape[0]
|
11 |
+
x_squared_norms = row_norms(X, squared=True)
|
12 |
+
|
13 |
+
def get_initial_means(X, init_params, r):
|
14 |
+
gmm = GaussianMixture(
|
15 |
+
n_components=n_components, init_params=init_params, tol=1e-9, max_iter=0, random_state=r
|
16 |
+
).fit(X)
|
17 |
+
return gmm.means_
|
18 |
+
|
19 |
+
r = np.random.RandomState(seed=1234)
|
20 |
+
start = timer()
|
21 |
+
ini = get_initial_means(X, method, r)
|
22 |
+
end = timer()
|
23 |
+
init_time = end - start
|
24 |
+
|
25 |
+
gmm = GaussianMixture(
|
26 |
+
n_components=n_components, means_init=ini, tol=1e-9, max_iter=max_iter, random_state=r
|
27 |
+
).fit(X)
|
28 |
+
|
29 |
+
return gmm, ini, init_time
|
30 |
+
|
31 |
+
|
32 |
+
def visualize_gmm(X, gmm, ini, init_time, n_components):
|
33 |
+
methods = ["kmeans", "random_from_data", "k-means++", "random"]
|
34 |
+
colors = ["navy", "turquoise", "cornflowerblue", "darkorange"]
|
35 |
+
times_init = {}
|
36 |
+
relative_times = {}
|
37 |
+
|
38 |
+
plt.figure(figsize=(4 * len(methods) // 2, 6))
|
39 |
+
plt.subplots_adjust(
|
40 |
+
bottom=0.1, top=0.9, hspace=0.15, wspace=0.05, left=0.05, right=0.95
|
41 |
+
)
|
42 |
+
|
43 |
+
for n, method in enumerate(methods):
|
44 |
+
r = np.random.RandomState(seed=1234)
|
45 |
+
plt.subplot(2, len(methods) // 2, n + 1)
|
46 |
+
|
47 |
+
gmm = GaussianMixture(
|
48 |
+
n_components=n_components, means_init=ini, tol=1e-9, max_iter=gmm.n_iter_, random_state=r
|
49 |
+
).fit(X)
|
50 |
+
|
51 |
+
times_init[method] = init_time
|
52 |
+
for i, color in enumerate(colors):
|
53 |
+
data = X[gmm.predict(X) == i]
|
54 |
+
plt.scatter(data[:, 0], data[:, 1], color=color, marker="x")
|
55 |
+
|
56 |
+
plt.scatter(
|
57 |
+
ini[:, 0], ini[:, 1], s=75, marker="D", c="orange", lw=1.5, edgecolors="black"
|
58 |
+
)
|
59 |
+
relative_times[method] = times_init[method] / times_init[methods[0]]
|
60 |
+
|
61 |
+
plt.xticks(())
|
62 |
+
plt.yticks(())
|
63 |
+
plt.title(method, loc="left", fontsize=12)
|
64 |
+
plt.title(
|
65 |
+
"Iter %i | Init Time %.2fx" % (gmm.n_iter_, relative_times[method]),
|
66 |
+
loc="right",
|
67 |
+
fontsize=10,
|
68 |
+
)
|
69 |
+
|
70 |
+
return plt
|
71 |
+
|
72 |
+
# Generate some data
|
73 |
+
X, y_true = make_blobs(n_samples=4000, centers=4, cluster_std=0.60, random_state=0)
|
74 |
+
X = X[:, ::-1]
|
75 |
+
|
76 |
+
def run_gmm(method, n_components=4, max_iter=2000):
|
77 |
+
gmm, ini, init_time = initialize_gmm(X, method, int(n_components), int(max_iter))
|
78 |
+
plot = visualize_gmm(X, gmm, ini, init_time, int(n_components))
|
79 |
+
return plot
|
80 |
+
|
81 |
+
iface = gr.Interface(
|
82 |
+
fn=run_gmm,
|
83 |
+
title="Gaussian Mixture Model Initialization Methods",
|
84 |
+
description="GMM Initialization Methods is a visualization tool showcasing different initialization methods in Gaussian Mixture Models. The example demonstrates four initialization approaches: kmeans (default), random, random_from_data, and k-means++. The plot displays orange diamonds representing the initialization centers for each method, while crosses represent the data points with color-coded classifications after GMM convergence. The numbers in the subplots indicate the iteration count and relative initialization time. Alternative methods show lower initialization times but may require more iterations to converge. Notably, k-means++ achieves a good balance of fast initialization and convergence. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_init.html",
|
85 |
+
inputs=[
|
86 |
+
gr.inputs.Dropdown(["kmeans", "random_from_data", "k-means++", "random"], label="Method", default="kmeans"),
|
87 |
+
gr.inputs.Number(default=4, label="Number of Components"),
|
88 |
+
gr.inputs.Number(default=2000, label="Max Iterations")
|
89 |
+
],
|
90 |
+
outputs="plot",
|
91 |
+
examples=[
|
92 |
+
["kmeans", 4, 2000],
|
93 |
+
["random_from_data", 3, 1000],
|
94 |
+
["k-means++", 8, 1000],
|
95 |
+
["random", 11, 1000],
|
96 |
+
],
|
97 |
+
)
|
98 |
+
|
99 |
+
iface.launch()
|