|
import random |
|
from torchvision.datasets import ImageFolder |
|
import torchvision.transforms as T |
|
from PIL import ImageFilter |
|
from .transforms import MultiSample, aug_transform |
|
from .base import BaseDataset |
|
|
|
|
|
class RandomBlur:
    """Gaussian-blur an image with a randomly chosen radius.

    Each call draws a radius uniformly from ``[r0, r1]`` and applies
    ``ImageFilter.GaussianBlur`` through the image's ``filter`` method,
    so the input is expected to be a PIL image (or any object exposing a
    compatible ``filter``).
    """

    def __init__(self, r0, r1):
        """Store the lower (``r0``) and upper (``r1``) radius bounds."""
        self.r0 = r0
        self.r1 = r1

    def __call__(self, image):
        """Return ``image`` blurred with a freshly sampled radius."""
        radius = random.uniform(self.r0, self.r1)
        blur = ImageFilter.GaussianBlur(radius=radius)
        return image.filter(blur)
|
|
|
|
|
def base_transform():
    """Build the minimal preprocessing pipeline: tensor conversion + normalize.

    Returns:
        A ``T.Compose`` that first applies ``T.ToTensor`` and then
        ``T.Normalize`` with fixed per-channel statistics.
    """
    # Per-channel (R, G, B) statistics; these are the values commonly used
    # for ImageNet-trained models.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    to_tensor = T.ToTensor()
    normalize = T.Normalize(channel_mean, channel_std)
    return T.Compose([to_tensor, normalize])
|
|
|
|
|
class ImageNet(BaseDataset):
    """ImageNet datasets read from ``ImageFolder`` directory trees.

    Three sub-directories of ``self.aug_cfg.imagenet_path`` are used:
    ``train`` (augmented, wrapped in ``MultiSample``), ``clf`` and ``test``
    (both only passed through ``base_transform``).
    """

    def _imagenet_split_root(self, split):
        """Return the directory of ``split`` under ``aug_cfg.imagenet_path``.

        The path is built with ``os.path.join`` rather than string
        concatenation, so the configured root works with or without a
        trailing separator (``"/data/imagenet" + "train"`` would otherwise
        silently point at ``/data/imagenettrain``). Path-like roots are
        accepted as well.
        """
        import os  # local import keeps this block self-contained

        return os.path.join(self.aug_cfg.imagenet_path, split)

    def ds_train(self):
        """Return the training ``ImageFolder`` with multi-sample augmentation."""
        # Blur is applied to half of the samples, with a radius in [0.1, 2.0].
        blur = T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)
        aug_with_blur = aug_transform(
            224,  # presumably the output crop size -- confirm in aug_transform
            base_transform,  # passed uncalled: aug_transform receives the factory
            self.aug_cfg,
            extra_t=[blur],
        )
        # NOTE(review): MultiSample presumably yields ``num_samples`` augmented
        # views per image -- confirm against .transforms.
        t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples)
        return ImageFolder(root=self._imagenet_split_root("train"), transform=t)

    def ds_clf(self):
        """Return the ``clf`` split with only ``base_transform`` applied."""
        t = base_transform()
        return ImageFolder(root=self._imagenet_split_root("clf"), transform=t)

    def ds_test(self):
        """Return the ``test`` split with only ``base_transform`` applied."""
        t = base_transform()
        return ImageFolder(root=self._imagenet_split_root("test"), transform=t)
|
|