from abc import ABCMeta, abstractmethod
from functools import lru_cache
from torch.utils.data import DataLoader


class BaseDataset(metaclass=ABCMeta):
    """
    Base class for datasets. It exposes three data loaders:
        - train: for self-supervised training,
        - clf: for training the classifier used in evaluation,
        - test: for testing.
    """

    def __init__(
        self, bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000,
    ):
        self.aug_cfg = aug_cfg
        self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test
        self.num_workers = num_workers

    @abstractmethod
    def ds_train(self):
        raise NotImplementedError

    @abstractmethod
    def ds_clf(self):
        raise NotImplementedError

    @abstractmethod
    def ds_test(self):
        raise NotImplementedError

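    # Each loader is built lazily on first access and then cached.
    # Note: lru_cache keys on the instance and keeps a reference to it.
    @property
    @lru_cache()
    def train(self):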
        return DataLoader(
            dataset=self.ds_train(),
            batch_size=self.bs_train,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )

    @property
    @lru_cache()
    def clf(self):
        return DataLoader(
            dataset=self.ds_clf(),
            batch_size=self.bs_clf,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )

    @property
    @lru_cache()
    def test(self):
        return DataLoader(
            dataset=self.ds_test(),
            batch_size=self.bs_test,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False,
        )
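
# The properties above stack @property on @lru_cache() so that each loader is
# built once per instance. A minimal sketch of an alternative for the same
# lazy-build-and-cache pattern is functools.cached_property (Python 3.8+),
# which stores the value on the instance rather than in a module-level cache.
# LoaderHolder and build_count are hypothetical names used only for
# illustration, and a plain list stands in for a real dataset or DataLoader so
# the sketch runs without torch.

```python
from functools import cached_property


class LoaderHolder:
    """Hypothetical stand-in for a BaseDataset subclass (no torch required)."""

    def __init__(self):
        # Counts how many times the underlying dataset is constructed.
        self.build_count = 0

    def ds_train(self):
        # Stand-in for building a real dataset object.
        self.build_count += 1
        return list(range(8))

    @cached_property
    def train(self):
        # Computed on first access, then stored in the instance's __dict__,
        # so the cached value is released together with the instance.
        return self.ds_train()


holder = LoaderHolder()
first = holder.train
second = holder.train  # served from the instance, ds_train is not called again
```

# Whether this is preferable depends on the supported Python versions; the
# original lru_cache approach also works on Python 3.7 and earlier.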