"""Gradio demo that runs a U-Net segmentation model on an uploaded image."""

import os
import cv2
import gradio as gr
import tensorflow as tf
import urllib.request
import numpy as np
import keras.backend as K
from PIL import Image
from matplotlib import cm

# (height, width, channels) the network is fed; other input sizes are resized to it.
resized_shape = (768, 768, 3)
# Row/column stride applied after resizing ((1, 1) keeps every pixel).
IMG_SCALING = (1, 1)

model_file = "./seg_unet_model.h5"

# Hyper-parameters of Combo_loss. They are only needed so the saved model can be
# deserialized / re-compiled; inference never evaluates the loss.
# NOTE(review): 0.5 / 0.5 are assumed defaults -- confirm they match training.
ALPHA = 0.5     # weight of the positive-class CE term; < 0.5 penalises FP more, > 0.5 FN more
CE_RATIO = 0.5  # weight of the cross-entropy term relative to the Dice term


# Custom objects for model
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1, alpha=ALPHA, ce_ratio=CE_RATIO):
    """Combo loss: weighted binary cross-entropy minus a soft Dice score.

    Args:
        y_true: ground-truth mask tensor (flattened here, so any shape works).
        y_pred: predicted probabilities with the same number of elements.
        eps: clipping margin that keeps ``log`` finite.
        smooth: additive smoothing for the Dice ratio.
        alpha: weight of the positive (target == 1) cross-entropy term; the
            negative term is weighted by ``1 - alpha``.
        ce_ratio: weight of the cross-entropy term versus the Dice term.

    Returns:
        Scalar tensor ``ce_ratio * weighted_ce - (1 - ce_ratio) * dice``.
    """
    targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)

    intersection = K.sum(targets * inputs)
    dice = (2.0 * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)

    inputs = K.clip(inputs, eps, 1.0 - eps)
    # alpha multiplies only the positive term and (1 - alpha) only the negative one.
    out = -(
        alpha * targets * K.log(inputs)
        + (1.0 - alpha) * (1.0 - targets) * K.log(1.0 - inputs)
    )
    weighted_ce = K.mean(out, axis=-1)
    return (ce_ratio * weighted_ce) - ((1.0 - ce_ratio) * dice)


def dice_coef(y_true, y_pred, smooth=1, threshold=0.5):
    """Batch-mean Dice coefficient between binarised predictions and targets.

    Reduces over axes 1-3, so rank-4 (batch, H, W, C) tensors are expected.

    Args:
        y_true: ground-truth masks (assumed 0/1 -- TODO confirm).
        y_pred: predicted probabilities.
        smooth: additive smoothing so empty masks score 1 instead of 0/0.
        threshold: probability above which a pixel counts as foreground.

    Returns:
        Scalar tensor: mean Dice score over the batch.
    """
    y_true = tf.dtypes.cast(y_true, tf.float32)
    # Binarise first: casting probabilities straight to int truncates every
    # value below 1.0 to 0, which made the metric meaningless.
    y_pred = tf.dtypes.cast(y_pred > threshold, tf.float32)
    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
    return K.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)


# Load the model. The custom loss/metric must be supplied so Keras can rebuild
# the compile configuration stored in the .h5 file.
seg_model = tf.keras.models.load_model(
    model_file,
    custom_objects={"Combo_loss": Combo_loss, "dice_coef": dice_coef},
)


def gen_pred(img, model=seg_model):
    """Run the segmentation model on one image and return its mask as an image.

    Args:
        img: RGB image as an ``(H, W, 3)`` uint8 array (what ``gr.Image()``
            delivers by default); grayscale and RGBA are converted to RGB.
            ``None`` when the user submitted nothing.
        model: Keras model to run; defaults to the model loaded at import.

    Returns:
        uint8 array holding the predicted probabilities scaled to 0-255, or
        ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    img = np.asarray(img)
    # gr.Image() already yields RGB, so no BGR->RGB swap here (doing one
    # exchanged the red and blue channels the model was trained on).
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[-1] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    orig_h, orig_w = img.shape[:2]
    target_h, target_w = resized_shape[:2]
    resized = (orig_h, orig_w) != (target_h, target_w)
    if resized:
        # cv2.resize takes dsize as (width, height).
        img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_AREA)

    img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
    batch = np.expand_dims(img.astype(np.float32) / 255.0, axis=0)

    pred = np.squeeze(model.predict(batch), axis=0)
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]

    # The prediction is a float array, not an encoded PNG/JPEG, so it is scaled
    # to uint8 directly (cv2.imdecode on its raw bytes always returned None).
    mask = np.rint(np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    if resized:
        # Bring the mask back to the uploaded image's size so they line up.
        mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    return mask


title = "Semantic Segmentation"
description = "Upload an image and get prediction mask"

demo = gr.Interface(
    fn=gen_pred,
    inputs=[gr.Image()],
    outputs="image",
    title=title,
    examples=[
        ["003e2c95d.jpg"],
        ["003b50a15.jpg"],
        ["003b48a9e.jpg"],
        ["0038cbe45.jpg"],
        ["00371aa92.jpg"],
    ],
    description=description,
)

if __name__ == "__main__":
    # .queue() replaces the removed enable_queue=True Interface argument.
    demo.queue().launch()