|
import json |
|
from typing import Optional, Sequence |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as ptdist |
|
from monai.data import ( |
|
CacheDataset, |
|
PersistentDataset, |
|
partition_dataset, |
|
) |
|
from monai.data.utils import pad_list_data_collate |
|
from monai.transforms import ( |
|
Compose, |
|
CropForegroundd, |
|
EnsureChannelFirstd, |
|
LoadImaged, |
|
Orientationd, |
|
RandSpatialCropSamplesd, |
|
ScaleIntensityRanged, |
|
Spacingd, |
|
SpatialPadd, |
|
CenterSpatialCropd, |
|
ToTensord, |
|
Transform, |
|
) |
|
|
|
|
|
class PermuteImage(Transform):
    """Reorder the axes of tensor entries in a data dictionary.

    With the default arguments this permutes ``data["image"]`` by
    ``(3, 0, 1, 2)``, i.e. the last axis of a 4-D tensor is moved to the front.
    In ``CTDataset.train_transforms`` it runs after ``EnsureChannelFirstd``,
    the crop/pad to ``(img_size, img_size, depth)`` and ``ToTensord``, so a
    ``(C, img_size, img_size, depth)`` volume becomes
    ``(depth, C, img_size, img_size)``.
    """

    def __init__(
        self,
        keys: Sequence[str] = ("image",),
        dims: Sequence[int] = (3, 0, 1, 2),
    ):
        """
        Args:
            keys: dictionary key(s) whose values are permuted. A single string
                is treated as one key.
            dims: permutation passed to ``Tensor.permute``; its length must
                equal the rank of each permuted tensor.
        """
        # Guard against tuple("image") splitting a single key into characters.
        self.keys = (keys,) if isinstance(keys, str) else tuple(keys)
        self.dims = tuple(dims)

    def __call__(self, data):
        """Return a shallow copy of ``data`` with every entry in ``self.keys`` permuted.

        The caller's dict is not modified; only the listed entries are
        replaced in the returned copy. Each value must support ``.permute``
        (e.g. a ``torch.Tensor``).
        """
        d = dict(data)
        for key in self.keys:
            d[key] = d[key].permute(*self.dims)
        return d
|
|
|
|
|
class CTDataset:
    """Build MONAI datasets and PyTorch data loaders for CT volumes listed in a JSON file.

    The JSON file is read once at construction time. Its optional ``"train"``
    and ``"validation"`` entries become ``train_list`` / ``val_list`` and are
    passed unchanged as the ``data`` argument of the MONAI datasets built in
    :meth:`setup`. Every item goes through :meth:`train_transforms`;
    validation and test reuse the same deterministic pipeline.
    """

    def __init__(
        self,
        json_path: str,
        img_size: int,
        depth: int,
        mask_patch_size: int,
        patch_size: int,
        downsample_ratio: Sequence[float],
        cache_dir: str,
        batch_size: int = 1,
        val_batch_size: int = 1,
        num_workers: int = 4,
        cache_num: int = 0,
        cache_rate: float = 0.0,
        dist: bool = False,
    ):
        """
        Args:
            json_path: path to a JSON object whose "train" / "validation" keys
                (if present) hold the data lists. Items are presumably dicts
                with an "image" path, since LoadImaged reads key "image"
                (TODO confirm against the JSON producer).
            img_size: size of the first two spatial dims after centre crop / pad.
            depth: size of the third spatial dim after centre crop / pad.
            mask_patch_size: stored only; not used by this class.
            patch_size: stored only; not used by this class.
            downsample_ratio: target voxel spacing, passed to Spacingd as pixdim.
            cache_dir: on-disk cache directory used by PersistentDataset.
            batch_size: batch size of the training loader.
            val_batch_size: batch size of the validation and test loaders.
            num_workers: worker count for the loaders and for CacheDataset.
            cache_num: max number of training items kept in memory
                (validation/test use ``cache_num // 4``); 0 means no count limit.
            cache_rate: fraction of items kept in memory; 0.0 means no rate
                limit. If both cache_num and cache_rate are 0, an on-disk
                PersistentDataset is used instead of an in-memory CacheDataset.
            dist: if True, split the train/validation lists across
                ``torch.distributed`` ranks.
        """
        super().__init__()
        self.json_path = json_path
        self.img_size = img_size
        self.depth = depth
        self.mask_patch_size = mask_patch_size
        self.patch_size = patch_size
        self.cache_dir = cache_dir
        self.downsample_ratio = downsample_ratio
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.cache_num = cache_num
        self.cache_rate = cache_rate
        self.dist = dist

        # Use a context manager so the file handle is closed deterministically.
        with open(json_path, "r") as f:
            data_list = json.load(f)

        # None when the split is absent; setup() raises a clear error if it needs it.
        self.train_list = data_list.get("train")
        self.val_list = data_list.get("validation")

    def val_transforms(self):
        """Return the validation pipeline (identical to :meth:`train_transforms`)."""
        return self.train_transforms()

    def train_transforms(self):
        """Return the preprocessing pipeline applied to every item.

        Loads the image, makes it channel-first and RAS-oriented, resamples it
        to ``downsample_ratio`` spacing, and maps intensities in [-175, 250]
        (presumably a Hounsfield-unit soft-tissue window) linearly to [0, 1]
        with clipping. It then crops to the foreground, centre-crops and pads
        to ``(img_size, img_size, depth)``, converts to a tensor and moves the
        last axis to the front with :class:`PermuteImage`.
        """
        roi = (self.img_size, self.img_size, self.depth)
        return Compose(
            [
                LoadImaged(keys=["image"]),
                EnsureChannelFirstd(keys=["image"]),
                Orientationd(keys=["image"], axcodes="RAS"),
                Spacingd(
                    keys=["image"],
                    pixdim=self.downsample_ratio,
                    mode="bilinear",
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-175,
                    a_max=250,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image"], source_key="image"),
                # Crop first, then pad, so every sample ends up exactly `roi`.
                CenterSpatialCropd(keys=["image"], roi_size=roi),
                SpatialPadd(keys=["image"], spatial_size=roi),
                ToTensord(keys=["image"]),
                PermuteImage(),
            ]
        )

    def _require_split(self, split, name: str):
        """Return ``split``, or raise ValueError if the JSON file lacked it."""
        if split is None:
            raise ValueError(
                f"{name!r} split not found in data list {self.json_path!r}"
            )
        return split

    def _make_dataset(self, data, transform, cache_divisor: int = 1):
        """Wrap ``data`` in an in-memory CacheDataset or an on-disk PersistentDataset.

        Args:
            data: list of items to wrap.
            transform: pipeline applied to each item.
            cache_divisor: ``cache_num`` is divided by this (validation/test
                use 4, matching the original sizing).
        """
        if not (self.cache_num or self.cache_rate):
            return PersistentDataset(
                data,
                transform=transform,
                cache_dir=self.cache_dir,
            )
        # CacheDataset caches min(cache_num, int(len(data) * cache_rate)) items,
        # so a zero in either argument would silently disable caching. Treat a
        # zero setting as "no limit" so configuring only one of them still caches.
        cache_num = self.cache_num // cache_divisor if self.cache_num else len(data)
        cache_rate = self.cache_rate if self.cache_rate else 1.0
        return CacheDataset(
            data,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=self.num_workers,
            transform=transform,
        )

    def setup(self, stage: Optional[str] = None):
        """Build the datasets for ``stage``.

        Args:
            stage: None or "train" builds train + validation datasets (split
                per rank when ``dist`` is True). "test" builds a test dataset
                over the full validation list. Note that None is handled by
                the train branch, so it never yields a test dataset.

        Returns:
            {"train": ds, "validation": ds} for None/"train", {"test": ds} for
            "test", and {"train": None, "validation": None} for anything else.

        Raises:
            ValueError: if a required split is missing from the JSON file.
        """
        if stage in (None, "train"):
            train_list = self._require_split(self.train_list, "train")
            val_list = self._require_split(self.val_list, "validation")

            if self.dist:
                world_size = ptdist.get_world_size()
                rank = ptdist.get_rank()
                # partition_dataset uses a fixed default seed, so every rank
                # computes the same shuffled split and takes its own shard.
                train_partition = partition_dataset(
                    data=train_list,
                    num_partitions=world_size,
                    shuffle=True,
                    even_divisible=True,
                    drop_last=False,
                )[rank]
                valid_partition = partition_dataset(
                    data=val_list,
                    num_partitions=world_size,
                    shuffle=False,
                    even_divisible=True,
                    drop_last=False,
                )[rank]
            else:
                train_partition = train_list
                valid_partition = val_list

            return {
                "train": self._make_dataset(train_partition, self.train_transforms()),
                "validation": self._make_dataset(
                    valid_partition, self.val_transforms(), cache_divisor=4
                ),
            }

        if stage == "test":
            val_list = self._require_split(self.val_list, "validation")
            return {
                "test": self._make_dataset(
                    val_list, self.val_transforms(), cache_divisor=4
                )
            }

        return {"train": None, "validation": None}

    def _make_loader(self, ds, batch_size: int, shuffle: bool):
        """Return a DataLoader over ``ds`` that pads variable-size samples per batch."""
        return torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=shuffle,
            collate_fn=pad_list_data_collate,
        )

    def train_dataloader(self, train_ds):
        """Return a shuffled loader over ``train_ds`` using ``batch_size``."""
        return self._make_loader(train_ds, self.batch_size, shuffle=True)

    def val_dataloader(self, valid_ds):
        """Return an unshuffled loader over ``valid_ds`` using ``val_batch_size``."""
        return self._make_loader(valid_ds, self.val_batch_size, shuffle=False)

    def test_dataloader(self, test_ds):
        """Return an unshuffled loader over ``test_ds`` using ``val_batch_size``."""
        return self._make_loader(test_ds, self.val_batch_size, shuffle=False)
|
|