Satyajithchary commited on
Commit
62aadf9
·
verified ·
1 Parent(s): 54b67cc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from segment_anything_2 import SAM2ImagePredictor, build_sam2
6
+
7
+ # Load your model
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ checkpoint = "checkpoints/sam2_hiera_large.pt"
10
+ model_cfg = "sam2_hiera_l.yaml"
11
+ model = build_sam2(model_cfg, checkpoint, device=device)
12
+ predictor = SAM2ImagePredictor(model)
13
+
14
+ def process_image(image, input_points, input_labels):
15
+ input_point = np.array([input_points])
16
+ input_label = np.array([input_labels])
17
+
18
+ # Use predictor to predict mask
19
+ masks, scores, logits = predictor.predict(
20
+ point_coords=input_point,
21
+ point_labels=input_label,
22
+ multimask_output=True,
23
+ )
24
+ return Image.fromarray(masks[0].astype(np.uint8))
25
+
26
+ # Define Gradio Interface
27
+ image_input = gr.inputs.Image(type="pil")
28
+ point_input = gr.inputs.Number(label="Point X,Y (comma-separated)")
29
+ label_input = gr.inputs.Radio([0, 1], label="Label (0 for background, 1 for object)")
30
+
31
+ iface = gr.Interface(
32
+ fn=process_image,
33
+ inputs=[image_input, point_input, label_input],
34
+ outputs="image",
35
+ description="Interactive tool for mask prediction with Segment Anything 2 and CUTIE"
36
+ )
37
+
38
+ iface.launch()