amaye15 commited on
Commit
933c40c
·
1 Parent(s): 7e76bff

Sam 2 point prompt

Browse files
Files changed (2) hide show
  1. app.py +170 -9
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,21 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from gradio_image_prompter import ImagePrompter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Define the Gradio interface
5
  demo = gr.Interface(
6
- fn=lambda prompts: (
7
- prompts["image"],
8
- prompts["points"],
9
- ), # Extract image and points from the ImagePrompter
10
  inputs=ImagePrompter(
11
  show_label=False
12
  ), # ImagePrompter for image input and point selection
13
  outputs=[
14
- gr.Image(show_label=False),
15
- gr.Dataframe(label="Points"),
16
- ], # Outputs: Image and DataFrame of points
17
- title="Image Point Collector",
18
- description="Upload an image, click on it, and get the coordinates of the clicked points.",
19
  )
20
 
21
  # Launch the Gradio app
 
1
+ # import gradio as gr
2
+ # from gradio_image_prompter import ImagePrompter
3
+
4
+ # import os
5
+ # import torch
6
+
7
+
8
+ # def prompter(prompts):
9
+ # image = prompts["image"] # Get the image from prompts
10
+ # points = prompts["points"] # Get the points from prompts
11
+
12
+ # # Print the collected inputs for debugging or logging
13
+ # print("Image received:", image)
14
+ # print("Points received:", points)
15
+
16
+ # import torch
17
+ # from sam2.sam2_image_predictor import SAM2ImagePredictor
18
+
19
+ # device = torch.device("cpu")
20
+
21
+ # predictor = SAM2ImagePredictor.from_pretrained(
22
+ # "facebook/sam2-hiera-base-plus", device=device
23
+ # )
24
+
25
+ # with torch.inference_mode():
26
+ # predictor.set_image(image)
27
+ # # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points])
28
+ # input_point = [[point[0], point[1]] for point in points]
29
+ # input_label = [1]
30
+ # masks, _, _ = predictor.predict(
31
+ # point_coords=input_point, point_labels=input_label
32
+ # )
33
+ # print("Predicted Mask:", masks)
34
+
35
+ # return image, points
36
+
37
+
38
+ # # Define the Gradio interface
39
+ # demo = gr.Interface(
40
+ # fn=prompter, # Use the custom prompter function
41
+ # inputs=ImagePrompter(
42
+ # show_label=False
43
+ # ), # ImagePrompter for image input and point selection
44
+ # outputs=[
45
+ # gr.Image(show_label=False), # Display the image
46
+ # gr.Dataframe(label="Points"), # Display the points in a DataFrame
47
+ # ],
48
+ # title="Image Point Collector",
49
+ # description="Upload an image, click on it, and get the coordinates of the clicked points.",
50
+ # )
51
+
52
+ # # Launch the Gradio app
53
+ # demo.launch()
54
+
55
+
56
+ # import gradio as gr
57
+ # from gradio_image_prompter import ImagePrompter
58
+ # import torch
59
+ # from sam2.sam2_image_predictor import SAM2ImagePredictor
60
+
61
+
62
+ # def prompter(prompts):
63
+ # image = prompts["image"] # Get the image from prompts
64
+ # points = prompts["points"] # Get the points from prompts
65
+
66
+ # # Print the collected inputs for debugging or logging
67
+ # print("Image received:", image)
68
+ # print("Points received:", points)
69
+
70
+ # device = torch.device("cpu")
71
+
72
+ # # Load the SAM2ImagePredictor model
73
+ # predictor = SAM2ImagePredictor.from_pretrained(
74
+ # "facebook/sam2-hiera-base-plus", device=device
75
+ # )
76
+
77
+ # # Perform inference
78
+ # with torch.inference_mode():
79
+ # predictor.set_image(image)
80
+ # input_point = [[point[0], point[1]] for point in points]
81
+ # input_label = [1] * len(points) # Assuming all points are foreground
82
+ # masks, _, _ = predictor.predict(
83
+ # point_coords=input_point, point_labels=input_label
84
+ # )
85
+
86
+ # # The masks are returned as a list of numpy arrays
87
+ # print("Predicted Mask:", masks)
88
+
89
+ # # Assuming there's only one mask returned, you can adjust if there are multiple
90
+ # predicted_mask = masks[0]
91
+
92
+ # print(len(image))
93
+
94
+ # print(len(predicted_mask))
95
+
96
+ # # Create annotations for AnnotatedImage
97
+ # annotations = [(predicted_mask, "Predicted Mask")]
98
+
99
+ # return image, annotations
100
+
101
+
102
+ # # Define the Gradio interface
103
+ # demo = gr.Interface(
104
+ # fn=prompter, # Use the custom prompter function
105
+ # inputs=ImagePrompter(
106
+ # show_label=False
107
+ # ), # ImagePrompter for image input and point selection
108
+ # outputs=gr.AnnotatedImage(), # Display the image with the predicted mask
109
+ # title="Image Point Collector with Mask Overlay",
110
+ # description="Upload an image, click on it, and get the predicted mask overlayed on the image.",
111
+ # )
112
+
113
+ # # Launch the Gradio app
114
+ # demo.launch()
115
+
116
+
117
  import gradio as gr
118
  from gradio_image_prompter import ImagePrompter
119
+ import torch
120
+ import numpy as np
121
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
122
+ from PIL import Image
123
+
124
+
125
+ def prompter(prompts):
126
+ image = np.array(prompts["image"]) # Convert the image to a numpy array
127
+ points = prompts["points"] # Get the points from prompts
128
+
129
+ # Print the collected inputs for debugging or logging
130
+ print("Image received:", image)
131
+ print("Points received:", points)
132
+
133
+ device = torch.device("cpu")
134
+
135
+ # Load the SAM2ImagePredictor model
136
+ predictor = SAM2ImagePredictor.from_pretrained(
137
+ "facebook/sam2-hiera-base-plus", device=device
138
+ )
139
+
140
+ # Perform inference with multimask_output=True
141
+ with torch.inference_mode():
142
+ predictor.set_image(image)
143
+ input_point = [[point[0], point[1]] for point in points]
144
+ input_label = [1] * len(points) # Assuming all points are foreground
145
+ masks, _, _ = predictor.predict(
146
+ point_coords=input_point, point_labels=input_label, multimask_output=True
147
+ )
148
+
149
+ # Prepare individual images with separate overlays
150
+ overlay_images = []
151
+ for i, mask in enumerate(masks):
152
+ print(f"Predicted Mask {i+1}:", mask)
153
+ red_mask = np.zeros_like(image)
154
+ red_mask[:, :, 0] = mask.astype(np.uint8) * 255 # Apply the red channel
155
+ red_mask = Image.fromarray(red_mask)
156
+
157
+ # Convert the original image to a PIL image
158
+ original_image = Image.fromarray(image)
159
+
160
+ # Blend the original image with the red mask
161
+ blended_image = Image.blend(original_image, red_mask, alpha=0.5)
162
+
163
+ # Add the blended image to the list
164
+ overlay_images.append(blended_image)
165
+
166
+ return overlay_images
167
+
168
 
169
  # Define the Gradio interface
170
  demo = gr.Interface(
171
+ fn=prompter, # Use the custom prompter function
 
 
 
172
  inputs=ImagePrompter(
173
  show_label=False
174
  ), # ImagePrompter for image input and point selection
175
  outputs=[
176
+ gr.Image(show_label=False) for _ in range(3)
177
+ ], # Display up to 3 overlay images
178
+ title="Image Point Collector with Multiple Separate Mask Overlays",
179
+ description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
 
180
  )
181
 
182
  # Launch the Gradio app
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  gradio
2
  gradio-image-prompter
3
- Pillow
 
 
 
1
  gradio
2
  gradio-image-prompter
3
+ Pillow
4
+ opencv-python
5
+ git+https://github.com/facebookresearch/segment-anything-2.git