yunyangx commited on
Commit
9516ab6
·
1 Parent(s): f77a5b5

demo file with dependency

Browse files
Files changed (2) hide show
  1. app.py +78 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from torchvision.transforms import ToTensor
5
+ from PIL import Image
6
+
7
+ # loading EfficientSAM model
8
+ model_path = "efficientsam_s_cpu.jit"
9
+ with open(model_path, "rb") as f:
10
+ model = torch.jit.load(f)
11
+
12
+ # getting mask using points
13
+ def get_sam_mask_using_points(img_tensor, pts_sampled, model):
14
+ pts_sampled = torch.reshape(torch.tensor(pts_sampled), [1, 1, -1, 2])
15
+ max_num_pts = pts_sampled.shape[2]
16
+ pts_labels = torch.ones(1, 1, max_num_pts)
17
+
18
+ predicted_logits, predicted_iou = model(
19
+ img_tensor[None, ...],
20
+ pts_sampled,
21
+ pts_labels,
22
+ )
23
+ predicted_logits = predicted_logits.cpu()
24
+ all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
25
+ predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
26
+
27
+ max_predicted_iou = -1
28
+ selected_mask_using_predicted_iou = None
29
+ for m in range(all_masks.shape[0]):
30
+ curr_predicted_iou = predicted_iou[m]
31
+ if (
32
+ curr_predicted_iou > max_predicted_iou
33
+ or selected_mask_using_predicted_iou is None
34
+ ):
35
+ max_predicted_iou = curr_predicted_iou
36
+ selected_mask_using_predicted_iou = all_masks[m]
37
+ return selected_mask_using_predicted_iou
38
+
39
+ # examples
40
+ examples = [["examples/image1.jpg"], ["examples/image2.jpg"], ["examples/image3.jpg"], ["examples/image4.jpg"],
41
+ ["examples/image5.jpg"], ["examples/image6.jpg"], ["examples/image7.jpg"], ["examples/image8.jpg"],
42
+ ["examples/image9.jpg"], ["examples/image10.jpg"], ["examples/image11.jpg"], ["examples/image12.jpg"]
43
+ ["examples/image13.jpg"], ["examples/image14.jpg"]]
44
+
45
+
46
+ with gr.Blocks() as demo:
47
+ with gr.Row():
48
+ input_img = gr.Image(label="Input",height=512)
49
+ output_img = gr.Image(label="Selected Segment",height=512)
50
+
51
+ with gr.Row():
52
+ gr.Markdown("Try some of the examples below ⬇️")
53
+ gr.Examples(examples=examples,
54
+ inputs=[input_img])
55
+
56
+ def get_select_coords(img, evt: gr.SelectData):
57
+ img_tensor = ToTensor()(img)
58
+ _, H, W = img_tensor.shape
59
+
60
+ visited_pixels = set()
61
+ pixels_in_queue = set()
62
+ pixels_in_segment = set()
63
+
64
+ mask = get_sam_mask_using_points(img_tensor, [[evt.index[0], evt.index[1]]], model)
65
+
66
+ out = img.copy()
67
+
68
+ out = out.astype(np.uint8)
69
+ out *= mask[:,:,None]
70
+ for pixel in pixels_in_segment:
71
+ out[pixel[0], pixel[1]] = img[pixel[0], pixel[1]]
72
+ print(out)
73
+ return out
74
+
75
+ input_img.select(get_select_coords, [input_img], output_img)
76
+
77
+ if __name__ == "__main__":
78
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ gradio
4
+ transformers==4.32.0
5
+ opencv-python
6
+ pandas==2.0.3
7
+ matplotlib==3.7.2