caliex commited on
Commit
3f1124b
·
1 Parent(s): 3287554

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ from tempfile import NamedTemporaryFile
4
+
5
+ from sklearn.neighbors import KNeighborsTransformer, KNeighborsClassifier
6
+ from sklearn.model_selection import GridSearchCV
7
+ from sklearn.datasets import load_digits
8
+ from sklearn.pipeline import Pipeline
9
+
10
+ def classify_digits(n_neighbors):
11
+ X, y = load_digits(return_X_y=True)
12
+ n_neighbors_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
13
+
14
+ graph_model = KNeighborsTransformer(n_neighbors=max(n_neighbors_list), mode="distance")
15
+ classifier_model = KNeighborsClassifier(metric="precomputed")
16
+
17
+ full_model = Pipeline(
18
+ steps=[("graph", graph_model), ("classifier", classifier_model)]
19
+ )
20
+
21
+ param_grid = {"classifier__n_neighbors": n_neighbors_list}
22
+ grid_model = GridSearchCV(full_model, param_grid)
23
+ grid_model.fit(X, y)
24
+
25
+ # Plot the results of the grid search.
26
+ fig, axes = plt.subplots(1, 2, figsize=(8, 4))
27
+ axes[0].errorbar(
28
+ x=n_neighbors_list,
29
+ y=grid_model.cv_results_["mean_test_score"],
30
+ yerr=grid_model.cv_results_["std_test_score"],
31
+ )
32
+ axes[0].set(xlabel="n_neighbors", title="Classification accuracy")
33
+ axes[1].errorbar(
34
+ x=n_neighbors_list,
35
+ y=grid_model.cv_results_["mean_fit_time"],
36
+ yerr=grid_model.cv_results_["std_fit_time"],
37
+ color="r",
38
+ )
39
+ axes[1].set(xlabel="n_neighbors", title="Fit time (with caching)")
40
+ fig.tight_layout()
41
+
42
+ # Save the plot to a temporary file
43
+ with NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
44
+ plot_path = temp_file.name
45
+ plt.savefig(plot_path)
46
+
47
+ plt.close()
48
+
49
+ return plot_path
50
+
51
+ # Create a Gradio interface with adjustable parameters
52
+ n_neighbors_input = gr.inputs.Slider(minimum=1, maximum=10, default=5, step=1, label="Number of Neighbors")
53
+ plot_output = gr.outputs.Image(type="pil")
54
+
55
+ iface = gr.Interface(
56
+ fn=classify_digits,
57
+ inputs=n_neighbors_input,
58
+ outputs=plot_output,
59
+ title="Digits Classifier",
60
+ description="This example demonstrates how to precompute the k nearest neighbors before using them in KNeighborsClassifier. KNeighborsClassifier can compute the nearest neighbors internally, but precomputing them can have several benefits, such as finer parameter control, caching for multiple use, or custom implementations. See the original scikit-learn example [here](https://scikit-learn.org/stable/auto_examples/neighbors/plot_caching_nearest_neighbors.html).",
61
+ examples=[
62
+ ["2"], # Example 1
63
+ ["7"], # Example 2
64
+ ["4"], # Example 3
65
+ ]
66
+ )
67
+ iface.launch()