integrate with transformers
Browse filesthis pr will fix the integration with the `transformers` library
following https://github.com/huggingface/transformers/pull/29262 there is no need to further use the method i used in [pr-9](https://huggingface.co/briaai/RMBG-1.4/discussions/9)
I have fixed the `requirements.txt` and the `README.md` files for future users beforehand so no need to change those.
# this is some exrtra bit of code to test the pr before mergin
```
wget https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt && pip install -qr requirements.txt
pip install -q git+https://github.com/huggingface/transformers.git
```
how to use before
```python
from transformers import pipeline
pipe = pipeline("image-segmentation",
model="briaai/RMBG-1.4",
revision ="refs/pr/21", # only when using the pr
trust_remote_code=True)
pipe("image_path.webp",out_name="myout.png") # applies mask and saves the extracted image as `myout.png`
```
also friendly tag to
@OriLib
Sincerely,
Hafedh Hichri
- MyConfig.py +14 -0
- MyPipe.py +76 -0
- README.md +22 -32
- briarmbg.py +9 -7
- config.json +24 -3
- requirements.txt +2 -1
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from transformers import PretrainedConfig
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
class RMBGConfig(PretrainedConfig):
|
6 |
+
model_type = "SegformerForSemanticSegmentation"
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
in_ch=3,
|
10 |
+
out_ch=1,
|
11 |
+
**kwargs):
|
12 |
+
self.in_ch = in_ch
|
13 |
+
self.out_ch = out_ch
|
14 |
+
super().__init__(**kwargs)
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch, os
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision.transforms.functional import normalize
|
5 |
+
import numpy as np
|
6 |
+
from transformers import Pipeline
|
7 |
+
from skimage import io
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
class RMBGPipe(Pipeline):
|
11 |
+
def __init__(self,**kwargs):
|
12 |
+
Pipeline.__init__(self,**kwargs)
|
13 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
+
self.model.to(self.device)
|
15 |
+
self.model.eval()
|
16 |
+
|
17 |
+
def _sanitize_parameters(self, **kwargs):
|
18 |
+
# parse parameters
|
19 |
+
preprocess_kwargs = {}
|
20 |
+
postprocess_kwargs = {}
|
21 |
+
if "model_input_size" in kwargs :
|
22 |
+
preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
|
23 |
+
if "out_name" in kwargs:
|
24 |
+
postprocess_kwargs["out_name"] = kwargs["out_name"]
|
25 |
+
return preprocess_kwargs, {}, postprocess_kwargs
|
26 |
+
|
27 |
+
def preprocess(self,im_path:str,model_input_size: list=[1024,1024]):
|
28 |
+
# preprocess the input
|
29 |
+
orig_im = io.imread(im_path)
|
30 |
+
orig_im_size = orig_im.shape[0:2]
|
31 |
+
image = self.preprocess_image(orig_im, model_input_size).to(self.device)
|
32 |
+
inputs = {
|
33 |
+
"image":image,
|
34 |
+
"orig_im_size":orig_im_size,
|
35 |
+
"im_path" : im_path
|
36 |
+
}
|
37 |
+
return inputs
|
38 |
+
|
39 |
+
def _forward(self,inputs):
|
40 |
+
result = self.model(inputs.pop("image"))
|
41 |
+
inputs["result"] = result
|
42 |
+
return inputs
|
43 |
+
def postprocess(self,inputs,out_name = ""):
|
44 |
+
result = inputs.pop("result")
|
45 |
+
orig_im_size = inputs.pop("orig_im_size")
|
46 |
+
im_path = inputs.pop("im_path")
|
47 |
+
result_image = self.postprocess_image(result[0][0], orig_im_size)
|
48 |
+
if out_name != "" :
|
49 |
+
# if out_name is specified we save the image using that name
|
50 |
+
pil_im = Image.fromarray(result_image)
|
51 |
+
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
52 |
+
orig_image = Image.open(im_path)
|
53 |
+
no_bg_image.paste(orig_image, mask=pil_im)
|
54 |
+
no_bg_image.save(out_name)
|
55 |
+
else :
|
56 |
+
return result_image
|
57 |
+
|
58 |
+
# utilities functions
|
59 |
+
def preprocess_image(self,im: np.ndarray, model_input_size: list=[1024,1024]) -> torch.Tensor:
|
60 |
+
# same as utilities.py with minor modification
|
61 |
+
if len(im.shape) < 3:
|
62 |
+
im = im[:, :, np.newaxis]
|
63 |
+
# orig_im_size=im.shape[0:2]
|
64 |
+
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
65 |
+
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
|
66 |
+
image = torch.divide(im_tensor,255.0)
|
67 |
+
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
68 |
+
return image
|
69 |
+
def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
|
70 |
+
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
|
71 |
+
ma = torch.max(result)
|
72 |
+
mi = torch.min(result)
|
73 |
+
result = (result-mi)/(ma-mi)
|
74 |
+
im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
|
75 |
+
im_array = np.squeeze(im_array)
|
76 |
+
return im_array
|
@@ -2,7 +2,7 @@
|
|
2 |
license: other
|
3 |
license_name: bria-rmbg-1.4
|
4 |
license_link: https://bria.ai/bria-huggingface-model-license-agreement/
|
5 |
-
pipeline_tag: image-
|
6 |
tags:
|
7 |
- remove background
|
8 |
- background
|
@@ -10,6 +10,7 @@ tags:
|
|
10 |
- Pytorch
|
11 |
- vision
|
12 |
- legal liability
|
|
|
13 |
|
14 |
extra_gated_prompt: This model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we reach out to you.
|
15 |
extra_gated_fields:
|
@@ -94,43 +95,32 @@ These modifications significantly improve the model’s accuracy and effectivene
|
|
94 |
|
95 |
## Installation
|
96 |
```bash
|
97 |
-
|
98 |
-
cd RMBG-1.4/
|
99 |
-
pip install -r requirements.txt
|
100 |
```
|
101 |
|
102 |
## Usage
|
103 |
|
104 |
```python
|
105 |
-
|
106 |
-
import torch, os
|
107 |
-
from PIL import Image
|
108 |
-
from briarmbg import BriaRMBG
|
109 |
-
from utilities import preprocess_image, postprocess_image
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
net.to(device)
|
117 |
-
|
118 |
-
# prepare input
|
119 |
-
model_input_size = [1024,1024]
|
120 |
-
orig_im = io.imread(im_path)
|
121 |
-
orig_im_size = orig_im.shape[0:2]
|
122 |
-
image = preprocess_image(orig_im, model_input_size).to(device)
|
123 |
-
|
124 |
-
# inference
|
125 |
-
result=net(image)
|
126 |
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
-
#
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
```
|
|
|
2 |
license: other
|
3 |
license_name: bria-rmbg-1.4
|
4 |
license_link: https://bria.ai/bria-huggingface-model-license-agreement/
|
5 |
+
pipeline_tag: image-segmentation
|
6 |
tags:
|
7 |
- remove background
|
8 |
- background
|
|
|
10 |
- Pytorch
|
11 |
- vision
|
12 |
- legal liability
|
13 |
+
- transformers
|
14 |
|
15 |
extra_gated_prompt: This model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we reach out to you.
|
16 |
extra_gated_fields:
|
|
|
95 |
|
96 |
## Installation
|
97 |
```bash
|
98 |
+
wget https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt && pip install -qr requirements.txt
|
|
|
|
|
99 |
```
|
100 |
|
101 |
## Usage
|
102 |
|
103 |
```python
|
104 |
+
# How to use
|
|
|
|
|
|
|
|
|
105 |
|
106 |
+
either load the model
|
107 |
+
```python
|
108 |
+
from transformers import AutoModelForImageSegmentation
|
109 |
+
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4",trust_remote_code=True)
|
110 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
+
or load the pipeline
|
113 |
+
```python
|
114 |
+
from transformers import pipeline
|
115 |
+
pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
|
116 |
+
numpy_mask = pipe("img_path") # outputs numpy mask
|
117 |
+
pipe("image_path",out_name="myout.png") # applies mask and saves the extracted image as `myout.png`
|
118 |
+
```
|
119 |
|
120 |
+
# parameters :
|
121 |
+
for the pipeline you can use the following parameters :
|
122 |
+
* `model_input_size` : default to [1024,1024]
|
123 |
+
* `out_name` : if specified it will use the numpy mask to extract the image and save it using the `out_name`
|
124 |
+
* `preprocess_image` : original method created by briaai
|
125 |
+
* `postprocess_image` : original method created by briaai
|
126 |
```
|
@@ -1,7 +1,9 @@
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
-
from
|
|
|
5 |
|
6 |
class REBNCONV(nn.Module):
|
7 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
@@ -345,12 +347,12 @@ class myrebnconv(nn.Module):
|
|
345 |
return self.rl(self.bn(self.conv(x)))
|
346 |
|
347 |
|
348 |
-
class BriaRMBG(
|
349 |
-
|
350 |
-
def __init__(self,config
|
351 |
-
super(
|
352 |
-
in_ch=config
|
353 |
-
out_ch=config
|
354 |
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
355 |
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
356 |
|
|
|
1 |
+
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
5 |
+
from transformers import PreTrainedModel
|
6 |
+
from .MyConfig import RMBGConfig
|
7 |
|
8 |
class REBNCONV(nn.Module):
|
9 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
|
|
347 |
return self.rl(self.bn(self.conv(x)))
|
348 |
|
349 |
|
350 |
+
class BriaRMBG(PreTrainedModel):
|
351 |
+
config_class = RMBGConfig
|
352 |
+
def __init__(self,config):
|
353 |
+
super().__init__(config)
|
354 |
+
in_ch = config.in_ch # 3
|
355 |
+
out_ch = config.out_ch # 1
|
356 |
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
357 |
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
358 |
|
@@ -1,4 +1,25 @@
|
|
1 |
{
|
2 |
-
"
|
3 |
-
"
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "briaai/RMBG-1.4",
|
3 |
+
"architectures": [
|
4 |
+
"BriaRMBG"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "MyConfig.RMBGConfig",
|
8 |
+
"AutoModelForImageSegmentation": "briarmbg.BriaRMBG"
|
9 |
+
},
|
10 |
+
"custom_pipelines": {
|
11 |
+
"image-segmentation": {
|
12 |
+
"impl": "MyPipe.RMBGPipe",
|
13 |
+
"pt": [
|
14 |
+
"AutoModelForImageSegmentation"
|
15 |
+
],
|
16 |
+
"tf": [],
|
17 |
+
"type": "image"
|
18 |
+
}
|
19 |
+
},
|
20 |
+
"in_ch": 3,
|
21 |
+
"model_type": "SegformerForSemanticSegmentation",
|
22 |
+
"out_ch": 1,
|
23 |
+
"torch_dtype": "float32",
|
24 |
+
"transformers_version": "4.38.0.dev0"
|
25 |
+
}
|
@@ -4,4 +4,5 @@ pillow
|
|
4 |
numpy
|
5 |
typing
|
6 |
scikit-image
|
7 |
-
huggingface_hub
|
|
|
|
4 |
numpy
|
5 |
typing
|
6 |
scikit-image
|
7 |
+
huggingface_hub
|
8 |
+
git+https://github.com/huggingface/transformers.git
|