Spaces:
Runtime error
Runtime error
File size: 4,507 Bytes
621050a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from .networks import U2NET
import torchvision.transforms as transforms
import torch.nn.functional as F
import os
from PIL import Image
from collections import OrderedDict
import torch
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = 'cuda'
    # Return unused blocks held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()

# Application root used for the Hugging Face deployment
# (for a local run, os.getcwd() is the usual substitute).
BASE_DIR = "/home/path/app"

# The paths below are relative, so they resolve against the current
# working directory of the process.
image_dir = 'cloth'          # input garment images
result_dir = 'cloth_mask'    # where generated masks are written
checkpoint_path = 'cloth_segmentation/checkpoints/cloth_segm_u2net_latest.pth'
def load_checkpoint_mgpu(model, checkpoint_path):
    """Load a saved state dict into ``model``, tolerating a ``module.`` prefix.

    Multi-GPU checkpoints (e.g. saved from an ``nn.DataParallel`` wrapper)
    prefix every key with ``"module."``. That prefix is stripped when present
    so the weights fit the bare ``model``; keys without it are used unchanged,
    so single-GPU checkpoints load too. Tensors are mapped to the CPU while
    loading; move the model to its target device afterwards.

    Args:
        model: ``torch.nn.Module`` that receives the weights (modified in place).
        checkpoint_path: Path to a file containing a saved state dict.

    Returns:
        The same ``model`` instance with the weights loaded.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        # Fail loudly here rather than returning None, which would only
        # surface later as a confusing AttributeError in the caller.
        raise FileNotFoundError(
            "----No checkpoints at given path: {}----".format(checkpoint_path))

    prefix = "module."
    model_state_dict = torch.load(
        checkpoint_path, map_location=torch.device("cpu"))
    # Strip the prefix only where it is actually present; slicing blindly
    # would corrupt keys of checkpoints saved without a wrapper.
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in model_state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    print("----checkpoints loaded from path: {}----".format(checkpoint_path))
    return model
class Normalize_image(object):
    """Normalize a channels-first tensor with a single scalar mean and std.

    Every element ``x`` becomes ``(x - mean) / std`` using
    ``torchvision.transforms.Normalize``. Pre-built normalizers are kept for
    tensors whose first dimension (channel count) is 1, 3 or 18.

    Args:
        mean (float): Value to subtract from every channel.
        std (float): Value to divide every channel by.

    Raises:
        TypeError: If ``mean`` or ``std`` is not an int or float.
    """

    def __init__(self, mean, std):
        # Explicit checks instead of ``assert`` (which is stripped under
        # ``python -O``). ``std`` must be validated too: a non-float ``std``
        # used to leave ``self.std`` undefined and crash below.
        for arg_name, value in (("mean", mean), ("std", std)):
            if not isinstance(value, (int, float)):
                raise TypeError(
                    "{} must be a float, got {!r}".format(arg_name, value))
        self.mean = float(mean)
        self.std = float(std)

        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize(
            [self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize(
            [self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        """Normalize ``image_tensor``; its ``shape[0]`` must be 1, 3 or 18.

        Raises:
            ValueError: For any other channel count.
        """
        channels = image_tensor.shape[0]
        if channels == 1:
            return self.normalize_1(image_tensor)
        if channels == 3:
            return self.normalize_3(image_tensor)
        if channels == 18:
            return self.normalize_18(image_tensor)
        # ``assert "<message>"`` can never fail (a non-empty string is
        # truthy), so raise explicitly instead of silently returning None.
        raise ValueError(
            "Please set proper channels! Normalization implemented only "
            "for 1, 3 and 18 (got {})".format(channels))
def get_palette(num_cls):
    """Return a flat RGB palette that renders a label map as a binary mask.

    Class 0 maps to black ``(0, 0, 0)`` and every other class maps to white
    ``(255, 255, 255)``. The result is a flat list of ``3 * num_cls`` ints,
    the layout expected by ``PIL.Image.putpalette``.

    Args:
        num_cls: Number of classes. Non-positive values yield an empty list.

    Returns:
        list[int]: ``[R0, G0, B0, R1, G1, B1, ...]``.
    """
    if num_cls <= 0:
        return []
    return [0, 0, 0] + [255, 255, 255] * (num_cls - 1)
def generate_cloth_mask(input_dir=None, output_dir=None, weights_path=None):
    """Segment every image in a directory and save one binary mask per image.

    Each image is resized to 768x768, converted with ``ToTensor`` and
    ``Normalize_image(0.5, 0.5)``, and fed to a 4-class U2NET. The per-pixel
    argmax label map is resized back to the input size, mapped through
    ``get_palette(4)`` (class 0 -> black, other classes -> white) and saved as
    a grayscale ``.jpg`` named after the input file's stem.

    Args:
        input_dir: Directory of input images; defaults to ``image_dir``.
        output_dir: Directory for the masks (created if missing); defaults
            to ``result_dir``.
        weights_path: U2NET checkpoint; defaults to ``checkpoint_path``.

    Raises:
        FileNotFoundError: If the checkpoint cannot be loaded.
    """
    # Resolve at call time so the module-level settings can still be changed.
    input_dir = image_dir if input_dir is None else input_dir
    output_dir = result_dir if output_dir is None else output_dir
    weights_path = checkpoint_path if weights_path is None else weights_path

    transform_rgb = transforms.Compose([
        transforms.ToTensor(),
        Normalize_image(0.5, 0.5),
    ])

    net = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), weights_path)
    if net is None:
        # Guard against a loader that reports a missing file by returning None.
        raise FileNotFoundError(
            "No checkpoint found at {}".format(weights_path))
    net = net.to(device).eval()

    palette = get_palette(4)
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for image_name in sorted(os.listdir(input_dir)):
            image_path = os.path.join(input_dir, image_name)
            if not os.path.isfile(image_path):
                continue  # e.g. sub-directories such as .ipynb_checkpoints

            # Context manager guarantees the source file handle is released.
            with Image.open(image_path) as src:
                img = src.convert('RGB')
            original_size = img.size
            img = img.resize((768, 768), Image.Resampling.BICUBIC)

            image_tensor = torch.unsqueeze(transform_rgb(img), 0)
            logits = net(image_tensor.to(device))[0]
            log_probs = F.log_softmax(logits, dim=1)
            labels = torch.max(log_probs, dim=1, keepdim=True)[1]
            # Expected (1, 1, H, W) -> (H, W).
            labels = torch.squeeze(torch.squeeze(labels, dim=0), dim=0)

            output_img = Image.fromarray(labels.cpu().numpy().astype('uint8'))
            # Class indices are categorical: nearest-neighbour keeps every
            # pixel a valid label, while bicubic would blend and overshoot
            # neighbouring labels along object edges.
            output_img = output_img.resize(
                original_size, Image.Resampling.NEAREST)
            output_img.putpalette(palette)
            output_img = output_img.convert('L')

            # splitext handles extensions of any length (".jpeg", ".webp").
            # NOTE(review): JPEG is lossy and adds artifacts to a binary
            # mask; ".jpg" is kept since downstream code may expect it.
            stem = os.path.splitext(image_name)[0]
            output_img.save(os.path.join(output_dir, stem + '.jpg'))
if __name__ == '__main__':
    # Script entry point: build a mask for every image in the input folder.
    generate_cloth_mask()
|