Spaces:
No application file
No application file
import os | |
import numpy as np | |
import config | |
from torch.utils.data import Dataset, DataLoader | |
from PIL import Image | |
class MyImageFolder(Dataset): | |
def __init__(self, root_dir): | |
super(MyImageFolder, self).__init__() | |
self.data = [] | |
self.root_dir = root_dir | |
self.class_names = os.listdir(root_dir) | |
for index, name in enumerate(self.class_names): | |
files = os.listdir(os.path.join(root_dir, name)) | |
self.data += list(zip(files, [index] * len(files))) | |
def __len__(self): | |
return len(self.data) | |
def __getitem__(self, index): | |
img_file, label = self.data[index] | |
root_and_dir = os.path.join(self.root_dir, self.class_names[label]) | |
image = np.array(Image.open(os.path.join(root_and_dir, img_file))) | |
image = config.both_transforms(image=image)["image"] | |
high_res = config.highres_transform(image=image)["image"] | |
low_res = config.lowres_transform(image=image)["image"] | |
return low_res, high_res | |
def test(): | |
dataset = MyImageFolder(root_dir="new_data/") | |
loader = DataLoader(dataset, batch_size=1, num_workers=8) | |
for low_res, high_res in loader: | |
print(low_res.shape) | |
print(high_res.shape) | |
if __name__ == "__main__": | |
test() | |