"""Gradio demo that classifies Pokémon images with a pre-trained Keras model."""

import logging

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image

logger = logging.getLogger(__name__)

# Load the pre-trained Pokémon model once at import time so every request reuses it.
model_path = "pokemon_classifier_model.keras"
model = tf.keras.models.load_model(model_path)

# Class names, in the same order as the model's output units.
# Adjust these as per your model's classes.
classes = ['Aerodactyl', 'Alakazam', 'Beedrill']

# (width, height) the image is resized to before prediction; must match the
# model's input size. PIL's Image.resize takes (width, height).
IMAGE_SIZE = (150, 150)


def _to_rgb(image):
    """Return *image* as an (H, W, 3) uint8 array.

    Accepts grayscale (H, W) or (H, W, 1) arrays, RGB (H, W, 3) arrays and
    RGBA (H, W, 4) arrays. The alpha channel is dropped, not composited.

    Raises:
        ValueError: if the array has any other shape.
    """
    array = np.asarray(image)
    if array.ndim == 3 and array.shape[2] == 1:
        # Collapse (H, W, 1) to (H, W) so the grayscale branch below handles it.
        array = array[:, :, 0]
    if array.ndim == 2:
        # Grayscale -> RGB by repeating the gray channel three times.
        array = np.stack((array,) * 3, axis=-1)
    elif array.ndim == 3 and array.shape[2] == 4:
        # Drop the alpha channel.
        array = array[:, :, :3]
    if array.ndim != 3 or array.shape[2] != 3:
        raise ValueError(
            f"Unsupported image shape {array.shape}; expected grayscale, RGB or RGBA."
        )
    return array.astype(np.uint8, copy=False)


def _preprocess(image):
    """Convert an input image array into a (1, H, W, 3) float32 batch scaled to [0, 1]."""
    # An (H, W, 3) uint8 array is inferred as mode "RGB"; no explicit mode needed.
    pil_image = Image.fromarray(_to_rgb(image)).resize(IMAGE_SIZE)
    pixels = np.asarray(pil_image, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)  # add the batch dimension


def classify_image(image):
    """Classify an uploaded image and return a human-readable prediction.

    Args:
        image: NumPy array delivered by ``gr.Image`` (grayscale, RGB or RGBA),
            or ``None`` when the user submits without uploading an image.

    Returns:
        A string of the form ``"Predicted Pokémon: <name>, Confidence: <pct>%"``.

    Raises:
        gr.Error: if no image was provided, the image could not be processed,
            or the model's output size does not match ``classes``. Gradio shows
            the message to the user instead of rendering it as a prediction.
    """
    if image is None:
        raise gr.Error("Please upload an image of a Pokémon.")

    try:
        batch = _preprocess(image)
        # verbose=0 stops Keras from printing a progress bar on every request.
        scores = np.asarray(model.predict(batch, verbose=0)).reshape(-1)
    except Exception as exc:  # UI boundary: report any preprocessing/model failure
        logger.exception("Classification failed")
        raise gr.Error(f"Could not classify the image: {exc}") from exc

    if scores.size != len(classes):
        raise gr.Error(
            f"Model returned {scores.size} scores but {len(classes)} class names are configured."
        )

    best = int(np.argmax(scores))
    # NOTE(review): treats the top score as a probability, which assumes the
    # model ends in a softmax layer — confirm against the training code.
    confidence = float(scores[best])
    return f"Predicted Pokémon: {classes[best]}, Confidence: {confidence * 100:.2f}%"


# Gradio interface: image upload in, prediction text shown in a Label.
input_image = gr.Image(type="numpy")  # deliver uploads to classify_image as NumPy arrays
output_label = gr.Label()
interface = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_label,
    examples=["pokemon/aerodactyl.png", "pokemon/alakazam.png", "pokemon/beedrill.png"],
    description="Upload an image of a Pokémon (Aerodactyl, Alakazam or Beedrill) to classify!",
)

if __name__ == "__main__":
    interface.launch()