ctheodoris committed (verified)
Commit 8140c51 · 1 Parent(s): df297bc

update pretrainer to not use distributed sampler (Trainer uses accelerate)

Files changed (1)
  1. geneformer/pretrainer.py +5 -171
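
The change drops the hand-rolled distributed sampling path because, as the commit message and the new in-code comment note, the Hugging Face Trainer now prepares its DataLoader through accelerate, which shards batches across processes on its own. A minimal standalone sketch of that behavior (illustrative only, not code from this commit; it assumes accelerate is installed and the script is started with `accelerate launch`):

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from accelerate import Accelerator

# Ordinary single-process sampler; no DistributedSampler is built by hand.
dataset = TensorDataset(torch.arange(32).unsqueeze(1))
loader = DataLoader(dataset, batch_size=8, sampler=RandomSampler(dataset))

# accelerate wraps the loader so that, under `accelerate launch --num_processes N`,
# each rank draws its own batches from a distinct shard of the sampler's order.
accelerator = Accelerator()
loader = accelerator.prepare(loader)

for (batch,) in loader:
    pass  # each process sees its own batch of 8 examples per step

This is why `_get_train_sampler` in the diff below only needs to return a plain `LengthGroupedSampler` or `RandomSampler`.
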
geneformer/pretrainer.py CHANGED
@@ -607,7 +607,7 @@ class GeneformerPretrainer(Trainer):
         )
         super().__init__(*args, **kwargs)

-    # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
+    # updated to not use distributed sampler since Trainer now distributes with accelerate
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
@@ -630,181 +630,15 @@ class GeneformerPretrainer(Trainer):
                 if self.tokenizer is not None
                 else None
             )
-            if self.args.world_size <= 1:
-                return LengthGroupedSampler(
+            return LengthGroupedSampler(
                 dataset=self.train_dataset,
                 batch_size=self.args.train_batch_size,
                 lengths=lengths,
                 model_input_name=model_input_name,
                 generator=generator,
-                )
-            else:
-                return CustomDistributedLengthGroupedSampler(
-                    dataset=self.train_dataset,
-                    batch_size=self.args.train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    lengths=lengths,
-                    model_input_name=model_input_name,
-                    seed=self.args.seed,
-                )
-
-        else:
-            if self.args.world_size <= 1:
-                if _is_torch_generator_available:
-                    return RandomSampler(self.train_dataset, generator=generator)
-                return RandomSampler(self.train_dataset)
-            elif (
-                self.args.parallel_mode
-                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
-                and not self.args.dataloader_drop_last
-            ):
-                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
-                return DistributedSamplerWithLoop(
-                    self.train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=self.args.seed,
-                )
-            else:
-                return DistributedSampler(
-                    self.train_dataset,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=self.args.seed,
-                )
-
-
-class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
-    r"""
-    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
-    length while keeping a bit of randomness.
-    """
-
-    # Copied and adapted from PyTorch DistributedSampler.
-    def __init__(
-        self,
-        dataset: Dataset,
-        batch_size: int,
-        num_replicas: Optional[int] = None,
-        rank: Optional[int] = None,
-        seed: int = 0,
-        drop_last: bool = False,
-        lengths: Optional[List[int]] = None,
-        model_input_name: Optional[str] = None,
-    ):
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.drop_last = drop_last
-        # If the dataset length is evenly divisible by # of replicas, then there
-        # is no need to drop any data, since the dataset will be split equally.
-        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
-            # Split to nearest available length that is evenly divisible.
-            # This is to ensure each rank receives the same amount of data when
-            # using this Sampler.
-            self.num_samples = math.ceil(
-                (len(self.dataset) - self.num_replicas) / self.num_replicas
             )
-        else:
-            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
-        self.total_size = self.num_samples * self.num_replicas
-        self.seed = seed
-        self.model_input_name = (
-            model_input_name if model_input_name is not None else "input_ids"
-        )
-
-        if lengths is None:
-            print("Lengths is none - calculating lengths.")
-            if (
-                not (
-                    isinstance(dataset[0], dict)
-                    or isinstance(dataset[0], BatchEncoding)
-                )
-                or self.model_input_name not in dataset[0]
-            ):
-                raise ValueError(
-                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
-                    f"'{self.model_input_name}' key."
-                )
-            lengths = [len(feature[self.model_input_name]) for feature in dataset]
-        self.lengths = lengths
-
-    def __iter__(self) -> Iterator:
-        # Deterministically shuffle based on epoch and seed
-        g = torch.Generator()
-        g.manual_seed(self.seed + self.epoch)
-
-        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

-        if not self.drop_last:
-            # add extra samples to make it evenly divisible
-            indices += indices[: (self.total_size - len(indices))]
         else:
+            if _is_torch_generator_available:
+                return RandomSampler(self.train_dataset, generator=generator)
+            return RandomSampler(self.train_dataset)
-            # remove tail of data to make it evenly divisible.
-            indices = indices[: self.total_size]
-        assert len(indices) == self.total_size
-
-        # subsample
-        indices = indices[self.rank : self.total_size : self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        return iter(indices)
-
-
-def get_length_grouped_indices(
-    lengths, batch_size, mega_batch_mult=None, generator=None
-):
-    """
-    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
-    similar lengths. To do this, the indices are:
-
-    - randomly permuted
-    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
-    - sorted by length in each mega-batch
-
-    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
-    maximum length placed first, so that an OOM happens sooner rather than later.
-    """
-    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
-    if mega_batch_mult is None:
-        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
-        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
-        # Just in case, for tiny datasets
-        if mega_batch_mult == 0:
-            mega_batch_mult = 1
-
-    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
-    indices = torch.randperm(len(lengths), generator=generator)
-    megabatch_size = mega_batch_mult * batch_size
-    megabatches = [
-        indices[i : i + megabatch_size].tolist()
-        for i in range(0, len(lengths), megabatch_size)
-    ]
-    megabatches = [
-        list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
-        for megabatch in megabatches
-    ]
-
-    # The rest is to get the biggest batch first.
-    # Since each megabatch is sorted by descending length, the longest element is the first
-    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
-    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
-    # Switch to put the longest element in first position
-    megabatches[0][0], megabatches[max_idx][0] = (
-        megabatches[max_idx][0],
-        megabatches[0][0],
-    )
-
-    return [item for sublist in megabatches for item in sublist]
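
The removed `get_length_grouped_indices` helper (like the `LengthGroupedSampler` that remains in use) groups samples of similar length as its docstring describes: permute the indices, split them into mega-batches of `mega_batch_mult * batch_size`, and sort each mega-batch by descending length. A toy illustration of that grouping on made-up lengths (standalone sketch, not repository code):

import torch

lengths = [5, 3, 9, 2, 8, 7, 1, 6]   # e.g. number of tokens per example
batch_size = 2
mega_batch_mult = 2                   # mega-batch size = 2 * 2 = 4

g = torch.Generator()
g.manual_seed(0)
indices = torch.randperm(len(lengths), generator=g).tolist()

megabatch_size = mega_batch_mult * batch_size
megabatches = [
    indices[i : i + megabatch_size]
    for i in range(0, len(indices), megabatch_size)
]
# sort each mega-batch by descending length so consecutive batch_size slices are similar
megabatches = [sorted(mb, key=lambda i: lengths[i], reverse=True) for mb in megabatches]
grouped = [i for mb in megabatches for i in mb]

print([lengths[i] for i in grouped])  # lengths now descend within each mega-batch

Keeping batches of similar length makes padding, and therefore memory use, predictable, which is why the pretrainer retains length-grouped sampling even after dropping the distributed variant.
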