# NOTE: this header replaces web-viewer residue that preceded the code
# (file size "1,561 Bytes", revision ids 9e8c5c7 / 2b20220 / 62aadf9, and a
# 1-48 line-number gutter). It is kept as a comment so the module parses.
# Environment setup -- run this in a shell, it is not Python:
#     pip install --upgrade pip
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

# NOTE(review): `predictor` is rebound further down from a local checkpoint,
# so this hub download looks redundant -- confirm which of the two is wanted.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

# Usage template copied from the SAM 2 README. It stays commented out because
# the <...> placeholders are not valid Python; substitute a real image and
# real prompts before enabling it.
#
# with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
#     predictor.set_image(<your_image>)
#     masks, _, _ = predictor.predict(<input_prompts>)
import gradio as gr
import torch
import numpy as np
from PIL import Image
from segment_anything_2 import SAM2ImagePredictor, build_sam2

# Model setup: choose the compute device, then build the SAM 2 predictor
# from a checkpoint file and its config.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Weights file and config name, passed to build_sam2 below.
checkpoint, model_cfg = "checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml"

model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)

def _as_flat_array(value, dtype, what):
    """Return ``value`` (number, sequence, or comma-separated text) as a 1-D array."""
    if value is None:
        raise ValueError(f"{what} is required")
    try:
        if isinstance(value, str):
            # Text such as "120, 85": split on commas and ignore empty pieces.
            value = [float(part) for part in value.split(",") if part.strip()]
        return np.asarray(value, dtype=np.float64).reshape(-1).astype(dtype)
    except (TypeError, ValueError) as err:
        raise ValueError(f"{what} must be numeric, got {value!r}") from err


def process_image(image, input_points, input_labels):
    """Segment ``image`` from point prompts and return the best mask as an image.

    Args:
        image: Picture to segment. PIL images are converted to RGB before
            being handed to ``predictor.set_image``; any other object is
            passed through unchanged.
        input_points: Prompt point(s) in pixel coordinates, either as
            comma-separated text (``"120, 85"``; more points as
            ``"x1,y1,x2,y2"``) or as a sequence/array of ``x, y`` values.
        input_labels: One label per point (1 = object, 0 = background), as a
            single number, comma-separated text, or a sequence.

    Returns:
        A PIL image of the highest-scoring mask: 255 inside the mask and 0
        elsewhere.

    Raises:
        ValueError: If the image is missing, the points do not form complete
            ``x, y`` pairs, or the number of labels differs from the number
            of points.
    """
    if image is None:
        raise ValueError("image is required")

    # Validate the prompts first so bad input fails before the (expensive)
    # image embedding is computed.
    flat_points = _as_flat_array(input_points, np.float32, "input_points")
    if flat_points.size == 0 or flat_points.size % 2:
        raise ValueError(
            "input_points must contain complete x,y pairs, e.g. '120, 85'"
        )
    point_coords = flat_points.reshape(-1, 2)  # one (x, y) row per prompt

    point_labels = _as_flat_array(input_labels, np.int32, "input_labels")
    if point_labels.size != len(point_coords):
        raise ValueError(
            f"expected {len(point_coords)} label(s), got {point_labels.size}"
        )

    if isinstance(image, Image.Image):
        # Uploads may be RGBA / palette / grayscale; normalise to 3 channels.
        image = image.convert("RGB")

    # The predictor has to be given the image before it can be prompted.
    predictor.set_image(image)

    # Use predictor to predict mask
    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )

    # multimask_output=True yields several candidates; keep the best-scored one.
    best_mask = masks[int(np.argmax(scores))]

    # Scale the 0/1 mask to 0/255 so it is visible as an image.
    return Image.fromarray((best_mask > 0).astype(np.uint8) * 255)

# Define Gradio Interface
# Components come from the top-level ``gr`` namespace; the legacy
# ``gr.inputs`` module was deprecated in Gradio 3 and removed in Gradio 4.
image_input = gr.Image(type="pil")
# A Textbox rather than a Number: the label asks for an "X,Y" pair, which a
# single numeric field cannot hold.
point_input = gr.Textbox(label="Point X,Y (comma-separated)")
# Default to "object" so a label is always submitted with the point.
label_input = gr.Radio(
    [0, 1],
    value=1,
    label="Label (0 for background, 1 for object)",
)

iface = gr.Interface(
    fn=process_image,
    inputs=[image_input, point_input, label_input],
    outputs="image",
    description="Interactive tool for mask prediction with Segment Anything 2 and CUTIE"
)

iface.launch()