# Uploaded by wgcban ("Upload 98 files", commit 803ef9e, 1.69 kB; raw / history / blame views).
from abc import ABCMeta, abstractmethod
from functools import lru_cache
from torch.utils.data import DataLoader
class BaseDataset(metaclass=ABCMeta):
    """
    Base class for datasets. It exposes three lazily built DataLoaders:
    - ``train``: for self-supervised training,
    - ``clf``: for classifier training for evaluation,
    - ``test``: for testing.

    Subclasses provide the underlying datasets by implementing
    ``ds_train``, ``ds_clf`` and ``ds_test``. Each loader is built on first
    access and then reused for the lifetime of the instance.
    """

    def __init__(
        self, bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000,
    ):
        """
        Store the loader configuration.

        Args:
            bs_train: batch size of the ``train`` loader.
            aug_cfg: stored as ``self.aug_cfg``; not used by this base class
                (presumably consumed by subclasses when building datasets).
            num_workers: ``num_workers`` passed to every DataLoader.
            bs_clf: batch size of the ``clf`` loader.
            bs_test: batch size of the ``test`` loader.
        """
        self.aug_cfg = aug_cfg
        self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test
        self.num_workers = num_workers

    @abstractmethod
    def ds_train(self):
        """Return the dataset wrapped by the ``train`` loader."""
        raise NotImplementedError

    @abstractmethod
    def ds_clf(self):
        """Return the dataset wrapped by the ``clf`` loader."""
        raise NotImplementedError

    @abstractmethod
    def ds_test(self):
        """Return the dataset wrapped by the ``test`` loader."""
        raise NotImplementedError

    def _cached_loader(self, key, make_dataset, batch_size, shuffle, drop_last):
        """
        Return the DataLoader stored under ``key``, building it on first use.

        The cache lives in the instance ``__dict__`` rather than in a
        class-level ``lru_cache`` keyed on ``self``: that way the instance is
        not kept alive by the cache, it does not need to be hashable, and a
        loader is never evicted and silently rebuilt.
        """
        loaders = vars(self).setdefault("_loaders", {})
        if key not in loaders:
            loaders[key] = DataLoader(
                dataset=make_dataset(),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=self.num_workers,
                pin_memory=True,
                drop_last=drop_last,
            )
        return loaders[key]

    @property
    def train(self):
        """Shuffled loader over ``ds_train()``; the last partial batch is dropped."""
        return self._cached_loader(
            "train", self.ds_train, self.bs_train, shuffle=True, drop_last=True,
        )

    @property
    def clf(self):
        """Shuffled loader over ``ds_clf()``; the last partial batch is dropped."""
        return self._cached_loader(
            "clf", self.ds_clf, self.bs_clf, shuffle=True, drop_last=True,
        )

    @property
    def test(self):
        """Ordered loader over ``ds_test()``; every sample is kept."""
        return self._cached_loader(
            "test", self.ds_test, self.bs_test, shuffle=False, drop_last=False,
        )