import os import numpy as np from PIL import Image from os.path import join from collections import defaultdict import torch.utils.data as data DATA_ROOTS = 'data/Aircraft' # url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' # wget http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz # python # from torchvision.datasets.utils import extract_archive # extract_archive("fgvc-aircraft-2013b.tar.gz") # Download and preprocess: https://github.com/lvyilin/pytorch-fgvc-dataset/blob/master/aircraft.py # class_types = ('variant', 'family', 'manufacturer') # splits = ('train', 'val', 'trainval', 'test') # img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images') class Aircraft(data.Dataset): def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None): super().__init__() self.root = root self.train = train self.image_transforms = image_transforms paths, bboxes, labels = self.load_images() self.paths = paths self.bboxes = bboxes self.labels = labels def load_images(self): split = 'trainval' if self.train else 'test' variant_path = os.path.join(self.root, 'data', 'images_variant_%s.txt'%split) with open(variant_path, 'r') as f: names_to_variants = [line.split('\n')[0].split(' ', 1) for line in f.readlines()] names_to_variants = dict(names_to_variants) variants_to_names = defaultdict(list) for name, variant in names_to_variants.items(): variants_to_names[variant].append(name) variants = sorted(list(set(variants_to_names.keys()))) names_to_bboxes = self.get_bounding_boxes() split_files, split_labels, split_bboxes = [], [], [] for variant_id, variant in enumerate(variants): class_files = [join(self.root, 'data', 'images', '%s.jpg'%filename) for filename in sorted(variants_to_names[variant])] bboxes = [names_to_bboxes[name] for name in sorted(variants_to_names[variant])] labels = list([variant_id] * len(class_files)) split_files += class_files split_labels += labels split_bboxes += bboxes return split_files, split_bboxes, split_labels def get_bounding_boxes(self): bboxes_path = os.path.join(self.root, 'data', 'images_box.txt') with open(bboxes_path, 'r') as f: names_to_bboxes = [line.split('\n')[0].split(' ') for line in f.readlines()] names_to_bboxes = dict((name, list(map(int, (xmin, ymin, xmax, ymax)))) for name, xmin, ymin, xmax, ymax in names_to_bboxes) return names_to_bboxes def __len__(self): return len(self.paths) def __getitem__(self, index): path = self.paths[index] bbox = tuple(self.bboxes[index]) label = self.labels[index] image = Image.open(path).convert(mode='RGB') image = image.crop(bbox) if self.image_transforms: image = self.image_transforms(image) return image, label