ctheodoris
committed
update pretrainer to not use distributed sampler (Trainer uses accelerate)
geneformer/pretrainer.py: +5 -171
geneformer/pretrainer.py
CHANGED
@@ -607,7 +607,7 @@ class GeneformerPretrainer(Trainer):
         )
         super().__init__(*args, **kwargs)
 
-    #
+    # updated to not use distributed sampler since Trainer now distributes with accelerate
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
@@ -630,181 +630,15 @@ class GeneformerPretrainer(Trainer):
                 if self.tokenizer is not None
                 else None
             )
-            if self.args.world_size <= 1:
-                return LengthGroupedSampler(
-                    dataset=self.train_dataset,
-                    batch_size=self.args.train_batch_size,
-                    lengths=lengths,
-                    model_input_name=model_input_name,
-                    generator=generator,
-                )
-            else:
-                return CustomDistributedLengthGroupedSampler(
-                    dataset=self.train_dataset,
-                    batch_size=self.args.train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    lengths=lengths,
-                    model_input_name=model_input_name,
-                    seed=self.args.seed,
-                )
-
-        else:
-            if self.args.world_size <= 1:
-                if _is_torch_generator_available:
-                    return RandomSampler(self.train_dataset, generator=generator)
-                return RandomSampler(self.train_dataset)
-            elif (
-                self.args.parallel_mode
-                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
-                and not self.args.dataloader_drop_last
-            ):
-                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
-                return DistributedSamplerWithLoop(
-                    self.train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=self.args.seed,
-                )
-            else:
-                return DistributedSampler(
-                    self.train_dataset,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=self.args.seed,
-                )
-
-
-class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
-    r"""
-    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
-    length while keeping a bit of randomness.
-    """
-
-    # Copied and adapted from PyTorch DistributedSampler.
-    def __init__(
-        self,
-        dataset: Dataset,
-        batch_size: int,
-        num_replicas: Optional[int] = None,
-        rank: Optional[int] = None,
-        seed: int = 0,
-        drop_last: bool = False,
-        lengths: Optional[List[int]] = None,
-        model_input_name: Optional[str] = None,
-    ):
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.drop_last = drop_last
-        # If the dataset length is evenly divisible by # of replicas, then there
-        # is no need to drop any data, since the dataset will be split equally.
-        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
-            # Split to nearest available length that is evenly divisible.
-            # This is to ensure each rank receives the same amount of data when
-            # using this Sampler.
-            self.num_samples = math.ceil(
-                (len(self.dataset) - self.num_replicas) / self.num_replicas
-            )
-        else:
-            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
-        self.total_size = self.num_samples * self.num_replicas
-        self.seed = seed
-        self.model_input_name = (
-            model_input_name if model_input_name is not None else "input_ids"
-        )
-
-        if lengths is None:
-            print("Lengths is none - calculating lengths.")
-            if (
-                not (
-                    isinstance(dataset[0], dict)
-                    or isinstance(dataset[0], BatchEncoding)
-                )
-                or self.model_input_name not in dataset[0]
-            ):
-                raise ValueError(
-                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
-                    f"'{self.model_input_name}' key."
-                )
-            lengths = [len(feature[self.model_input_name]) for feature in dataset]
-        self.lengths = lengths
-
-    def __iter__(self) -> Iterator:
-        # Deterministically shuffle based on epoch and seed
-        g = torch.Generator()
-        g.manual_seed(self.seed + self.epoch)
-
-        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
-
-        if not self.drop_last:
-            # add extra samples to make it evenly divisible
-            indices += indices[: (self.total_size - len(indices))]
-        else:
-            # remove tail of data to make it evenly divisible
-            indices = indices[: self.total_size]
-        assert len(indices) == self.total_size
-
-        # subsample
-        indices = indices[self.rank : self.total_size : self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        return iter(indices)
-
-
-def get_length_grouped_indices(
-    lengths, batch_size, mega_batch_mult=None, generator=None
-):
-    """
-    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
-    similar lengths. To do this, the indices are:
-
-    - randomly permuted
-    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
-    - sorted by length in each mega-batch
-
-    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
-    maximum length placed first, so that an OOM happens sooner rather than later.
-    """
-    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
-    if mega_batch_mult is None:
-        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
-        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
-        # Just in case, for tiny datasets
-        if mega_batch_mult == 0:
-            mega_batch_mult = 1
-
-    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
-    indices = torch.randperm(len(lengths), generator=generator)
-    megabatch_size = mega_batch_mult * batch_size
-    megabatches = [
-        indices[i : i + megabatch_size].tolist()
-        for i in range(0, len(lengths), megabatch_size)
-    ]
-    megabatches = [
-        list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
-        for megabatch in megabatches
-    ]
-
-    # The rest is to get the biggest batch first.
-    # Since each megabatch is sorted by descending length, the longest element is the first
-    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
-    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
-    # Switch to put the longest element in first position
-    megabatches[0][0], megabatches[max_idx][0] = (
-        megabatches[max_idx][0],
-        megabatches[0][0],
-    )
-
-    return [item for sublist in megabatches for item in sublist]
+            return LengthGroupedSampler(
+                dataset=self.train_dataset,
+                batch_size=self.args.train_batch_size,
+                lengths=lengths,
+                model_input_name=model_input_name,
+                generator=generator,
+            )
+
+        else:
+            if _is_torch_generator_available:
+                return RandomSampler(self.train_dataset, generator=generator)
+            return RandomSampler(self.train_dataset)
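
The rationale for dropping the distributed samplers is that the Trainer now builds its dataloaders through accelerate, which shards batches across processes on its own, so `_get_train_sampler` only has to define a per-process ordering. The snippet below is a minimal standalone sketch of that behavior, not repository code; the toy dataset, loader, and launch command are illustrative assumptions.

# Illustrative sketch (not part of pretrainer.py): a plain sampler plus
# Accelerator.prepare() replaces an explicit DistributedSampler.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

accelerator = Accelerator()

dataset = TensorDataset(torch.arange(32))
# Plain per-process sampler; no rank or world_size handling here.
loader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))

# prepare() wraps the DataLoader so each process receives its own shard of batches.
loader = accelerator.prepare(loader)

for batch in loader:
    # Under e.g. `accelerate launch --num_processes 2`, each rank sees different indices.
    print(accelerator.process_index, batch[0].tolist())

Because the sharding happens at the DataLoader level, the sampler returned by the simplified `_get_train_sampler` (length-grouped or plain random) works unchanged for both single-process and multi-process training.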
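The removed `get_length_grouped_indices` helper and `CustomDistributedLengthGroupedSampler` duplicated the megabatch-sort logic that the stock `LengthGroupedSampler` (now returned directly) already provides. A compressed sketch of that idea follows; the function name and toy `lengths` list are illustrative, and the "longest element first" swap of the original helper is omitted for brevity.

# Illustrative sketch of length grouping: shuffle, split into megabatches,
# sort each megabatch by length so consecutive batches hold similar lengths.
import torch


def length_grouped_indices(lengths, batch_size, mega_batch_mult=50, generator=None):
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [
        sorted(indices[i : i + megabatch_size].tolist(), key=lambda j: lengths[j], reverse=True)
        for i in range(0, len(lengths), megabatch_size)
    ]
    return [j for megabatch in megabatches for j in megabatch]


lengths = [3, 11, 2, 9, 8, 4, 10, 5]  # toy sequence lengths
order = length_grouped_indices(lengths, batch_size=2, mega_batch_mult=2)
print([lengths[i] for i in order])  # similar lengths land in the same batch of 2, reducing padding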