from .cifar10 import CIFAR10 | |
from .cifar100 import CIFAR100 | |
from .stl10 import STL10 | |
from .tiny_in import TinyImageNet | |
from .imagenet import ImageNet | |
# Dataset identifiers accepted by get_ds().
DS_LIST = ["cifar10", "cifar100", "stl10", "tinyimagenet", "imagenet"]


def get_ds(name):
    """Return the dataset class registered under ``name``.

    Args:
        name: Dataset identifier; must be one of the entries in ``DS_LIST``
            (exact, case-sensitive match).

    Returns:
        The class object imported for that identifier (the class itself,
        not an instance).

    Raises:
        AssertionError: If ``name`` is not listed in ``DS_LIST``, or is
            listed but has no class mapped to it here.
    """
    # Raise explicitly rather than via an `assert` statement: asserts are
    # stripped under `python -O`, which let unknown names fall through and
    # return None. AssertionError is kept as the exception type so existing
    # callers that catch it keep working.
    if name not in DS_LIST:
        raise AssertionError(
            f"unknown dataset {name!r}; expected one of {DS_LIST}"
        )

    # Built per call so the lookup only touches the imported classes once the
    # name has been validated.
    classes = {
        "cifar10": CIFAR10,
        "cifar100": CIFAR100,
        "stl10": STL10,
        "tinyimagenet": TinyImageNet,
        "imagenet": ImageNet,
    }
    if name not in classes:
        # Listed in DS_LIST but never wired up here: fail loudly instead of
        # handing None back to the caller.
        raise AssertionError(
            f"dataset {name!r} is listed in DS_LIST but has no class mapped"
        )
    return classes[name]