"""Gradio demo: point-prompted mask prediction with Segment Anything 2."""

import gradio as gr
import numpy as np
import torch
from PIL import Image
# NOTE(review): the upstream SAM 2 package exposes these names as
# `sam2.build_sam.build_sam2` and
# `sam2.sam2_image_predictor.SAM2ImagePredictor` -- confirm that this import
# path matches the package installed in this environment.
from segment_anything_2 import SAM2ImagePredictor, build_sam2

# Load the model once at import time so every request reuses the same weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)


def _parse_points(input_points):
    """Return the prompt coordinates as a float32 array of shape (N, 2).

    Accepts either text such as ``"120, 85"`` (several points may be chained
    as ``"x1,y1;x2,y2"``) or any array-like holding x/y pairs.

    Raises:
        ValueError: if the input is missing, non-numeric, or does not contain
            a whole number of x/y pairs.
    """
    if input_points is None:
        raise ValueError("No point given; expected 'X,Y'.")

    if isinstance(input_points, str):
        tokens = [
            tok.strip()
            for tok in input_points.replace(";", ",").split(",")
            if tok.strip()
        ]
        try:
            values = [float(tok) for tok in tokens]
        except ValueError as err:
            raise ValueError(
                f"Could not parse point {input_points!r}; expected 'X,Y'."
            ) from err
        flat = np.asarray(values, dtype=np.float32)
    else:
        try:
            flat = np.asarray(input_points, dtype=np.float32).ravel()
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Could not interpret point {input_points!r} as numbers."
            ) from err

    if flat.size == 0 or flat.size % 2:
        raise ValueError("A point needs exactly two values: X and Y.")
    return flat.reshape(-1, 2)


def _parse_labels(input_labels, n_points):
    """Return one integer label per point as an int32 array of shape (N,).

    A single label is repeated for every point, so one radio selection can
    be combined with several coordinates.

    Raises:
        ValueError: if no label is given, a label is not an integer, or the
            number of labels does not match the number of points.
    """
    if input_labels is None:
        raise ValueError("No label selected; choose 0 or 1.")

    try:
        labels = np.array(
            [int(v) for v in np.atleast_1d(input_labels).ravel()],
            dtype=np.int32,
        )
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Could not interpret label {input_labels!r} as an integer."
        ) from err

    if labels.size == 1 and n_points > 1:
        labels = np.full(n_points, labels[0], dtype=np.int32)
    if labels.size != n_points:
        raise ValueError(
            f"Got {labels.size} label(s) for {n_points} point(s)."
        )
    return labels


def process_image(image, input_points, input_labels):
    """Predict a segmentation mask for ``image`` from point prompts.

    Args:
        image: The picture to segment, as a PIL image or an array-like of
            pixel data.
        input_points: Prompt location(s): ``"X,Y"`` text or an array-like of
            x/y pairs.
        input_labels: Prompt label(s); the UI describes 0 as background and
            1 as object. A single label applies to every point.

    Returns:
        A greyscale PIL image holding the highest-scoring candidate mask,
        with mask pixels at 255 and all others at 0.

    Raises:
        gr.Error: if the image is missing or the prompts cannot be parsed,
            so the message is shown in the Gradio UI.
    """
    if image is None:
        raise gr.Error("Please provide an image.")

    try:
        point_coords = _parse_points(input_points)
        point_labels = _parse_labels(input_labels, len(point_coords))
    except ValueError as err:
        raise gr.Error(str(err)) from err

    # Normalise to 3-channel RGB so RGBA / greyscale uploads are handled too.
    if isinstance(image, Image.Image):
        rgb = np.asarray(image.convert("RGB"))
    else:
        rgb = np.asarray(image)

    with torch.inference_mode():
        # The predictor must embed the image before it can answer prompts.
        predictor.set_image(rgb)
        masks, scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True,
        )

    # multimask_output=True yields several candidates; keep the best-scored.
    best_mask = masks[int(np.argmax(scores))]
    # Scale the 0/1 mask to 0/255, otherwise the image looks entirely black.
    return Image.fromarray((best_mask > 0).astype(np.uint8) * 255)


# Define Gradio Interface
# Top-level components are used because the `gr.inputs` namespace was removed
# in Gradio 4. The point is a text field since a Number cannot hold "X,Y".
image_input = gr.Image(type="pil")
point_input = gr.Textbox(label="Point X,Y (comma-separated)")
label_input = gr.Radio([0, 1], label="Label (0 for background, 1 for object)")

iface = gr.Interface(
    fn=process_image,
    inputs=[image_input, point_input, label_input],
    outputs="image",
    description="Interactive tool for mask prediction with Segment Anything 2 and CUTIE",
)

if __name__ == "__main__":
    iface.launch()