# smb-vision-large / dataload.py
# chenz53's picture
# Update dataload.py
# f8a660a verified
import json
from typing import Optional, Sequence
import numpy as np
import torch
import torch.distributed as ptdist
from monai.data import (
CacheDataset,
PersistentDataset,
partition_dataset,
)
from monai.data.utils import pad_list_data_collate
from monai.transforms import (
Compose,
CropForegroundd,
EnsureChannelFirstd,
LoadImaged,
Orientationd,
RandSpatialCropSamplesd,
ScaleIntensityRanged,
Spacingd,
SpatialPadd,
CenterSpatialCropd,
ToTensord,
Transform,
)
class PermuteImage(Transform):
    """Dictionary transform that reorders the axes of ``data["image"]``.

    The entry stored under ``"image"`` is replaced, in the same dictionary,
    by a view whose axes are arranged as ``(3, 0, 1, 2)`` relative to the
    input, i.e. the last axis of a 4-D tensor becomes the first one.
    """

    def __call__(self, data):
        """Permute ``data["image"]`` and return the same dictionary.

        Args:
            data: Mapping holding a 4-D tensor-like object (anything exposing
                ``permute``) under the key ``"image"``.

        Returns:
            The input ``data`` object, with its ``"image"`` entry replaced.
        """
        # Adjust permutation order as needed.
        axis_order = (3, 0, 1, 2)
        image = data["image"]
        data["image"] = image.permute(*axis_order)
        return data
class CTDataset:
    """Build MONAI datasets and PyTorch dataloaders for volumes listed in a JSON file.

    The JSON file is expected to hold a mapping with an optional ``"train"``
    and an optional ``"validation"`` entry. Each present entry is stored as
    ``self.train_list`` / ``self.val_list``; an absent entry leaves the
    corresponding attribute undefined, so ``setup`` raises ``AttributeError``
    when it needs a split that was not provided.
    """

    def __init__(
        self,
        json_path: str,
        img_size: int,
        depth: int,
        mask_patch_size: int,
        patch_size: int,
        downsample_ratio: Sequence[float],
        cache_dir: str,
        batch_size: int = 1,
        val_batch_size: int = 1,
        num_workers: int = 4,
        cache_num: int = 0,
        cache_rate: float = 0.0,
        dist: bool = False,
    ):
        """Store the configuration and read the data lists from ``json_path``.

        Args:
            json_path: Path of the JSON file describing the data splits.
            img_size: In-plane size used for the first two spatial axes of the
                crop/pad target.
            depth: Size used for the third spatial axis of the crop/pad target.
            mask_patch_size: Stored on the instance; not used by this class.
            patch_size: Stored on the instance; not used by this class.
            downsample_ratio: Passed to ``Spacingd`` as ``pixdim``.
            cache_dir: Directory handed to ``PersistentDataset``.
            batch_size: Batch size of the training dataloader.
            val_batch_size: Batch size of the validation and test dataloaders.
            num_workers: Worker count for ``CacheDataset`` and the dataloaders.
            cache_num: ``cache_num`` for ``CacheDataset`` (validation and test
                use ``cache_num // 4``).
            cache_rate: ``cache_rate`` for ``CacheDataset``.
            dist: When true, ``setup`` shards the data across the ranks of the
                initialised ``torch.distributed`` process group.
        """
        super().__init__()
        self.json_path = json_path
        self.img_size = img_size
        self.depth = depth
        self.mask_patch_size = mask_patch_size
        self.patch_size = patch_size
        self.cache_dir = cache_dir
        self.downsample_ratio = downsample_ratio
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.cache_num = cache_num
        self.cache_rate = cache_rate
        self.dist = dist

        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(json_path, "r", encoding="utf-8") as handle:
            data_list = json.load(handle)

        if "train" in data_list:
            self.train_list = data_list["train"]
        if "validation" in data_list:
            self.val_list = data_list["validation"]

    def val_transforms(self) -> "Compose":
        """Return the validation pipeline, which is the training pipeline."""
        return self.train_transforms()

    def train_transforms(self) -> "Compose":
        """Return the preprocessing pipeline applied to every sample.

        The pipeline loads the ``"image"`` entry, makes it channel-first,
        reorients it to RAS, resamples it to ``self.downsample_ratio``, rescales
        intensities from ``[-175, 250]`` to ``[0, 1]`` with clipping, crops the
        foreground, then center-crops and pads to
        ``(img_size, img_size, depth)`` before converting to a tensor and
        permuting its axes with ``PermuteImage``.
        """
        target_size = (self.img_size, self.img_size, self.depth)
        keys = ["image"]
        # NOTE: random intensity augmentations (RandScaleIntensityd /
        # RandShiftIntensityd) are currently disabled.
        return Compose(
            [
                LoadImaged(keys=keys),
                EnsureChannelFirstd(keys=keys),
                Orientationd(keys=keys, axcodes="RAS"),
                Spacingd(
                    keys=keys,
                    pixdim=self.downsample_ratio,
                    mode="bilinear",
                ),
                ScaleIntensityRanged(
                    keys=keys,
                    a_min=-175,
                    a_max=250,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=keys, source_key="image"),
                # Crop first, then pad: the result always has `target_size`.
                CenterSpatialCropd(keys=keys, roi_size=target_size),
                SpatialPadd(keys=keys, spatial_size=target_size),
                ToTensord(keys=keys),
                PermuteImage(),
            ]
        )

    def _build_dataset(self, data, transform, cache_num: int):
        """Wrap ``data`` in a ``CacheDataset`` or a ``PersistentDataset``.

        In-memory caching is selected when either ``self.cache_num`` or
        ``self.cache_rate`` is non-zero; ``cache_num`` is the (possibly
        reduced) amount actually given to the ``CacheDataset``.
        """
        if any((self.cache_num, self.cache_rate)):
            return CacheDataset(
                data,
                cache_num=cache_num,
                cache_rate=self.cache_rate,
                num_workers=self.num_workers,
                transform=transform,
            )
        return PersistentDataset(
            data,
            transform=transform,
            cache_dir=self.cache_dir,
        )

    def _make_loader(self, dataset, batch_size: int, shuffle: bool):
        """Create a ``DataLoader`` with the settings shared by all splits."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=shuffle,
            collate_fn=pad_list_data_collate,
        )

    def setup(self, stage: Optional[str] = None) -> dict:
        """Create the datasets for ``stage``.

        Args:
            stage: ``None`` or ``"train"`` builds the training and validation
                datasets; ``"test"`` builds a test dataset from the validation
                list. ``None`` never reaches the test branch because the
                training branch returns first.

        Returns:
            ``{"train": ..., "validation": ...}`` for ``None``/``"train"``,
            ``{"test": ...}`` for ``"test"``, and
            ``{"train": None, "validation": None}`` for any other value.

        Raises:
            AttributeError: If a required split was absent from the JSON file.
        """
        if stage in (None, "train"):
            if self.dist:
                # Every rank keeps only its own shard of each split.
                world_size = ptdist.get_world_size()
                rank = ptdist.get_rank()
                train_partition = partition_dataset(
                    data=self.train_list,
                    num_partitions=world_size,
                    shuffle=True,
                    even_divisible=True,
                    drop_last=False,
                )[rank]
                valid_partition = partition_dataset(
                    data=self.val_list,
                    num_partitions=world_size,
                    shuffle=False,
                    even_divisible=True,
                    drop_last=False,
                )[rank]
            else:
                train_partition = self.train_list
                valid_partition = self.val_list

            train_ds = self._build_dataset(
                train_partition, self.train_transforms(), self.cache_num
            )
            # Validation caches a quarter of the training cache size.
            valid_ds = self._build_dataset(
                valid_partition, self.val_transforms(), self.cache_num // 4
            )
            return {"train": train_ds, "validation": valid_ds}

        if stage == "test":
            test_ds = self._build_dataset(
                self.val_list, self.val_transforms(), self.cache_num // 4
            )
            return {"test": test_ds}

        return {"train": None, "validation": None}

    def train_dataloader(self, train_ds):
        """Return a shuffling dataloader over ``train_ds`` using ``batch_size``."""
        return self._make_loader(train_ds, self.batch_size, shuffle=True)

    def val_dataloader(self, valid_ds):
        """Return a non-shuffling dataloader over ``valid_ds`` using ``val_batch_size``."""
        return self._make_loader(valid_ds, self.val_batch_size, shuffle=False)

    def test_dataloader(self, test_ds):
        """Return a non-shuffling dataloader over ``test_ds`` using ``val_batch_size``."""
        return self._make_loader(test_ds, self.val_batch_size, shuffle=False)