caliex committed on
Commit
ff66f4a
·
1 Parent(s): 3250df6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from sklearn.mixture import GaussianMixture
5
+ from sklearn.utils.extmath import row_norms
6
+ from sklearn.datasets._samples_generator import make_blobs
7
+ from timeit import default_timer as timer
8
+
9
def initialize_gmm(X, method, n_components, max_iter):
    """Fit a Gaussian mixture on ``X`` from means chosen by ``method``.

    The work happens in two steps:

    1. A throw-away ``GaussianMixture`` is fitted with ``max_iter=0`` and
       ``init_params=method``; its ``means_`` are taken as the initial means.
       Only this step is timed.
    2. A second ``GaussianMixture`` is fitted with ``means_init`` set to those
       means and ``max_iter`` as its iteration cap.

    Args:
        X: Data passed to ``GaussianMixture.fit``.
        method: Initialisation strategy, forwarded as ``init_params``.
        n_components: Number of mixture components for both fits.
        max_iter: Iteration cap for the final fit.

    Returns:
        A ``(gmm, ini, init_time)`` tuple: the fitted model, the initial
        means, and the elapsed ``default_timer`` time of the initialisation
        step.
    """
    # One RandomState instance is shared by both fits, so the second fit
    # continues the random stream left by the initialisation step.
    r = np.random.RandomState(seed=1234)

    start = timer()
    ini = GaussianMixture(
        n_components=n_components,
        init_params=method,
        tol=1e-9,
        max_iter=0,
        random_state=r,
    ).fit(X).means_
    init_time = timer() - start

    gmm = GaussianMixture(
        n_components=n_components,
        means_init=ini,
        tol=1e-9,
        max_iter=max_iter,
        random_state=r,
    ).fit(X)

    return gmm, ini, init_time
30
+
31
+
32
def visualize_gmm(X, gmm, ini, init_time, n_components):
    """Draw a 2x2 grid of scatter plots of ``X`` coloured by GMM component.

    A mixture is refitted from ``ini`` with ``gmm.n_iter_`` as its iteration
    cap, and its predictions colour the points in every panel. The initial
    means are overlaid as orange diamonds.

    Args:
        X: Data array; columns 0 and 1 are plotted.
        gmm: Previously fitted mixture; only its ``n_iter_`` is read.
        ini: Initial means, passed as ``means_init`` and plotted.
        init_time: Initialisation time used for the relative-time label.
        n_components: Number of mixture components.

    Returns:
        The ``matplotlib.pyplot`` module, with the drawn figure current.
    """
    # NOTE(review): the caller supplies a single ``ini``/``init_time`` pair,
    # so the four panels differ only in their title. Showing a genuinely
    # different initialisation per panel would need per-method inputs.
    methods = ["kmeans", "random_from_data", "k-means++", "random"]

    # Keep the four original colours and extend the palette so components
    # with index >= 4 are drawn too instead of being silently skipped.
    base_colors = ["navy", "turquoise", "cornflowerblue", "darkorange"]
    if n_components > len(base_colors):
        extra = plt.get_cmap("tab20")
        colors = base_colors + [
            extra(i % extra.N) for i in range(n_components - len(base_colors))
        ]
    else:
        colors = base_colors[:n_components]

    # The refit has the same seed, initial means and data for every panel,
    # so it is done once here rather than once per panel.
    r = np.random.RandomState(seed=1234)
    fitted = GaussianMixture(
        n_components=n_components,
        means_init=ini,
        tol=1e-9,
        max_iter=gmm.n_iter_,
        random_state=r,
    ).fit(X)
    labels = fitted.predict(X)

    # Each panel's time relative to the first panel is init_time / init_time;
    # guard the zero case so it cannot raise ZeroDivisionError.
    relative_time = init_time / init_time if init_time else 1.0

    n_cols = len(methods) // 2
    # Reuse one named figure and clear it, so repeated calls do not pile up
    # open pyplot figures in a long-running process.
    plt.figure(num="gmm_init", figsize=(4 * len(methods) // 2, 6), clear=True)
    plt.subplots_adjust(
        bottom=0.1, top=0.9, hspace=0.15, wspace=0.05, left=0.05, right=0.95
    )

    for n, method in enumerate(methods):
        plt.subplot(2, n_cols, n + 1)

        for i, color in enumerate(colors):
            data = X[labels == i]
            plt.scatter(data[:, 0], data[:, 1], color=color, marker="x")

        plt.scatter(
            ini[:, 0], ini[:, 1], s=75, marker="D", c="orange", lw=1.5, edgecolors="black"
        )

        plt.xticks(())
        plt.yticks(())
        plt.title(method, loc="left", fontsize=12)
        plt.title(
            "Iter %i | Init Time %.2fx" % (fitted.n_iter_, relative_time),
            loc="right",
            fontsize=10,
        )

    return plt
71
+
72
# Module-level dataset shared by every request: 4000 samples around 4 blob
# centres, generated with a fixed seed.
X, y_true = make_blobs(
    n_samples=4000,
    centers=4,
    cluster_std=0.60,
    random_state=0,
)
# Reverse the column order of the generated samples.
X = X[:, ::-1]
75
+
76
def run_gmm(method, n_components=4, max_iter=2000):
    """Fit a GMM on the module-level ``X`` and return the resulting plot.

    ``n_components`` and ``max_iter`` are converted with ``int`` before use.
    The return value is whatever ``visualize_gmm`` returns.
    """
    k = int(n_components)
    iteration_cap = int(max_iter)
    model, initial_means, elapsed = initialize_gmm(X, method, k, iteration_cap)
    return visualize_gmm(X, model, initial_means, elapsed, k)
80
+
81
# Gradio UI: a method dropdown plus two numeric inputs feed ``run_gmm``, and
# the returned plot is rendered as the single output.
# The inputs use the top-level components with ``value=`` because the
# ``gr.inputs`` namespace and its ``default=`` argument are deprecated.
iface = gr.Interface(
    fn=run_gmm,
    title="Gaussian Mixture Model Initialization Methods",
    description="GMM Initialization Methods is a visualization tool showcasing different initialization methods in Gaussian Mixture Models. The example demonstrates four initialization approaches: kmeans (default), random, random_from_data, and k-means++. The plot displays orange diamonds representing the initialization centers for each method, while crosses represent the data points with color-coded classifications after GMM convergence. The numbers in the subplots indicate the iteration count and relative initialization time. Alternative methods show lower initialization times but may require more iterations to converge. Notably, k-means++ achieves a good balance of fast initialization and convergence. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_init.html",
    inputs=[
        gr.Dropdown(
            ["kmeans", "random_from_data", "k-means++", "random"],
            label="Method",
            value="kmeans",
        ),
        gr.Number(value=4, label="Number of Components"),
        gr.Number(value=2000, label="Max Iterations"),
    ],
    outputs="plot",
    examples=[
        ["kmeans", 4, 2000],
        ["random_from_data", 3, 1000],
        ["k-means++", 8, 1000],
        ["random", 11, 1000],
    ],
)

iface.launch()