"""Gradio demo that runs a saved Keras model on an uploaded image.

The script loads ``AROI_image_segmentation.keras`` at import time, exposes a
``predict`` function that computes a per-pixel class mask, and launches a
Gradio interface around it.
"""
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import cv2
import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.models import load_model

# 1. Image geometry and class setup
IMG_H = 320
IMG_W = 416
NUM_CLASSES = 8
input_shape = (IMG_H, IMG_W, 3)

# NOTE(review): the original script referenced `class_names`, `colormap` and
# `get_legend_patches` without defining them. Placeholder labels are used
# here; replace them with the real class names of the trained model.
class_names = [f"class_{index}" for index in range(NUM_CLASSES)]

# One colour per class, sampled from the same 'jet' map used to draw the mask
# so that the legend and the mask agree.
_MASK_CMAP = plt.get_cmap("jet", NUM_CLASSES)
colormap = [_MASK_CMAP(index) for index in range(NUM_CLASSES)]

# 2. Model preparation
MODEL_PATH = "AROI_image_segmentation.keras"

model = load_model(MODEL_PATH)
print("The model loaded successfully")
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(1e-4),
)


def get_legend_patches(colormap, class_names):
    """Return one legend patch per (colour, class name) pair.

    Args:
        colormap: Sequence of matplotlib-compatible colours, one per class.
        class_names: Sequence of labels, aligned with ``colormap``.

    Returns:
        A list of ``matplotlib.patches.Patch`` objects usable as legend handles.
    """
    return [
        mpatches.Patch(color=color, label=name)
        for color, name in zip(colormap, class_names)
    ]


# 3. Prediction helpers
def load_and_prep_imgg(image_path, input_shape=(IMG_H, IMG_W), scale=True):
    """Load an image and prepare it as model input.

    Args:
        image_path: Either a filesystem path (``str`` / ``os.PathLike``) or a
            ``PIL.Image.Image``. The Gradio input component in this file is
            declared with ``type='pil'``, so both forms must be accepted.
        input_shape: ``(height, width)`` the image is resized to.
        scale: When true, pixel values are divided by 255.

    Returns:
        A ``float32`` array of shape ``(height, width, 3)`` in OpenCV's BGR
        channel order, or ``None`` when the image cannot be loaded.
    """
    if isinstance(image_path, Image.Image):
        # Convert to the BGR order that cv2.imread produces so that both
        # input kinds reach the model in the same layout.
        rgb_pixels = np.array(image_path.convert("RGB"))
        image = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    elif isinstance(image_path, (str, os.PathLike)):
        image = cv2.imread(os.fspath(image_path))
        if image is None:
            print(f"Error: Cannot load image from {image_path}")
            return None
    else:
        print(f"Error: Cannot load image from {image_path}")
        return None

    # cv2.resize expects (width, height).
    image = cv2.resize(image, (input_shape[1], input_shape[0]))
    if scale:
        image = image / 255.0
    return image.astype(np.float32)


def _plot_prediction(image, pred_mask):
    """Draw the input image next to its predicted mask and return the figure."""
    figure, (image_axis, mask_axis) = plt.subplots(1, 2, figsize=(10, 5))

    # The prepared image is in BGR order; matplotlib expects RGB.
    image_axis.imshow(image[..., ::-1])
    image_axis.set_title("Original Image")
    image_axis.axis("off")

    # Fixed limits keep class k on colour k regardless of which classes are
    # present in this particular mask.
    mask_axis.imshow(
        pred_mask, cmap=_MASK_CMAP, vmin=-0.5, vmax=NUM_CLASSES - 0.5
    )
    mask_axis.set_title("Predicted Mask")
    mask_axis.axis("off")

    legend_patches = get_legend_patches(colormap, class_names)
    mask_axis.legend(
        handles=legend_patches,
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        borderaxespad=0.0,
    )

    figure.tight_layout()
    return figure


def predict(img) -> Tuple[Dict, float, float]:
    """Run the model on ``img`` and summarise the predicted mask.

    Args:
        img: Image supplied by Gradio (a PIL image) or a path to an image.

    Returns:
        A tuple ``(class_shares, confidence_percent, seconds)`` where
        ``class_shares`` maps each class name to the fraction of pixels
        assigned to it, ``confidence_percent`` is the mean per-pixel maximum
        model output times 100, and ``seconds`` is the elapsed time rounded
        to four decimals.

    Raises:
        gr.Error: If no image was supplied or it could not be loaded.
    """
    start_time = timer()

    if img is None:
        raise gr.Error("No image was provided.")
    image = load_and_prep_imgg(img)
    if image is None:
        raise gr.Error("The image could not be loaded.")

    # Add a batch dimension for the model, then drop it again after argmax.
    # Assumes the model output is (1, height, width, num_classes) - TODO confirm.
    probabilities = model.predict(np.expand_dims(image, axis=0))
    pred_mask = np.argmax(probabilities, axis=-1)[0]

    # NOTE(review): the figure is displayed with plt.show() as in the original
    # script; it is not part of the Gradio outputs. Consider exposing it via a
    # plot output instead. It is closed afterwards so figures do not
    # accumulate across requests.
    figure = _plot_prediction(image, pred_mask)
    plt.show()
    plt.close(figure)

    # NOTE(review): the original returned `pred_class` and `pred_img.max()`,
    # neither of which was defined. The values below are the closest
    # mask-level equivalents - confirm that this is the intended summary.
    pixel_counts = np.bincount(pred_mask.ravel(), minlength=len(class_names))
    pixel_shares = pixel_counts / max(pred_mask.size, 1)
    pred_class = {
        name: float(share) for name, share in zip(class_names, pixel_shares)
    }
    pred_prob = float(np.max(probabilities, axis=-1).mean() * 100)

    pred_time = round(timer() - start_time, 4)
    return pred_class, pred_prob, pred_time


# 4. Gradio app - interface definition and launch command
# NOTE(review): the model file name suggests segmentation while the strings
# below describe classification - confirm the wording before publishing.
title = 'Macular Disease Classification'
description = 'Feature Extraction VGG model to classify Macular Diseases by OCT'
article = 'Created with TensorFlow Model Deployment'

# Build the example list; tolerate a missing examples folder.
EXAMPLES_DIR = "examples"
if os.path.isdir(EXAMPLES_DIR):
    example_list = [
        [os.path.join(EXAMPLES_DIR, example)]
        for example in sorted(os.listdir(EXAMPLES_DIR))
    ]
else:
    example_list = []

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=3, label='prediction'),
        gr.Number(label='Prediction Probabilities'),
        gr.Number(label='Prediction time (s)'),
    ],
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch the demo
demo.launch(debug=False)