"""Gradio demo for MobileSAM segmentation with point or bounding-box prompts."""
import importlib.util
import os  # noqa: F401  (kept from the original script)
import subprocess
import sys
import warnings

# Ensure ultralytics (whose dependencies include OpenCV) is installed BEFORE
# the third-party imports below. `sys.executable -m pip` targets the running
# interpreter rather than whichever `pip` happens to be first on PATH, and the
# find_spec check skips the network round-trip when everything is present.
# check=False mirrors the original best-effort install: if it fails, the
# imports below raise a clear ImportError anyway.
if any(importlib.util.find_spec(name) is None for name in ("ultralytics", "cv2")):
    subprocess.run([sys.executable, "-m", "pip", "install", "ultralytics"], check=False)
    # Packages installed after interpreter start may be hidden by cached
    # directory listings; importlib docs recommend invalidating the caches.
    importlib.invalidate_caches()

import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
from PIL.Image import Image  # noqa: E402
from ultralytics import SAM  # noqa: E402

# Silence all Python warnings process-wide. Note: this also hides deprecation
# warnings emitted by gradio/ultralytics.
warnings.filterwarnings("ignore")


class SAMModel:
    """Wrap an ultralytics SAM model and expose it through a two-tab Gradio UI."""

    def __init__(self, model_path: str = 'mobile_sam.pt'):
        """Load the SAM weights.

        Args:
            model_path: Path or name of the SAM checkpoint passed to
                ``ultralytics.SAM``; defaults to the original ``'mobile_sam.pt'``.
        """
        self.model = SAM(model_path)

    @staticmethod
    def _render_rgb(result) -> np.ndarray:
        """Draw the first prediction's masks and return the image in RGB order."""
        plotted = result[0].plot()
        # The ultralytics plot is in OpenCV's BGR channel order; Gradio shows RGB.
        return cv2.cvtColor(np.array(plotted), cv2.COLOR_BGR2RGB)

    @staticmethod
    def _parse_bbox(bbox: str) -> list:
        """Parse ``"x1, y1, x2, y2"`` into four ints.

        Raises:
            gr.Error: if the text is missing, non-integer, or not four values.
        """
        try:
            coords = [int(part) for part in bbox.split(',')]
        except (AttributeError, ValueError) as err:
            raise gr.Error(
                "Bounding box must be four comma-separated integers: x1, y1, x2, y2"
            ) from err
        if len(coords) != 4:
            raise gr.Error(
                f"Bounding box needs exactly 4 values (x1, y1, x2, y2), got {len(coords)}"
            )
        return coords

    def mobilesam_point_predict(self, image: Image, x: float, y: float) -> np.ndarray:
        """Segment the object under a single foreground point.

        Args:
            image: Input PIL image from the Gradio component.
            x: Horizontal pixel coordinate of the prompt point.
            y: Vertical pixel coordinate of the prompt point.

        Returns:
            RGB array with the predicted mask drawn on the image.

        Raises:
            gr.Error: if no image was supplied or a coordinate is empty.
        """
        if image is None:
            raise gr.Error("Please upload an input image.")
        if x is None or y is None:
            raise gr.Error("Please provide both X and Y coordinates.")
        # labels=[1] marks the point as foreground.
        result = self.model.predict(image, points=[x, y], labels=[1])
        return self._render_rgb(result)

    def mobile_bbox_predict(self, image: Image, bbox: str) -> np.ndarray:
        """Segment the object inside a bounding box.

        Args:
            image: Input PIL image from the Gradio component.
            bbox: Box as text, ``"x1, y1, x2, y2"``.

        Returns:
            RGB array with the predicted mask drawn on the image.

        Raises:
            gr.Error: if no image was supplied or the box text is malformed.
        """
        if image is None:
            raise gr.Error("Please upload an input image.")
        bbox_list = self._parse_bbox(bbox)
        result = self.model.predict(image, bboxes=bbox_list)
        return self._render_rgb(result)

    def launch(self):
        """Build the Gradio interface and start serving it."""
        with gr.Blocks() as app:
            gr.Markdown("# SAM Model Demo")

            with gr.Tabs():
                # Point prompt: one foreground click location.
                with gr.TabItem("point-predict"):
                    with gr.Column():
                        point_inputs = [
                            gr.Image(type='pil', label='Input Image'),
                            gr.Number(value=900, label='X Coordinate'),
                            gr.Number(value=370, label='Y Coordinate'),
                        ]
                        point_output = gr.Image(type='pil', label='Output Image')
                        point_predict_button = gr.Button("inference")
                    point_predict_button.click(
                        self.mobilesam_point_predict,
                        inputs=point_inputs,
                        outputs=point_output,
                    )

                # Box prompt: "x1, y1, x2, y2" typed as text.
                with gr.TabItem("bbox-predict"):
                    bbox_image_input = gr.Image(type='pil', label='Input Image')
                    bbox_text_input = gr.Textbox(
                        lines=1,
                        label="Bounding Box (x1, y1, x2, y2)",
                        value="439, 437, 524, 709",
                    )
                    bbox_output = gr.Image(type='pil', label='Output Image')
                    bbox_predict_button = gr.Button("inference")
                    bbox_predict_button.click(
                        self.mobile_bbox_predict,
                        inputs=[bbox_image_input, bbox_text_input],
                        outputs=bbox_output,
                    )

        app.launch()


if __name__ == '__main__':
    web_ui = SAMModel()
    web_ui.launch()