Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
|
4 |
+
import albumentations as A
|
5 |
+
import cv2
|
6 |
+
import torchvision.transforms as T
|
7 |
+
from albumentations.pytorch import ToTensorV2
|
8 |
+
from torch.utils.data import DataLoader, Dataset
|
9 |
+
|
10 |
+
|
11 |
+
class CustomDataset(Dataset):
    """Image classification dataset laid out as one sub-directory per class.

    Every directory directly under ``data_dir`` is treated as a class; its
    name is the class name and its position in the sorted directory listing
    is the integer label. Files inside a class directory are kept when their
    extension (compared case-insensitively) is one of ``self.extensions``.

    Args:
        data_dir: Path to the directory that contains the class directories.
        transforms: Optional callable invoked as ``transforms(image=img)`` and
            expected to return a mapping with an ``"image"`` entry (the
            calling convention used by albumentations pipelines).

    Attributes populated by ``__init__``:
        data: list of ``[img_path, class_name]`` pairs.
        class_map: dict mapping class name to integer label.
    """

    def __init__(self, data_dir: str, transforms=None):
        self.data_dir = data_dir
        self.transforms = transforms
        self.data = []
        self.class_map = {}
        self.extensions = ("jpeg", "jpg", "png")

        # Only directories are classes; a stray file next to them (README,
        # archive, ...) must not consume a label index.
        class_dirs = sorted(
            path
            for path in glob.glob(os.path.join(self.data_dir, "*"))
            if os.path.isdir(path)
        )

        for idx, class_path in enumerate(class_dirs):
            class_name = os.path.basename(class_path)
            self.class_map[class_name] = idx
            # Sorted so that the sample order is reproducible across runs
            # and file systems.
            for img_path in sorted(glob.glob(os.path.join(class_path, "*"))):
                ext = os.path.splitext(img_path)[1][1:].lower()
                if ext in self.extensions:
                    self.data.append([img_path, class_name])

    def __len__(self):
        """Return the number of image files that were collected."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(img, label)``.

        Raises:
            OSError: if the image file cannot be read/decoded.
        """
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None instead of raising;
        # surface it here with the offending path.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        # NOTE(review): cv2.imread yields channels in BGR order; confirm that
        # is what the downstream model expects.
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]

        label = self.class_map[class_name]

        return img, label
|
45 |
+
|
46 |
+
|
47 |
+
def get_transform():
    """Build the albumentations preprocessing pipeline.

    The pipeline resizes to 224x224, normalizes each channel with mean 0.5
    and std 0.5, and finally applies ``ToTensorV2``.
    """
    half = (0.5, 0.5, 0.5)
    return A.Compose(
        [
            A.Resize(224, 224),
            A.Normalize(mean=half, std=half),
            ToTensorV2(),
        ]
    )
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
    # Smoke test: build the dataset from <parent of this file's dir>/data/train/
    # and pull a single batch through a DataLoader.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    data_dir = os.path.join(project_root, "data/train/")
    print(data_dir)

    dataset = CustomDataset(data_dir=data_dir, transforms=get_transform())
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Only the first batch is inspected, so the count printed below is the
    # size of that batch rather than of the whole dataset.
    total_imgs = 0
    for batch_imgs, _batch_labels in data_loader:
        total_imgs += int(batch_imgs.shape[0])
        print(batch_imgs.shape)
        break
    print(total_imgs)
|