# railway_3 / app.py
# Sompote's picture
# Upload app.py
# 438f042 verified
import gradio as gr
from ultralytics import YOLO
import numpy as np
import cv2
# Load both models once, at module level; process_image() reads them as globals.
SEGMENTATION_WEIGHTS = "best-3.pt"  # custom segmentation model (protection zone)
DETECTION_WEIGHTS = "yolo11s.pt"  # second model, used for object detection

model = YOLO(SEGMENTATION_WEIGHTS)
model2 = YOLO(DETECTION_WEIGHTS)
def _build_protection_mask(segment_results, height, width):
    """Merge every predicted segment into one uint8 mask (255 in the zone, 0 elsewhere)."""
    zone_mask = np.zeros((height, width), dtype=np.uint8)
    for result in segment_results:
        if result.masks is None:
            continue
        for segment in result.masks.data:
            segment_array = segment.cpu().numpy().astype(np.uint8)
            # NOTE(review): assumes that stretching the raw mask to the image
            # size aligns it with the image - confirm against the model output.
            segment_array = cv2.resize(segment_array, (width, height))
            zone_mask = cv2.bitwise_or(zone_mask, segment_array * 255)
    return zone_mask


def _box_touches_zone(zone_mask, x1, y1, x2, y2):
    """Return True when at least one zone pixel lies inside rows y1:y2, columns x1:x2."""
    # Clamp at zero so a negative coordinate cannot wrap around to the far
    # edge of the image through Python's negative slicing.
    x1, y1, x2, y2 = (max(value, 0) for value in (x1, y1, x2, y2))
    # Slicing the existing mask avoids allocating a full-frame mask per box.
    return bool(zone_mask[y1:y2, x1:x2].any())


def _draw_label(canvas, label, x1, y1, color):
    """Draw `label` in white on a filled `color` strip anchored at the box's top-left corner."""
    (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    strip_bottom = y1
    if strip_bottom - label_height - 5 < 0:
        # No room above the box: move the strip just inside it so the text
        # is not drawn off the top of the image.
        strip_bottom = y1 + label_height + 5
    cv2.rectangle(canvas, (x1, strip_bottom - label_height - 5), (x1 + label_width, strip_bottom), color, -1)
    cv2.putText(canvas, label, (x1, strip_bottom - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)


def process_image(image):
    """Highlight the protection zone and box every detected object.

    The zone predicted by `model` is tinted red. Each object predicted by
    `model2` gets a bounding box: red with its class name when the box
    overlaps the zone, green without a label otherwise.

    Args:
        image: Input image convertible with `np.array`; treated as RGB
            (the colours below are written in RGB order). `None` is
            accepted and means "no image supplied".

    Returns:
        The annotated image as a NumPy array, or `None` when `image` is `None`.
    """
    if image is None:
        # Nothing was uploaded; return an empty result instead of crashing.
        return None

    image_rgb = np.array(image)
    height, width = image_rgb.shape[:2]

    # Predict the protection zone with the first model.
    protection_mask = _build_protection_mask(model(image_rgb), height, width)

    # Blend a red layer over the zone; addWeighted returns a new array, so
    # the input image is left untouched.
    protection_overlay = np.zeros_like(image_rgb)
    protection_overlay[protection_mask > 0] = [255, 0, 0]  # Red color in RGB
    output_image = cv2.addWeighted(image_rgb, 1, protection_overlay, 0.3, 0)

    # Predict objects with the second model.
    for result in model2(image_rgb):
        boxes = result.boxes.xyxy.cpu().numpy().astype(int)
        classes = result.boxes.cls.cpu().numpy()
        names = result.names
        for box, cls in zip(boxes, classes):
            # tolist() yields plain Python ints for the OpenCV drawing calls.
            x1, y1, x2, y2 = box.tolist()
            is_inside = _box_touches_zone(protection_mask, x1, y1, x2, y2)
            # Red if in zone, green if outside (in RGB).
            color = (255, 0, 0) if is_inside else (0, 255, 0)
            cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)
            if is_inside:
                # Only objects inside the protection zone get a class label.
                _draw_label(output_image, f"{names[int(cls)]}", x1, y1, color)
    return output_image
# Build the Gradio UI: one image in, one annotated image out.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(),
    outputs=gr.Image(label="Protection Zone and Detected Objects"),
    title="Protection Zone and Object Detection",
    description=(
        "Upload an image to detect protection zones (in red) and objects. "
        "Objects inside the protection zone are marked in red with their class name, "
        "while objects outside are marked in green."
    ),
)

# Start the Gradio app.
iface.launch()