Christina Theodoris
committed on
Commit
·
b925dcc
1
Parent(s):
8ce598f
Update pretrainer for transformers==4.28.0
Browse files
examples/pretrain_geneformer_w_deepspeed.py
CHANGED
@@ -137,9 +137,8 @@ training_args = {
|
|
137 |
"weight_decay": weight_decay,
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
140 |
-
"load_best_model_at_end": True,
|
141 |
"save_strategy": "steps",
|
142 |
-
"save_steps": num_examples / geneformer_batch_size / 8, # 8 saves per epoch
|
143 |
"logging_steps": 1000,
|
144 |
"output_dir": training_output_dir,
|
145 |
"logging_dir": logging_dir,
|
|
|
137 |
"weight_decay": weight_decay,
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
|
|
140 |
"save_strategy": "steps",
|
141 |
+
"save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
|
142 |
"logging_steps": 1000,
|
143 |
"output_dir": training_output_dir,
|
144 |
"logging_dir": logging_dir,
|
geneformer/pretrainer.py
CHANGED
@@ -106,19 +106,23 @@ class TensorType(ExplicitEnum):
|
|
106 |
|
107 |
class GeneformerPreCollator(SpecialTokensMixin):
|
108 |
def __init__(self, *args, **kwargs) -> None:
|
|
|
|
|
|
|
109 |
self.token_dictionary = kwargs.get("token_dictionary")
|
110 |
-
self.mask_token = "<mask>"
|
111 |
-
self.mask_token_id = self.token_dictionary.get("<mask>")
|
112 |
-
self.pad_token = "<pad>"
|
113 |
-
self.pad_token_id = self.token_dictionary.get("<pad>")
|
114 |
self.padding_side = "right"
|
115 |
-
self.all_special_ids = [
|
116 |
-
|
117 |
-
|
118 |
-
]
|
119 |
self.model_input_names = ["input_ids"]
|
120 |
-
|
121 |
-
|
|
|
122 |
|
123 |
def _get_padding_truncation_strategies(
|
124 |
self,
|
@@ -592,8 +596,8 @@ class GeneformerPreCollator(SpecialTokensMixin):
|
|
592 |
|
593 |
class GeneformerPretrainer(Trainer):
|
594 |
def __init__(self, *args, **kwargs):
|
595 |
-
data_collator = kwargs.get("data_collator")
|
596 |
-
token_dictionary = kwargs.
|
597 |
|
598 |
if data_collator is None:
|
599 |
precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
|
@@ -604,17 +608,17 @@ class GeneformerPretrainer(Trainer):
|
|
604 |
)
|
605 |
kwargs["data_collator"] = data_collator
|
606 |
|
607 |
-
super().__init__(*args, **kwargs)
|
608 |
-
|
609 |
# load previously saved length vector for dataset to speed up LengthGroupedSampler
|
610 |
# pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
|
611 |
-
|
612 |
-
|
|
|
613 |
self.example_lengths = pickle.load(f)
|
614 |
else:
|
615 |
raise Exception(
|
616 |
"example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
|
617 |
)
|
|
|
618 |
|
619 |
# modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
|
620 |
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
@@ -634,7 +638,6 @@ class GeneformerPretrainer(Trainer):
|
|
634 |
lengths = self.example_lengths
|
635 |
else:
|
636 |
lengths = None
|
637 |
-
print(f"Lengths: {len(lengths)}")
|
638 |
model_input_name = (
|
639 |
self.tokenizer.model_input_names[0]
|
640 |
if self.tokenizer is not None
|
@@ -642,16 +645,16 @@ class GeneformerPretrainer(Trainer):
|
|
642 |
)
|
643 |
if self.args.world_size <= 1:
|
644 |
return LengthGroupedSampler(
|
645 |
-
self.train_dataset,
|
646 |
-
self.args.train_batch_size,
|
647 |
lengths=lengths,
|
648 |
model_input_name=model_input_name,
|
649 |
generator=generator,
|
650 |
)
|
651 |
else:
|
652 |
return CustomDistributedLengthGroupedSampler(
|
653 |
-
self.train_dataset,
|
654 |
-
self.args.train_batch_size,
|
655 |
num_replicas=self.args.world_size,
|
656 |
rank=self.args.process_index,
|
657 |
lengths=lengths,
|
@@ -754,7 +757,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
|
|
754 |
# Deterministically shuffle based on epoch and seed
|
755 |
g = torch.Generator()
|
756 |
g.manual_seed(self.seed + self.epoch)
|
757 |
-
|
758 |
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
|
759 |
|
760 |
if not self.drop_last:
|
|
|
106 |
|
107 |
class GeneformerPreCollator(SpecialTokensMixin):
|
108 |
def __init__(self, *args, **kwargs) -> None:
|
109 |
+
|
110 |
+
super().__init__(mask_token = "<mask>", pad_token = "<pad>")
|
111 |
+
|
112 |
self.token_dictionary = kwargs.get("token_dictionary")
|
113 |
+
# self.mask_token = "<mask>"
|
114 |
+
# self.mask_token_id = self.token_dictionary.get("<mask>")
|
115 |
+
# self.pad_token = "<pad>"
|
116 |
+
# self.pad_token_id = self.token_dictionary.get("<pad>")
|
117 |
self.padding_side = "right"
|
118 |
+
# self.all_special_ids = [
|
119 |
+
# self.token_dictionary.get("<mask>"),
|
120 |
+
# self.token_dictionary.get("<pad>"),
|
121 |
+
# ]
|
122 |
self.model_input_names = ["input_ids"]
|
123 |
+
|
124 |
+
def convert_ids_to_tokens(self,value):
|
125 |
+
return self.token_dictionary.get(value)
|
126 |
|
127 |
def _get_padding_truncation_strategies(
|
128 |
self,
|
|
|
596 |
|
597 |
class GeneformerPretrainer(Trainer):
|
598 |
def __init__(self, *args, **kwargs):
|
599 |
+
data_collator = kwargs.get("data_collator",None)
|
600 |
+
token_dictionary = kwargs.pop("token_dictionary")
|
601 |
|
602 |
if data_collator is None:
|
603 |
precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
|
|
|
608 |
)
|
609 |
kwargs["data_collator"] = data_collator
|
610 |
|
|
|
|
|
611 |
# load previously saved length vector for dataset to speed up LengthGroupedSampler
|
612 |
# pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
|
613 |
+
example_lengths_file = kwargs.pop("example_lengths_file")
|
614 |
+
if example_lengths_file:
|
615 |
+
with open(example_lengths_file, "rb") as f:
|
616 |
self.example_lengths = pickle.load(f)
|
617 |
else:
|
618 |
raise Exception(
|
619 |
"example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
|
620 |
)
|
621 |
+
super().__init__(*args, **kwargs)
|
622 |
|
623 |
# modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
|
624 |
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
|
|
638 |
lengths = self.example_lengths
|
639 |
else:
|
640 |
lengths = None
|
|
|
641 |
model_input_name = (
|
642 |
self.tokenizer.model_input_names[0]
|
643 |
if self.tokenizer is not None
|
|
|
645 |
)
|
646 |
if self.args.world_size <= 1:
|
647 |
return LengthGroupedSampler(
|
648 |
+
dataset=self.train_dataset,
|
649 |
+
batch_size=self.args.train_batch_size,
|
650 |
lengths=lengths,
|
651 |
model_input_name=model_input_name,
|
652 |
generator=generator,
|
653 |
)
|
654 |
else:
|
655 |
return CustomDistributedLengthGroupedSampler(
|
656 |
+
dataset=self.train_dataset,
|
657 |
+
batch_size=self.args.train_batch_size,
|
658 |
num_replicas=self.args.world_size,
|
659 |
rank=self.args.process_index,
|
660 |
lengths=lengths,
|
|
|
757 |
# Deterministically shuffle based on epoch and seed
|
758 |
g = torch.Generator()
|
759 |
g.manual_seed(self.seed + self.epoch)
|
760 |
+
|
761 |
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
|
762 |
|
763 |
if not self.drop_last:
|