# File size: 4,507 Bytes
# 621050a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from .networks import U2NET
import torchvision.transforms as transforms
import torch.nn.functional as F
import os
from PIL import Image
from collections import OrderedDict

import torch

# Pick the torch device name once at import time: 'cuda' when torch reports a
# usable GPU, 'cpu' otherwise. On the GPU path the CUDA caching allocator is
# emptied before any model is created.
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.empty_cache()
else:
    device = 'cpu'

# Application root hard-coded for the Hugging Face deployment.
# Alternative for local runs: BASE_DIR = os.getcwd()
BASE_DIR = "/home/path/app"

# Directory that generate_cloth_mask() reads input images from.
image_dir = 'cloth'
# Directory that generate_cloth_mask() writes the mask images to.
result_dir = 'cloth_mask'
# Weights file handed to load_checkpoint_mgpu() by generate_cloth_mask().
checkpoint_path = 'cloth_segmentation/checkpoints/cloth_segm_u2net_latest.pth'


def load_checkpoint_mgpu(model, checkpoint_path):
    """Load a state dict from ``checkpoint_path`` into ``model``.

    Keys that start with ``module.`` have that prefix removed before loading
    (presumably the checkpoint was saved from a multi-GPU wrapper -- confirm
    against the training code). Keys without the prefix are used unchanged, so
    plain single-device checkpoints load as well.

    Args:
        model: Module exposing ``load_state_dict``; it is updated in place.
        checkpoint_path: Path of the file to read with ``torch.load``.

    Returns:
        The same ``model`` object after its weights were loaded, or ``None``
        when ``checkpoint_path`` does not exist (a message is printed in that
        case and ``model`` is left untouched).
    """
    if not os.path.exists(checkpoint_path):
        print("----No checkpoints at given path----")
        return None

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    # Tensors are mapped to the CPU so loading works on machines without CUDA.
    model_state_dict = torch.load(
        checkpoint_path, map_location=torch.device("cpu"))

    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in model_state_dict.items():
        # Strip the prefix only when it is really there; slicing blindly would
        # mangle the keys of a checkpoint saved without it.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value

    model.load_state_dict(new_state_dict)
    print("----checkpoints loaded from path: {}----".format(checkpoint_path))
    return model


class Normalize_image(object):
    """Normalize a tensor with one scalar mean and one scalar standard deviation.

    The same ``mean`` / ``std`` pair is applied to every channel. Normalizers
    are prepared for tensors whose first dimension is 1, 3 or 18.

    Args:
        mean (float): Value subtracted from the tensor (ints are accepted and
            converted to float).
        std (float): Value the tensor is divided by (ints are accepted and
            converted to float).

    Raises:
        TypeError: If ``mean`` or ``std`` is not an int or a float.
    """

    def __init__(self, mean, std):
        # Validate both arguments explicitly: asserts disappear under -O, and
        # an unchecked std would only fail later with an AttributeError.
        for label, value in (("mean", mean), ("std", std)):
            if not isinstance(value, (int, float)):
                raise TypeError(
                    "{} must be a float, got {}".format(
                        label, type(value).__name__))

        self.mean = float(mean)
        self.std = float(std)

        # One torchvision Normalize per supported channel count.
        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize(
            [self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize(
            [self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        """Return ``image_tensor`` normalized by the matching normalizer.

        The normalizer is selected from ``image_tensor.shape[0]``.

        Raises:
            ValueError: If ``image_tensor.shape[0]`` is not 1, 3 or 18.
        """
        channels = image_tensor.shape[0]
        if channels == 1:
            return self.normalize_1(image_tensor)
        if channels == 3:
            return self.normalize_3(image_tensor)
        if channels == 18:
            return self.normalize_18(image_tensor)

        # The previous `assert "<message>"` was always true, so unsupported
        # inputs silently produced None; fail loudly instead.
        raise ValueError(
            "Please set proper channels! Normlization implemented only for "
            "1, 3 and 18 (got {})".format(channels))


def get_palette(num_cls):
    """Build the flat RGB palette used to render a segmentation mask.

    Label 0 is mapped to black and every other label to white, so the
    rendered mask is binary regardless of how many classes there are.

    Args:
        num_cls: Number of classes
    Returns:
        A list of ``num_cls * 3`` ints laid out as R, G, B per label.
    """
    black = [0, 0, 0]
    white = [255, 255, 255]

    palette = []
    for label in range(num_cls):
        # Only the zero label keeps the black entry.
        palette.extend(white if label else black)
    return palette


def generate_cloth_mask():
    """Write a black/white cloth mask for every image found in ``image_dir``.

    Each file in ``image_dir`` is converted to RGB, resized to 768x768,
    normalized and passed through a 4-class U2NET. The per-pixel winning class
    is resized back to the original image size, rendered through the palette
    from ``get_palette(4)`` and saved as a grayscale JPEG in ``result_dir``
    under the input's base name with a ``.jpg`` extension.

    Raises:
        FileNotFoundError: If no checkpoint exists at ``checkpoint_path``.
    """
    transform_rgb = transforms.Compose([
        transforms.ToTensor(),
        Normalize_image(0.5, 0.5),
    ])

    net = U2NET(in_ch=3, out_ch=4)
    net = load_checkpoint_mgpu(net, checkpoint_path)
    if net is None:
        # The loader signals a missing file by returning None; stop here with
        # a clear error instead of failing on `None.to(...)` below.
        raise FileNotFoundError(
            "No checkpoint found at: {}".format(checkpoint_path))
    net = net.to(device)
    net = net.eval()

    palette = get_palette(4)

    # Make sure the output directory exists before the first save.
    os.makedirs(result_dir, exist_ok=True)

    with torch.no_grad():
        for image_name in sorted(os.listdir(image_dir)):
            # Open inside a context manager so the file handle is released as
            # soon as the RGB copy has been made.
            with Image.open(os.path.join(image_dir, image_name)) as source:
                img = source.convert('RGB')
            img_size = img.size
            img = img.resize((768, 768), Image.Resampling.BICUBIC)

            # Add the batch dimension expected by the network.
            image_tensor = torch.unsqueeze(transform_rgb(img), 0)

            output_tensor = net(image_tensor.to(device))
            # assumes output_tensor[0] holds per-class scores shaped
            # (1, 4, H, W) -- TODO confirm against U2NET.forward
            output_tensor = F.log_softmax(output_tensor[0], dim=1)
            # Index of the best-scoring class for every pixel.
            output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
            output_tensor = torch.squeeze(output_tensor, dim=0)
            output_tensor = torch.squeeze(output_tensor, dim=0)
            output_arr = output_tensor.cpu().numpy()

            output_img = Image.fromarray(output_arr.astype('uint8'), mode='L')
            output_img = output_img.resize(img_size, Image.Resampling.BICUBIC)

            # Render class indices through the palette, then flatten back to a
            # single-channel image.
            output_img.putpalette(palette)
            output_img = output_img.convert('L')

            # splitext copes with extensions of any length (e.g. ".jpeg"),
            # unlike slicing off the last four characters.
            base_name = os.path.splitext(image_name)[0]
            output_img.save(os.path.join(result_dir, base_name + '.jpg'))


if __name__ == '__main__':
    generate_cloth_mask()