dillonlaird committed
Commit cdab98d · 1 Parent(s): 8040aeb

added sam_hq
app/sam_hq DELETED
@@ -1 +0,0 @@
- Subproject commit 759948f24f5524e5946bc274b4086ff1cc13b676
 
 
app/sam_hq/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .build_sam import build_sam
+from .predictor import SamPredictor
app/sam_hq/build_sam.py ADDED
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .image_encoder import ImageEncoderViT
+from .mask_decoder_hq import MaskDecoderHQ
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
+from .sam import Sam
+
+
+def build_sam_vit_h(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_b(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+    )
+
+
+sam_model_registry = {
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
+    "vit_l": build_sam_vit_l,
+    "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    sam = Sam(
+        image_encoder=ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoderHQ(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            vit_dim=encoder_embed_dim,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    # sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f, map_location=torch.device("cpu"))
+        info = sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for n, p in sam.named_parameters():
+        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+            p.requires_grad = False
+
+    return sam
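
For reference, a minimal usage sketch of the builders above (not part of the commit). It assumes the app/ directory is importable as a package and that an HQ checkpoint file such as sam_hq_vit_h.pth has been downloaded; the filename is a placeholder.

    import torch
    from app.sam_hq.build_sam import sam_model_registry

    # "vit_h" / "vit_l" / "vit_b" select the encoder size; "default" maps to ViT-H.
    sam = sam_model_registry["vit_h"](checkpoint="sam_hq_vit_h.pth")  # hypothetical weights file
    sam.to("cuda" if torch.cuda.is_available() else "cpu")
    # Only the HQ-specific parameters (hf_token, hf_mlp, compress_vit_feat,
    # embedding_encoder, embedding_maskfeature) remain trainable; everything
    # else was frozen by _build_sam.
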
app/sam_hq/common.py ADDED
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_dim: int,
+        act: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+        self.act = act()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
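
A quick, illustrative shape check for LayerNorm2d (not part of the commit): unlike nn.LayerNorm, it normalizes over the channel dimension of an NCHW tensor, which is why the encoder neck and the HQ fusion layers can apply it directly to conv feature maps.

    import torch
    from app.sam_hq.common import LayerNorm2d

    x = torch.randn(2, 256, 64, 64)       # N, C, H, W feature map
    y = LayerNorm2d(256)(x)               # per-pixel normalization across the 256 channels
    print(y.shape)                        # torch.Size([2, 256, 64, 64])
    print(y.mean(dim=1).abs().max())      # ~0: each spatial location is zero-mean over channels
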
app/sam_hq/image_encoder.py ADDED
@@ -0,0 +1,398 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type
+
+from .common import LayerNorm2d, MLPBlock
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
+    ) -> None:
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_abs_pos (bool): If True, use absolute positional embeddings.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks.
+            global_attn_indexes (list): Indexes for blocks using global attention.
+        """
+        super().__init__()
+        self.img_size = img_size
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        self.pos_embed: Optional[nn.Parameter] = None
+        if use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size.
+            self.pos_embed = nn.Parameter(
+                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+            )
+
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.blocks.append(block)
+
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2d(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+
+        interm_embeddings = []
+        for blk in self.blocks:
+            x = blk(x)
+            if blk.window_size == 0:
+                interm_embeddings.append(x)
+
+        x = self.neck(x.permute(0, 3, 1, 2))
+
+        return x, interm_embeddings
+
+
+class Block(nn.Module):
+    """Transformer blocks with support of window attention and residual propagation blocks"""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If it equals 0, then
+                use global attention.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+        self.window_size = window_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + x
+        x = x + self.mlp(self.norm2(x))
+
+        return x
+
+
+class Attention(nn.Module):
+    """Multi-head Attention block with relative position embeddings."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            assert (
+                input_size is not None
+            ), "Input size must be provided if using relative positional encoding."
+            # initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, H, W, _ = x.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        # q, k, v with shape (B * nHead, H * W, C)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+
+        if self.use_rel_pos:
+            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        x = self.proj(x)
+
+        return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Get relative positional embeddings according to the relative positions of
+    query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+    attn = (
+        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+    ).view(B, q_h * q_w, k_h * k_w)
+
+    return attn
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        padding: Tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
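
Illustrative shape sketch (not part of the commit): the HQ change to this file is that forward returns the intermediate features of every global-attention block alongside the neck output, which MaskDecoderHQ later fuses. ViT-B sizes are used here purely as an example.

    import torch
    from functools import partial
    from app.sam_hq.image_encoder import ImageEncoderViT

    encoder = ImageEncoderViT(
        depth=12, embed_dim=768, num_heads=12,              # ViT-B sized example
        global_attn_indexes=[2, 5, 8, 11], window_size=14,
        use_rel_pos=True, out_chans=256,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    )
    feats, interm = encoder(torch.randn(1, 3, 1024, 1024))
    print(feats.shape)                   # torch.Size([1, 256, 64, 64]) neck output
    print(len(interm), interm[0].shape)  # 4 global-attention features, each (1, 64, 64, 768)
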
app/sam_hq/mask_decoder_hq.py ADDED
@@ -0,0 +1,264 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Modified by HQ-SAM team
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+from .common import LayerNorm2d
+
+
+class MaskDecoderHQ(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+        vit_dim: int = 1024,
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(
+                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+            ),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(
+                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+            ),
+            activation(),
+        )
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+        )
+
+        # HQ-SAM parameters
+        self.hf_token = nn.Embedding(1, transformer_dim)  # HQ-Output-Token
+        self.hf_mlp = MLP(
+            transformer_dim, transformer_dim, transformer_dim // 8, 3
+        )  # corresponding new MLP layer for HQ-Output-Token
+        self.num_mask_tokens = self.num_mask_tokens + 1
+
+        # three conv fusion layers for obtaining HQ-Feature
+        self.compress_vit_feat = nn.Sequential(
+            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim),
+            nn.GELU(),
+            nn.ConvTranspose2d(
+                transformer_dim, transformer_dim // 8, kernel_size=2, stride=2
+            ),
+        )
+
+        self.embedding_encoder = nn.Sequential(
+            nn.ConvTranspose2d(
+                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+            ),
+            LayerNorm2d(transformer_dim // 4),
+            nn.GELU(),
+            nn.ConvTranspose2d(
+                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+            ),
+        )
+        self.embedding_maskfeature = nn.Sequential(
+            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
+            LayerNorm2d(transformer_dim // 4),
+            nn.GELU(),
+            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1),
+        )
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched predicted masks
+          torch.Tensor: batched predictions of mask quality
+        """
+        vit_features = interm_embeddings[0].permute(
+            0, 3, 1, 2
+        )  # early-layer ViT feature, after 1st global attention block in ViT
+        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(
+            vit_features
+        )
+
+        masks, iou_pred = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            hq_features=hq_features,
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            # mask with highest score
+            mask_slice = slice(1, self.num_mask_tokens - 1)
+            iou_pred = iou_pred[:, mask_slice]
+            iou_pred, max_iou_idx = torch.max(iou_pred, dim=1)
+            iou_pred = iou_pred.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[
+                torch.arange(masks_multi.size(0)), max_iou_idx
+            ].unsqueeze(1)
+        else:
+            # single mask output, default
+            mask_slice = slice(0, 1)
+            iou_pred = iou_pred[:, mask_slice]
+            masks_sam = masks[:, mask_slice]
+
+        masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens)]
+        if hq_token_only:
+            masks = masks_hq
+        else:
+            masks = masks_sam + masks_hq
+        # Prepare output
+        return masks, iou_pred
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        hq_features: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        output_tokens = torch.cat(
+            [self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight],
+            dim=0,
+        )
+        output_tokens = output_tokens.unsqueeze(0).expand(
+            sparse_prompt_embeddings.size(0), -1, -1
+        )
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        src = src + dense_prompt_embeddings
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+
+        upscaled_embedding_sam = self.output_upscaling(src)
+        upscaled_embedding_hq = self.embedding_maskfeature(
+            upscaled_embedding_sam
+        ) + hq_features.repeat(b, 1, 1, 1)
+
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            if i < self.num_mask_tokens - 1:
+                hyper_in_list.append(
+                    self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+                )
+            else:
+                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
+
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding_sam.shape
+
+        masks_sam = (
+            hyper_in[:, : self.num_mask_tokens - 1]
+            @ upscaled_embedding_sam.view(b, c, h * w)
+        ).view(b, -1, h, w)
+        masks_sam_hq = (
+            hyper_in[:, self.num_mask_tokens - 1 :]
+            @ upscaled_embedding_hq.view(b, c, h * w)
+        ).view(b, -1, h, w)
+        masks = torch.cat([masks_sam, masks_sam_hq], dim=1)
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
app/sam_hq/predictor.py ADDED
@@ -0,0 +1,291 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+
+from typing import Optional, Tuple
+from .sam import Sam
+from .transforms import ResizeLongestSide
+
+
+class SamPredictor:
+    def __init__(
+        self,
+        sam_model: Sam,
+    ) -> None:
+        """
+        Uses SAM to calculate the image embedding for an image, and then
+        allows repeated, efficient mask prediction given prompts.
+
+        Arguments:
+          sam_model (Sam): The model to use for mask prediction.
+        """
+        super().__init__()
+        self.model = sam_model
+        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
+        self.reset_image()
+
+    def set_image(
+        self,
+        image: np.ndarray,
+        image_format: str = "RGB",
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image, allowing
+        masks to be predicted with the 'predict' method.
+
+        Arguments:
+          image (np.ndarray): The image for calculating masks. Expects an
+            image in HWC uint8 format, with pixel values in [0, 255].
+          image_format (str): The color format of the image, in ['RGB', 'BGR'].
+        """
+        assert image_format in [
+            "RGB",
+            "BGR",
+        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+        # import pdb;pdb.set_trace()
+        if image_format != self.model.image_format:
+            image = image[..., ::-1]
+
+        # Transform the image to the form expected by the model
+        # import pdb;pdb.set_trace()
+        input_image = self.transform.apply_image(image)
+        input_image_torch = torch.as_tensor(input_image, device=self.device)
+        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
+            None, :, :, :
+        ]
+
+        self.set_torch_image(input_image_torch, image.shape[:2])
+
+    @torch.no_grad()
+    def set_torch_image(
+        self,
+        transformed_image: torch.Tensor,
+        original_image_size: Tuple[int, ...],
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image, allowing
+        masks to be predicted with the 'predict' method. Expects the input
+        image to be already transformed to the format expected by the model.
+
+        Arguments:
+          transformed_image (torch.Tensor): The input image, with shape
+            1x3xHxW, which has been transformed with ResizeLongestSide.
+          original_image_size (tuple(int, int)): The size of the image
+            before transformation, in (H, W) format.
+        """
+        assert (
+            len(transformed_image.shape) == 4
+            and transformed_image.shape[1] == 3
+            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
+        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
+        self.reset_image()
+
+        self.original_size = original_image_size
+        self.input_size = tuple(transformed_image.shape[-2:])
+        input_image = self.model.preprocess(transformed_image)
+        self.features, self.interm_features = self.model.image_encoder(input_image)
+        self.is_image_set = True
+
+    def predict(
+        self,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        hq_token_only: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+
+        Arguments:
+          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (np.ndarray or None): A length N array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          box (np.ndarray or None): A length 4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form 1xHxW, where
+            for SAM, H=W=256.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (np.ndarray): An array of length C containing the model's
+            predictions for the quality of each mask.
+          (np.ndarray): An array of shape CxHxW, where C is the number
+            of masks and H=W=256. These low resolution logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self.is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        # Transform input prompts
+        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+        if point_coords is not None:
+            assert (
+                point_labels is not None
+            ), "point_labels must be supplied if point_coords is supplied."
+            point_coords = self.transform.apply_coords(point_coords, self.original_size)
+            coords_torch = torch.as_tensor(
+                point_coords, dtype=torch.float, device=self.device
+            )
+            labels_torch = torch.as_tensor(
+                point_labels, dtype=torch.int, device=self.device
+            )
+            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+        if box is not None:
+            box = self.transform.apply_boxes(box, self.original_size)
+            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+            box_torch = box_torch[None, :]
+        if mask_input is not None:
+            mask_input_torch = torch.as_tensor(
+                mask_input, dtype=torch.float, device=self.device
+            )
+            mask_input_torch = mask_input_torch[None, :, :, :]
+
+        masks, iou_predictions, low_res_masks = self.predict_torch(
+            coords_torch,
+            labels_torch,
+            box_torch,
+            mask_input_torch,
+            multimask_output,
+            return_logits=return_logits,
+            hq_token_only=hq_token_only,
+        )
+
+        masks_np = masks[0].detach().cpu().numpy()
+        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
+        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
+        return masks_np, iou_predictions_np, low_res_masks_np
+
+    @torch.no_grad()
+    def predict_torch(
+        self,
+        point_coords: Optional[torch.Tensor],
+        point_labels: Optional[torch.Tensor],
+        boxes: Optional[torch.Tensor] = None,
+        mask_input: Optional[torch.Tensor] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        hq_token_only: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using ResizeLongestSide.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self.is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        if point_coords is not None:
+            points = (point_coords, point_labels)
+        else:
+            points = None
+
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+            points=points,
+            boxes=boxes,
+            masks=mask_input,
+        )
+
+        # Predict masks
+        low_res_masks, iou_predictions = self.model.mask_decoder(
+            image_embeddings=self.features,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            hq_token_only=hq_token_only,
+            interm_embeddings=self.interm_features,
+        )
+
+        # Upscale the masks to the original image resolution
+        masks = self.model.postprocess_masks(
+            low_res_masks, self.input_size, self.original_size
+        )
+
+        if not return_logits:
+            masks = masks > self.model.mask_threshold
+
+        return masks, iou_predictions, low_res_masks
+
+    def get_image_embedding(self) -> torch.Tensor:
+        """
+        Returns the image embeddings for the currently set image, with
+        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+        the embedding spatial dimension of SAM (typically C=256, H=W=64).
+        """
+        if not self.is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) to generate an embedding."
+            )
+        assert (
+            self.features is not None
+        ), "Features must exist if an image has been set."
+        return self.features
+
+    @property
+    def device(self) -> torch.device:
+        return self.model.device
+
+    def reset_image(self) -> None:
+        """Resets the currently set image."""
+        self.is_image_set = False
+        self.features = None
+        self.orig_h = None
+        self.orig_w = None
+        self.input_h = None
+        self.input_w = None
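
A minimal end-to-end sketch with SamPredictor (not part of the commit). The image path, checkpoint name, and box coordinates are placeholders; hq_token_only=True returns only the HQ-refined mask rather than the summed SAM+HQ mask.

    import cv2
    import numpy as np
    from app.sam_hq import SamPredictor, build_sam

    predictor = SamPredictor(build_sam(checkpoint="sam_hq_vit_h.pth"))   # hypothetical weights file
    image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)   # placeholder path; HWC uint8, RGB
    predictor.set_image(image)

    box = np.array([100, 100, 400, 400])                                 # XYXY in original image pixels
    masks, scores, low_res_logits = predictor.predict(
        box=box,
        multimask_output=False,
        hq_token_only=True,
    )
    print(masks.shape, scores)  # (1, H, W) boolean mask at the original resolution, plus its IoU score
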
app/sam_hq/prompt_encoder.py ADDED
@@ -0,0 +1,213 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn
+
+from typing import Any, Optional, Tuple, Type
+from .common import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        image_embedding_size: Tuple[int, int],
+        input_image_size: Tuple[int, int],
+        mask_in_chans: int,
+        activation: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        """
+        Encodes prompts for input to SAM's mask decoder.
+
+        Arguments:
+          embed_dim (int): The prompts' embedding dimension
+          image_embedding_size (tuple(int, int)): The spatial size of the
+            image embedding, as (H, W).
+          input_image_size (int): The padded size of the image as input
+            to the image encoder, as (H, W).
+          mask_in_chans (int): The number of hidden channels used for
+            encoding input masks.
+          activation (nn.Module): The activation to use when encoding
+            input masks.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.input_image_size = input_image_size
+        self.image_embedding_size = image_embedding_size
+        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
+        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+        self.point_embeddings = nn.ModuleList(point_embeddings)
+        self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+        self.mask_downscaling = nn.Sequential(
+            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans // 4),
+            activation(),
+            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans),
+            activation(),
+            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+        )
+        self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+    def get_dense_pe(self) -> torch.Tensor:
+        """
+        Returns the positional encoding used to encode point prompts,
+        applied to a dense set of points the shape of the image encoding.
+
+        Returns:
+          torch.Tensor: Positional encoding with shape
+            1x(embed_dim)x(embedding_h)x(embedding_w)
+        """
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor,
+        pad: bool,
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        if pad:
+            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+            points = torch.cat([points, padding_point], dim=1)
+            labels = torch.cat([labels, padding_label], dim=1)
+        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+        point_embedding[labels == -1] = 0.0
+        point_embedding[labels == -1] += self.not_a_point_embed.weight
+        point_embedding[labels == 0] += self.point_embeddings[0].weight
+        point_embedding[labels == 1] += self.point_embeddings[1].weight
+        return point_embedding
+
+    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+        """Embeds box prompts."""
+        boxes = boxes + 0.5  # Shift to center of pixel
+        coords = boxes.reshape(-1, 2, 2)
+        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+        return corner_embedding
+
+    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+        """Embeds mask inputs."""
+        mask_embedding = self.mask_downscaling(masks)
+        return mask_embedding
+
+    def _get_batch_size(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> int:
+        """
+        Gets the batch size of the output given the batch size of the input prompts.
+        """
+        if points is not None:
+            return points[0].shape[0]
+        elif boxes is not None:
+            return boxes.shape[0]
+        elif masks is not None:
+            return masks.shape[0]
+        else:
+            return 1
+
+    def _get_device(self) -> torch.device:
+        return self.point_embeddings[0].weight.device
+
+    def forward(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Embeds different types of prompts, returning both sparse and dense
+        embeddings.
+
+        Arguments:
+          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+            and labels to embed.
+          boxes (torch.Tensor or none): boxes to embed
+          masks (torch.Tensor or none): masks to embed
+
+        Returns:
+          torch.Tensor: sparse embeddings for the points and boxes, with shape
+            BxNx(embed_dim), where N is determined by the number of input points
+            and boxes.
+          torch.Tensor: dense embeddings for the masks, in the shape
+            Bx(embed_dim)x(embed_H)x(embed_W)
+        """
+        bs = self._get_batch_size(points, boxes, masks)
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        if points is not None:
+            coords, labels = points
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+        if boxes is not None:
+            box_embeddings = self._embed_boxes(boxes)
+            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+        if masks is not None:
+            dense_embeddings = self._embed_masks(masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        return sparse_embeddings, dense_embeddings
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix",
+            scale * torch.randn((2, num_pos_feats)),
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device: Any = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones((h, w), device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
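
Illustrative shape check (not part of the commit): a single box prompt becomes two sparse corner tokens, and with no mask prompt the dense embedding is the learned no-mask embedding broadcast over the 64x64 embedding grid. The sizes below match the defaults wired up in _build_sam.

    import torch
    from app.sam_hq.prompt_encoder import PromptEncoder

    pe = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    sparse, dense = pe(points=None, boxes=torch.tensor([[100.0, 100.0, 400.0, 400.0]]), masks=None)
    print(sparse.shape, dense.shape)  # torch.Size([1, 2, 256]) torch.Size([1, 256, 64, 64])
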
app/sam_hq/sam.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import Any, Dict, List, Tuple
12
+
13
+ from .image_encoder import ImageEncoderViT
14
+ from .mask_decoder_hq import MaskDecoderHQ
15
+ from .prompt_encoder import PromptEncoder
16
+
17
+
18
+ class Sam(nn.Module):
19
+ mask_threshold: float = 0.0
20
+ image_format: str = "RGB"
21
+
22
+ def __init__(
23
+ self,
24
+ image_encoder: ImageEncoderViT,
25
+ prompt_encoder: PromptEncoder,
26
+ mask_decoder: MaskDecoderHQ,
27
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
28
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
29
+ ) -> None:
30
+ """
31
+ SAM predicts object masks from an image and input prompts.
32
+
33
+ Arguments:
34
+ image_encoder (ImageEncoderViT): The backbone used to encode the
35
+ image into image embeddings that allow for efficient mask prediction.
36
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38
+ and encoded prompts.
39
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
41
+ """
42
+ super().__init__()
43
+ self.image_encoder = image_encoder
44
+ self.prompt_encoder = prompt_encoder
45
+ self.mask_decoder = mask_decoder
46
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48
+
49
+ @property
50
+ def device(self) -> Any:
51
+ return self.pixel_mean.device
52
+
53
+ def forward(
54
+ self,
55
+ batched_input: List[Dict[str, Any]],
56
+ multimask_output: bool,
57
+ hq_token_only: bool =False,
58
+ ) -> List[Dict[str, torch.Tensor]]:
59
+ """
60
+ Predicts masks end-to-end from provided images and prompts.
61
+ If prompts are not known in advance, using SamPredictor is
62
+ recommended over calling the model directly.
63
+
64
+ Arguments:
65
+ batched_input (list(dict)): A list over input images, each a
66
+ dictionary with the following keys. A prompt key can be
67
+ excluded if it is not present.
68
+ 'image': The image as a torch tensor in 3xHxW format,
69
+ already transformed for input to the model.
70
+ 'original_size': (tuple(int, int)) The original size of
71
+ the image before transformation, as (H, W).
72
+ 'point_coords': (torch.Tensor) Batched point prompts for
73
+ this image, with shape BxNx2. Already transformed to the
74
+ input frame of the model.
75
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
76
+ with shape BxN.
77
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78
+ Already transformed to the input frame of the model.
79
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80
+ in the form Bx1xHxW.
81
+ multimask_output (bool): Whether the model should predict multiple
82
+ disambiguating masks, or return a single mask.
83
+
84
+ Returns:
85
+ (list(dict)): A list over input images, where each element is
86
+ as dictionary with the following keys.
87
+ 'masks': (torch.Tensor) Batched binary mask predictions,
88
+ with shape BxCxHxW, where B is the number of input prompts,
89
+ C is determined by multimask_output, and (H, W) is the
90
+ original size of the image.
91
+ 'iou_predictions': (torch.Tensor) The model's predictions
92
+ of mask quality, in shape BxC.
93
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
94
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
95
+ to subsequent iterations of prediction.
96
+ """
97
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98
+ image_embeddings, interm_embeddings = self.image_encoder(input_images)
99
+ interm_embeddings = interm_embeddings[0] # early layer
100
+
101
+ outputs = []
102
+ for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings):
103
+ if "point_coords" in image_record:
104
+ points = (image_record["point_coords"], image_record["point_labels"])
105
+ else:
106
+ points = None
107
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
108
+ points=points,
109
+ boxes=image_record.get("boxes", None),
110
+ masks=image_record.get("mask_inputs", None),
111
+ )
112
+ low_res_masks, iou_predictions = self.mask_decoder(
113
+ image_embeddings=curr_embedding.unsqueeze(0),
114
+ image_pe=self.prompt_encoder.get_dense_pe(),
115
+ sparse_prompt_embeddings=sparse_embeddings,
116
+ dense_prompt_embeddings=dense_embeddings,
117
+ multimask_output=multimask_output,
118
+ hq_token_only=hq_token_only,
119
+ interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
120
+ )
121
+ masks = self.postprocess_masks(
122
+ low_res_masks,
123
+ input_size=image_record["image"].shape[-2:],
124
+ original_size=image_record["original_size"],
125
+ )
126
+ masks = masks > self.mask_threshold
127
+ outputs.append(
128
+ {
129
+ "masks": masks,
130
+ "iou_predictions": iou_predictions,
131
+ "low_res_logits": low_res_masks,
132
+ }
133
+ )
134
+ return outputs
135
+ 
+     def postprocess_masks(
+         self,
+         masks: torch.Tensor,
+         input_size: Tuple[int, ...],
+         original_size: Tuple[int, ...],
+     ) -> torch.Tensor:
+         """
+         Remove padding and upscale masks to the original image size.
+ 
+         Arguments:
+           masks (torch.Tensor): Batched masks from the mask_decoder,
+             in BxCxHxW format.
+           input_size (tuple(int, int)): The size of the image input to the
+             model, in (H, W) format. Used to remove padding.
+           original_size (tuple(int, int)): The original size of the image
+             before resizing for input to the model, in (H, W) format.
+ 
+         Returns:
+           (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+             is given by original_size.
+         """
+         masks = F.interpolate(
+             masks,
+             (self.image_encoder.img_size, self.image_encoder.img_size),
+             mode="bilinear",
+             align_corners=False,
+         )
+         masks = masks[..., : input_size[0], : input_size[1]]
+         masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+         return masks
+ 
+     def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+         """Normalize pixel values and pad to a square input."""
+         # Normalize colors
+         x = (x - self.pixel_mean) / self.pixel_std
+ 
+         # Pad
+         h, w = x.shape[-2:]
+         padh = self.image_encoder.img_size - h
+         padw = self.image_encoder.img_size - w
+         x = F.pad(x, (0, padw, 0, padh))
+         return x
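
Usage sketch (illustrative, not part of the committed files): the forward API documented above can be driven end to end roughly as follows. This is a minimal sketch that assumes the package is importable as app.sam_hq, that a SAM-HQ ViT-H checkpoint exists at the placeholder path, and it uses the ResizeLongestSide transform added later in this commit; the image and point prompt are dummy stand-ins.

import numpy as np
import torch

from app.sam_hq import build_sam                     # assumed import path
from app.sam_hq.transforms import ResizeLongestSide  # added below in this commit

sam = build_sam(checkpoint="sam_hq_vit_h.pth")  # placeholder checkpoint path
transform = ResizeLongestSide(sam.image_encoder.img_size)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real HxWxC RGB image
input_image = transform.apply_image(image)
input_tensor = torch.as_tensor(input_image).permute(2, 0, 1).contiguous()

# One foreground point, transformed into the resized input frame
point_coords = transform.apply_coords(np.array([[[320.0, 240.0]]]), image.shape[:2])

batched_input = [
    {
        "image": input_tensor,             # 3xHxW, already resized
        "original_size": image.shape[:2],  # (H, W) before resizing
        "point_coords": torch.as_tensor(point_coords, dtype=torch.float),  # BxNx2
        "point_labels": torch.ones((1, 1), dtype=torch.int),               # BxN
    }
]

with torch.no_grad():
    outputs = sam(batched_input, multimask_output=False, hq_token_only=True)

print(outputs[0]["masks"].shape)            # B x 1 x H x W boolean masks at original size
print(outputs[0]["iou_predictions"].shape)  # B x 1 predicted mask quality
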
app/sam_hq/transformer.py ADDED
@@ -0,0 +1,239 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ 
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ import torch
+ import math
+ 
+ from torch import Tensor, nn
+ from typing import Tuple, Type
+ from .common import MLPBlock
+ 
+ 
+ class TwoWayTransformer(nn.Module):
+     def __init__(
+         self,
+         depth: int,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int,
+         activation: Type[nn.Module] = nn.ReLU,
+         attention_downsample_rate: int = 2,
+     ) -> None:
+         """
+         A transformer decoder that attends to an input image using
+         queries whose positional embedding is supplied.
+ 
+         Args:
+           depth (int): number of layers in the transformer
+           embedding_dim (int): the channel dimension for the input embeddings
+           num_heads (int): the number of heads for multihead attention. Must
+             divide embedding_dim
+           mlp_dim (int): the channel dimension internal to the MLP block
+           activation (nn.Module): the activation to use in the MLP block
+         """
+         super().__init__()
+         self.depth = depth
+         self.embedding_dim = embedding_dim
+         self.num_heads = num_heads
+         self.mlp_dim = mlp_dim
+         self.layers = nn.ModuleList()
+ 
+         for i in range(depth):
+             self.layers.append(
+                 TwoWayAttentionBlock(
+                     embedding_dim=embedding_dim,
+                     num_heads=num_heads,
+                     mlp_dim=mlp_dim,
+                     activation=activation,
+                     attention_downsample_rate=attention_downsample_rate,
+                     skip_first_layer_pe=(i == 0),
+                 )
+             )
+ 
+         self.final_attn_token_to_image = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+         self.norm_final_attn = nn.LayerNorm(embedding_dim)
+ 
+     def forward(
+         self,
+         image_embedding: Tensor,
+         image_pe: Tensor,
+         point_embedding: Tensor,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Args:
+           image_embedding (torch.Tensor): image to attend to. Should be shape
+             B x embedding_dim x h x w for any h and w.
+           image_pe (torch.Tensor): the positional encoding to add to the image. Must
+             have the same shape as image_embedding.
+           point_embedding (torch.Tensor): the embedding to add to the query points.
+             Must have shape B x N_points x embedding_dim for any N_points.
+ 
+         Returns:
+           torch.Tensor: the processed point_embedding
+           torch.Tensor: the processed image_embedding
+         """
+         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+         bs, c, h, w = image_embedding.shape
+         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+         image_pe = image_pe.flatten(2).permute(0, 2, 1)
+ 
+         # Prepare queries
+         queries = point_embedding
+         keys = image_embedding
+ 
+         # Apply transformer blocks and final layernorm
+         for layer in self.layers:
+             queries, keys = layer(
+                 queries=queries,
+                 keys=keys,
+                 query_pe=point_embedding,
+                 key_pe=image_pe,
+             )
+ 
+         # Apply the final attention layer from the points to the image
+         q = queries + point_embedding
+         k = keys + image_pe
+         attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm_final_attn(queries)
+ 
+         return queries, keys
+ 
+ 
+ class TwoWayAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int = 2048,
+         activation: Type[nn.Module] = nn.ReLU,
+         attention_downsample_rate: int = 2,
+         skip_first_layer_pe: bool = False,
+     ) -> None:
+         """
+         A transformer block with four layers: (1) self-attention of sparse
+         inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+         block on sparse inputs, and (4) cross attention of dense inputs to sparse
+         inputs.
+ 
+         Arguments:
+           embedding_dim (int): the channel dimension of the embeddings
+           num_heads (int): the number of heads in the attention layers
+           mlp_dim (int): the hidden dimension of the mlp block
+           activation (nn.Module): the activation of the mlp block
+           skip_first_layer_pe (bool): skip the PE on the first layer
+         """
+         super().__init__()
+         self.self_attn = Attention(embedding_dim, num_heads)
+         self.norm1 = nn.LayerNorm(embedding_dim)
+ 
+         self.cross_attn_token_to_image = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+         self.norm2 = nn.LayerNorm(embedding_dim)
+ 
+         self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+         self.norm3 = nn.LayerNorm(embedding_dim)
+ 
+         self.norm4 = nn.LayerNorm(embedding_dim)
+         self.cross_attn_image_to_token = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+ 
+         self.skip_first_layer_pe = skip_first_layer_pe
+ 
+     def forward(
+         self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+     ) -> Tuple[Tensor, Tensor]:
+         # Self attention block
+         if self.skip_first_layer_pe:
+             queries = self.self_attn(q=queries, k=queries, v=queries)
+         else:
+             q = queries + query_pe
+             attn_out = self.self_attn(q=q, k=q, v=queries)
+             queries = queries + attn_out
+         queries = self.norm1(queries)
+ 
+         # Cross attention block, tokens attending to image embedding
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm2(queries)
+ 
+         # MLP block
+         mlp_out = self.mlp(queries)
+         queries = queries + mlp_out
+         queries = self.norm3(queries)
+ 
+         # Cross attention block, image embedding attending to tokens
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+         keys = keys + attn_out
+         keys = self.norm4(keys)
+ 
+         return queries, keys
+ 
+ 
+ class Attention(nn.Module):
+     """
+     An attention layer that allows for downscaling the size of the embedding
+     after projection to queries, keys, and values.
+     """
+ 
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         downsample_rate: int = 1,
+     ) -> None:
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.internal_dim = embedding_dim // downsample_rate
+         self.num_heads = num_heads
+         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+ 
+         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+ 
+     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+         b, n, c = x.shape
+         x = x.reshape(b, n, num_heads, c // num_heads)
+         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+ 
+     def _recombine_heads(self, x: Tensor) -> Tensor:
+         b, n_heads, n_tokens, c_per_head = x.shape
+         x = x.transpose(1, 2)
+         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+ 
+     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+         # Input projections
+         q = self.q_proj(q)
+         k = self.k_proj(k)
+         v = self.v_proj(v)
+ 
+         # Separate into heads
+         q = self._separate_heads(q, self.num_heads)
+         k = self._separate_heads(k, self.num_heads)
+         v = self._separate_heads(v, self.num_heads)
+ 
+         # Attention
+         _, _, _, c_per_head = q.shape
+         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+         attn = attn / math.sqrt(c_per_head)
+         attn = torch.softmax(attn, dim=-1)
+ 
+         # Get output
+         out = attn @ v
+         out = self._recombine_heads(out)
+         out = self.out_proj(out)
+ 
+         return out
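
Shape sketch (illustrative, not part of the committed files): a quick way to check the contract documented in TwoWayTransformer.forward. The dimensions below (embedding_dim=256, a 64x64 image-embedding grid, 5 point tokens) are assumptions matching SAM's defaults, and the import path is assumed.

import torch

from app.sam_hq.transformer import TwoWayTransformer  # assumed import path

decoder = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)  # B x embedding_dim x h x w
image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
point_embedding = torch.randn(1, 5, 256)       # B x N_points x embedding_dim

queries, keys = decoder(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([1, 5, 256])    processed point tokens
print(keys.shape)     # torch.Size([1, 4096, 256]) flattened image tokens (64*64)

Every Attention built with downsample_rate=2 projects the 256-dim embedding down to 128 internally and back to 256 on output, which is the downscaling described in the Attention docstring.
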
app/sam_hq/transforms.py ADDED
@@ -0,0 +1,102 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ 
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+ from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
+ 
+ from copy import deepcopy
+ from typing import Tuple
+ 
+ 
+ class ResizeLongestSide:
+     """
+     Resizes images to the longest side 'target_length', as well as provides
+     methods for resizing coordinates and boxes. Provides methods for
+     transforming both numpy array and batched torch tensors.
+     """
+ 
+     def __init__(self, target_length: int) -> None:
+         self.target_length = target_length
+ 
+     def apply_image(self, image: np.ndarray) -> np.ndarray:
+         """
+         Expects a numpy array with shape HxWxC in uint8 format.
+         """
+         target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+         return np.array(resize(to_pil_image(image), target_size))
+ 
+     def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+         """
+         Expects a numpy array of length 2 in the final dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).astype(float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+ 
+     def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+         """
+         Expects a numpy array shape Bx4. Requires the original image size
+         in (H, W) format.
+         """
+         boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+ 
+     def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+         """
+         Expects batched images with shape BxCxHxW and float format. This
+         transformation may not exactly match apply_image. apply_image is
+         the transformation expected by the model.
+         """
+         # Expects an image in BCHW format. May not exactly match apply_image.
+         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+         return F.interpolate(
+             image, target_size, mode="bilinear", align_corners=False, antialias=True
+         )
+ 
+     def apply_coords_torch(
+         self, coords: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with length 2 in the last dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).to(torch.float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+ 
+     def apply_boxes_torch(
+         self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with shape Bx4. Requires the original image
+         size in (H, W) format.
+         """
+         boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+ 
+     @staticmethod
+     def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+         """
+         Compute the output size given input size and target long side length.
+         """
+         scale = long_side_length * 1.0 / max(oldh, oldw)
+         newh, neww = oldh * scale, oldw * scale
+         neww = int(neww + 0.5)
+         newh = int(newh + 0.5)
+         return (newh, neww)
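
Usage sketch (illustrative, not part of the committed files): ResizeLongestSide with target_length=1024 on a made-up 600x800 image. get_preprocess_shape scales the long side to 1024, so the scale factor is 1024/800 = 1.28 and the output size is (768, 1024); coordinates and boxes are scaled by the same factor.

import numpy as np

from app.sam_hq.transforms import ResizeLongestSide  # assumed import path

transform = ResizeLongestSide(1024)

print(ResizeLongestSide.get_preprocess_shape(600, 800, 1024))  # (768, 1024)

image = np.zeros((600, 800, 3), dtype=np.uint8)  # HxWxC uint8 stand-in
resized = transform.apply_image(image)
print(resized.shape)  # (768, 1024, 3)

coords = np.array([[400.0, 300.0]])  # (x, y) in the original image frame
print(transform.apply_coords(coords, image.shape[:2]))  # [[512. 384.]]

boxes = np.array([[100.0, 100.0, 700.0, 500.0]])  # x1, y1, x2, y2
print(transform.apply_boxes(boxes, image.shape[:2]))  # [[128. 128. 896. 640.]]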