|
import logging

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
|
|
|
|
|
# Label names, looked up by the index of the highest-scoring model output
# (see classify_image). Their order must therefore match the order of the
# model's output units.
classes = [
    'Aerodactyl',
    'Alakazam',
    'Beedrill',
]

# Keras model, loaded once when this module runs and reused by every call to
# classify_image. The path is relative, so it is resolved against the current
# working directory.
model_path = "pokemon_classifier_model.keras"
model = tf.keras.models.load_model(model_path)
|
|
|
|
|
def _to_rgb_uint8(image):
    """Return *image* as an (H, W, 3) uint8 array.

    A 2-D array, or an (H, W, 1) array, is treated as grayscale and replicated
    into three channels. A 4-channel array has its last channel (presumably
    alpha) dropped. Three-channel input passes through apart from the cast.
    """
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]
    return image.astype("uint8")


def classify_image(image):
    """Classify an uploaded image as one of the Pokémon in ``classes``.

    Args:
        image: The picture as a NumPy array (as passed by the ``gr.Image()``
            input), or ``None`` when the user submits without an image.

    Returns:
        A human-readable string with the predicted class and its confidence.
        On failure, a message string is returned instead. Errors are returned
        rather than raised so they appear in the UI, and the full traceback is
        logged.
    """
    if image is None:
        return "Please upload an image."

    try:
        rgb = _to_rgb_uint8(np.asarray(image))
        # fromarray infers "RGB" from an (H, W, 3) uint8 array. Passing the
        # mode explicitly is deprecated in recent Pillow releases.
        resized = Image.fromarray(rgb).resize((150, 150))
        # Scale pixels to [0, 1] and add a batch axis -> (1, 150, 150, 3).
        image_array = np.asarray(resized, dtype=np.float32)[np.newaxis] / 255.0

        # verbose=0 stops predict() from printing a progress bar per request.
        # [0] selects the scores for our single-image batch.
        probabilities = model.predict(image_array, verbose=0)[0]
        best = int(np.argmax(probabilities))
        predicted_class = classes[best]
        confidence = probabilities[best]

        return f"Predicted Pokémon: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
    except Exception as e:
        # UI boundary: surface the message to the user but keep the traceback.
        logging.getLogger(__name__).exception("Failed to classify image")
        return str(e)
|
|
|
|
|
# Input widget: the uploaded picture is handed to classify_image.
# Output widget: displays the string classify_image returns.
input_image = gr.Image()
output_label = gr.Label()

interface = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_label,
    description="Upload an image of a Pokémon (Aerodactyl, Alakazam or Beedrill) to classify!",
    # Bundled sample images, one per supported class.
    examples=[f"pokemon/{name}.png" for name in ("aerodactyl", "alakazam", "beedrill")],
)

# Start the web app.
interface.launch()