|
from abc import ABCMeta, abstractmethod |
|
from functools import lru_cache |
|
from torch.utils.data import DataLoader |
|
|
|
|
|
class BaseDataset(metaclass=ABCMeta): |
|
""" |
|
base class for datasets, it includes 3 types: |
|
- for self-supervised training, |
|
- for classifier training for evaluation, |
|
- for testing |
|
""" |
|
|
|
def __init__( |
|
self, bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000, |
|
): |
|
self.aug_cfg = aug_cfg |
|
self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test |
|
self.num_workers = num_workers |
|
|
|
@abstractmethod |
|
def ds_train(self): |
|
raise NotImplementedError |
|
|
|
@abstractmethod |
|
def ds_clf(self): |
|
raise NotImplementedError |
|
|
|
@abstractmethod |
|
def ds_test(self): |
|
raise NotImplementedError |
|
|
|
@property |
|
@lru_cache() |
|
def train(self): |
|
return DataLoader( |
|
dataset=self.ds_train(), |
|
batch_size=self.bs_train, |
|
shuffle=True, |
|
num_workers=self.num_workers, |
|
pin_memory=True, |
|
drop_last=True, |
|
) |
|
|
|
@property |
|
@lru_cache() |
|
def clf(self): |
|
return DataLoader( |
|
dataset=self.ds_clf(), |
|
batch_size=self.bs_clf, |
|
shuffle=True, |
|
num_workers=self.num_workers, |
|
pin_memory=True, |
|
drop_last=True, |
|
) |
|
|
|
@property |
|
@lru_cache() |
|
def test(self): |
|
return DataLoader( |
|
dataset=self.ds_test(), |
|
batch_size=self.bs_test, |
|
shuffle=False, |
|
num_workers=self.num_workers, |
|
pin_memory=True, |
|
drop_last=False, |
|
) |
|
|