sidharthism commited on
Commit
621050a
·
1 Parent(s): d3d01c9

Create new file

Browse files
cloth_segmentation DELETED
@@ -1 +0,0 @@
1
- Subproject commit 28392f0da3aa5eb9ae64db73d04b31be10ce6350
 
 
cloth_segmentation/generate_cloth_mask.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .networks import U2NET
2
+ import torchvision.transforms as transforms
3
+ import torch.nn.functional as F
4
+ import os
5
+ from PIL import Image
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+
10
+ device = 'cuda' if torch.cuda.is_available() else "cpu"
11
+
12
+ if device == 'cuda':
13
+ torch.cuda.empty_cache()
14
+
15
+ # for hugging face
16
+ BASE_DIR = "/home/path/app"
17
+
18
+ # BASE_DIR = os.getcwd()
19
+
20
+ image_dir = 'cloth'
21
+ result_dir = 'cloth_mask'
22
+ checkpoint_path = 'cloth_segmentation/checkpoints/cloth_segm_u2net_latest.pth'
23
+
24
+
25
+ def load_checkpoint_mgpu(model, checkpoint_path):
26
+ if not os.path.exists(checkpoint_path):
27
+ print("----No checkpoints at given path----")
28
+ return
29
+ model_state_dict = torch.load(
30
+ checkpoint_path, map_location=torch.device("cpu"))
31
+ new_state_dict = OrderedDict()
32
+ for k, v in model_state_dict.items():
33
+ name = k[7:] # remove `module.`
34
+ new_state_dict[name] = v
35
+
36
+ model.load_state_dict(new_state_dict)
37
+ print("----checkpoints loaded from path: {}----".format(checkpoint_path))
38
+ return model
39
+
40
+
41
+ class Normalize_image(object):
42
+ """Normalize given tensor into given mean and standard dev
43
+ Args:
44
+ mean (float): Desired mean to substract from tensors
45
+ std (float): Desired std to divide from tensors
46
+ """
47
+
48
+ def __init__(self, mean, std):
49
+ assert isinstance(mean, (float))
50
+ if isinstance(mean, float):
51
+ self.mean = mean
52
+
53
+ if isinstance(std, float):
54
+ self.std = std
55
+
56
+ self.normalize_1 = transforms.Normalize(self.mean, self.std)
57
+ self.normalize_3 = transforms.Normalize(
58
+ [self.mean] * 3, [self.std] * 3)
59
+ self.normalize_18 = transforms.Normalize(
60
+ [self.mean] * 18, [self.std] * 18)
61
+
62
+ def __call__(self, image_tensor):
63
+ if image_tensor.shape[0] == 1:
64
+ return self.normalize_1(image_tensor)
65
+
66
+ elif image_tensor.shape[0] == 3:
67
+ return self.normalize_3(image_tensor)
68
+
69
+ elif image_tensor.shape[0] == 18:
70
+ return self.normalize_18(image_tensor)
71
+
72
+ else:
73
+ assert "Please set proper channels! Normlization implemented only for 1, 3 and 18"
74
+
75
+
76
+ def get_palette(num_cls):
77
+ """ Returns the color map for visualizing the segmentation mask.
78
+ Args:
79
+ num_cls: Number of classes
80
+ Returns:
81
+ The color map
82
+ """
83
+ n = num_cls
84
+ palette = [0] * (n * 3)
85
+ for j in range(0, n):
86
+ lab = j
87
+ palette[j * 3 + 0] = 0
88
+ palette[j * 3 + 1] = 0
89
+ palette[j * 3 + 2] = 0
90
+ i = 0
91
+ while lab:
92
+ palette[j * 3 + 0] = 255
93
+ palette[j * 3 + 1] = 255
94
+ palette[j * 3 + 2] = 255
95
+ # palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
96
+ # palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
97
+ # palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
98
+ i += 1
99
+ lab >>= 3
100
+ return palette
101
+
102
+
103
+ def generate_cloth_mask():
104
+ transforms_list = []
105
+ transforms_list += [transforms.ToTensor()]
106
+ transforms_list += [Normalize_image(0.5, 0.5)]
107
+ transform_rgb = transforms.Compose(transforms_list)
108
+
109
+ net = U2NET(in_ch=3, out_ch=4)
110
+ with torch.no_grad():
111
+ net = load_checkpoint_mgpu(net, checkpoint_path)
112
+ net = net.to(device)
113
+ net = net.eval()
114
+
115
+ palette = get_palette(4)
116
+
117
+ images_list = sorted(os.listdir(image_dir))
118
+ for image_name in images_list:
119
+ img = Image.open(os.path.join(
120
+ image_dir, image_name)).convert('RGB')
121
+ img_size = img.size
122
+ img = img.resize((768, 768), Image.Resampling.BICUBIC)
123
+ image_tensor = transform_rgb(img)
124
+ image_tensor = torch.unsqueeze(image_tensor, 0)
125
+
126
+ output_tensor = net(image_tensor.to(device))
127
+ output_tensor = F.log_softmax(output_tensor[0], dim=1)
128
+ output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
129
+ output_tensor = torch.squeeze(output_tensor, dim=0)
130
+ output_tensor = torch.squeeze(output_tensor, dim=0)
131
+ output_arr = output_tensor.cpu().numpy()
132
+
133
+ output_img = Image.fromarray(output_arr.astype('uint8'), mode='L')
134
+ output_img = output_img.resize(img_size, Image.Resampling.BICUBIC)
135
+
136
+ output_img.putpalette(palette)
137
+ output_img = output_img.convert('L')
138
+ output_img.save(os.path.join(result_dir, image_name[:-4]+'.jpg'))
139
+
140
+
141
+ if __name__ == '__main__':
142
+ generate_cloth_mask()