File size: 3,336 Bytes
1c92b51 706be24 1c92b51 706be24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import functools

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.neighbors import KernelDensity
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
@functools.lru_cache(maxsize=1)
def _load_projected_digits(n_components=15):
    """Load the scikit-learn digits and project them with a fitted PCA.

    Nothing here depends on the user's inputs, so the result is cached;
    otherwise every slider change of the live interface would reload the
    dataset and refit PCA. Callers must treat the returned arrays as
    read-only because they are shared between calls.

    Returns:
        Tuple ``(digits, pca, data)``: the ``load_digits()`` bunch, the
        fitted ``PCA`` model, and the projected data of shape
        ``(n_images, n_components)``.
    """
    digits = load_digits()
    pca = PCA(n_components=n_components, whiten=False)
    data = pca.fit_transform(digits.data)
    return digits, pca, data


def _plot_digit_grid(real_data, new_data, num_samples, n_rows=4, n_cols=11):
    """Draw real digits above sampled digits and return the new figure.

    The figure has ``2 * n_rows + 1`` rows of axes: ``n_rows`` rows of real
    digits, one hidden spacer row, then ``n_rows`` rows of sampled digits.
    Cell ``(i, j)`` shows image ``i * n_cols + j``; cells at or beyond
    ``num_samples`` are blanked, and images past ``n_rows * n_cols`` are not
    shown at all.
    """
    fig, ax = plt.subplots(
        2 * n_rows + 1, n_cols, subplot_kw=dict(xticks=[], yticks=[])
    )
    offset = n_rows + 1  # first row of the "sampled digits" panel
    for j in range(n_cols):
        # The middle row only separates the two panels.
        ax[n_rows, j].set_visible(False)
        for i in range(n_rows):
            index = i * n_cols + j
            if index >= num_samples:
                ax[i, j].axis("off")
                ax[i + offset, j].axis("off")
                continue
            for row, images in ((i, real_data), (i + offset, new_data)):
                im = ax[row, j].imshow(
                    images[index].reshape((8, 8)),
                    cmap=plt.cm.binary,
                    interpolation="nearest",
                )
                # Fixed colour limits so every cell shares the same scale.
                im.set_clim(0, 16)
    ax[0, n_cols // 2].set_title("Selection from the input data")
    ax[offset, n_cols // 2].set_title(
        '"New" digits drawn from the kernel density model'
    )
    return fig


def generate_digits(bandwidth, num_samples):
    """Fit a KDE to PCA-projected digits, sample new digits and plot them.

    Args:
        bandwidth: Bandwidth passed to ``KernelDensity``; coerced with
            ``float`` so fractional values are not truncated.
        num_samples: Number of digits to sample; coerced with ``int``. Only
            the first 44 (4 rows x 11 columns) real and sampled digits are
            drawn.

    Returns:
        The path ``"digits_plot.png"`` (relative to the working directory)
        of the PNG written by this call.
    """
    bandwidth = float(bandwidth)
    num_samples = int(num_samples)

    digits, pca, data = _load_projected_digits()

    # The bandwidth is chosen by the user, so no cross-validated search is
    # run here; that search was previously computed and then discarded.
    kde = KernelDensity(bandwidth=bandwidth)
    kde.fit(data)
    # Sample in PCA space, then map back to 64-pixel images.
    new_data = pca.inverse_transform(kde.sample(num_samples, random_state=0))
    real_data = digits.data[:num_samples]

    fig = _plot_digit_grid(real_data, new_data, num_samples)
    # NOTE(review): a fixed file name means concurrent requests overwrite
    # each other's image; consider a per-request temp file if that matters.
    try:
        fig.savefig("digits_plot.png")
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
    return "digits_plot.png"
# --- Gradio front end ---------------------------------------------------
# NOTE(review): gr.inputs / gr.outputs are believed to be the legacy Gradio
# namespaces (replaced by gr.Slider / gr.Image in newer releases); confirm
# against the pinned Gradio version before migrating.
inputs = [
    gr.inputs.Slider(minimum=1, maximum=10, step=1, label="Bandwidth"),
    gr.inputs.Slider(minimum=1, maximum=100, step=1, label="Number of Samples"),
]
output = gr.outputs.Image(type="pil")

title = "Kernel Density Estimation"
description = (
    "This example shows how kernel density estimation (KDE), a powerful "
    "non-parametric density estimation technique, can be used to learn a "
    "generative model for a dataset. With this generative model in place, new "
    "samples can be drawn. These new samples reflect the underlying model of the "
    "data. See the original scikit-learn example here: "
    "https://scikit-learn.org/stable/auto_examples/neighbors/plot_digits_kde_sampling.html"
)

# Each example row is [bandwidth, num_samples], matching the slider order.
examples = [[1, 44], [8, 22], [7, 51]]

demo = gr.Interface(
    generate_digits,
    inputs,
    output,
    title=title,
    description=description,
    examples=examples,
    live=True,
)
demo.launch()
|