# mix-bt / ssl-sota / datasets / imagenet.py
# Uploaded by wgcban ("Upload 98 files"), commit 803ef9e, 1.23 kB.
import os
import random

import torchvision.transforms as T
from PIL import ImageFilter
from torchvision.datasets import ImageFolder

from .base import BaseDataset
from .transforms import MultiSample, aug_transform
class RandomBlur:
    """Gaussian-blur an image with a radius drawn at random on every call.

    Args:
        r0: One end of the radius range passed to ``random.uniform``.
        r1: The other end of the radius range passed to ``random.uniform``.
    """

    def __init__(self, r0, r1):
        self.r0, self.r1 = r0, r1

    def __call__(self, image):
        """Return ``image.filter(ImageFilter.GaussianBlur(radius=r))``.

        ``r`` is re-sampled with ``random.uniform(self.r0, self.r1)`` on each
        call, so repeated calls blur with different strengths. ``image`` is
        expected to be a PIL image (anything exposing a PIL-style ``filter``).
        """
        r = random.uniform(self.r0, self.r1)
        return image.filter(ImageFilter.GaussianBlur(radius=r))

    def __repr__(self):
        # Printed when an enclosing T.Compose / T.RandomApply pipeline is
        # printed; the default object repr would hide the radius range.
        return f"{type(self).__name__}(r0={self.r0}, r1={self.r1})"
def base_transform():
    """Build the tensor-conversion and normalization pipeline shared by all splits.

    Returns a new ``T.Compose`` that applies ``T.ToTensor()`` followed by
    ``T.Normalize`` with the per-channel mean/std below. The values are
    presumably the standard ImageNet channel statistics (confirm).
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)
class ImageNet(BaseDataset):
    """ImageNet-style dataset loaded from ``train``, ``clf`` and ``test`` folders.

    Each split is a ``torchvision.datasets.ImageFolder`` rooted at
    ``<self.aug_cfg.imagenet_path>/<split>``. ``self.aug_cfg`` is presumably
    set up by ``BaseDataset`` (not visible here; confirm).
    """

    def _image_folder(self, split, transform):
        """Return an ``ImageFolder`` for ``split`` under the configured root.

        ``os.path.join`` is used instead of string concatenation so the root
        works with or without a trailing separator (and as a ``PathLike``).
        Plain ``+`` silently produced paths such as ``/data/IN100train`` when
        the separator was missing.
        """
        root = os.path.join(self.aug_cfg.imagenet_path, split)
        return ImageFolder(root=root, transform=transform)

    def ds_train(self):
        """Training split with augmentation, including a random Gaussian blur.

        The blur (``RandomBlur(0.1, 2.0)``) is applied with probability 0.5 and
        is passed to ``aug_transform`` via ``extra_t``. The resulting transform
        is wrapped in ``MultiSample`` with ``n=self.aug_cfg.num_samples``,
        presumably producing that many augmented views per image (see
        ``.transforms``; confirm).
        """
        aug_with_blur = aug_transform(
            224,  # presumably the output crop size; see aug_transform
            base_transform,  # passed uncalled; aug_transform presumably calls it
            self.aug_cfg,
            extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)],
        )
        t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples)
        return self._image_folder("train", t)

    def ds_clf(self):
        """``clf`` split with only ``base_transform()`` (tensor + normalize)."""
        return self._image_folder("clf", base_transform())

    def ds_test(self):
        """``test`` split with only ``base_transform()`` (tensor + normalize)."""
        return self._image_folder("test", base_transform())