mix-bt / transfer_datasets /cu_birds.py
wgcban's picture
Upload 98 files
803ef9e
raw
history blame
2.27 kB
import os
import numpy as np
from PIL import Image
import torch.utils.data as data
# Default location of the extracted CUB-200-2011 folder (the directory that
# contains images.txt, image_class_labels.txt, train_test_split.txt, images/).
DATA_ROOTS = 'data/CUBirds'
# To obtain the data:
#   wget https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
#   tar -xvzf CUB_200_2011.tgz


class CUBirds(data.Dataset):
    """Image-classification dataset backed by the CUB-200-2011 metadata files.

    The constructor reads three text files from ``root`` (``images.txt``,
    ``image_class_labels.txt`` and ``train_test_split.txt``), each made of
    ``<image_id> <value>`` lines, and keeps only the images that belong to
    the requested split.

    Args:
        root: Directory holding the metadata files and the ``images``
            sub-directory.
        train: If truthy, keep images whose split flag is ``1``; otherwise
            keep images whose split flag is ``0``.
        image_transforms: Optional callable applied to the RGB PIL image in
            ``__getitem__``.

    Attributes:
        paths: Image paths relative to ``<root>/images``.
        labels: Class ids exactly as read from the label file (strings);
            they are converted to zero-based ints in ``__getitem__``.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        super().__init__()
        self.root = root
        self.train = train
        self.image_transforms = image_transforms
        self.paths, self.labels = self.load_images()

    def _read_index_file(self, filename):
        """Parse a ``<image_id> <value>`` text file under ``root`` into a dict.

        Blank lines are skipped so that a trailing empty line does not abort
        loading. A non-blank line without a space separator raises
        ``ValueError``.
        """
        mapping = {}
        with open(os.path.join(self.root, filename), 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                # Split only on the first space: the value may contain spaces.
                key, value = line.split(' ', 1)
                mapping[key] = value
        return mapping

    def load_images(self):
        """Return ``(paths, labels)`` for the split selected by ``self.train``.

        Returns:
            Two parallel lists: relative image paths and their class ids
            (as strings), in the order the images appear in ``images.txt``.

        Raises:
            KeyError: If an image id from ``images.txt`` is missing from the
                label file or the split file.
        """
        image_info = self._read_index_file('images.txt')
        # image id -> class id
        label_info = self._read_index_file('image_class_labels.txt')
        # image id -> split flag (1 is selected when train is truthy, else 0)
        train_test_info = self._read_index_file('train_test_split.txt')

        wanted_split = 1 if self.train else 0
        all_paths, all_labels = [], []
        for index, image_path in image_info.items():
            label = label_info[index]
            if int(train_test_info[index]) == wanted_split:
                all_paths.append(image_path)
                all_labels.append(label)
        return all_paths, all_labels

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at position ``index``.

        The image is opened from ``<root>/images/<relative path>`` and
        converted to RGB; ``image_transforms`` is applied when it is set.
        The label is the stored class id minus one.
        """
        path = os.path.join(self.root, 'images', self.paths[index])
        label = int(self.labels[index]) - 1
        # convert() returns a new, fully loaded image, so the source file can
        # be closed right away instead of lingering until garbage collection.
        with Image.open(path) as source:
            image = source.convert(mode='RGB')
        if self.image_transforms is not None:
            image = self.image_transforms(image)
        return image, label