# app.py -- Hugging Face Space file by yuragoithf (commit 6577280, "Update app.py", verified).
import os, io
import cv2
import gradio as gr
import tensorflow as tf
import numpy as np
import keras.backend as K
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras
# Path of the saved segmentation model, relative to the working directory.
# A disabled helper used to fetch this file with gdown from
# https://drive.google.com/uc?id=1FhICkeGn6GcNXWTDn1s83ctC-6Mo1UXk
# and store it as "seg_unet_model.h5".
model_file = "./seg_unet_model.h5"

# Row / column step sizes used by gen_pred to subsample the input image;
# (1, 1) keeps every pixel.
IMG_SCALING = (1, 1)

# Presumably the (height, width, channels) the model was trained on; it is
# not referenced anywhere else in this file -- TODO confirm before relying on it.
resized_shape = (768, 768, 3)
#Custom objects for model
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1, alpha=0.5, ce_ratio=0.5):
    """Combine a weighted cross-entropy term with a (negated) Dice term.

    The original body read module-level names ``ALPHA`` and ``CE_RATIO`` that
    are not defined anywhere in this file, so calling it raised ``NameError``.
    They are now keyword parameters with defaults.

    Args:
        y_true: Ground-truth tensor; flattened and cast to float32.
        y_pred: Predicted tensor; flattened and cast to float32.
        eps: Predictions are clipped to ``[eps, 1 - eps]`` before the
            logarithms are taken.
        smooth: Added to both numerator and denominator of the Dice ratio.
        alpha: Weight applied inside the cross-entropy term.
        ce_ratio: Share of the cross-entropy term in the result; the Dice
            term is weighted by ``1 - ce_ratio``.

    Returns:
        ``ce_ratio * weighted_ce - (1 - ce_ratio) * dice``.

    NOTE(review): the values of alpha / ce_ratio used when the model was
    trained are not visible in this file; 0.5 is a placeholder default --
    confirm before using this function for training.
    """
    targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)

    # Dice is computed on the unclipped predictions.
    intersection = K.sum(targets * inputs)
    dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)

    # Clip only for the logarithms, so log(0) is never evaluated.
    inputs = K.clip(inputs, eps, 1.0 - eps)

    # NOTE(review): as written, alpha scales the whole bracket, so the
    # negative-class term is weighted by alpha * (1 - alpha). The expression
    # is kept exactly as it was so loss values do not change.
    out = - (alpha * ((targets * K.log(inputs)) + ((1 - alpha) * (1.0 - targets) * K.log(1.0 - inputs))))
    weighted_ce = K.mean(out, axis=-1)

    combo = (ce_ratio * weighted_ce) - ((1 - ce_ratio) * dice)
    return combo
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice ratio per sample, averaged over the first axis.

    Both tensors are cast to int32, multiplied element-wise and summed over
    axes 1, 2 and 3; the resulting per-sample ratios are then averaged over
    axis 0.

    NOTE(review): the int32 cast discards the fractional part of y_pred, so
    soft predictions below 1 contribute nothing -- confirm this is intended.

    Args:
        y_true: Ground-truth tensor (rank 4, given the axes reduced here).
        y_pred: Predicted tensor of the same shape.
        smooth: Added to numerator and denominator of the ratio.

    Returns:
        Mean over axis 0 of ``(2 * intersection + smooth) / (union + smooth)``.
    """
    reduce_axes = [1, 2, 3]

    pred_int = tf.cast(y_pred, tf.int32)
    true_int = tf.cast(y_true, tf.int32)

    overlap = K.sum(true_int * pred_int, axis=reduce_axes)
    total = K.sum(true_int, axis=reduce_axes) + K.sum(pred_int, axis=reduce_axes)

    per_sample = (2 * overlap + smooth) / (total + smooth)
    return K.mean(per_sample, axis=0)
def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Focal loss built from separate positive- and negative-label terms.

    Where ``y_true == 1`` the prediction is kept and every other position is
    replaced by 1; where ``y_true == 0`` the prediction is kept and every
    other position is replaced by 0. Each masked tensor feeds one term.

    Args:
        y_true: Ground-truth tensor, compared against 1 and 0.
        y_pred: Predicted tensor of the same shape.
        gamma: Exponent applied to ``1 - p`` (positive term) and ``p``
            (negative term).
        alpha: Weight of the positive term; the negative term gets
            ``1 - alpha``.

    Returns:
        The negated sum of the two mean terms.
    """
    tiny = K.epsilon()  # keeps the logarithm arguments away from zero

    positive_prob = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    negative_prob = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))

    positive_term = K.mean(
        alpha * K.pow(1. - positive_prob, gamma) * K.log(positive_prob + tiny)
    )
    negative_term = K.mean(
        (1 - alpha) * K.pow(negative_prob, gamma) * K.log(1. - negative_prob + tiny)
    )

    return -positive_term - negative_term
# Load the saved segmentation model from the path held in `model_file`
# (previously the file name was repeated here as a second literal while the
# constant went unused). `custom_objects` maps the custom loss / metric names
# to the functions defined above so Keras can resolve them while loading.
seg_model = keras.models.load_model(
    model_file,
    custom_objects={
        'Combo_loss': Combo_loss,
        'focal_loss_fixed': focal_loss_fixed,
        'dice_coef': dice_coef,
    },
)

# Disabled component definitions using the gr.inputs / gr.outputs namespaces,
# kept for reference:
# inputs = gr.inputs.Image(type="pil", label="Upload an image")
# image_output = gr.outputs.Image(type="pil", label="Output Image")
# outputs = gr.outputs.HTML() #uncomment for single class output

# Subplot grid dimensions read by gen_pred when it draws the prediction.
rows = 1
columns = 1
def gen_pred(img, model=seg_model):
    """Run the segmentation model on one image and return a plot of the output.

    Args:
        img: PIL image (the Gradio input is declared with ``type='pil'``).
        model: Object exposing ``predict``; defaults to the module-level
            ``seg_model``.

    Returns:
        A matplotlib ``Figure`` (3x3 inches) showing the prediction with the
        axes hidden.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    # Imported locally because the file only imports pyplot at module level.
    from matplotlib.figure import Figure

    # Presumably Gradio passes None when the form is submitted without an
    # image; fail with a readable message instead of an AttributeError.
    if img is None:
        raise gr.Error("Please upload an image first.")

    # PIL already delivers RGB channel order. The previous RGB -> BGR -> RGB
    # round-trip through OpenCV left the pixels unchanged, so it is dropped.
    rgb = np.array(img.convert('RGB'))
    rgb = rgb[::IMG_SCALING[0], ::IMG_SCALING[1]]

    # Scale to [0, 1] and add a leading batch axis for predict().
    batch = tf.expand_dims(rgb / 255, axis=0)
    pred = np.squeeze(model.predict(batch), axis=0)

    # Build the figure without pyplot: pyplot keeps a global reference to
    # every figure it creates, so one figure leaked per request, and
    # plt.show() has no place in a server callback.
    fig = Figure(figsize=(3, 3))
    ax = fig.add_subplot(rows, columns, 1)
    ax.imshow(pred)
    ax.axis('off')
    return fig
# Text shown at the top of the Gradio page.
title = "<h1 style='text-align: center;'>Semantic Segmentation (Airbus Ship Detection Challenge)</h1>"
description = "Upload an image and get prediction mask"

# One PIL image in, one plot out; the listed files are offered as examples.
demo = gr.Interface(
    fn=gen_pred,
    inputs=[gr.components.Image(type='pil')],
    outputs=["plot"],
    title=title,
    examples=[["00c3db267.jpg"], ["00dc34840.jpg"], ["00371aa92.jpg"]],
    description=description,
)
demo.launch()