Spaces:
Runtime error
Runtime error
import functools

import gradio as gr
import joblib
import numpy as np
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
# Gradio components for the interface below.
# The legacy ``gr.inputs`` / ``gr.outputs`` namespaces were deprecated in
# Gradio 3.x and removed in 4.0, so the top-level component classes are used
# (they exist in both major versions).
# ``type="numpy"`` delivers the upload as a NumPy array; the default RGB image
# mode keeps it 3-D (height, width, channels).
input_image = gr.Image(label="Input Image", type="numpy")
# A default selection means the handler never receives ``None`` here.
select_algorithm = gr.Dropdown(
    choices=["Decision Tree", "Random Forest", "AdaBoost", "Gradient Tree Boosting", "KNN"],
    value="Random Forest",
    label="Select Algorithm",
)
out_classify = gr.Textbox(label="Predict Class")
out_prob = gr.Textbox(label="Predict Probability")

# --- Gradio interface --------------------------------------------------------
# Trained model file for each entry of the "Select Algorithm" dropdown.
_MODEL_FILES = {
    "Decision Tree": "best_dt_model.joblib",
    "Random Forest": "best_rf_model.joblib",
    "AdaBoost": "best_ada_model.joblib",
    "Gradient Tree Boosting": "best_gbc_model.joblib",
    "KNN": "best_knn_model.joblib",
}

# Display name for each integer class label (this is the Fashion-MNIST label set).
_CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# The models consume a flattened square image with this side (28 * 28 features).
_IMAGE_SIDE = 28


@functools.lru_cache(maxsize=None)
def _load_model(path):
    """Load the estimator stored at *path* once and reuse it on later calls."""
    # NOTE(review): every model that has been selected stays in memory. With a
    # fixed set of five paths the cache is bounded, but a large model (e.g. a
    # KNN that keeps its training set) may make this costly; lower maxsize if
    # memory is tight.
    return joblib.load(path)


def _to_feature_vector(image, side=_IMAGE_SIDE):
    """Turn an image array into the ``(1, side * side)`` feature row the models take.

    Colour images are averaged over their first three channels (an alpha
    channel is ignored). Images of any other size are nearest-neighbour
    resampled to ``side`` x ``side``.

    Raises:
        ValueError: if the array is not 2-D or 3-D, or is empty.
    """
    pixels = np.asarray(image, dtype=float)
    if pixels.ndim == 3:
        pixels = pixels[..., :3].mean(axis=2)
    elif pixels.ndim != 2:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {pixels.shape}")
    if pixels.size == 0:
        raise ValueError("The uploaded image is empty.")

    height, width = pixels.shape
    if (height, width) != (side, side):
        # Take the source pixel under the centre of each target pixel.
        rows = ((np.arange(side) + 0.5) * height / side).astype(int)
        cols = ((np.arange(side) + 0.5) * width / side).astype(int)
        pixels = pixels[np.ix_(rows, cols)]

    # NOTE(review): pixel values are passed through unscaled (0-255 for uint8
    # input), as before; confirm this matches how the models were trained.
    return pixels.reshape(1, side * side)


def predict_interface(input_image, select_algorithm):
    """Classify a clothing image with the selected pre-trained model.

    Args:
        input_image: Image array from the Gradio ``Image`` component, either
            2-D grayscale or 3-D (height, width, channels). It can be any
            size; it is resampled to 28x28 before prediction.
        select_algorithm: A key of ``_MODEL_FILES`` naming the model to use.

    Returns:
        A ``(class_name, probabilities_text)`` tuple. ``class_name`` is the
        predicted label. ``probabilities_text`` has one
        ``name<TAB><TAB><TAB>probability`` line per class, with
        probabilities rounded to two decimals.

    Raises:
        ValueError: if no image was supplied, the algorithm is unknown, or the
            image array has an unsupported shape.
    """
    if input_image is None:
        raise ValueError("Please upload an image before submitting.")
    try:
        model_path = _MODEL_FILES[select_algorithm]
    except KeyError:
        raise ValueError(f"Unknown algorithm: {select_algorithm!r}") from None

    features = _to_feature_vector(input_image)
    model = _load_model(model_path)

    predicted_label = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # predict_proba's columns follow model.classes_, so name each column from
    # it rather than assuming the labels are exactly 0..9 in order.
    column_names = [_CLASS_NAMES[int(label)] for label in model.classes_]
    prob_text = "\n".join(
        f"{name}\t\t\t{np.round(prob, 2)}"
        for name, prob in zip(column_names, probabilities)
    )
    return _CLASS_NAMES[int(predicted_label)], prob_text
# Assemble the web UI. Each example row pre-fills the image and the algorithm.
demo = gr.Interface(
    fn=predict_interface,
    inputs=[input_image, select_algorithm],
    outputs=[out_classify, out_prob],
    examples=[
        ["fashion_1.png", "Random Forest"],
        ["fashion_2.png", "Random Forest"],
        ["fashion_3.png", "Random Forest"],
    ],
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or tooling) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)