File size: 533 Bytes
803ef9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from .cifar10 import CIFAR10
from .cifar100 import CIFAR100
from .stl10 import STL10
from .tiny_in import TinyImageNet
from .imagenet import ImageNet


# Names accepted by get_ds(); keep in sync with the lookup table inside it.
DS_LIST = ["cifar10", "cifar100", "stl10", "tinyimagenet", "imagenet"]


def get_ds(name: str) -> type:
    """Return the dataset class registered under ``name``.

    The class object itself is returned; this function does not
    instantiate it.

    Args:
        name: One of the identifiers listed in ``DS_LIST``.

    Returns:
        The class imported at the top of this module that corresponds to
        ``name`` (``CIFAR10``, ``CIFAR100``, ``STL10``, ``TinyImageNet`` or
        ``ImageNet``).

    Raises:
        AssertionError: If ``name`` is not in ``DS_LIST``.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed when Python runs with -O, which would let an unknown name fall
    # through and silently return None. AssertionError is kept as the type so
    # callers that already catch it keep working.
    if name not in DS_LIST:
        raise AssertionError(
            f"unknown dataset {name!r}; expected one of {DS_LIST}"
        )

    # Lookup table is built only after validation, so an invalid name is
    # rejected before any dataset class is touched.
    registry = {
        "cifar10": CIFAR10,
        "cifar100": CIFAR100,
        "stl10": STL10,
        "tinyimagenet": TinyImageNet,
        "imagenet": ImageNet,
    }
    return registry[name]