File size: 929 Bytes
803ef9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torchvision.transforms as T
def aug_transform(crop, base_transform, cfg, extra_t=()):
    """Build the augmentation pipeline described by ``cfg``.

    The composed pipeline applies, in order: randomly applied colour jitter,
    random grayscale, random resized crop, random horizontal flip, every
    transform in ``extra_t``, and finally the object returned by
    ``base_transform()``.

    Args:
        crop: Output size forwarded to ``T.RandomResizedCrop``.
        base_transform: Zero-argument callable; its return value becomes the
            last step of the pipeline.
        cfg: Configuration object read for the attributes ``cj0``..``cj3``
            (passed positionally to ``T.ColorJitter``), ``cj_p``, ``gs_p``,
            ``crop_s0``/``crop_s1`` (crop scale range),
            ``crop_r0``/``crop_r1`` (crop aspect-ratio range) and ``hf_p``.
        extra_t: Optional iterable of additional transforms inserted just
            before the base transform. Defaults to an empty tuple so that no
            mutable object is shared between calls.

    Returns:
        A ``T.Compose`` wrapping the steps listed above.
    """
    color_jitter = T.RandomApply(
        [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)],
        p=cfg.cj_p,
    )
    resized_crop = T.RandomResizedCrop(
        crop,
        scale=(cfg.crop_s0, cfg.crop_s1),
        ratio=(cfg.crop_r0, cfg.crop_r1),
        # NOTE(review): 3 is presumably the PIL integer code for bicubic
        # resampling; newer torchvision releases prefer
        # InterpolationMode.BICUBIC -- confirm against the pinned version.
        interpolation=3,
    )
    return T.Compose(
        [
            color_jitter,
            T.RandomGrayscale(p=cfg.gs_p),
            resized_crop,
            T.RandomHorizontalFlip(p=cfg.hf_p),
            *extra_t,
            base_transform(),
        ]
    )
class MultiSample:
    """Callable that applies one transform to the same input several times.

    Each call invokes ``transform`` on the input ``n`` times and returns the
    results together as a tuple.
    """

    def __init__(self, transform, n=2):
        """Store the transform and how many samples each call produces."""
        self.transform = transform
        self.num = n

    def __call__(self, x):
        """Return a tuple holding ``self.num`` results of ``self.transform(x)``."""
        samples = []
        for _ in range(self.num):
            samples.append(self.transform(x))
        return tuple(samples)
|