"""Gradio app that highlights a protection zone and flags objects inside it.

A custom YOLO segmentation model finds the protection zone, which is tinted
red. A second YOLO detection model finds objects. Boxes that touch the zone
are drawn in red with their class name; all other boxes are drawn in green.
"""

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO

# Drawing constants. Colours are RGB because the image returned to Gradio is RGB.
ZONE_COLOR = (255, 0, 0)  # red: protection zone and objects inside it
OUTSIDE_COLOR = (0, 255, 0)  # green: objects outside the zone
LABEL_TEXT_COLOR = (255, 255, 255)
ZONE_ALPHA = 0.3  # weight of the red tint added over the zone
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE = 0.5
FONT_THICKNESS = 1
LABEL_PADDING = 5  # pixels between label text and the box edge

# Load the models once at import time so every request reuses them.
model = YOLO("best-3.pt")  # custom segmentation model (protection zone)
model2 = YOLO("yolo11s.pt")  # object-detection model


def _to_rgb(image):
    """Return *image* as an H x W x 3 RGB numpy array (grey and RGBA are converted)."""
    image_rgb = np.array(image)
    if image_rgb.ndim == 2:
        image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_GRAY2RGB)
    elif image_rgb.shape[2] == 4:
        image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_RGBA2RGB)
    return image_rgb


def _protection_zone_mask(image_bgr):
    """Return a uint8 mask (0 or 255) covering every protection-zone segment.

    The mask has the same height and width as *image_bgr*.
    """
    height, width = image_bgr.shape[:2]
    zone_mask = np.zeros((height, width), dtype=np.uint8)
    # retina_masks=True asks Ultralytics for masks at the original image
    # resolution. By default the masks are at the letterboxed inference size,
    # and simply resizing those would stretch the zone out of alignment.
    for result in model(image_bgr, retina_masks=True):
        if result.masks is None:
            continue
        for segment in result.masks.data:
            segment_on = segment.cpu().numpy() > 0.5
            if segment_on.shape != zone_mask.shape:
                # Safety net only. Nearest-neighbour keeps the mask binary.
                segment_on = cv2.resize(
                    segment_on.astype(np.uint8),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                ).astype(bool)
            zone_mask[segment_on] = 255
    return zone_mask


def _overlay_zone(image_rgb, zone_mask):
    """Return a new image with a translucent red tint where *zone_mask* is set."""
    overlay = np.zeros_like(image_rgb)
    overlay[zone_mask > 0] = ZONE_COLOR
    return cv2.addWeighted(image_rgb, 1, overlay, ZONE_ALPHA, 0)


def _draw_label(image, text, x1, y1, color):
    """Draw *text* on a filled *color* background at the top-left corner of a box."""
    (text_width, text_height), _ = cv2.getTextSize(text, FONT, FONT_SCALE, FONT_THICKNESS)
    top = y1 - text_height - LABEL_PADDING
    if top < 0:
        # No room above the box: put the label just inside its top edge instead.
        top = y1
    bottom = top + text_height + LABEL_PADDING
    cv2.rectangle(image, (x1, top), (x1 + text_width, bottom), color, -1)
    cv2.putText(
        image,
        text,
        (x1, bottom - LABEL_PADDING),
        FONT,
        FONT_SCALE,
        LABEL_TEXT_COLOR,
        FONT_THICKNESS,
        cv2.LINE_AA,
    )


def _draw_detections(output_image, image_bgr, zone_mask):
    """Detect objects in *image_bgr* and draw their boxes on *output_image* in place."""
    for result in model2(image_bgr):
        boxes = result.boxes.xyxy.cpu().numpy().astype(int)
        classes = result.boxes.cls.cpu().numpy().astype(int)
        names = result.names
        for box, cls in zip(boxes, classes):
            x1, y1, x2, y2 = (int(v) for v in box)
            # An object counts as "inside" when any zone pixel falls within its box.
            # Clamping at 0 keeps the slice from wrapping around to the array's end.
            box_region = zone_mask[max(y1, 0):max(y2, 0), max(x1, 0):max(x2, 0)]
            is_inside = bool(box_region.any())
            color = ZONE_COLOR if is_inside else OUTSIDE_COLOR
            cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)
            if is_inside:
                _draw_label(output_image, f"{names[int(cls)]}", x1, y1, color)


def process_image(image):
    """Annotate *image* with the protection zone and detected objects.

    Args:
        image: the image from ``gr.Image`` (an RGB numpy array by default), or
            ``None`` when nothing was uploaded.

    Returns:
        An RGB numpy array. The zone is tinted red, objects touching the zone
        are boxed in red with their class name, and other objects are boxed in
        green. Returns ``None`` when *image* is ``None``.
    """
    if image is None:
        return None
    image_rgb = _to_rgb(image)
    # Ultralytics reads numpy arrays as BGR (the OpenCV convention), but Gradio
    # supplies RGB. Run inference on a BGR copy and draw on the RGB original.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    zone_mask = _protection_zone_mask(image_bgr)
    output_image = _overlay_zone(image_rgb, zone_mask)
    _draw_detections(output_image, image_bgr, zone_mask)
    return output_image


# Define Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(),
    outputs=gr.Image(label="Protection Zone and Detected Objects"),
    title="Protection Zone and Object Detection",
    description="Upload an image to detect protection zones (in red) and objects. Objects inside the protection zone are marked in red with their class name, while objects outside are marked in green."
)

if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, not on import.
    iface.launch()