kadirnar's picture
Update app.py (#7)
c992a7d
raw
history blame
6.24 kB
from sahi import utils, predict, AutoDetectionModel
from PIL import Image
import gradio as gr
import numpy
import torch
# Model ids offered in the "Choose Model" dropdown.
model_id_list = [
    'deprem-ml/Binafarktespit-yolo5x-v1-xview',
    'SerdarHelli/deprem_satellite_labeled_yolov8',
    'kadirnar/yolov7-v0.1',
    'kadirnar/UNet-EfficientNet-b6-Istanbul',
]

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    current_device = "cuda"
else:
    current_device = "cpu"

# Pipeline names that sahi_yolov5_inference dispatches on.
model_types = [
    "YOLOv5",
    "YOLOv5 + SAHI",
    "YOLOv8",
    "YOLOv7",
    "Unet-Istanbul",
]
def _yolov5_line_widths(image):
    """Return ``(rect_th, text_th)`` line widths scaled to the image's mean side, never below 1."""
    rect_th = max(round(sum(image.size) / 2 * 0.0001), 1)
    text_th = max(rect_th - 2, 1)
    return rect_th, text_th


def _load_yolov5(model_id, image_size):
    """Load YOLOv5 weights through SAHI on ``current_device`` with confidence threshold 0.5."""
    return AutoDetectionModel.from_pretrained(
        model_type="yolov5",
        model_path=model_id,
        device=current_device,
        confidence_threshold=0.5,
        image_size=image_size,
    )


def _draw_predictions(image, object_prediction_list):
    """Draw SAHI object predictions onto ``image`` and return a PIL image."""
    rect_th, text_th = _yolov5_line_widths(image)
    visual_result = utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=object_prediction_list,
        rect_th=rect_th,
        text_th=text_th,
    )
    return Image.fromarray(visual_result["image"])


def sahi_yolov5_inference(
    image,
    model_id,
    model_type,
    image_size,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.1,
    overlap_width_ratio=0.1,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.25,
    postprocess_class_agnostic=False,
):
    """Run the pipeline selected by ``model_type`` on ``image`` and return the rendering.

    Despite its name, this dispatches to one of five pipelines: plain YOLOv5,
    YOLOv5 with SAHI sliced inference, YOLOv8 (ultralyticsplus), YOLOv7
    (yolov7 package) or the Istanbul U-Net.

    Args:
        image: Input image. The Gradio input delivers a PIL image; the YOLOv5
            paths read ``image.size`` and convert it with ``numpy.array``.
        model_id: Weights id/path. Ignored by the YOLOv8 branch, which always
            loads 'SerdarHelli/deprem_satellite_labeled_yolov8'.
        model_type: One of "YOLOv5", "YOLOv5 + SAHI", "YOLOv8", "YOLOv7",
            "Unet-Istanbul".
        image_size: Inference resolution passed to the detector.
        slice_height, slice_width: Slice size in pixels ("YOLOv5 + SAHI" only;
            cast to int because sliders may deliver floats).
        overlap_height_ratio, overlap_width_ratio: Slice overlap ratios
            ("YOLOv5 + SAHI" only).
        postprocess_type, postprocess_match_metric, postprocess_match_threshold,
        postprocess_class_agnostic: SAHI merge settings ("YOLOv5 + SAHI" only).

    Returns:
        The annotated image: a PIL image for the YOLOv5 paths; for the other
        paths, whatever the respective library's render helper returns.

    Raises:
        ValueError: if ``image`` is None or ``model_type`` is not recognised.
    """
    if image is None:
        raise ValueError("No input image was provided.")

    if model_type == "YOLOv5":
        # Standard full-image inference.
        model = _load_yolov5(model_id, image_size)
        prediction = predict.get_prediction(image=image, detection_model=model)
        return _draw_predictions(image, prediction.object_prediction_list)

    if model_type == "YOLOv5 + SAHI":
        # Sliced inference: predict on overlapping tiles, then merge the results.
        model = _load_yolov5(model_id, image_size)
        prediction = predict.get_sliced_prediction(
            image=image,
            detection_model=model,
            slice_height=int(slice_height),
            slice_width=int(slice_width),
            overlap_height_ratio=overlap_height_ratio,
            overlap_width_ratio=overlap_width_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )
        return _draw_predictions(image, prediction.object_prediction_list)

    if model_type == "YOLOv8":
        # Imported here so the dependency is only needed for this path.
        from ultralyticsplus import YOLO, render_result

        # NOTE(review): model_id is ignored; the YOLOv8 weights are hardcoded.
        model = YOLO('SerdarHelli/deprem_satellite_labeled_yolov8')
        result = model.predict(image, imgsz=image_size)[0]
        return render_result(model=model, image=image, result=result)

    if model_type == "YOLOv7":
        import yolov7

        # Use the GPU only when one was detected; a hardcoded "cuda:0" fails
        # on CPU-only hosts.
        device = "cuda:0" if current_device == "cuda" else "cpu"
        model = yolov7.load(model_id, device=device, hf_model=True, trace=False)
        results = model([image], size=image_size)
        return results.render()[0]

    if model_type == "Unet-Istanbul":
        from istanbul_unet import unet_prediction

        return unet_prediction(input_path=image, model_path=model_id)

    raise ValueError(f"Unsupported model type: {model_type!r}")
# Input components, in the same order as sahi_yolov5_inference's parameters.
inputs = [
    gr.Image(type="pil", label="Original Image"),
    gr.Dropdown(choices=model_id_list, label="Choose Model", value=model_id_list[0]),
    gr.Dropdown(choices=model_types, label="Choose Model Type", value=model_types[1]),
    gr.Slider(minimum=128, maximum=2048, value=640, step=32, label="Image Size"),
    gr.Slider(minimum=128, maximum=2048, value=512, step=32, label="Slice Height"),
    gr.Slider(minimum=128, maximum=2048, value=512, step=32, label="Slice Width"),
    # Overlap ratios are capped below 1.0. At 1.0 each slice would overlap the
    # previous one completely, so the sliding window would not advance.
    gr.Slider(minimum=0.0, maximum=0.9, value=0.1, step=0.1, label="Overlap Height Ratio"),
    gr.Slider(minimum=0.0, maximum=0.9, value=0.1, step=0.1, label="Overlap Width Ratio"),
    gr.Dropdown(["NMS", "GREEDYNMM"], type="value", value="NMS", label="Postprocess Type"),
    gr.Dropdown(["IOU", "IOS"], type="value", value="IOU", label="Postprocess Match Metric"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1, label="Postprocess Match Threshold"),
    gr.Checkbox(value=True, label="Postprocess Class Agnostic"),
]
# Use the gr.Image component like the inputs do, instead of the deprecated
# legacy gr.outputs namespace.
outputs = [gr.Image(type="pil", label="Output")]
title = "Building Detection from Satellite Images with SAHI + YOLOv5"
description = "SAHI + YOLOv5 demo for building detection from satellite images. Upload an image or click an example image to use."
article = "<p style='text-align: center'>SAHI is a lightweight vision library for performing large scale object detection/ instance segmentation. <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> | <a href='https://github.com/fcakyon/yolov5-pip'>YOLOv5 Github</a> </p>"
# Each row supplies values for the 12 inputs, in order. The YOLOv8 row uses the
# model id that sahi_yolov5_inference actually loads for YOLOv8, which is also
# one of the dropdown choices.
examples = [
    ["data/26.jpg", 'deprem-ml/Binafarktespit-yolo5x-v1-xview', "YOLOv5 + SAHI", 640, 512, 512, 0.1, 0.1, "NMS", "IOU", 0.25, False],
    ["data/27.jpg", 'deprem-ml/Binafarktespit-yolo5x-v1-xview', "YOLOv5 + SAHI", 640, 512, 512, 0.1, 0.1, "NMS", "IOU", 0.25, False],
    ["data/28.jpg", 'deprem-ml/Binafarktespit-yolo5x-v1-xview', "YOLOv5 + SAHI", 640, 512, 512, 0.1, 0.1, "NMS", "IOU", 0.25, False],
    ["data/31.jpg", 'SerdarHelli/deprem_satellite_labeled_yolov8', "YOLOv8", 640, 512, 512, 0.1, 0.1, "NMS", "IOU", 0.25, False],
    ["data/Istanbul.jpg", 'kadirnar/UNet-EfficientNet-b6-Istanbul', "Unet-Istanbul", 512, 512, 512, 0.1, 0.1, "NMS", "IOU", 0.25, False],
]
# Bind the inference function to the UI components defined above and start the
# server. cache_examples=True asks Gradio to precompute the example outputs.
demo = gr.Interface(
    fn=sahi_yolov5_inference,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    cache_examples=True,
    title=title,
    description=description,
    article=article,
    theme="huggingface",
)
# NOTE(review): enable_queue= on launch() is a legacy Gradio 3.x option —
# confirm against the pinned Gradio version before upgrading.
demo.launch(enable_queue=True, debug=True)