Max Meyer commited on
Commit
a9521f8
·
verified ·
1 Parent(s): 29c887d

Fix example, load weights safely and remove extra whitespace (#2)

Browse files

- Fix example, load weights safely, remove whitespace (e840827775a894eaf8241a70535c2921b36bd430)

Files changed (3) hide show
  1. README.md +11 -11
  2. image2.png +0 -0
  3. model.py +27 -27
README.md CHANGED
@@ -13,34 +13,34 @@ tags:
13
 
14
  # BEN - Background Erase Network (Beta Base Model)
15
 
16
- BEN is a deep learning model designed to automatically remove backgrounds from images, producing both a mask and a foreground image.
17
 
18
  - MADE IN AMERICA
19
 
20
  ## Quick Start Code
 
21
  ```python
22
- from BEN import model
23
  from PIL import Image
24
  import torch
25
 
26
 
27
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
 
29
- file = "./image2.jpg" # input image
30
 
31
- model = model.BEN_Base().to(device).eval() #init pipeline
32
 
33
- model.loadcheckpoints("./BEN/BEN_Base.pth")
34
  image = Image.open(file)
35
- mask, foreground = model.inference(image)
 
36
 
37
  mask.save("./mask.png")
38
  foreground.save("./foreground.png")
39
-
40
-
41
-
42
  ```
43
- # BEN SOA Benchmarks on Disk 5k Eval
 
44
 
45
  ![Demo Results](demo.jpg)
46
 
@@ -84,4 +84,4 @@ foreground.save("./foreground.png")
84
 
85
  ## Installation
86
  1. Clone Repo
87
- 2. Install requirements.txt
 
13
 
14
  # BEN - Background Erase Network (Beta Base Model)
15
 
16
+ BEN is a deep learning model designed to automatically remove backgrounds from images, producing both a mask and a foreground image.
17
 
18
  - MADE IN AMERICA
19
 
20
  ## Quick Start Code
21
+
22
  ```python
23
+ import model
24
  from PIL import Image
25
  import torch
26
 
27
 
28
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
 
30
+ file = "./image2.png" # input image
31
 
32
+ model = model.BEN_Base().to(device).eval() #init pipeline
33
 
34
+ model.loadcheckpoints("./BEN_Base.pth")
35
  image = Image.open(file)
36
+ with torch.no_grad():
37
+ mask, foreground = model.inference(image)
38
 
39
  mask.save("./mask.png")
40
  foreground.save("./foreground.png")
 
 
 
41
  ```
42
+
43
+ # BEN SOA Benchmarks on Disk 5k Eval
44
 
45
  ![Demo Results](demo.jpg)
46
 
 
84
 
85
  ## Installation
86
  1. Clone Repo
87
+ 2. Install requirements.txt
image2.png ADDED
model.py CHANGED
@@ -560,7 +560,7 @@ class SwinTransformer(nn.Module):
560
  # interpolate the position embedding to the corresponding size
561
  absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
562
  x = (x + absolute_pos_embed) # B Wh*Ww C
563
-
564
  outs = [x.contiguous()]
565
  x = x.flatten(2).transpose(1, 2)
566
  x = self.pos_drop(x)
@@ -634,7 +634,7 @@ class PositionEmbeddingSine:
634
  scale = 2 * math.pi
635
  self.scale = scale
636
  self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
637
-
638
  def __call__(self, b, h, w):
639
  device = self.dim_t.device
640
  mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
@@ -646,18 +646,18 @@ class PositionEmbeddingSine:
646
  eps = 1e-6
647
  y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
648
  x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
649
-
650
  dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
651
  pos_x = x_embed[:, :, :, None] / dim_t
652
  pos_y = y_embed[:, :, :, None] / dim_t
653
-
654
  pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
655
  pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
656
-
657
- return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
658
 
 
659
 
660
- class MCLM(nn.Module):
 
661
  def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
662
  super(MCLM, self).__init__()
663
  self.attention = nn.ModuleList([
@@ -688,10 +688,10 @@ class MCLM(nn.Module):
688
  l: 4,c,h,w
689
  g: 1,c,h,w
690
  """
691
- b, c, h, w = l.size()
692
  # 4,c,h,w -> 1,c,2h,2w
693
  concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
694
-
695
  pools = []
696
  for pool_ratio in self.pool_ratios:
697
  # b,c,h,w
@@ -734,7 +734,7 @@ class MCLM(nn.Module):
734
  l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
735
  l_hw_b_c = self.norm1(l_hw_b_c)
736
  l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
737
- l_hw_b_c = self.norm2(l_hw_b_c)
738
 
739
  l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
740
  return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
@@ -770,42 +770,42 @@ class MCRM(nn.Module):
770
 
771
  def forward(self, x):
772
  device = x.device
773
- b, c, h, w = x.size()
774
  loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
775
-
776
  patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
777
-
778
  token_attention_map = self.sigmoid(self.sal_conv(glb))
779
  token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
780
  loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
781
-
782
  pools = []
783
  for pool_ratio in self.pool_ratios:
784
  tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
785
  pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
786
  pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
787
-
788
  pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
789
  loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
790
-
791
  outputs = []
792
  for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
793
  v = pools[i]
794
  k = v
795
  outputs.append(self.attention[i](q, k, v)[0])
796
-
797
- outputs = torch.cat(outputs, 1)
798
  src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
799
  src = self.norm1(src)
800
  src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
801
  src = self.norm2(src)
802
  src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
803
  glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb
804
-
805
  return torch.cat((src, glb), 0), token_attention_map
806
 
807
 
808
- class BEN_Base(nn.Module):
809
  def __init__(self):
810
  super().__init__()
811
 
@@ -868,7 +868,7 @@ class BEN_Base(nn.Module):
868
  e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
869
 
870
  e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
871
- e4 = self.conv4(e4)
872
  e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
873
  e3 = self.conv3(e3)
874
  e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
@@ -909,11 +909,11 @@ class BEN_Base(nn.Module):
909
  return blurred_mask, foreground
910
 
911
  def loadcheckpoints(self,model_path):
912
- model_dict = torch.load(model_path,map_location="cpu")
913
  self.load_state_dict(model_dict['model_state_dict'], strict=True)
914
  del model_path
915
 
916
-
917
 
918
 
919
  def rgb_loader_refiner( original_image):
@@ -923,16 +923,16 @@ def rgb_loader_refiner( original_image):
923
  # Convert to RGB if necessary
924
  if image.mode != 'RGB':
925
  image = image.convert('RGB')
926
-
927
  # Resize the image
928
  image = image.resize((1024, 1024), resample=Image.LANCZOS)
929
 
930
- return image.convert('RGB'), h, w,original_image
931
-
932
  # Define the image transformation
933
  img_transform = transforms.Compose([
934
  transforms.ToTensor(),
935
- transforms.ConvertImageDtype(torch.float32),
936
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
937
  ])
938
 
 
560
  # interpolate the position embedding to the corresponding size
561
  absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
562
  x = (x + absolute_pos_embed) # B Wh*Ww C
563
+
564
  outs = [x.contiguous()]
565
  x = x.flatten(2).transpose(1, 2)
566
  x = self.pos_drop(x)
 
634
  scale = 2 * math.pi
635
  self.scale = scale
636
  self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
637
+
638
  def __call__(self, b, h, w):
639
  device = self.dim_t.device
640
  mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
 
646
  eps = 1e-6
647
  y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
648
  x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
649
+
650
  dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
651
  pos_x = x_embed[:, :, :, None] / dim_t
652
  pos_y = y_embed[:, :, :, None] / dim_t
653
+
654
  pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
655
  pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
 
 
656
 
657
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
658
 
659
+
660
+ class MCLM(nn.Module):
661
  def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
662
  super(MCLM, self).__init__()
663
  self.attention = nn.ModuleList([
 
688
  l: 4,c,h,w
689
  g: 1,c,h,w
690
  """
691
+ b, c, h, w = l.size()
692
  # 4,c,h,w -> 1,c,2h,2w
693
  concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
694
+
695
  pools = []
696
  for pool_ratio in self.pool_ratios:
697
  # b,c,h,w
 
734
  l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
735
  l_hw_b_c = self.norm1(l_hw_b_c)
736
  l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
737
+ l_hw_b_c = self.norm2(l_hw_b_c)
738
 
739
  l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
740
  return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
 
770
 
771
  def forward(self, x):
772
  device = x.device
773
+ b, c, h, w = x.size()
774
  loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
775
+
776
  patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
777
+
778
  token_attention_map = self.sigmoid(self.sal_conv(glb))
779
  token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
780
  loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
781
+
782
  pools = []
783
  for pool_ratio in self.pool_ratios:
784
  tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
785
  pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
786
  pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
787
+
788
  pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
789
  loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
790
+
791
  outputs = []
792
  for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
793
  v = pools[i]
794
  k = v
795
  outputs.append(self.attention[i](q, k, v)[0])
796
+
797
+ outputs = torch.cat(outputs, 1)
798
  src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
799
  src = self.norm1(src)
800
  src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
801
  src = self.norm2(src)
802
  src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
803
  glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb
804
+
805
  return torch.cat((src, glb), 0), token_attention_map
806
 
807
 
808
+ class BEN_Base(nn.Module):
809
  def __init__(self):
810
  super().__init__()
811
 
 
868
  e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
869
 
870
  e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
871
+ e4 = self.conv4(e4)
872
  e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
873
  e3 = self.conv3(e3)
874
  e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
 
909
  return blurred_mask, foreground
910
 
911
  def loadcheckpoints(self,model_path):
912
+ model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
913
  self.load_state_dict(model_dict['model_state_dict'], strict=True)
914
  del model_path
915
 
916
+
917
 
918
 
919
  def rgb_loader_refiner( original_image):
 
923
  # Convert to RGB if necessary
924
  if image.mode != 'RGB':
925
  image = image.convert('RGB')
926
+
927
  # Resize the image
928
  image = image.resize((1024, 1024), resample=Image.LANCZOS)
929
 
930
+ return image.convert('RGB'), h, w,original_image
931
+
932
  # Define the image transformation
933
  img_transform = transforms.Compose([
934
  transforms.ToTensor(),
935
+ transforms.ConvertImageDtype(torch.float32),
936
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
937
  ])
938