import os import copy import json import operator import numpy as np from PIL import Image from os.path import join from itertools import chain from scipy.io import loadmat from collections import defaultdict import torch import torch.utils.data as data from torchvision import transforms DATA_ROOTS = 'data/VGGFlower' # wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz # tar -xvzf 102flowers.tgz # rename file to VGGFlower # cd VGGFlower # wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat class VGGFlower(data.Dataset): def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None): super().__init__() self.root = root self.train = train self.image_transforms = image_transforms paths, labels = self.load_images() self.paths, self.labels = paths, labels def load_images(self): rs = np.random.RandomState(42) imagelabels_path = os.path.join(self.root, 'imagelabels.mat') with open(imagelabels_path, 'rb') as f: labels = loadmat(f)['labels'][0] all_filepaths = defaultdict(list) for i, label in enumerate(labels): # all_filepaths[label].append(os.path.join(self.root, 'jpg', 'image_{:05d}.jpg'.format(i+1))) all_filepaths[label].append(os.path.join(self.root, 'image_{:05d}.jpg'.format(i+1))) # train test split split_filepaths, split_labels = [], [] for label, paths in all_filepaths.items(): num = len(paths) paths = np.array(paths) indexer = np.arange(num) rs.shuffle(indexer) paths = paths[indexer].tolist() if self.train: paths = paths[:int(0.8 * num)] else: paths = paths[int(0.8 * num):] labels = [label] * len(paths) split_filepaths.extend(paths) split_labels.extend(labels) return split_filepaths, split_labels def __len__(self): return len(self.paths) def __getitem__(self, index): path = self.paths[index] label = int(self.labels[index]) - 1 image = Image.open(path).convert(mode='RGB') if self.image_transforms: image = self.image_transforms(image) return image, label