import os, io

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
import keras.backend as K

from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras


# Shape the U-Net was trained on and the row/column stride used when downscaling inputs.
resized_shape = (768, 768, 3)
IMG_SCALING = (1, 1)

# Path to the trained U-Net segmentation weights.
model_file = "./seg_unet_model.h5"
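
# NOTE: ALPHA and CE_RATIO are referenced inside Combo_loss below but were not
# defined anywhere in this script; the 0.5/0.5 split here is an assumed
# placeholder (an equal cross-entropy/Dice weighting is a common default for
# Combo loss). Replace them with the values the model was actually trained with.
ALPHA = 0.5      # weight of the positive class in the weighted cross-entropy term
CE_RATIO = 0.5   # weight of the cross-entropy term relative to the Dice term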


# Combo loss: CE_RATIO * weighted binary cross-entropy minus (1 - CE_RATIO) * Dice coefficient.
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
    targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)
    intersection = K.sum(targets * inputs)
    dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
    inputs = K.clip(inputs, eps, 1.0 - eps)
    out = -(ALPHA * ((targets * K.log(inputs)) + ((1 - ALPHA) * (1.0 - targets) * K.log(1.0 - inputs))))
    weighted_ce = K.mean(out, axis=-1)
    combo = (CE_RATIO * weighted_ce) - ((1 - CE_RATIO) * dice)
    return combo


# Dice coefficient computed per image over binary masks, then averaged over the batch.
def dice_coef(y_true, y_pred, smooth=1):
    y_pred = tf.dtypes.cast(y_pred, tf.int32)
    y_true = tf.dtypes.cast(y_true, tf.int32)
    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
    return K.mean((2 * intersection + smooth) / (union + smooth), axis=0)


# Binary focal loss; registered as a custom object so the saved model can be loaded.
def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    return -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1 + K.epsilon())) \
           - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))


# Load the trained segmentation model; the custom loss/metric functions above must be
# supplied so Keras can deserialize the compiled model.
seg_model = keras.models.load_model(
    model_file,
    custom_objects={'Combo_loss': Combo_loss,
                    'focal_loss_fixed': focal_loss_fixed,
                    'dice_coef': dice_coef})


# Single-panel figure layout used when plotting the predicted mask.
rows = 1
columns = 1


def gen_pred(img, model=seg_model):
    # Gradio hands over a PIL image; convert it to an OpenCV-style BGR array.
    pil_image = img.convert('RGB')
    open_cv_image = np.array(pil_image)
    img = open_cv_image[:, :, ::-1].copy()

    # Downsample by the configured stride (a no-op with IMG_SCALING = (1, 1)),
    # convert back to RGB, and scale pixel values to [0, 1].
    img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img / 255
    img = tf.expand_dims(img, axis=0)

    # Predict the ship mask and drop the batch dimension.
    pred = model.predict(img)
    pred = np.squeeze(pred, axis=0)

    # Render the mask as a matplotlib figure for the Gradio "plot" output.
    fig = plt.figure(figsize=(3, 3))
    fig.add_subplot(rows, columns, 1)
    plt.imshow(pred)
    plt.axis('off')
    plt.show()
    return fig
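
# Optional local check (not part of the original app): when the SMOKE_TEST
# environment variable is set to "1", run the model on one of the bundled sample
# images (filename taken from the examples list below) and save the mask figure
# to disk instead of relying on the Gradio UI.
if os.environ.get("SMOKE_TEST") == "1":
    gen_pred(Image.open("00c3db267.jpg")).savefig("pred_mask.png", bbox_inches="tight")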


title = "<h1 style='text-align: center;'>Semantic Segmentation (Airbus Ship Detection Challenge)</h1>"
description = "Upload an image to get the predicted ship mask."

# Build and launch the Gradio demo; the example images are expected to sit next to this script.
gr.Interface(fn=gen_pred,
             inputs=[gr.components.Image(type='pil')],
             outputs=["plot"],
             title=title,
             examples=[["00c3db267.jpg"], ["00dc34840.jpg"], ["00371aa92.jpg"]],
             description=description).launch()