"""Gradio app: point-prompted image segmentation with Meta's SAM2.

Upload an image, pick an (x, y) point via the sliders, and the app
returns the original image masked to the segment SAM2 predicts for
that foreground point.
"""

import gradio as gr
import numpy as np
import torch
from PIL import Image

from huggingface_hub import hf_hub_download  # kept: available for manual weight management
from sam2.sam2_image_predictor import SAM2ImagePredictor

# `from_pretrained` expects a Hugging Face repo id and handles the weight
# download itself — passing a local .pth path here is not supported.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")


def segment_image(input_image, x, y):
    """Segment `input_image` at foreground point (x, y) and return the cutout.

    Args:
        input_image: HxWx3 uint8 RGB array from the Gradio image component.
        x: X pixel coordinate of the prompt point.
        y: Y pixel coordinate of the prompt point.

    Returns:
        PIL.Image.Image: the input with everything outside the predicted
        segment blacked out.
    """
    pil_image = Image.fromarray(input_image.astype("uint8"), "RGB")
    predictor.set_image(pil_image)

    # A single positive (label=1) point prompt.
    input_point = np.array([[x, y]])
    input_label = np.array([1])

    with torch.inference_mode():
        # predict() returns NumPy arrays (masks, iou scores, low-res logits),
        # not torch tensors — do not call .cpu() on them.
        masks, scores, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
        )

    # Pick the highest-scoring candidate mask rather than blindly taking
    # the first one.
    best = int(np.argmax(scores))
    mask = masks[best]

    # Masks are boolean/float — cast before scaling to an 8-bit alpha mask.
    mask_image = Image.fromarray((mask.astype(np.uint8)) * 255)

    # Keep pixels inside the mask; black out the rest.
    background = Image.new("RGB", pil_image.size, "black")
    result = Image.composite(pil_image, background, mask_image)
    return result


iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy"),
        gr.Slider(0, 1000, label="X coordinate"),
        gr.Slider(0, 1000, label="Y coordinate"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM2 Image Segmentation",
    description="Upload an image and select a point to segment. Adjust X and Y coordinates to refine the selection.",
)

# Guard so importing this module does not start the web server.
if __name__ == "__main__":
    iface.launch()