MaxwellMeyer commited on
Commit
2ca249a
1 Parent(s): bbb0aa9

Created inference file, moved torch no grad to model.py, removed timm user warning and used a different photo from the demo for default inference image.

Browse files
Files changed (3) hide show
  1. __pycache__/model.cpython-39.pyc +0 -0
  2. inference.py +17 -0
  3. model.py +2 -1
__pycache__/model.cpython-39.pyc ADDED
Binary file (29.3 kB). View file
 
inference.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import model
2
+ from PIL import Image
3
+ import torch
4
+
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+
8
+ file = "./image.png" # input image
9
+
10
+ model = model.BEN_Base().to(device).eval() #init pipeline
11
+
12
+ model.loadcheckpoints("./BEN_Base.pth")
13
+ image = Image.open(file)
14
+ mask, foreground = model.inference(image)
15
+
16
+ mask.save("./mask.png")
17
+ foreground.save("./foreground.png")
model.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
  import torch.utils.checkpoint as checkpoint
7
  from einops import rearrange
8
  from PIL import Image, ImageFilter, ImageOps
9
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
10
  from torchvision import transforms
11
 
12
  class Mlp(nn.Module):
@@ -887,6 +887,7 @@ class BEN_Base(nn.Module):
887
 
888
  return final_output.sigmoid()
889
 
 
890
  def inference(self,image):
891
  image, h, w,original_image = rgb_loader_refiner(image)
892
 
 
6
  import torch.utils.checkpoint as checkpoint
7
  from einops import rearrange
8
  from PIL import Image, ImageFilter, ImageOps
9
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
10
  from torchvision import transforms
11
 
12
  class Mlp(nn.Module):
 
887
 
888
  return final_output.sigmoid()
889
 
890
+ @torch.no_grad()
891
  def inference(self,image):
892
  image, h, w,original_image = rgb_loader_refiner(image)
893