# Author: drkareemkamal
# Last commit: fd49d44 "Update app.py" (verified)
import gradio as gr
from timeit import default_timer as timer
from typing import Tuple , Dict
import tensorflow as tf
import numpy as np
import cv2
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image
import os
# 1. Model input geometry and class count
IMG_H = 320  # input height in pixels
IMG_W = 416  # input width in pixels
NUM_CLASSES = 8
input_shape = (IMG_H, IMG_W, 3)

# 2. Model preparation
MODEL_PATH = "AROI_image_segmentation.keras"

# Load the saved Keras model from disk; the path variable and the loaded
# model are kept under separate names.
model = load_model(MODEL_PATH)
print("The model loaded successfully")

# Re-compile with an explicit loss and a low-learning-rate Adam optimizer.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="categorical_crossentropy",
)
# 3. Preprocessing and prediction functions

# TODO(review): replace these placeholders with the real label names, listed
# in the same order as the model's output channels.
SEGMENTATION_CLASS_NAMES = [f"class {i}" for i in range(NUM_CLASSES)]


def load_and_prep_imgg(image_path, input_shape=(IMG_H, IMG_W), scale=True):
    """Load an image and turn it into a single model input.

    Args:
        image_path: A filesystem path (``str`` / ``os.PathLike``), a PIL
            image, or a NumPy array. Paths are read with ``cv2.imread``
            (BGR channel order); PIL images and arrays are treated as RGB(A)
            or grayscale and converted to BGR so every input reaches the
            model in the same channel order.
        input_shape: ``(height, width)`` to resize to.
        scale: If True, divide pixel values by 255.

    Returns:
        A ``float32`` array of shape ``(height, width, 3)``.

    Raises:
        ValueError: If a path is given and the file cannot be read.
    """
    if isinstance(image_path, (str, os.PathLike)):
        image = cv2.imread(os.fspath(image_path))
        if image is None:
            raise ValueError(f"Error: Cannot load image from {image_path}")
    else:
        # Gradio's gr.Image(type='pil') hands us a PIL image, not a path.
        image = _to_bgr_array(image_path)
    # cv2.resize takes (width, height), the reverse of input_shape.
    image = cv2.resize(image, (input_shape[1], input_shape[0]))
    if scale:
        image = image / 255.0
    return image.astype(np.float32)


def _to_bgr_array(img):
    """Convert an in-memory RGB(A)/grayscale image to a 3-channel BGR array."""
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    array = np.asarray(img)
    if array.ndim == 2:
        return cv2.cvtColor(array, cv2.COLOR_GRAY2BGR)
    if array.shape[-1] == 4:
        return cv2.cvtColor(array, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)


def _class_label(index):
    """Return the display name for model output channel *index*."""
    if index < len(SEGMENTATION_CLASS_NAMES):
        return SEGMENTATION_CLASS_NAMES[index]
    return f"class {index}"


def predict(img) -> Tuple[Dict, float, float]:
    """Segment one image and summarize the predicted mask for the Gradio UI.

    Args:
        img: The uploaded image (a PIL image with the current Gradio input)
            or anything else accepted by ``load_and_prep_imgg``.

    Returns:
        A tuple ``(class_fractions, confidence, pred_time)``:
          * class_fractions: dict mapping class label -> fraction of pixels
            the model assigned to that class (the ``gr.Label`` output).
          * confidence: mean over all pixels of the highest per-pixel model
            score, times 100.
          * pred_time: seconds spent on preprocessing + inference, rounded
            to 4 decimal places.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    start_time = timer()
    image = load_and_prep_imgg(img)
    # Add a batch dimension for the model, then take the single result back
    # out; presumably shaped (H, W, n_classes) - TODO confirm.
    scores = model.predict(np.expand_dims(image, axis=0), verbose=0)[0]
    pred_mask = np.argmax(scores, axis=-1)  # per-pixel class index

    n_classes = scores.shape[-1]
    pixel_counts = np.bincount(pred_mask.ravel(), minlength=n_classes)
    class_fractions = {
        _class_label(idx): float(count) / pred_mask.size
        for idx, count in enumerate(pixel_counts)
    }
    # NOTE(review): assumes the outputs are softmax probabilities (the model
    # is compiled with categorical_crossentropy); if the last layer emits
    # logits this number is not a probability.
    pred_probbb = float(np.max(scores, axis=-1).mean() * 100)

    # The matplotlib figure previously drawn here was never returned to
    # Gradio (plt.show() on a server only leaks figures), so it is omitted.
    end_time = timer()
    pred_time = round(end_time - start_time, 4)
    return class_fractions, pred_probbb, pred_time
### 4. Gradio app - our Gradio interface + launch command
title = 'Retinal OCT Image Segmentation'
description = 'Segment retinal OCT images with the AROI image-segmentation model'
article = 'Created with TensorFlow Model Deployment'

EXAMPLES_DIR = 'examples'

# One single-input row per example file. A missing folder or hidden files
# (e.g. .gitkeep) must not stop the app from starting.
if os.path.isdir(EXAMPLES_DIR):
    example_list = [
        [os.path.join(EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if not name.startswith('.')
        and os.path.isfile(os.path.join(EXAMPLES_DIR, name))
    ]
else:
    example_list = []

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=3, label='prediction'),
        gr.Number(label='Prediction Probabilities'),
        gr.Number(label='Prediction time (s)'),
    ],
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch only when run as a script (`python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch(debug=False)