import gradio as gr
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.neighbors import KernelDensity
from sklearn.decomposition import PCA

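# Build the demo around scikit-learn's KDE sampling example: reduce the digits
# with PCA, fit a kernel density model in the reduced space, sample new points,
# and project them back to pixel space for display.
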
def generate_digits(bandwidth, num_samples):
    # Slider values arrive as floats; cast them to the integers the UI implies.
    bandwidth = int(bandwidth)
    num_samples = int(num_samples)

    # Load the 8x8 digit images and reduce them to 15 principal components;
    # the density model is fit in this reduced space.
    digits = load_digits()
    pca = PCA(n_components=15, whiten=False)
    data = pca.fit_transform(digits.data)
    # Fit a kernel density model with the user-selected bandwidth, then draw
    # new samples and map them back to 64-dimensional pixel space.
    kde = KernelDensity(bandwidth=bandwidth)
    kde.fit(data)

    new_data = kde.sample(num_samples, random_state=0)
    new_data = pca.inverse_transform(new_data)

    new_data = new_data.reshape((num_samples, 64))
    real_data = digits.data[:num_samples].reshape((num_samples, 64))
    # Lay out a 9x11 grid: rows 0-3 show real digits, row 4 is a hidden
    # spacer, and rows 5-8 show digits sampled from the KDE (at most 44 each).
    fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[]))
    for j in range(11):
        ax[4, j].set_visible(False)
        for i in range(4):
            index = i * 11 + j
            if index < num_samples:
                im = ax[i, j].imshow(
                    real_data[index].reshape((8, 8)),
                    cmap=plt.cm.binary,
                    interpolation="nearest",
                )
                im.set_clim(0, 16)
                im = ax[i + 5, j].imshow(
                    new_data[index].reshape((8, 8)),
                    cmap=plt.cm.binary,
                    interpolation="nearest",
                )
                im.set_clim(0, 16)
            else:
                ax[i, j].axis("off")
                ax[i + 5, j].axis("off")

    ax[0, 5].set_title("Selection from the input data")
    ax[5, 5].set_title('"New" digits drawn from the kernel density model')
    # Save the plot and close the figure so repeated calls don't accumulate
    # open figures; Gradio displays the returned file.
    plt.savefig("digits_plot.png")
    plt.close(fig)

    return "digits_plot.png"

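# The Gradio UI: two sliders control the KDE bandwidth and the number of
# digits to sample; the saved plot is shown as the output image.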
inputs = [
    gr.inputs.Slider(minimum=1, maximum=10, step=1, label="Bandwidth"),
    gr.inputs.Slider(minimum=1, maximum=100, step=1, label="Number of Samples"),
]
output = gr.outputs.Image(type="file")

title = "Kernel Density Estimation"
description = (
    "This example shows how kernel density estimation (KDE), a powerful "
    "non-parametric density estimation technique, can be used to learn a "
    "generative model for a dataset. With this generative model in place, "
    "new samples can be drawn. These new samples reflect the underlying "
    "model of the data. See the original scikit-learn example here: "
    "https://scikit-learn.org/stable/auto_examples/neighbors/plot_digits_kde_sampling.html"
)
examples = [
    [1, 44],
    [8, 22],
    [7, 51],
]

gr.Interface(
    generate_digits,
    inputs,
    output,
    title=title,
    description=description,
    examples=examples,
    live=True,
).launch()
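# Running this script serves the demo locally (http://127.0.0.1:7860 by default).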