Spaces:
Build error
Build error
import math | |
import cv2 | |
import numpy as np | |
from matplotlib import pyplot as plt | |
from scipy import ndimage | |
from skimage import measure, color, io | |
from tensorflow.keras.preprocessing import image | |
from scipy import ndimage | |
import skimage.io as io | |
import skimage.transform as trans | |
import numpy as np | |
import tensorflow as tf | |
import gradio as gr | |
from huggingface_hub.keras_mixin import from_pretrained_keras | |
def predict_sample(image):
    """Predict a binary cluster mask for a single image using the global ``model``.

    The image is given a leading batch axis of size 1 before prediction.
    Model outputs strictly greater than 0.5 become 1, everything else 0,
    and the first (only) batch entry is returned scaled to 0/255, keeping
    the dtype that ``model.predict`` returned.
    """
    batched = image[tf.newaxis, ...]
    probabilities = model.predict(batched)
    first = probabilities[0]
    # A boolean cast back to the prediction's dtype reproduces the
    # "> 0.5 -> 1, otherwise -> 0" thresholding in one step.
    return (first > 0.5).astype(probabilities.dtype) * 255
def create_input_image(data, visualize=False, size=256):
    """Rasterize 2-D points onto a ``size`` x ``size`` grid of ones.

    Each point ``(x, y)`` sets pixel ``[floor(x), floor(y)]`` to 0. Coordinates
    outside ``[0, size)`` are clamped to the nearest edge pixel. Points lying
    exactly on the upper bound therefore still land on the grid; this includes
    the maximum after rescaling to ``[0, size]``. The previous loop raised
    IndexError when both coordinates were >= size and wrapped negative
    coordinates to the far edge.

    Args:
        data: Sequence or array of points; only the first two columns are used.
        visualize: If true, draw the grid with matplotlib (transposed, y-axis up).
        size: Side length of the square grid. Defaults to 256, the grid size
            the rest of this file assumes.

    Returns:
        A float ``np.ndarray`` of shape ``(size, size)``.
    """
    grid = np.ones((size, size))
    points = np.asarray(data, dtype=float)
    if points.size:
        # Vectorized floor + clamp replaces a per-point Python loop, which was
        # slow for large sample counts.
        rows = np.clip(np.floor(points[:, 0]), 0, size - 1).astype(int)
        cols = np.clip(np.floor(points[:, 1]), 0, size - 1).astype(int)
        grid[rows, cols] = 0
    if visualize:
        plt.imshow(grid.T, cmap='gray')
        plt.gca().invert_yaxis()
    return grid
# Hugging Face Hub repository of the pretrained U-Net; the loaded model is
# the module-level global that predict_sample calls.
_UNET_REPO_ID = "tareknaous/unet-visual-clustering"
model = from_pretrained_keras(_UNET_REPO_ID)
def _segment_clusters(prediction, max_filter_size):
    """Segment a 0/255 U-Net mask into labelled regions with a watershed.

    Returns the watershed marker image after maximum filtering:
    - 10 is the background.
    - 11, 12, ... are the connected components of the sure-foreground.
    - -1 marks watershed boundaries, unless the maximum filter overwrote them.
    """
    # Adjust format: clusters become 255 and the rest 0. Both masks are
    # computed from the untouched input, so no sentinel values are needed.
    # The swap is done on a copy rather than mutating the caller's array.
    swapped = np.array(prediction)
    swapped[prediction == 255] = 0
    swapped[prediction == 0] = 255

    # Convert to an 8-bit image and take one colour channel.
    mask = image.img_to_array(swapped, dtype='uint8')
    cells = mask[:, :, 0]

    _, thresh = cv2.threshold(cells, 0, 255, cv2.THRESH_BINARY)

    # Morphological opening removes small noise specks.
    kernel = np.ones((3, 3), np.uint8)
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

    # The dilated opening bounds the sure-background. Pixels farther than 4%
    # of the largest distance-to-background are the sure-foreground.
    background = cv2.dilate(opening, kernel, iterations=5)
    dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
    _, foreground = cv2.threshold(dist_transform, 0.04 * dist_transform.max(), 255, 0)
    foreground = np.uint8(foreground)
    unknown = cv2.subtract(background, foreground)

    # Offset component labels by 10 so that 0 can mark the uncertain band
    # for the watershed to resolve.
    _, markers = cv2.connectedComponents(foreground)
    markers = markers + 10
    markers[unknown == 255] = 0

    img = cv2.merge((mask, mask, mask))
    markers = cv2.watershed(img, markers)

    # Maximum filtering lets neighbouring labels absorb the -1 boundary
    # pixels (size=1 leaves the markers unchanged).
    return ndimage.maximum_filter(markers, size=max_filter_size)


def _lookup_cluster_ids(markers, data):
    """Return, for each point, the cluster id of the marker pixel it falls on.

    Ids are ``marker - 10``, so clusters are 1, 2, ... Background pixels,
    watershed boundaries (-1), and points outside the marker grid all map to 0.
    """
    points = np.asarray(data, dtype=float)
    # int64 instead of int8: int8 silently wrapped once there were more than
    # 127 labels.
    cluster_ids = np.zeros(len(points), dtype=np.int64)
    if not len(points):
        return cluster_ids
    rows = np.floor(points[:, 0]).astype(np.int64)
    cols = np.floor(points[:, 1]).astype(np.int64)
    height, width = markers.shape[:2]
    # Out-of-grid points (including negatives, which used to wrap) stay 0.
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    cluster_ids[inside] = markers[rows[inside], cols[inside]] - 10
    # Watershed boundary pixels (-1 - 10) count as unassigned.
    cluster_ids[cluster_ids == -11] = 0
    return cluster_ids


def get_instances(prediction, data, max_filter_size=1):
    """Assign a cluster id to every data point from a predicted cluster mask.

    Args:
        prediction: Mask from ``predict_sample`` whose clusters are 0 and
            background 255 (it is inverted internally). It is not modified.
        data: Points in the same pixel coordinates used to build the model
            input; only the first two columns are read.
        max_filter_size: Size of the maximum filter applied to the watershed
            markers to absorb boundary pixels.

    Returns:
        A 1-D integer array with one id per point. Ids 1, 2, ... are clusters.
        0 means background, a watershed boundary, or a point outside the grid.
    """
    markers = _segment_clusters(prediction, max_filter_size)
    return _lookup_cluster_ids(markers, data)
import gradio as gr | |
from itertools import cycle, islice | |
def _rescale_to_range(values, new_min, new_max):
    """Linearly map ``values`` onto ``[new_min, new_max]``.

    A constant input (zero span) maps to ``new_min`` instead of dividing by zero.
    """
    low = values.min()
    high = values.max()
    span = high - low
    if span == 0:
        return np.full(values.shape, float(new_min))
    return ((values - low) * (new_max - new_min)) / span + new_min


def _make_dataset(cluster_type, num_clusters, num_samples, random_state):
    """Generate the requested toy dataset as an ``(X, y)`` tuple.

    Raises:
        ValueError: if ``cluster_type`` is not one of the supported names.
    """
    # NOTE(review): `datasets` is not imported anywhere in this file. It is
    # presumably `sklearn.datasets` (make_blobs / make_moons / make_circles);
    # add that import.
    if cluster_type == "blobs":
        return datasets.make_blobs(n_samples=num_samples, centers=num_clusters,
                                   random_state=random_state, center_box=(0, 256),
                                   cluster_std=4 * np.ones(num_clusters))
    if cluster_type == "varied blobs":
        return datasets.make_blobs(n_samples=num_samples, centers=num_clusters,
                                   cluster_std=1.5 * np.ones(num_clusters),
                                   random_state=random_state)
    if cluster_type == "aniso":
        X, y = datasets.make_blobs(n_samples=num_samples, centers=num_clusters,
                                   random_state=random_state, center_box=(-30, 30))
        transformation = [[0.8, -0.6], [-0.4, 0.8]]
        return np.dot(X, transformation), y
    # random_state is now forwarded so these datasets are reproducible too.
    if cluster_type == "noisy moons":
        return datasets.make_moons(n_samples=num_samples, noise=.05,
                                   random_state=random_state)
    if cluster_type == "noisy circles":
        return datasets.make_circles(n_samples=num_samples, factor=.01, noise=.05,
                                     random_state=random_state)
    # Previously an unknown type left `data` unbound (UnboundLocalError).
    raise ValueError(f"Unknown cluster type: {cluster_type!r}")


def visual_clustering(cluster_type, num_clusters, num_samples, random_state, median_kernel_size, max_kernel_size):
    """Generate a toy dataset, cluster it with the U-Net, and plot both.

    The points are rescaled to the [0, 256] pixel range and rasterised. The
    raster is median filtered, segmented by the model, and each point is given
    the id of the cluster it falls in.

    Returns:
        ``(fig1, fig2)``: a scatter plot of the raw dataset and one coloured
        by cluster id. Id 0 (background/outlier) is drawn black.
    """
    # Cast UI values to int in case the UI passes floats. Sample counts,
    # seeds, and filter sizes must be integers.
    num_clusters = int(num_clusters)
    num_samples = int(num_samples)
    random_state = int(random_state)
    median_kernel_size = int(median_kernel_size)
    max_kernel_size = int(max_kernel_size)

    points, _ = _make_dataset(cluster_type, num_clusters, num_samples, random_state)
    points[:, 0] = _rescale_to_range(points[:, 0], 0, 256)
    points[:, 1] = _rescale_to_range(points[:, 1], 0, 256)

    fig1 = plt.figure()
    plt.scatter(points[:, 0], points[:, 1], s=1, c='black')
    plt.close()

    grid = create_input_image(points)
    filtered = ndimage.median_filter(grid, size=median_kernel_size)
    result = predict_sample(filtered)
    # A plain-int copy avoids int8 overflow in `max + 1` and when indexing.
    y_km = np.asarray(get_instances(result, points, max_filter_size=max_kernel_size)).astype(int)

    palette = ["#000000", '#377eb8', '#ff7f00', '#4daf4a',
               '#f781bf', '#a65628', '#984ea3',
               '#999999', '#e41a1c', '#dede00', '#491010']
    colors = np.array(list(islice(cycle(palette), int(y_km.max()) + 1)))
    # add black color for outliers (if any)
    colors = np.append(colors, ["#000000"])
    fig2 = plt.figure()
    plt.scatter(points[:, 0], points[:, 1], s=10, color=colors[y_km])
    plt.close()
    return fig1, fig2
# Gradio UI: one dataset selector and five sliders feeding visual_clustering's
# parameters in order, with the two figures it returns as outputs.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces. If Gradio is upgraded, migrate to gr.Dropdown / gr.Slider /
# gr.Plot, and confirm against the installed version.
iface = gr.Interface(
    fn=visual_clustering,
    inputs=[
        # A default keeps the dropdown from submitting None, which
        # visual_clustering cannot map to a dataset.
        gr.inputs.Dropdown(
            ["blobs", "varied blobs", "aniso", "noisy moons", "noisy circles"],
            default="blobs",
            label='Dataset Type',
        ),
        gr.inputs.Slider(1, 10, step=1, label='Number of Clusters'),
        gr.inputs.Slider(10000, 1000000, step=10000, label='Number of Samples'),
        gr.inputs.Slider(1, 100, step=1, label='Random State'),
        gr.inputs.Slider(1, 100, step=1, label='Denoising Filter Kernel Size'),
        gr.inputs.Slider(1, 100, step=1, label='Max Filter Kernel Size'),
    ],
    outputs=[
        gr.outputs.Image(type='plot', label='Dataset'),
        gr.outputs.Image(type='plot', label='Clustering Result'),
    ],
)
iface.launch(debug=True)