import torchvision.transforms as T


def aug_transform(crop, base_transform, cfg, extra_t=None):
    """Build an augmentation pipeline from the settings in ``cfg``.

    Transforms are applied in this order:

    1. ``ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)``, applied with
       probability ``cfg.cj_p``.
    2. ``RandomGrayscale`` with probability ``cfg.gs_p``.
    3. ``RandomResizedCrop`` to ``crop``, with scale range
       ``(cfg.crop_s0, cfg.crop_s1)`` and aspect-ratio range
       ``(cfg.crop_r0, cfg.crop_r1)``.
    4. ``RandomHorizontalFlip`` with probability ``cfg.hf_p``.
    5. Every transform in ``extra_t``, in order.
    6. The transform returned by calling ``base_transform()``.

    Args:
        crop: Output size forwarded to ``T.RandomResizedCrop``.
        base_transform: Zero-argument callable; it is *called* here and its
            result is appended as the last step of the pipeline.
        cfg: Config object exposing the attributes ``cj0``..``cj3``,
            ``cj_p``, ``gs_p``, ``crop_s0``, ``crop_s1``, ``crop_r0``,
            ``crop_r1`` and ``hf_p``.
        extra_t: Optional iterable of additional transforms inserted between
            the horizontal flip and the base transform. ``None`` (the
            default) means no extra transforms.

    Returns:
        A ``T.Compose`` wrapping the transforms listed above.
    """
    # Use a None sentinel rather than a mutable ``[]`` default so no list
    # object is shared between calls.
    if extra_t is None:
        extra_t = ()
    return T.Compose(
        [
            T.RandomApply(
                [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)], p=cfg.cj_p
            ),
            T.RandomGrayscale(p=cfg.gs_p),
            T.RandomResizedCrop(
                crop,
                scale=(cfg.crop_s0, cfg.crop_s1),
                ratio=(cfg.crop_r0, cfg.crop_r1),
                # 3 is PIL's BICUBIC resampling constant. It is kept as an int,
                # not T.InterpolationMode.BICUBIC, to stay compatible with
                # torchvision releases that predate InterpolationMode.
                interpolation=3,
            ),
            T.RandomHorizontalFlip(p=cfg.hf_p),
            *extra_t,
            base_transform(),
        ]
    )


class MultiSample:
    """Callable that applies ``transform`` to a single input several times.

    Calling an instance returns a tuple of ``n`` results of
    ``transform(x)``. When ``transform`` is random, this gives multiple
    augmented views of the same input.
    """

    def __init__(self, transform, n=2):
        # transform: callable applied to each input.
        # n: how many results to produce per call.
        self.transform = transform
        self.num = n

    def __call__(self, x):
        """Return a tuple of ``self.num`` outputs of ``self.transform(x)``."""
        return tuple(self.transform(x) for _ in range(self.num))

    def __repr__(self):
        return f"{type(self).__name__}(transform={self.transform!r}, n={self.num})"