depi_final / app.py
HebaAllah's picture
Update app.py
b5a24ea verified
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
class CustomDataset(Dataset):
def __init__(self, data_dir=str, transforms=None):
self.data_dir = data_dir
self.transforms = transforms
self.data = []
self.class_map = {}
self.extensions = ("jpeg", "jpg", "png")
file_list = sorted(glob.glob(self.data_dir + "/*"))
for class_path in file_list:
class_name = class_path.split("/")[-1]
for img_path in glob.glob(class_path + "/*"):
ext = img_path.split("/")[-1].split(".")[-1]
if ext in self.extensions:
self.data.append([img_path, class_name])
for idx, class_path in enumerate(file_list):
class_name = class_path.split("/")[-1]
self.class_map[class_name] = idx
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_path, class_name = self.data[idx]
img = cv2.imread(img_path)
# Applying transforms on image
if self.transforms:
img = self.transforms(image=img)["image"]
label = self.class_map[class_name]
return img, label
def get_transform():
resize = A.Resize(224, 224)
normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
to_tensor = ToTensorV2()
return A.Compose([resize, normalize, to_tensor])
if __name__ == "__main__":
data_dir = os.path.join(os.path.abspath(__file__ + "/../../"), "data/train/")
print(data_dir)
dataset = CustomDataset(data_dir=data_dir, transforms=get_transform())
# print(len(dataset))
# print(dataset[0][0].shape)
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
total_imgs = 0
for imgs, labels in data_loader:
total_imgs += int(imgs.shape[0])
print(imgs.shape)
break
print(total_imgs)