from torch.utils.data import DataLoader


class PriorDataLoader(DataLoader):
    """Marker base class for data loaders that produce prior data.

    The class itself adds nothing to ``DataLoader``; it only names the
    interface that implementations are expected to follow:

    * ``__init__`` accepts ``num_steps`` as its first argument.
    * The following attributes are set at class or object level:

      - ``num_features``: int
      - ``num_outputs``: int
      - ``fuse_x_y``: bool

    * Optionally, a ``validate`` function that accepts a transformer model.

    None of these conventions are enforced here.
    """