import gradio as gr
import torch
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Load the SAM2 predictor directly from the Hugging Face Hub.
# SAM2ImagePredictor.from_pretrained takes a model repo id and downloads the
# checkpoint itself, so a separate hf_hub_download call is not needed.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

def segment_image(input_image, x, y):
    # Gradio passes the uploaded image as a NumPy array; convert it to an RGB PIL image.
    input_image = Image.fromarray(input_image.astype('uint8'), 'RGB')

    # Compute the image embedding for this image.
    predictor.set_image(input_image)

    # A single foreground point prompt at (x, y).
    input_point = np.array([[x, y]])
    input_label = np.array([1])

    # Predict masks for the point prompt without tracking gradients.
    with torch.inference_mode():
        masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label)

    # predict() already returns NumPy arrays, so no .cpu() call is needed;
    # take the first mask and scale it to an 8-bit grayscale image.
    mask = masks[0]
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    # Keep the segmented region and black out everything else.
    result = Image.composite(input_image, Image.new('RGB', input_image.size, 'black'), mask_image)

    return result

iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy"),
        gr.Slider(0, 1000, label="X coordinate"),
        gr.Slider(0, 1000, label="Y coordinate"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM2 Image Segmentation",
    description="Upload an image and select a point to segment. Adjust X and Y coordinates to refine the selection.",
)

iface.launch()