remove subprocess, U2Net for bg removal
Files changed:
- .gitignore +1 -0
- PIFu/apps/eval.py +57 -27
- PIFu/inputs/.gitignore +0 -2
- PIFu/lib/options.py +5 -1
- PIFu/results/spaces_demo/.gitignore +0 -2
- app.py → PIFu/spaces.py +79 -33
- README.md +1 -1
- remove_bg.py +0 -58
- requirements.txt +3 -1
.gitignore
CHANGED
```diff
@@ -1,3 +1,4 @@
+results/
 # Python build
 .eggs/
 gradio.egg-info/*
```
PIFu/apps/eval.py
CHANGED
```diff
@@ -1,28 +1,27 @@
+import tqdm
+import glob
+import torchvision.transforms as transforms
+from PIL import Image
+from lib.model import *
+from lib.train_util import *
+from lib.sample_util import *
+from lib.mesh_util import *
+# from lib.options import BaseOptions
+from torch.utils.data import DataLoader
+import torch
+import numpy as np
+import json
+import time
 import sys
 import os
 
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+sys.path.insert(0, os.path.abspath(
+    os.path.join(os.path.dirname(__file__), '..')))
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
-import time
-import json
-import numpy as np
-import torch
-from torch.utils.data import DataLoader
 
-from lib.options import BaseOptions
-from lib.mesh_util import *
-from lib.sample_util import *
-from lib.train_util import *
-from lib.model import *
-
-from PIL import Image
-import torchvision.transforms as transforms
-import glob
-import tqdm
-
-# get options
-opt = BaseOptions().parse()
+# # get options
+# opt = BaseOptions().parse()
 
 class Evaluator:
     def __init__(self, opt, projection_mode='orthogonal'):
@@ -34,19 +33,22 @@ class Evaluator:
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])
         # set cuda
-        cuda = torch.device('cuda:%d' % opt.gpu_id)
+        cuda = torch.device(
+            'cuda:%d' % opt.gpu_id) if torch.cuda.is_available() else torch.device('cpu')
 
         # create net
         netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
         print('Using Network: ', netG.name)
 
         if opt.load_netG_checkpoint_path:
-            netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
+            netG.load_state_dict(torch.load(
+                opt.load_netG_checkpoint_path, map_location=cuda))
 
         if opt.load_netC_checkpoint_path is not None:
            print('loading for net C ...', opt.load_netC_checkpoint_path)
            netC = ResBlkPIFuNet(opt).to(device=cuda)
-            netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
+            netC.load_state_dict(torch.load(
+                opt.load_netC_checkpoint_path, map_location=cuda))
        else:
            netC = None
 
@@ -87,6 +89,30 @@ class Evaluator:
             'b_max': B_MAX,
         }
 
+    def load_image_from_memory(self, image_path, mask_path, img_name):
+        # Calib
+        B_MIN = np.array([-1, -1, -1])
+        B_MAX = np.array([1, 1, 1])
+        projection_matrix = np.identity(4)
+        projection_matrix[1, 1] = -1
+        calib = torch.Tensor(projection_matrix).float()
+        # Mask
+        mask = Image.fromarray(mask_path).convert('L')
+        mask = transforms.Resize(self.load_size)(mask)
+        mask = transforms.ToTensor()(mask).float()
+        # image
+        image = Image.fromarray(image_path).convert('RGB')
+        image = self.to_tensor(image)
+        image = mask.expand_as(image) * image
+        return {
+            'name': img_name,
+            'img': image.unsqueeze(0),
+            'calib': calib.unsqueeze(0),
+            'mask': mask.unsqueeze(0),
+            'b_min': B_MIN,
+            'b_max': B_MAX,
+        }
+
     def eval(self, data, use_octree=False):
         '''
         Evaluate a data point
@@ -98,18 +124,22 @@ class Evaluator:
         self.netG.eval()
         if self.netC:
             self.netC.eval()
-        save_path = '%s/%s/result_%s.obj' % (opt.results_path, opt.name, data['name'])
+        save_path = '%s/%s/result_%s.obj' % (
+            opt.results_path, opt.name, data['name'])
         if self.netC:
-            gen_mesh_color(opt, self.netG, self.netC, self.cuda, data, save_path, use_octree=use_octree)
+            gen_mesh_color(opt, self.netG, self.netC, self.cuda,
+                           data, save_path, use_octree=use_octree)
         else:
-            gen_mesh(opt, self.netG, self.cuda, data, save_path, use_octree=use_octree)
+            gen_mesh(opt, self.netG, self.cuda, data,
+                     save_path, use_octree=use_octree)
 
 
 if __name__ == '__main__':
     evaluator = Evaluator(opt)
 
     test_images = glob.glob(os.path.join(opt.test_folder_path, '*'))
-    test_images = [f for f in test_images if ('png' in f or 'jpg' in f) and (not 'mask' in f)]
+    test_images = [f for f in test_images if (
+        'png' in f or 'jpg' in f) and (not 'mask' in f)]
     test_masks = [f[:-4]+'_mask.png' for f in test_images]
 
     print("num; ", len(test_masks))
@@ -120,4 +150,4 @@ if __name__ == '__main__':
             data = evaluator.load_image(image_path, mask_path)
             evaluator.eval(data, True)
         except Exception as e:
-            print("error:", e.args)
+            print("error:", e.args)
```
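The heart of this change is `load_image_from_memory`: despite the `image_path`/`mask_path` parameter names it takes numpy arrays, so the Space can hand U2Net's output straight to PIFu without a disk round-trip. One caveat visible in the diff: `eval()` still reads the module-level `opt`, which this commit comments out, so a caller apparently has to supply it some other way. A minimal usage sketch under that assumption (the `pifu_eval.opt` re-injection and the zero arrays are illustrative, not from the commit):

```python
import numpy as np
import apps.eval as pifu_eval

# ASSUMPTION: eval() references the module-level `opt` that this commit
# comments out, so we re-inject it; `opts` is the SimpleNamespace that
# spaces.py builds from BaseOptions().parse_to_dict()
pifu_eval.opt = opts

rgb = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for an RGB photo (HxWx3)
msk = np.zeros((512, 512), dtype=np.uint8)     # stand-in for a U2Net matte (HxW)

evaluator = pifu_eval.Evaluator(opts)          # loads the net_G / net_C checkpoints
data = evaluator.load_image_from_memory(rgb, msk, "demo")
evaluator.eval(data, use_octree=True)          # writes results/spaces_demo/result_demo.obj
```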
PIFu/inputs/.gitignore
DELETED
```diff
@@ -1,2 +0,0 @@
-*
-!.gitignore
```
PIFu/lib/options.py
CHANGED
```diff
@@ -5,7 +5,7 @@ import os
 class BaseOptions():
     def __init__(self):
         self.initialized = False
-
+        argparse
     def initialize(self, parser):
         # Datasets related
         g_data = parser.add_argument_group('Data')
@@ -155,3 +155,7 @@ class BaseOptions():
     def parse(self):
         opt = self.gather_options()
         return opt
+
+    def parse_to_dict(self):
+        opt = self.gather_options()
+        return opt.__dict__
```
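`parse_to_dict` is the hook that lets the Space configure PIFu programmatically instead of through the CLI: it parses the usual argparse defaults, then hands back a plain dict to mutate. (The bare `argparse` added to `__init__` is a no-op expression statement, likely a stray auto-import artifact.) A sketch of the round-trip as spaces.py uses it:

```python
from types import SimpleNamespace

from lib.options import BaseOptions

opts = BaseOptions().parse_to_dict()  # argparse.Namespace -> mutable dict
opts['resolution'] = 128              # override defaults without CLI flags
opts['name'] = "spaces_demo"
opt = SimpleNamespace(**opts)         # restore attribute access: opt.resolution
```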
PIFu/results/spaces_demo/.gitignore
DELETED
```diff
@@ -1,2 +0,0 @@
-*
-!.gitignore
```
app.py → PIFu/spaces.py
RENAMED
```diff
@@ -1,4 +1,5 @@
 import os
+from xml.etree.ElementPath import ops
 try:
     os.system("pip install --upgrade torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html")
 except Exception as e:
@@ -7,67 +8,110 @@ except Exception as e:
 from pydoc import describe
 from huggingface_hub import hf_hub_download
 import gradio as gr
-import subprocess
 import os
-import datetime
+from datetime import datetime
 from PIL import Image
+import torch
+import torchvision
+import skimage
+import paddlehub
+import numpy as np
+from lib.options import BaseOptions
+from apps.crop_img import process_img
+from apps.eval import Evaluator
+from types import SimpleNamespace
+import trimesh
 
 print(
     "torch: ", torch.__version__,
-    "\ntorchvision: ",torchvision.__version__,
+    "\ntorchvision: ", torchvision.__version__,
     "\nskimage:", skimage.__version__
 )
 
 net_C = hf_hub_download("radames/PIFu-upright-standing", filename="net_C")
 net_G = hf_hub_download("radames/PIFu-upright-standing", filename="net_G")
-torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True)
 
-remove_bg = RemoveBackground()
+opt = BaseOptions()
+opts = opt.parse_to_dict()
+opts['batch_size'] = 1
+opts['mlp_dim'] = [257, 1024, 512, 256, 128, 1]
+opts['mlp_dim_color'] = [513, 1024, 512, 256, 128, 3]
+opts['num_stack'] = 4
+opts['num_hourglass'] = 2
+opts['resolution'] = 128
+opts['hg_down'] = 'ave_pool'
+opts['norm'] = 'group'
+opts['norm_color'] = 'group'
+opts['load_netG_checkpoint_path'] = net_G
+opts['load_netC_checkpoint_path'] = net_C
+opts['results_path'] = "./results"
+opts['name'] = "spaces_demo"
+opts = SimpleNamespace(**opts)
+evaluator = Evaluator(opts)
+bg_remover_model = paddlehub.Module(name="U2Net")
 
 
 def process(img_path):
     base = os.path.basename(img_path)
     img_name = os.path.splitext(base)[0]
+    print("\n\n\nStarting Process", datetime.now())
     print("image name", img_name)
-    img_raw = Image.open(img_path)
+    img_raw = Image.open(img_path).convert('RGB')
+
     img = img_raw.resize(
         (800, int(800 * img_raw.size[1] / img_raw.size[0])),
         Image.Resampling.LANCZOS)
 
+    try:
+        # remove background
+        print("Removing Background")
+        masks = bg_remover_model.Segmentation(
+            images=[np.array(img)],
+            paths=None,
+            batch_size=1,
+            input_size=320,
+            output_dir='./PIFu/inputs',
+            visualization=False)
+        mask = masks[0]["mask"]
+        front = masks[0]["front"]
     except Exception as e:
         print(e)
+
     print("Aliging mask with input training image")
-            **env,
-            "INPUT_IMAGE_PATH": f'./inputs/{img_name}.png',
-            "VOL_RES": "128"},
-        cwd="PIFu").communicate()
-    print("DONE 3D model")
-    return f'./PIFu/results/spaces_demo/result_{img_name}.glb'
+    print("Not aligned", front.shape, mask.shape)
+    img_new, msk_new = process_img(front, mask)
+    print("Aligned", img_new.shape, msk_new.shape)
+
+    try:
+        time = datetime.now()
+        data = evaluator.load_image_from_memory(img_new, msk_new, img_name)
+        print("Evaluating via PIFu", time)
+        evaluator.eval(data, True)
+        print("Success Evaluating via PIFu", datetime.now() - time)
+        result_path = f'{opts.results_path}/{opts.name}/result_{img_name}'
+    except Exception as e:
+        print("Error evaluating via PIFu", e)
+
+    try:
+        mesh = trimesh.load(result_path + '.obj')\
+        # flip mesh
+        mesh.apply_transform([[1, 0, 0, 0],
+                              [0, 1, 0, 0],
+                              [0, 0, -1, 0],
+                              [0, 0, 0, 1]])
+        mesh.export(file_obj=result_path + '.glb')
+        result_gltf = result_path + '.glb'
+        return result_gltf
 
+    except Exception as e:
+        print("error generating MESH", e)
 
 
 examples = [["./examples/" + img] for img in sorted(os.listdir("./examples/"))]
 description = '''
 # PIFu Clothed Human Digitization
-
+# PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization
 <base target="_blank">
 
 This is a demo for <a href="https://github.com/shunsukesaito/PIFu" target="_blank"> PIFu model </a>.
@@ -76,17 +120,17 @@ The pre-trained model has the following warning:
 
 **The inference takes about 180seconds for a new image.**
 
-<details>
+<details>
 <summary>More</summary>
 
-
+# Image Credits
 
 * Julien and Clem
 * [StyleGAN Humans](https://huggingface.co/spaces/hysts/StyleGAN-Human)
 * [Renderpeople: Dennis](https://renderpeople.com)
 
 
-
+# More
 * https://phorhum.github.io/
 * https://github.com/yuliangxiu/icon
 * https://shunsukesaito.github.io/PIFuHD/
@@ -102,6 +146,8 @@ iface = gr.Interface(
     examples=examples,
     allow_flagging="never",
     cache_examples=True
+
+
 )
 
 if __name__ == "__main__":
```
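The export step at the end of `process` deserves a note: `result_path` is only bound if the PIFu stage succeeds (the earlier `except` swallows failures), and the 4x4 transform is a homogeneous mirror: the -1 negates Z, presumably to match the orientation gradio's model viewer expects. A standalone sketch of that last stage (file names hypothetical):

```python
import trimesh

mesh = trimesh.load("result_demo.obj")   # hypothetical PIFu output
flip_z = [[1, 0, 0, 0],
          [0, 1, 0, 0],
          [0, 0, -1, 0],                 # negate Z to mirror the mesh
          [0, 0, 0, 1]]
mesh.apply_transform(flip_z)
mesh.export(file_obj="result_demo.glb")  # format inferred from the .glb extension
```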
README.md
CHANGED
```diff
@@ -5,7 +5,7 @@ colorFrom: pink
 colorTo: green
 sdk: gradio
 sdk_version: 2.9.0b8
-app_file: app.py
+app_file: ./PIFu/spaces.py
 pinned: false
 python_version: 3.7.13
 ---
```
remove_bg.py
DELETED
```diff
@@ -1,58 +0,0 @@
-# from https://huggingface.co/spaces/eugenesiow/remove-bg/blob/main/app.py
-import cv2
-import torch
-import numpy as np
-from torchvision import transforms
-
-class RemoveBackground(object):
-    def __init__(self):
-        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True)
-        self.model.eval()
-
-    def make_transparent_foreground(self, pic, mask):
-        # split the image into channels
-        b, g, r = cv2.split(np.array(pic).astype('uint8'))
-        # add an alpha channel with and fill all with transparent pixels (max 255)
-        a = np.ones(mask.shape, dtype='uint8') * 255
-        # merge the alpha channel back
-        alpha_im = cv2.merge([b, g, r, a], 4)
-        # create a transparent background
-        bg = np.zeros(alpha_im.shape)
-        # setup the new mask
-        new_mask = np.stack([mask, mask, mask, mask], axis=2)
-        # copy only the foreground color pixels from the original image where mask is set
-        foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
-
-        return foreground
-
-
-    def remove_background(self, input_image):
-        preprocess = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        ])
-
-        input_tensor = preprocess(input_image)
-        input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
-
-        # move the input and model to GPU for speed if available
-        if torch.cuda.is_available():
-            input_batch = input_batch.to('cuda')
-            self.model.to('cuda')
-
-        with torch.no_grad():
-            output = self.model(input_batch)['out'][0]
-        output_predictions = output.argmax(0)
-
-        # create a binary (black and white) mask of the profile foreground
-        mask = output_predictions.byte().cpu().numpy()
-        background = np.zeros(mask.shape)
-        bin_mask = np.where(mask, 255, background).astype(np.uint8)
-
-        foreground = self.make_transparent_foreground(input_image, bin_mask)
-
-        return foreground, bin_mask
-
-    def inference(self, img):
-        foreground, _ = self.remove_background(img)
-        return foreground
```
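For contrast with the deleted DeepLabV3 pipeline above, the replacement is a single PaddleHub call. This sketch mirrors the invocation in spaces.py; the `mask`/`front` return keys are taken from that diff, not independently verified against the PaddleHub docs:

```python
import numpy as np
import paddlehub

model = paddlehub.Module(name="U2Net")  # downloads weights on first use
results = model.Segmentation(
    images=[np.zeros((320, 320, 3), dtype=np.uint8)],  # stand-in RGB frame
    paths=None, batch_size=1, input_size=320,
    visualization=False)
mask = results[0]["mask"]    # single-channel matte
front = results[0]["front"]  # foreground RGB, background zeroed
```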
requirements.txt
CHANGED
```diff
@@ -19,4 +19,6 @@ six==1.14.0
 torch==1.4.0
 torchvision==0.5.0
 trimesh==3.5.23
-tqdm==4.64.0
+tqdm==4.64.0
+paddlehub
+paddlepaddle
```
|