HebaAllah commited on
Commit
f061edc
·
verified ·
1 Parent(s): d879f4c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+
4
+ import albumentations as A
5
+ import cv2
6
+ import torchvision.transforms as T
7
+ from albumentations.pytorch import ToTensorV2
8
+ from torch.utils.data import DataLoader, Dataset
9
+
10
+
11
+ class CustomDataset(Dataset):
12
+ def __init__(self, data_dir=str, transforms=None):
13
+ self.data_dir = data_dir
14
+ self.transforms = transforms
15
+ self.data = []
16
+ self.class_map = {}
17
+ self.extensions = ("jpeg", "jpg", "png")
18
+
19
+ file_list = sorted(glob.glob(self.data_dir + "/*"))
20
+
21
+ for class_path in file_list:
22
+ class_name = class_path.split("/")[-1]
23
+ for img_path in glob.glob(class_path + "/*"):
24
+ ext = img_path.split("/")[-1].split(".")[-1]
25
+ if ext in self.extensions:
26
+ self.data.append([img_path, class_name])
27
+
28
+ for idx, class_path in enumerate(file_list):
29
+ class_name = class_path.split("/")[-1]
30
+ self.class_map[class_name] = idx
31
+
32
+ def __len__(self):
33
+ return len(self.data)
34
+
35
+ def __getitem__(self, idx):
36
+ img_path, class_name = self.data[idx]
37
+ img = cv2.imread(img_path)
38
+ # Applying transforms on image
39
+ if self.transforms:
40
+ img = self.transforms(image=img)["image"]
41
+
42
+ label = self.class_map[class_name]
43
+
44
+ return img, label
45
+
46
+
47
+ def get_transform():
48
+ resize = A.Resize(224, 224)
49
+ normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
50
+ to_tensor = ToTensorV2()
51
+ return A.Compose([resize, normalize, to_tensor])
52
+
53
+
54
+ if __name__ == "__main__":
55
+ data_dir = os.path.join(os.path.abspath(__file__ + "/../../"), "data/train/")
56
+ print(data_dir)
57
+ dataset = CustomDataset(data_dir=data_dir, transforms=get_transform())
58
+ # print(len(dataset))
59
+ # print(dataset[0][0].shape)
60
+
61
+ data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
62
+ total_imgs = 0
63
+ for imgs, labels in data_loader:
64
+ total_imgs += int(imgs.shape[0])
65
+ print(imgs.shape)
66
+ break
67
+ print(total_imgs)