Spaces:
Runtime error
Runtime error
David Piscasio
commited on
Commit
·
7369193
1
Parent(s):
3ce13dc
Added data folder
Browse files- data/__init__.py +93 -0
- data/__pycache__/__init__.cpython-38.pyc +0 -0
- data/__pycache__/base_dataset.cpython-38.pyc +0 -0
- data/__pycache__/image_folder.cpython-38.pyc +0 -0
- data/__pycache__/single_dataset.cpython-38.pyc +0 -0
- data/aligned_dataset.py +60 -0
- data/base_dataset.py +157 -0
- data/colorization_dataset.py +68 -0
- data/image_folder.py +65 -0
- data/single_dataset.py +40 -0
- data/template_dataset.py +75 -0
- data/unaligned_dataset.py +71 -0
data/__init__.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package includes all the modules related to data loading and preprocessing
|
2 |
+
|
3 |
+
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
|
4 |
+
You need to implement four functions:
|
5 |
+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
6 |
+
-- <__len__>: return the size of dataset.
|
7 |
+
-- <__getitem__>: get a data point from data loader.
|
8 |
+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
9 |
+
|
10 |
+
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
|
11 |
+
See our template dataset class 'template_dataset.py' for more details.
|
12 |
+
"""
|
13 |
+
import importlib
|
14 |
+
import torch.utils.data
|
15 |
+
from data.base_dataset import BaseDataset
|
16 |
+
|
17 |
+
|
18 |
+
def find_dataset_using_name(dataset_name):
|
19 |
+
"""Import the module "data/[dataset_name]_dataset.py".
|
20 |
+
|
21 |
+
In the file, the class called DatasetNameDataset() will
|
22 |
+
be instantiated. It has to be a subclass of BaseDataset,
|
23 |
+
and it is case-insensitive.
|
24 |
+
"""
|
25 |
+
dataset_filename = "data." + dataset_name + "_dataset"
|
26 |
+
datasetlib = importlib.import_module(dataset_filename)
|
27 |
+
|
28 |
+
dataset = None
|
29 |
+
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
|
30 |
+
for name, cls in datasetlib.__dict__.items():
|
31 |
+
if name.lower() == target_dataset_name.lower() \
|
32 |
+
and issubclass(cls, BaseDataset):
|
33 |
+
dataset = cls
|
34 |
+
|
35 |
+
if dataset is None:
|
36 |
+
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
|
37 |
+
|
38 |
+
return dataset
|
39 |
+
|
40 |
+
|
41 |
+
def get_option_setter(dataset_name):
|
42 |
+
"""Return the static method <modify_commandline_options> of the dataset class."""
|
43 |
+
dataset_class = find_dataset_using_name(dataset_name)
|
44 |
+
return dataset_class.modify_commandline_options
|
45 |
+
|
46 |
+
|
47 |
+
def create_dataset(opt):
|
48 |
+
"""Create a dataset given the option.
|
49 |
+
|
50 |
+
This function wraps the class CustomDatasetDataLoader.
|
51 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
52 |
+
|
53 |
+
Example:
|
54 |
+
>>> from data import create_dataset
|
55 |
+
>>> dataset = create_dataset(opt)
|
56 |
+
"""
|
57 |
+
data_loader = CustomDatasetDataLoader(opt)
|
58 |
+
dataset = data_loader.load_data()
|
59 |
+
return dataset
|
60 |
+
|
61 |
+
|
62 |
+
class CustomDatasetDataLoader():
|
63 |
+
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
|
64 |
+
|
65 |
+
def __init__(self, opt):
|
66 |
+
"""Initialize this class
|
67 |
+
|
68 |
+
Step 1: create a dataset instance given the name [dataset_mode]
|
69 |
+
Step 2: create a multi-threaded data loader.
|
70 |
+
"""
|
71 |
+
self.opt = opt
|
72 |
+
dataset_class = find_dataset_using_name(opt.dataset_mode)
|
73 |
+
self.dataset = dataset_class(opt)
|
74 |
+
print("dataset [%s] was created" % type(self.dataset).__name__)
|
75 |
+
self.dataloader = torch.utils.data.DataLoader(
|
76 |
+
self.dataset,
|
77 |
+
batch_size=opt.batch_size,
|
78 |
+
shuffle=not opt.serial_batches,
|
79 |
+
num_workers=int(opt.num_threads))
|
80 |
+
|
81 |
+
def load_data(self):
|
82 |
+
return self
|
83 |
+
|
84 |
+
def __len__(self):
|
85 |
+
"""Return the number of data in the dataset"""
|
86 |
+
return min(len(self.dataset), self.opt.max_dataset_size)
|
87 |
+
|
88 |
+
def __iter__(self):
|
89 |
+
"""Return a batch of data"""
|
90 |
+
for i, data in enumerate(self.dataloader):
|
91 |
+
if i * self.opt.batch_size >= self.opt.max_dataset_size:
|
92 |
+
break
|
93 |
+
yield data
|
data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (4.03 kB). View file
|
|
data/__pycache__/base_dataset.cpython-38.pyc
ADDED
Binary file (5.9 kB). View file
|
|
data/__pycache__/image_folder.cpython-38.pyc
ADDED
Binary file (2.53 kB). View file
|
|
data/__pycache__/single_dataset.cpython-38.pyc
ADDED
Binary file (2.01 kB). View file
|
|
data/aligned_dataset.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from data.base_dataset import BaseDataset, get_params, get_transform
|
3 |
+
from data.image_folder import make_dataset
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
class AlignedDataset(BaseDataset):
|
8 |
+
"""A dataset class for paired image dataset.
|
9 |
+
|
10 |
+
It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
|
11 |
+
During test time, you need to prepare a directory '/path/to/data/test'.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, opt):
|
15 |
+
"""Initialize this dataset class.
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
19 |
+
"""
|
20 |
+
BaseDataset.__init__(self, opt)
|
21 |
+
self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory
|
22 |
+
self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths
|
23 |
+
assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image
|
24 |
+
self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
|
25 |
+
self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
|
26 |
+
|
27 |
+
def __getitem__(self, index):
|
28 |
+
"""Return a data point and its metadata information.
|
29 |
+
|
30 |
+
Parameters:
|
31 |
+
index - - a random integer for data indexing
|
32 |
+
|
33 |
+
Returns a dictionary that contains A, B, A_paths and B_paths
|
34 |
+
A (tensor) - - an image in the input domain
|
35 |
+
B (tensor) - - its corresponding image in the target domain
|
36 |
+
A_paths (str) - - image paths
|
37 |
+
B_paths (str) - - image paths (same as A_paths)
|
38 |
+
"""
|
39 |
+
# read a image given a random integer index
|
40 |
+
AB_path = self.AB_paths[index]
|
41 |
+
AB = Image.open(AB_path).convert('RGB')
|
42 |
+
# split AB image into A and B
|
43 |
+
w, h = AB.size
|
44 |
+
w2 = int(w / 2)
|
45 |
+
A = AB.crop((0, 0, w2, h))
|
46 |
+
B = AB.crop((w2, 0, w, h))
|
47 |
+
|
48 |
+
# apply the same transform to both A and B
|
49 |
+
transform_params = get_params(self.opt, A.size)
|
50 |
+
A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
|
51 |
+
B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
|
52 |
+
|
53 |
+
A = A_transform(A)
|
54 |
+
B = B_transform(B)
|
55 |
+
|
56 |
+
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
"""Return the total number of images in the dataset."""
|
60 |
+
return len(self.AB_paths)
|
data/base_dataset.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
|
2 |
+
|
3 |
+
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
|
4 |
+
"""
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
import torch.utils.data as data
|
8 |
+
from PIL import Image
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
|
12 |
+
|
13 |
+
class BaseDataset(data.Dataset, ABC):
|
14 |
+
"""This class is an abstract base class (ABC) for datasets.
|
15 |
+
|
16 |
+
To create a subclass, you need to implement the following four functions:
|
17 |
+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
18 |
+
-- <__len__>: return the size of dataset.
|
19 |
+
-- <__getitem__>: get a data point.
|
20 |
+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, opt):
|
24 |
+
"""Initialize the class; save the options in the class
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
28 |
+
"""
|
29 |
+
self.opt = opt
|
30 |
+
self.root = opt.dataroot
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def modify_commandline_options(parser, is_train):
|
34 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
parser -- original option parser
|
38 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
the modified parser.
|
42 |
+
"""
|
43 |
+
return parser
|
44 |
+
|
45 |
+
@abstractmethod
|
46 |
+
def __len__(self):
|
47 |
+
"""Return the total number of images in the dataset."""
|
48 |
+
return 0
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def __getitem__(self, index):
|
52 |
+
"""Return a data point and its metadata information.
|
53 |
+
|
54 |
+
Parameters:
|
55 |
+
index - - a random integer for data indexing
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
a dictionary of data with their names. It ususally contains the data itself and its metadata information.
|
59 |
+
"""
|
60 |
+
pass
|
61 |
+
|
62 |
+
|
63 |
+
def get_params(opt, size):
|
64 |
+
w, h = size
|
65 |
+
new_h = h
|
66 |
+
new_w = w
|
67 |
+
if opt.preprocess == 'resize_and_crop':
|
68 |
+
new_h = new_w = opt.load_size
|
69 |
+
elif opt.preprocess == 'scale_width_and_crop':
|
70 |
+
new_w = opt.load_size
|
71 |
+
new_h = opt.load_size * h // w
|
72 |
+
|
73 |
+
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
|
74 |
+
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
|
75 |
+
|
76 |
+
flip = random.random() > 0.5
|
77 |
+
|
78 |
+
return {'crop_pos': (x, y), 'flip': flip}
|
79 |
+
|
80 |
+
|
81 |
+
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
|
82 |
+
transform_list = []
|
83 |
+
if grayscale:
|
84 |
+
transform_list.append(transforms.Grayscale(1))
|
85 |
+
if 'resize' in opt.preprocess:
|
86 |
+
osize = [opt.load_size, opt.load_size]
|
87 |
+
transform_list.append(transforms.Resize(osize, method))
|
88 |
+
elif 'scale_width' in opt.preprocess:
|
89 |
+
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
|
90 |
+
|
91 |
+
if 'crop' in opt.preprocess:
|
92 |
+
if params is None:
|
93 |
+
transform_list.append(transforms.RandomCrop(opt.crop_size))
|
94 |
+
else:
|
95 |
+
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
|
96 |
+
|
97 |
+
if opt.preprocess == 'none':
|
98 |
+
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
|
99 |
+
|
100 |
+
if not opt.no_flip:
|
101 |
+
if params is None:
|
102 |
+
transform_list.append(transforms.RandomHorizontalFlip())
|
103 |
+
elif params['flip']:
|
104 |
+
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
|
105 |
+
|
106 |
+
if convert:
|
107 |
+
transform_list += [transforms.ToTensor()]
|
108 |
+
if grayscale:
|
109 |
+
transform_list += [transforms.Normalize((0.5,), (0.5,))]
|
110 |
+
else:
|
111 |
+
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
112 |
+
return transforms.Compose(transform_list)
|
113 |
+
|
114 |
+
|
115 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
|
116 |
+
ow, oh = img.size
|
117 |
+
h = int(round(oh / base) * base)
|
118 |
+
w = int(round(ow / base) * base)
|
119 |
+
if h == oh and w == ow:
|
120 |
+
return img
|
121 |
+
|
122 |
+
__print_size_warning(ow, oh, w, h)
|
123 |
+
return img.resize((w, h), method)
|
124 |
+
|
125 |
+
|
126 |
+
def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
|
127 |
+
ow, oh = img.size
|
128 |
+
if ow == target_size and oh >= crop_size:
|
129 |
+
return img
|
130 |
+
w = target_size
|
131 |
+
h = int(max(target_size * oh / ow, crop_size))
|
132 |
+
return img.resize((w, h), method)
|
133 |
+
|
134 |
+
|
135 |
+
def __crop(img, pos, size):
|
136 |
+
ow, oh = img.size
|
137 |
+
x1, y1 = pos
|
138 |
+
tw = th = size
|
139 |
+
if (ow > tw or oh > th):
|
140 |
+
return img.crop((x1, y1, x1 + tw, y1 + th))
|
141 |
+
return img
|
142 |
+
|
143 |
+
|
144 |
+
def __flip(img, flip):
|
145 |
+
if flip:
|
146 |
+
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
147 |
+
return img
|
148 |
+
|
149 |
+
|
150 |
+
def __print_size_warning(ow, oh, w, h):
|
151 |
+
"""Print warning information about image size(only print once)"""
|
152 |
+
if not hasattr(__print_size_warning, 'has_printed'):
|
153 |
+
print("The image size needs to be a multiple of 4. "
|
154 |
+
"The loaded image size was (%d, %d), so it was adjusted to "
|
155 |
+
"(%d, %d). This adjustment will be done to all images "
|
156 |
+
"whose sizes are not multiples of 4" % (ow, oh, w, h))
|
157 |
+
__print_size_warning.has_printed = True
|
data/colorization_dataset.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from data.base_dataset import BaseDataset, get_transform
|
3 |
+
from data.image_folder import make_dataset
|
4 |
+
from skimage import color # require skimage
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
|
9 |
+
|
10 |
+
class ColorizationDataset(BaseDataset):
|
11 |
+
"""This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space.
|
12 |
+
|
13 |
+
This dataset is required by pix2pix-based colorization model ('--model colorization')
|
14 |
+
"""
|
15 |
+
@staticmethod
|
16 |
+
def modify_commandline_options(parser, is_train):
|
17 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
18 |
+
|
19 |
+
Parameters:
|
20 |
+
parser -- original option parser
|
21 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
the modified parser.
|
25 |
+
|
26 |
+
By default, the number of channels for input image is 1 (L) and
|
27 |
+
the number of channels for output image is 2 (ab). The direction is from A to B
|
28 |
+
"""
|
29 |
+
parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
|
30 |
+
return parser
|
31 |
+
|
32 |
+
def __init__(self, opt):
|
33 |
+
"""Initialize this dataset class.
|
34 |
+
|
35 |
+
Parameters:
|
36 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
37 |
+
"""
|
38 |
+
BaseDataset.__init__(self, opt)
|
39 |
+
self.dir = os.path.join(opt.dataroot, opt.phase)
|
40 |
+
self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
|
41 |
+
assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
|
42 |
+
self.transform = get_transform(self.opt, convert=False)
|
43 |
+
|
44 |
+
def __getitem__(self, index):
|
45 |
+
"""Return a data point and its metadata information.
|
46 |
+
|
47 |
+
Parameters:
|
48 |
+
index - - a random integer for data indexing
|
49 |
+
|
50 |
+
Returns a dictionary that contains A, B, A_paths and B_paths
|
51 |
+
A (tensor) - - the L channel of an image
|
52 |
+
B (tensor) - - the ab channels of the same image
|
53 |
+
A_paths (str) - - image paths
|
54 |
+
B_paths (str) - - image paths (same as A_paths)
|
55 |
+
"""
|
56 |
+
path = self.AB_paths[index]
|
57 |
+
im = Image.open(path).convert('RGB')
|
58 |
+
im = self.transform(im)
|
59 |
+
im = np.array(im)
|
60 |
+
lab = color.rgb2lab(im).astype(np.float32)
|
61 |
+
lab_t = transforms.ToTensor()(lab)
|
62 |
+
A = lab_t[[0], ...] / 50.0 - 1.0
|
63 |
+
B = lab_t[[1, 2], ...] / 110.0
|
64 |
+
return {'A': A, 'B': B, 'A_paths': path, 'B_paths': path}
|
65 |
+
|
66 |
+
def __len__(self):
|
67 |
+
"""Return the total number of images in the dataset."""
|
68 |
+
return len(self.AB_paths)
|
data/image_folder.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A modified image folder class
|
2 |
+
|
3 |
+
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
|
4 |
+
so that this class can load images from both current directory and its subdirectories.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch.utils.data as data
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
import os
|
11 |
+
|
12 |
+
IMG_EXTENSIONS = [
|
13 |
+
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
14 |
+
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
|
15 |
+
'.tif', '.TIF', '.tiff', '.TIFF',
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
+
def is_image_file(filename):
|
20 |
+
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
21 |
+
|
22 |
+
|
23 |
+
def make_dataset(dir, max_dataset_size=float("inf")):
|
24 |
+
images = []
|
25 |
+
assert os.path.isdir(dir), '%s is not a valid directory' % dir
|
26 |
+
|
27 |
+
for root, _, fnames in sorted(os.walk(dir)):
|
28 |
+
for fname in fnames:
|
29 |
+
if is_image_file(fname):
|
30 |
+
path = os.path.join(root, fname)
|
31 |
+
images.append(path)
|
32 |
+
return images[:min(max_dataset_size, len(images))]
|
33 |
+
|
34 |
+
|
35 |
+
def default_loader(path):
|
36 |
+
return Image.open(path).convert('RGB')
|
37 |
+
|
38 |
+
|
39 |
+
class ImageFolder(data.Dataset):
|
40 |
+
|
41 |
+
def __init__(self, root, transform=None, return_paths=False,
|
42 |
+
loader=default_loader):
|
43 |
+
imgs = make_dataset(root)
|
44 |
+
if len(imgs) == 0:
|
45 |
+
raise(RuntimeError("Found 0 images in: " + root + "\n"
|
46 |
+
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
|
47 |
+
|
48 |
+
self.root = root
|
49 |
+
self.imgs = imgs
|
50 |
+
self.transform = transform
|
51 |
+
self.return_paths = return_paths
|
52 |
+
self.loader = loader
|
53 |
+
|
54 |
+
def __getitem__(self, index):
|
55 |
+
path = self.imgs[index]
|
56 |
+
img = self.loader(path)
|
57 |
+
if self.transform is not None:
|
58 |
+
img = self.transform(img)
|
59 |
+
if self.return_paths:
|
60 |
+
return img, path
|
61 |
+
else:
|
62 |
+
return img
|
63 |
+
|
64 |
+
def __len__(self):
|
65 |
+
return len(self.imgs)
|
data/single_dataset.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data.base_dataset import BaseDataset, get_transform
|
2 |
+
from data.image_folder import make_dataset
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
class SingleDataset(BaseDataset):
|
7 |
+
"""This dataset class can load a set of images specified by the path --dataroot /path/to/data.
|
8 |
+
|
9 |
+
It can be used for generating CycleGAN results only for one side with the model option '-model test'.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, opt):
|
13 |
+
"""Initialize this dataset class.
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
17 |
+
"""
|
18 |
+
BaseDataset.__init__(self, opt)
|
19 |
+
self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
|
20 |
+
input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
|
21 |
+
self.transform = get_transform(opt, grayscale=(input_nc == 1))
|
22 |
+
|
23 |
+
def __getitem__(self, index):
|
24 |
+
"""Return a data point and its metadata information.
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
index - - a random integer for data indexing
|
28 |
+
|
29 |
+
Returns a dictionary that contains A and A_paths
|
30 |
+
A(tensor) - - an image in one domain
|
31 |
+
A_paths(str) - - the path of the image
|
32 |
+
"""
|
33 |
+
A_path = self.A_paths[index]
|
34 |
+
A_img = Image.open(A_path).convert('RGB')
|
35 |
+
A = self.transform(A_img)
|
36 |
+
return {'A': A, 'A_paths': A_path}
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
"""Return the total number of images in the dataset."""
|
40 |
+
return len(self.A_paths)
|
data/template_dataset.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Dataset class template
|
2 |
+
|
3 |
+
This module provides a template for users to implement custom datasets.
|
4 |
+
You can specify '--dataset_mode template' to use this dataset.
|
5 |
+
The class name should be consistent with both the filename and its dataset_mode option.
|
6 |
+
The filename should be <dataset_mode>_dataset.py
|
7 |
+
The class name should be <Dataset_mode>Dataset.py
|
8 |
+
You need to implement the following functions:
|
9 |
+
-- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
|
10 |
+
-- <__init__>: Initialize this dataset class.
|
11 |
+
-- <__getitem__>: Return a data point and its metadata information.
|
12 |
+
-- <__len__>: Return the number of images.
|
13 |
+
"""
|
14 |
+
from data.base_dataset import BaseDataset, get_transform
|
15 |
+
# from data.image_folder import make_dataset
|
16 |
+
# from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
class TemplateDataset(BaseDataset):
|
20 |
+
"""A template dataset class for you to implement custom datasets."""
|
21 |
+
@staticmethod
|
22 |
+
def modify_commandline_options(parser, is_train):
|
23 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
parser -- original option parser
|
27 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
the modified parser.
|
31 |
+
"""
|
32 |
+
parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
|
33 |
+
parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
|
34 |
+
return parser
|
35 |
+
|
36 |
+
def __init__(self, opt):
|
37 |
+
"""Initialize this dataset class.
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
41 |
+
|
42 |
+
A few things can be done here.
|
43 |
+
- save the options (have been done in BaseDataset)
|
44 |
+
- get image paths and meta information of the dataset.
|
45 |
+
- define the image transformation.
|
46 |
+
"""
|
47 |
+
# save the option and dataset root
|
48 |
+
BaseDataset.__init__(self, opt)
|
49 |
+
# get the image paths of your dataset;
|
50 |
+
self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
|
51 |
+
# define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
|
52 |
+
self.transform = get_transform(opt)
|
53 |
+
|
54 |
+
def __getitem__(self, index):
|
55 |
+
"""Return a data point and its metadata information.
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
index -- a random integer for data indexing
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
a dictionary of data with their names. It usually contains the data itself and its metadata information.
|
62 |
+
|
63 |
+
Step 1: get a random image path: e.g., path = self.image_paths[index]
|
64 |
+
Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
|
65 |
+
Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
|
66 |
+
Step 4: return a data point as a dictionary.
|
67 |
+
"""
|
68 |
+
path = 'temp' # needs to be a string
|
69 |
+
data_A = None # needs to be a tensor
|
70 |
+
data_B = None # needs to be a tensor
|
71 |
+
return {'data_A': data_A, 'data_B': data_B, 'path': path}
|
72 |
+
|
73 |
+
def __len__(self):
|
74 |
+
"""Return the total number of images."""
|
75 |
+
return len(self.image_paths)
|
data/unaligned_dataset.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from data.base_dataset import BaseDataset, get_transform
|
3 |
+
from data.image_folder import make_dataset
|
4 |
+
from PIL import Image
|
5 |
+
import random
|
6 |
+
|
7 |
+
|
8 |
+
class UnalignedDataset(BaseDataset):
|
9 |
+
"""
|
10 |
+
This dataset class can load unaligned/unpaired datasets.
|
11 |
+
|
12 |
+
It requires two directories to host training images from domain A '/path/to/data/trainA'
|
13 |
+
and from domain B '/path/to/data/trainB' respectively.
|
14 |
+
You can train the model with the dataset flag '--dataroot /path/to/data'.
|
15 |
+
Similarly, you need to prepare two directories:
|
16 |
+
'/path/to/data/testA' and '/path/to/data/testB' during test time.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, opt):
|
20 |
+
"""Initialize this dataset class.
|
21 |
+
|
22 |
+
Parameters:
|
23 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
24 |
+
"""
|
25 |
+
BaseDataset.__init__(self, opt)
|
26 |
+
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
|
27 |
+
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
|
28 |
+
|
29 |
+
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
|
30 |
+
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
|
31 |
+
self.A_size = len(self.A_paths) # get the size of dataset A
|
32 |
+
self.B_size = len(self.B_paths) # get the size of dataset B
|
33 |
+
btoA = self.opt.direction == 'BtoA'
|
34 |
+
input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
|
35 |
+
output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
|
36 |
+
self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
|
37 |
+
self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
"""Return a data point and its metadata information.
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
index (int) -- a random integer for data indexing
|
44 |
+
|
45 |
+
Returns a dictionary that contains A, B, A_paths and B_paths
|
46 |
+
A (tensor) -- an image in the input domain
|
47 |
+
B (tensor) -- its corresponding image in the target domain
|
48 |
+
A_paths (str) -- image paths
|
49 |
+
B_paths (str) -- image paths
|
50 |
+
"""
|
51 |
+
A_path = self.A_paths[index % self.A_size] # make sure index is within then range
|
52 |
+
if self.opt.serial_batches: # make sure index is within then range
|
53 |
+
index_B = index % self.B_size
|
54 |
+
else: # randomize the index for domain B to avoid fixed pairs.
|
55 |
+
index_B = random.randint(0, self.B_size - 1)
|
56 |
+
B_path = self.B_paths[index_B]
|
57 |
+
A_img = Image.open(A_path).convert('RGB')
|
58 |
+
B_img = Image.open(B_path).convert('RGB')
|
59 |
+
# apply image transformation
|
60 |
+
A = self.transform_A(A_img)
|
61 |
+
B = self.transform_B(B_img)
|
62 |
+
|
63 |
+
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
|
64 |
+
|
65 |
+
def __len__(self):
|
66 |
+
"""Return the total number of images in the dataset.
|
67 |
+
|
68 |
+
As we have two datasets with potentially different number of images,
|
69 |
+
we take a maximum of
|
70 |
+
"""
|
71 |
+
return max(self.A_size, self.B_size)
|