# rnwang's picture
# add license
# 09e9630
# raw
# history blame
# 7.74 kB
# This file is copied from https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/data.py
# MIT License
# Copyright (c) 2022 Lorenzo Breschi
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Callable, Dict
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as F
from torchvision import transforms
import pytorch_lightning as pl
from collections.abc import Iterable
# image reader writer
from pathlib import Path
from PIL import Image
from typing import Tuple
def read_image(filepath: Path, mode: str = None) -> Image:
    """
    Opens the image file at ``filepath`` and returns it converted to ``mode``.

    :param filepath: location of the image file
    :param mode: mode forwarded to ``convert``; ``None`` by default
    :return: the image returned by ``convert(mode)``
    """
    # The conversion is performed while the file handle is still open.
    with open(filepath, 'rb') as stream:
        return Image.open(stream).convert(mode)
# Shared converter instances used by the tensor read/write helpers in this
# module: image2tensor is applied to images returned by read_image,
# tensor2image is applied to tensors before they are handed to write_image.
image2tensor = transforms.ToTensor()
tensor2image = transforms.ToPILImage()
def write_image(image: Image, filepath: Path):
    """
    Saves ``image`` at ``filepath``, creating missing parent folders first.

    :param image: object exposing a ``save`` method
    :param filepath: destination path of the image file
    """
    destination_dir = filepath.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    image.save(str(filepath))
def read_image_tensor(filepath: Path, mode: str = 'RGB') -> torch.Tensor:
    """
    Reads the image at ``filepath`` in the given ``mode`` and converts it
    with ``image2tensor``.

    :param filepath: location of the image file
    :param mode: mode passed on to ``read_image``
    :return: the output of ``image2tensor`` for the loaded image
    """
    loaded = read_image(filepath, mode)
    return image2tensor(loaded)
def write_image_tensor(input: torch.Tensor, filepath: Path):
    """
    Converts ``input`` with ``tensor2image`` and saves the result at
    ``filepath`` through ``write_image``.

    :param input: tensor to be written as an image
    :param filepath: destination path of the image file
    """
    as_image = tensor2image(input)
    write_image(as_image, filepath)
def get_valid_indices(H: int, W: int, patch_size: int, random_overlap: int = 0):
    """
    Builds the top-left corners of a regular grid of square patches inside
    an image of height ``H`` and width ``W``.

    Along each axis the corners start at ``random_overlap`` and advance by
    ``patch_size`` up to ``size - patch_size - random_overlap`` (inclusive).
    When ``random_overlap`` is positive every row coordinate and every column
    coordinate is shifted by an integer drawn from
    ``[-random_overlap, random_overlap)``. The shift is drawn once per grid
    row and once per grid column, not once per patch, so patches in the same
    grid row share the same vertical offset.

    :param H: image height
    :param W: image width
    :param patch_size: side length of the square patches
    :param random_overlap: margin kept from the border and bound of the
        random shift; 0 disables the shift
    :return: integer tensor of shape (N, 2) holding one (row, column) pair
        per patch, enumerated row by row
    """
    vih = torch.arange(random_overlap,
                       H - patch_size - random_overlap + 1, patch_size)
    viw = torch.arange(random_overlap,
                       W - patch_size - random_overlap + 1, patch_size)
    if random_overlap > 0:
        # randint_like excludes the upper bound, so shifts lie in [-r, r-1]
        vih += torch.randint_like(vih, -random_overlap, random_overlap)
        viw += torch.randint_like(viw, -random_overlap, random_overlap)
    # 'ij' is the matrix-style indexing the implicit default used to provide;
    # spelling it out avoids the deprecation warning of torch.meshgrid
    grid_h, grid_w = torch.meshgrid(vih, viw, indexing='ij')
    return torch.stack((grid_h, grid_w)).view(2, -1).t()
def cut_patches(input: torch.Tensor, indices: Tuple[Tuple[int, int]], patch_size: int, padding: int = 0):
    """
    Crops one square window out of ``input`` for every entry of ``indices``
    and concatenates all crops along dimension 0.

    Each entry of ``indices`` is shifted by ``-padding`` and passed to
    ``F.crop`` as the first two position arguments; the crop height and
    width are both ``patch_size + 2 * padding``.

    :param input: tensor the patches are cropped from
    :param indices: sequence of coordinate pairs, one per patch
    :param patch_size: side length of a patch without padding
    :param padding: extra border added on every side of each patch
    :return: all crops concatenated along dimension 0
    """
    # TODO use slices to get all patches at the same time ?
    side = patch_size + padding * 2
    crops = [F.crop(input, *(corner - padding), side, side)
             for corner in indices]
    return torch.cat(crops, dim=0)
def prepare_data(data_path: Path, read_func: Callable = read_image_tensor) -> Dict:
    """
    Takes a data_path of a folder which contains subfolders with input, target, etc.
    labelled by the same names, and loads them into stacked tensors.

    Only file names found in the "target" subfolder are read; the "input" and
    "mask" subfolders are loaded when they exist and are expected to contain
    files with those same names.

    :param data_path: Path of the folder containing data
    :param read_func: function that reads data and returns a tensor
    :return: dictionary with one stacked tensor per existing subfolder (keyed
        by the subfolder name), 'name' (the sorted file names), 'len' (the
        number of files) and 'H' / 'W' (last two dimensions of the last
        tensor that was read)
    :raises RuntimeError: if the "target" subfolder contains no files
    """
    subdir_names = ["target", "input", "mask"]  # ,"helper"
    target_dir = data_path / "target"
    # checks only files for which there is a target
    # TODO check for images
    # iterdir() yields entries in arbitrary order; sorting makes the sample
    # order reproducible across runs and file systems
    name_ls = sorted(
        file.name for file in target_dir.iterdir() if file.is_file())
    if not name_ls:
        # fail with a clear message instead of an error from torch.stack
        raise RuntimeError(f"no files found in {target_dir}")

    data_dict = {}
    for sdn in subdir_names:
        sd = data_path / sdn
        if not sd.is_dir():
            continue
        data_ls = [read_func(sd / name) for name in name_ls]
        H, W = data_ls[-1].shape[-2:]
        # TODO check that all sizes match
        data_dict[sd.name] = torch.stack(data_ls, dim=0)
    data_dict['name'] = name_ls
    data_dict['len'] = len(name_ls)
    data_dict['H'] = H
    data_dict['W'] = W
    return data_dict
# TODO an image is loaded whenever a patch is needed, this may be a bottleneck
class DataDictLoader():
    """
    Minimal batch iterator over a dictionary of aligned data.

    Values of the dictionary that are instances of ``collections.abc.Iterable``
    (tensors, lists, ...) are sliced into batches; every other value
    (e.g. the integers stored under 'len', 'H', 'W') is copied unchanged
    into each batch.
    """

    def __init__(self, data_dict: Dict,
                 batch_size: int = 16,
                 max_length: int = 128,
                 shuffle: bool = False):
        """
        :param data_dict: dictionary holding the data; its 'len' entry gives
            the number of samples
        :param batch_size: number of samples per batch
        :param max_length: maximum number of samples served per pass,
            None for no limit
        :param shuffle: if True the samples are permuted at the start of
            every pass
        """
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.data_dict = data_dict
        self.dataset_len = data_dict['len']
        self.len = self.dataset_len if max_length is None else min(
            self.dataset_len, max_length)
        # Calculate # batches
        num_batches, remainder = divmod(self.len, self.batch_size)
        if remainder > 0:
            num_batches += 1
        self.num_batches = num_batches
        # position of the next sample; defined here so that next() is usable
        # even before the first call to iter()
        self.i = 0

    @staticmethod
    def _reorder(value, order: torch.Tensor):
        """Returns ``value`` with its entries rearranged according to ``order``."""
        # plain Python sequences cannot be indexed with a tensor
        if isinstance(value, (list, tuple)):
            picked = [value[i] for i in order.tolist()]
            return picked if isinstance(value, list) else tuple(picked)
        return value[order]

    def __iter__(self):
        if self.shuffle:
            # the permutation covers the whole dataset, so with a max_length
            # limit every pass serves a different random subset
            r = torch.randperm(self.dataset_len)
            self.data_dict = {
                k: self._reorder(v, r) if isinstance(v, Iterable) else v
                for k, v in self.data_dict.items()}
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.len:
            raise StopIteration
        # clamp the slice so the last batch never goes beyond self.len
        stop = min(self.i + self.batch_size, self.len)
        batch = {k: v[self.i:stop] if isinstance(v, Iterable) else v
                 for k, v in self.data_dict.items()}
        self.i += self.batch_size
        return batch

    def __len__(self):
        """Returns the number of batches served per pass."""
        return self.num_batches
class PatchDataModule(pl.LightningDataModule):
    """
    Data module that trains on square patches cut out of the images in
    ``data_dict`` and validates / tests on the full images.
    """

    def __init__(self, data_dict,
                 patch_size: int = 2**5,
                 batch_size: int = 2**4,
                 patch_num: int = 2**6,
                 mask_threshold: float = 0.1):
        """
        :param data_dict: dictionary with the stacked data tensors and the
            'H', 'W' and 'len' entries
        :param patch_size: side length of the training patches
        :param batch_size: number of patches per training batch
        :param patch_num: maximum number of patches served per training pass
        :param mask_threshold: a patch is kept only if the mean of its mask
            is greater than this value
        """
        super().__init__()
        self.data_dict = data_dict
        self.H, self.W = data_dict['H'], data_dict['W']
        self.len = data_dict['len']
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.patch_num = patch_num
        self.mask_threshold = mask_threshold

    def dataloader(self, data_dict, **kwargs):
        """Wraps ``data_dict`` in a DataDictLoader, forwarding ``kwargs``."""
        return DataDictLoader(data_dict, **kwargs)

    def train_dataloader(self):
        """Cuts a fresh set of patches and serves them shuffled."""
        patches = self.cut_patches()
        return self.dataloader(patches, batch_size=self.batch_size, shuffle=True,
                               max_length=self.patch_num)

    def val_dataloader(self):
        """Serves the full images one at a time."""
        return self.dataloader(self.data_dict, batch_size=1)

    def test_dataloader(self):
        """Serves the full images with the loader's default settings."""
        return self.dataloader(self.data_dict)  # TODO batch size

    def cut_patches(self):
        """
        Cuts every tensor of ``data_dict`` into patches at the same randomly
        shifted grid positions and keeps the patches whose mask mean exceeds
        ``mask_threshold``.

        :return: dictionary with the selected patches of every tensor entry,
            plus 'len' (number of patches kept) and 'H' / 'W' (both equal to
            ``patch_size``)
        """
        # TODO cycle once
        patch_indices = get_valid_indices(
            self.H, self.W, self.patch_size, self.patch_size//4)
        dd = {k: cut_patches(
            v, patch_indices, self.patch_size) for k, v in self.data_dict.items()
            if isinstance(v, torch.Tensor)
        }
        # build the all-ones fallback only when no mask was provided, instead
        # of allocating it on every call
        if 'mask' in dd:
            mask = dd['mask']
        else:
            mask = torch.ones_like(dd['input'])
        mask_p = torch.mean(mask, dim=(-1, -2, -3))
        masked_idx = (mask_p > self.mask_threshold).nonzero(as_tuple=True)[0]
        dd = {k: v[masked_idx] for k, v in dd.items()}
        dd['len'] = len(masked_idx)
        dd['H'], dd['W'] = (self.patch_size,)*2
        return dd
class ImageDataset(Dataset):
    """Dataset that reads one file per index, on demand."""

    def __init__(self, file_paths: Iterable, read_func: Callable = read_image_tensor):
        """
        :param file_paths: indexable collection of file paths
        :param read_func: function that reads a file and returns a tensor
        """
        self.file_paths = file_paths
        # keep the reader so that a custom read_func is actually honoured
        self.read_func = read_func

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        """Returns the data read from the ``idx``-th file and the file's name."""
        file = self.file_paths[idx]
        return self.read_func(file), file.name

    def __len__(self) -> int:
        """Returns the number of file paths."""
        return len(self.file_paths)