franckew commited on
Commit
fcd73ff
·
verified ·
1 Parent(s): 90fcef0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -8,19 +8,19 @@ from tensorflow.keras.models import load_model
8
  # Load your trained model
9
  model = load_model('/home/user/app/resnet50.h5') # Ensure this path is correct
10
 
11
- def predict_character(img):
12
- img = Image.fromarray(img.astype('uint8'), 'RGB') # Ensure the image is in RGB format
13
- img = img.resize((299, 299)) # Resize the image to the required size for Xception
14
  img_array = keras_image.img_to_array(img) # Convert the image to an array
15
- img_array = np.expand_dims(img_array, axis=0) # Expand dimensions to match the model input
16
- img_array = preprocess_input(img_array) # Preprocess the input for Xception
17
 
18
- prediction = model.predict(img_array) # Make a prediction with the model
19
- classes = ['bishop', 'knight', 'rook'] # Specific character names
20
  return {classes[i]: float(prediction[0][i]) for i in range(3)} # Return the prediction
21
 
22
- # Define the Gradio interface
23
- interface = gr.Interface(fn=predict_character,
24
  inputs="image", # Simplified input type
25
  outputs="label", # Simplified output type
26
  title="Chess Piece Classifier",
 
8
  # Load your trained model
9
  model = load_model('/home/user/app/resnet50.h5') # Ensure this path is correct
10
 
11
+ def predict_pokemon(img):
12
+ img = Image.fromarray(img.astype('uint8'), 'RGB') # Ensure the image is in RGB
13
+ img = img.resize((224, 224)) # Resize the image properly using PIL
14
  img_array = keras_image.img_to_array(img) # Convert the image to an array
15
+ img_array = np.expand_dims(img_array, axis=0) # Expand dimensions to fit model input
16
+ img_array = preprocess_input(img_array) # Preprocess the input as expected by ResNet50
17
 
18
+ prediction = model.predict(img_array) # Predict using the model
19
+ classes = ['bishop', 'knight', 'rook' ] # Specific Pokémon names
20
  return {classes[i]: float(prediction[0][i]) for i in range(3)} # Return the prediction
21
 
22
+ # Define Gradio interface
23
+ interface = gr.Interface(fn=predict_pokemon,
24
  inputs="image", # Simplified input type
25
  outputs="label", # Simplified output type
26
  title="Chess Piece Classifier",