Spaces:
Sleeping
Sleeping
from albumentations.pytorch import ToTensorV2 | |
from torch.utils.data import DataLoader, Dataset | |
class CustomDataset(Dataset):
    """Image-classification dataset read from a directory of class folders.

    Every entry directly under ``data_dir`` is treated as one class, and the
    files directly inside it whose extension is ``jpeg``, ``jpg`` or ``png``
    (matched case-insensitively) become that class's samples. Class indices
    follow the sorted order of the entries under ``data_dir``.
    """

    def __init__(self, data_dir: str, transforms=None):
        """Index the image files found under ``data_dir``.

        Args:
            data_dir: Root directory holding one sub-directory per class.
            transforms: Optional callable invoked as ``transforms(image=img)``
                whose result is indexed with ``"image"``.
        """
        self.data_dir = data_dir
        self.transforms = transforms
        self.data = []  # [img_path, class_name] pairs, one per sample
        self.class_map = {}  # class_name -> integer label
        self.extensions = ("jpeg", "jpg", "png")

        class_paths = sorted(glob.glob(os.path.join(self.data_dir, "*")))
        for idx, class_path in enumerate(class_paths):
            # NOTE(review): every entry gets an index, even a stray file that
            # is not a directory; kept so existing label indices do not shift.
            class_name = os.path.basename(class_path)
            self.class_map[class_name] = idx
            # Sorted so sample order is reproducible across filesystems.
            for img_path in sorted(glob.glob(os.path.join(class_path, "*"))):
                ext = os.path.splitext(img_path)[1].lstrip(".").lower()
                if ext in self.extensions:
                    self.data.append([img_path, class_name])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A ``(image, label)`` pair: the image as read by ``cv2.imread``
            (passed through ``transforms`` when one was given) and the integer
            class index from ``class_map``.

        Raises:
            OSError: If OpenCV could not read the image file.
        """
        img_path, class_name = self.data[idx]
        # NOTE: OpenCV decodes to BGR channel order; no RGB conversion is done.
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image: {img_path}")
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        return img, self.class_map[class_name]
def get_transform():
    """Build the albumentations preprocessing pipeline.

    The pipeline resizes to 224x224, normalizes each of the three channels
    with mean 0.5 and std 0.5, then applies ``ToTensorV2``.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    steps = [
        A.Resize(224, 224),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)
if __name__ == "__main__":
    # Smoke test: build the training dataset, pull a single batch through a
    # DataLoader, and print its shape and size.
    project_root = os.path.abspath(__file__ + "/../../")
    data_dir = os.path.join(project_root, "data/train/")
    print(data_dir)

    dataset = CustomDataset(data_dir=data_dir, transforms=get_transform())
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

    total_imgs = 0
    for batch_imgs, _batch_labels in data_loader:
        print_shape = batch_imgs.shape
        total_imgs += int(print_shape[0])
        print(print_shape)
        # Only the first batch is inspected.
        break
    print(total_imgs)