# Source: mix-bt/transfer_datasets/aircraft.py
# Uploaded by wgcban ("Upload 98 files", commit 803ef9e).
import os
import numpy as np
from PIL import Image
from os.path import join
from collections import defaultdict
import torch.utils.data as data
# Default dataset root directory; dataset code reads "<root>/data/..." beneath it.
DATA_ROOTS: str = "data/Aircraft"
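# Dataset: FGVC-Aircraft 2013b.
# url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
# To download and extract, run in a shell:
#   wget http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz
# and then, in a Python interpreter:
#   from torchvision.datasets.utils import extract_archive
#   extract_archive("fgvc-aircraft-2013b.tar.gz")
# Reference download/preprocessing script: https://github.com/lvyilin/pytorch-fgvc-dataset/blob/master/aircraft.py
# Names from that reference, kept for orientation. The class below uses only the
# 'variant' labels and the 'trainval'/'test' splits, and expects the root to be the
# directory that contains data/images/:
# class_types = ('variant', 'family', 'manufacturer')
# splits = ('train', 'val', 'trainval', 'test')
# img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')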
class Aircraft(data.Dataset):
    """FGVC-Aircraft dataset labelled by aircraft variant, with bounding-box crops.

    Reads two annotation files under ``<root>/data``:

    * ``images_variant_<split>.txt``: one ``<image name> <variant>`` per line.
      ``<split>`` is ``trainval`` when ``train`` is true, else ``test``.
    * ``images_box.txt``: one ``<image name> <xmin> <ymin> <xmax> <ymax>`` per line.

    Images are loaded from ``<root>/data/images/<image name>.jpg``.

    Labels are integer indices into the sorted list of variant names that
    occur in the selected split.

    NOTE(review): because the class list is built per split, train and test
    labels agree only if both splits contain the same set of variants.
    Confirm this holds for the annotation files in use.

    Args:
        root: Dataset root directory (defaults to ``DATA_ROOTS``).
        train: Use the ``trainval`` split if true, otherwise ``test``.
        image_transforms: Optional callable applied to the cropped PIL image.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        super().__init__()
        self.root = root
        self.train = train
        self.image_transforms = image_transforms
        paths, bboxes, labels = self.load_images()
        self.paths = paths    # image file paths, grouped by label
        self.bboxes = bboxes  # [xmin, ymin, xmax, ymax] ints, parallel to paths
        self.labels = labels  # integer class index, parallel to paths

    @staticmethod
    def _read_lines(path):
        """Return the non-blank lines of text file ``path`` without trailing newlines."""
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank/whitespace-only lines (e.g. an extra newline at EOF).
            # Otherwise they would become a bogus '' entry or break unpacking.
            return [line.rstrip('\n') for line in f if line.strip()]

    def load_images(self):
        """Collect image paths, bounding boxes and labels for the selected split.

        Returns:
            ``(paths, bboxes, labels)``: three parallel lists, ordered by
            variant name and then by image name within each variant.

        Raises:
            KeyError: if an image listed in the split file has no entry in
                ``images_box.txt``.
        """
        split = 'trainval' if self.train else 'test'
        variant_path = os.path.join(self.root, 'data', 'images_variant_%s.txt' % split)

        names_to_variants = {}
        for line in self._read_lines(variant_path):
            # maxsplit=1 keeps everything after the first space as the
            # variant name, including any spaces inside it.
            name, variant = line.split(' ', 1)
            names_to_variants[name] = variant

        variants_to_names = defaultdict(list)
        for name, variant in names_to_variants.items():
            variants_to_names[variant].append(name)

        names_to_bboxes = self.get_bounding_boxes()
        split_files, split_bboxes, split_labels = [], [], []
        # Dict keys are already unique, so sorting them yields the class list.
        for variant_id, variant in enumerate(sorted(variants_to_names)):
            names = sorted(variants_to_names[variant])
            split_files.extend(
                join(self.root, 'data', 'images', '%s.jpg' % name) for name in names
            )
            split_bboxes.extend(names_to_bboxes[name] for name in names)
            split_labels.extend([variant_id] * len(names))
        return split_files, split_bboxes, split_labels

    def get_bounding_boxes(self):
        """Map each image name to its ``[xmin, ymin, xmax, ymax]`` box (ints).

        The boxes are parsed from ``<root>/data/images_box.txt``.
        """
        bboxes_path = os.path.join(self.root, 'data', 'images_box.txt')
        names_to_bboxes = {}
        for line in self._read_lines(bboxes_path):
            # split() with no argument tolerates repeated or trailing whitespace.
            name, xmin, ymin, xmax, ymax = line.split()
            names_to_bboxes[name] = [int(xmin), int(ymin), int(xmax), int(ymax)]
        return names_to_bboxes

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The image is loaded as RGB and cropped to its bounding box. It is then
        passed through ``image_transforms`` if one was given.
        """
        path = self.paths[index]
        bbox = tuple(self.bboxes[index])
        label = self.labels[index]
        # The context manager closes the file handle. convert() returns a new,
        # loaded image, so the result stays usable after the file is closed.
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        # NOTE(review): PIL's crop box is (left, upper, right, lower) with
        # right/lower exclusive. Confirm this matches the annotation convention.
        image = image.crop(bbox)
        if self.image_transforms:
            image = self.image_transforms(image)
        return image, label