mix-bt / ssl-sota /datasets /transforms.py
wgcban's picture
Upload 98 files
803ef9e
raw
history blame
929 Bytes
import torchvision.transforms as T
def aug_transform(crop, base_transform, cfg, extra_t=()):
    """Build the augmentation pipeline described by ``cfg``.

    The composed pipeline applies, in this order: randomly applied colour
    jitter, random grayscale, random resized crop, random horizontal flip,
    every transform in ``extra_t``, and finally the transform returned by
    ``base_transform()``.

    Args:
        crop: Output size forwarded to ``T.RandomResizedCrop``.
        base_transform: Zero-argument callable; it is invoked once and its
            return value becomes the last transform of the pipeline.
        cfg: Config object. The attributes read from it are ``cj0``..``cj3``
            (passed positionally to ``T.ColorJitter``, i.e. brightness,
            contrast, saturation, hue in torchvision's signature), ``cj_p``
            (probability of applying the jitter), ``gs_p`` (grayscale
            probability), ``crop_s0``/``crop_s1`` (crop scale range),
            ``crop_r0``/``crop_r1`` (crop aspect-ratio range) and ``hf_p``
            (horizontal-flip probability).
        extra_t: Iterable of additional transforms inserted just before the
            base transform. Defaults to an empty tuple; an immutable default
            is used so no mutable object is shared between calls.

    Returns:
        A ``T.Compose`` wrapping the transforms listed above.
    """
    color_jitter = T.RandomApply(
        [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)],
        p=cfg.cj_p,
    )
    grayscale = T.RandomGrayscale(p=cfg.gs_p)
    resized_crop = T.RandomResizedCrop(
        crop,
        scale=(cfg.crop_s0, cfg.crop_s1),
        ratio=(cfg.crop_r0, cfg.crop_r1),
        # 3 is the Pillow integer code for bicubic resampling
        # (PIL.Image.BICUBIC); kept as an int so the call behaves exactly
        # as before on whatever torchvision version is installed.
        interpolation=3,
    )
    horizontal_flip = T.RandomHorizontalFlip(p=cfg.hf_p)

    return T.Compose(
        [
            color_jitter,
            grayscale,
            resized_crop,
            horizontal_flip,
            *extra_t,
            # Called last, matching the original evaluation order.
            base_transform(),
        ]
    )
class MultiSample:
    """Callable that applies one transform several times to the same input.

    Each call invokes ``transform(x)`` ``num`` separate times and returns the
    results together, in call order, as a tuple.
    """

    def __init__(self, transform, n=2):
        """Store the transform and the number of samples to produce.

        Args:
            transform: Callable applied to the input on every repetition.
            n: How many times ``transform`` is applied per call (default 2).
        """
        self.num = n
        self.transform = transform

    def __call__(self, x):
        """Return a tuple with ``self.num`` results of ``self.transform(x)``."""
        samples = []
        for _ in range(self.num):
            samples.append(self.transform(x))
        return tuple(samples)