Christina Theodoris commited on
Commit
bcc03e8
·
1 Parent(s): 5426788

Add Geneformer trainer and pretraining example

Browse files
examples/pretrain_geneformer_w_deepspeed.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # run with:
5
+ # deepspeed --num_gpus=12 --num_nodes=3 pretrain_geneformer_w_deepspeed.py --deepspeed ds_config.json
6
+
7
+ import datetime
8
+
9
+ # imports
10
+ import os
11
+
12
+ os.environ["NCCL_DEBUG"] = "INFO"
13
+ os.environ["OMPI_MCA_opal_cuda_support"] = "true"
14
+ os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
15
+
16
+ import pickle
17
+ import random
18
+ import subprocess
19
+
20
+ import numpy as np
21
+ import pytz
22
+ import torch
23
+ from datasets import load_from_disk
24
+ from transformers import BertConfig, BertForMaskedLM, TrainingArguments
25
+
26
+ from .trainer import GeneformerTrainer
27
+
28
+ seed_num = 0
29
+ random.seed(seed_num)
30
+ np.random.seed(seed_num)
31
+ seed_val = 42
32
+ torch.manual_seed(seed_val)
33
+ torch.cuda.manual_seed_all(seed_val)
34
+
35
+ # set local time/directories
36
+ timezone = pytz.timezone("US/Eastern")
37
+ rootdir = "/parent_ouput_directory"
38
+
39
+ # set model parameters
40
+ # model type
41
+ model_type = "bert"
42
+ # max input size
43
+ max_input_size = 2**11 # 2048
44
+ # number of layers
45
+ num_layers = 6
46
+ # number of attention heads
47
+ num_attn_heads = 4
48
+ # number of embedding dimensions
49
+ num_embed_dim = 256
50
+ # intermediate size
51
+ intermed_size = num_embed_dim * 2
52
+ # activation function
53
+ activ_fn = "relu"
54
+ # initializer range, layer norm, dropout
55
+ initializer_range = 0.02
56
+ layer_norm_eps = 1e-12
57
+ attention_probs_dropout_prob = 0.02
58
+ hidden_dropout_prob = 0.02
59
+
60
+
61
+ # set training parameters
62
+ # total number of examples in Genecorpus-30M after QC filtering:
63
+ num_examples = 27_406_208
64
+ # number gpus
65
+ num_gpus = 12
66
+ # batch size for training and eval
67
+ geneformer_batch_size = 12
68
+ # max learning rate
69
+ max_lr = 1e-3
70
+ # learning schedule
71
+ lr_schedule_fn = "linear"
72
+ # warmup steps
73
+ warmup_steps = 10_000
74
+ # number of epochs
75
+ epochs = 3
76
+ # optimizer
77
+ optimizer = "adamw"
78
+ # weight_decay
79
+ weight_decay = 0.001
80
+
81
+
82
+ # output directories
83
+ current_date = datetime.datetime.now(tz=timezone)
84
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
85
+ run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
86
+ training_output_dir = f"{rootdir}/models/{run_name}/"
87
+ logging_dir = f"{rootdir}/runs/{run_name}/"
88
+ model_output_dir = os.path.join(training_output_dir, "models/")
89
+
90
+
91
+ # ensure not overwriting previously saved model
92
+ model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
93
+ if os.path.isfile(model_output_file) is True:
94
+ raise Exception("Model already saved to this directory.")
95
+
96
+
97
+ # make training and model output directories
98
+ subprocess.call(f"mkdir {training_output_dir}", shell=True)
99
+ subprocess.call(f"mkdir {model_output_dir}", shell=True)
100
+
101
+
102
+ # load gene_ensembl_id:token dictionary (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/datasets/token_dictionary.pkl)
103
+ with open("token_dictionary.pkl", "rb") as fp:
104
+ token_dictionary = pickle.load(fp)
105
+
106
+ # model configuration
107
+ config = {
108
+ "hidden_size": num_embed_dim,
109
+ "num_hidden_layers": num_layers,
110
+ "initializer_range": initializer_range,
111
+ "layer_norm_eps": layer_norm_eps,
112
+ "attention_probs_dropout_prob": attention_probs_dropout_prob,
113
+ "hidden_dropout_prob": hidden_dropout_prob,
114
+ "intermediate_size": intermed_size,
115
+ "hidden_act": activ_fn,
116
+ "max_position_embeddings": max_input_size,
117
+ "model_type": model_type,
118
+ "num_attention_heads": num_attn_heads,
119
+ "pad_token_id": token_dictionary.get("<pad>"),
120
+ "vocab_size": len(token_dictionary), # genes+2 for <mask> and <pad> tokens
121
+ }
122
+
123
+ config = BertConfig(**config)
124
+ model = BertForMaskedLM(config)
125
+ model = model.train()
126
+
127
+ # define the training arguments
128
+ training_args = {
129
+ "learning_rate": max_lr,
130
+ "do_train": True,
131
+ "do_eval": False,
132
+ "group_by_length": True,
133
+ "length_column_name": "length",
134
+ "disable_tqdm": False,
135
+ "lr_scheduler_type": lr_schedule_fn,
136
+ "warmup_steps": warmup_steps,
137
+ "weight_decay": weight_decay,
138
+ "per_device_train_batch_size": geneformer_batch_size,
139
+ "num_train_epochs": epochs,
140
+ "load_best_model_at_end": True,
141
+ "save_strategy": "steps",
142
+ "save_steps": num_examples / geneformer_batch_size / 8, # 8 saves per epoch
143
+ "logging_steps": 1000,
144
+ "output_dir": training_output_dir,
145
+ "logging_dir": logging_dir,
146
+ }
147
+ training_args = TrainingArguments(**training_args)
148
+
149
+ print("Starting training.")
150
+
151
+ # define the trainer
152
+ trainer = GeneformerTrainer(
153
+ model=model,
154
+ args=training_args,
155
+ # pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
156
+ train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
157
+ # file of lengths of each example cell (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl)
158
+ example_lengths_file="genecorpus_30M_2048_sorted_lengths.pkl",
159
+ token_dictionary=token_dictionary,
160
+ )
161
+
162
+ # train
163
+ trainer.train()
164
+
165
+ # save model
166
+ trainer.save_model(model_output_dir)
geneformer/trainer.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer trainer and collator.
3
+
4
+ Huggingface trainer and data collator modified to accommodate single-cell transcriptomics data.
5
+ """
6
+ import collections
7
+ import math
8
+ import pickle
9
+ import warnings
10
+ from enum import Enum
11
+ from typing import Dict, Iterator, List, Optional, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ from datasets import Dataset
16
+ from packaging import version
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from torch.utils.data.sampler import RandomSampler
19
+ from transformers import (
20
+ BatchEncoding,
21
+ DataCollatorForLanguageModeling,
22
+ SpecialTokensMixin,
23
+ Trainer,
24
+ )
25
+ from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
26
+ from transformers.trainer_pt_utils import (
27
+ DistributedLengthGroupedSampler,
28
+ DistributedSamplerWithLoop,
29
+ LengthGroupedSampler,
30
+ )
31
+ from transformers.training_args import ParallelMode
32
+ from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
+ from transformers.utils.generic import _is_tensorflow, _is_torch
34
+
35
+ from .tokenizer import TOKEN_DICTIONARY_FILE
36
+
37
+ logger = logging.get_logger(__name__)
38
+ EncodedInput = List[int]
39
+ VERY_LARGE_INTEGER = int(
40
+ 1e30
41
+ ) # This is used to set the max input length for a model with infinite size input
42
+ LARGE_INTEGER = int(
43
+ 1e20
44
+ ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
45
+
46
+ if is_sagemaker_dp_enabled():
47
+ import smdistributed.dataparallel.torch.distributed as dist
48
+ else:
49
+ import torch.distributed as dist
50
+
51
+ _is_torch_generator_available = False
52
+ if version.parse(torch.__version__) >= version.parse("1.6"):
53
+ _is_torch_generator_available = True
54
+
55
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
56
+ token_dictionary = pickle.load(f)
57
+
58
+
59
+ class ExplicitEnum(Enum):
60
+ """
61
+ Enum with more explicit error message for missing values.
62
+ """
63
+
64
+ @classmethod
65
+ def _missing_(cls, value):
66
+ raise ValueError(
67
+ "%r is not a valid %s, please select one of %s"
68
+ % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
69
+ )
70
+
71
+
72
+ class TruncationStrategy(ExplicitEnum):
73
+ """
74
+ Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
75
+ tab-completion in an IDE.
76
+ """
77
+
78
+ ONLY_FIRST = "only_first"
79
+ ONLY_SECOND = "only_second"
80
+ LONGEST_FIRST = "longest_first"
81
+ DO_NOT_TRUNCATE = "do_not_truncate"
82
+
83
+
84
+ class PaddingStrategy(ExplicitEnum):
85
+ """
86
+ Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
87
+ in an IDE.
88
+ """
89
+
90
+ LONGEST = "longest"
91
+ MAX_LENGTH = "max_length"
92
+ DO_NOT_PAD = "do_not_pad"
93
+
94
+
95
+ class TensorType(ExplicitEnum):
96
+ """
97
+ Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
98
+ tab-completion in an IDE.
99
+ """
100
+
101
+ PYTORCH = "pt"
102
+ TENSORFLOW = "tf"
103
+ NUMPY = "np"
104
+ JAX = "jax"
105
+
106
+
107
+ class GeneformerPreCollator(SpecialTokensMixin):
108
+ def __init__(self, *args, **kwargs) -> None:
109
+ self.token_dictionary = kwargs.get("token_dictionary")
110
+ self.mask_token = "<mask>"
111
+ self.mask_token_id = self.token_dictionary.get("<mask>")
112
+ self.pad_token = "<pad>"
113
+ self.pad_token_id = self.token_dictionary.get("<pad>")
114
+ self.padding_side = "right"
115
+ self.all_special_ids = [
116
+ self.token_dictionary.get("<mask>"),
117
+ self.token_dictionary.get("<pad>"),
118
+ ]
119
+ self.model_input_names = ["input_ids"]
120
+
121
+ super().__init__(*args, **kwargs)
122
+
123
+ def _get_padding_truncation_strategies(
124
+ self,
125
+ padding=False,
126
+ truncation=False,
127
+ max_length=None,
128
+ pad_to_multiple_of=None,
129
+ verbose=True,
130
+ **kwargs,
131
+ ):
132
+ """
133
+ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
134
+ and pad_to_max_length) and behaviors.
135
+ """
136
+ old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
137
+ old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
138
+
139
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
140
+ # If you only set max_length, it activates truncation for max_length
141
+ if max_length is not None and padding is False and truncation is False:
142
+ if verbose:
143
+ if not self.deprecation_warnings.get(
144
+ "Truncation-not-explicitly-activated", False
145
+ ):
146
+ logger.warning(
147
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, "
148
+ "please use `truncation=True` to explicitly truncate examples to max length. "
149
+ "Defaulting to 'longest_first' truncation strategy. "
150
+ "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
151
+ "more precisely by providing a specific strategy to `truncation`."
152
+ )
153
+ self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
154
+ truncation = "longest_first"
155
+
156
+ # Get padding strategy
157
+ if padding is False and old_pad_to_max_length:
158
+ if verbose:
159
+ warnings.warn(
160
+ "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
161
+ "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
162
+ "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
163
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
164
+ "maximal input size of the model (e.g. 512 for Bert).",
165
+ FutureWarning,
166
+ )
167
+ if max_length is None:
168
+ padding_strategy = PaddingStrategy.LONGEST
169
+ else:
170
+ padding_strategy = PaddingStrategy.MAX_LENGTH
171
+ elif padding is not False:
172
+ if padding is True:
173
+ padding_strategy = (
174
+ PaddingStrategy.LONGEST
175
+ ) # Default to pad to the longest sequence in the batch
176
+ elif not isinstance(padding, PaddingStrategy):
177
+ padding_strategy = PaddingStrategy(padding)
178
+ elif isinstance(padding, PaddingStrategy):
179
+ padding_strategy = padding
180
+ else:
181
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
182
+
183
+ # Get truncation strategy
184
+ if truncation is False and old_truncation_strategy != "do_not_truncate":
185
+ if verbose:
186
+ warnings.warn(
187
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
188
+ "use `truncation=True` to truncate examples to a max length. You can give a specific "
189
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
190
+ "maximal input size of the model (e.g. 512 for Bert). "
191
+ " If you have pairs of inputs, you can give a specific truncation strategy selected among "
192
+ "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
193
+ "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
194
+ "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
195
+ FutureWarning,
196
+ )
197
+ truncation_strategy = TruncationStrategy(old_truncation_strategy)
198
+ elif truncation is not False:
199
+ if truncation is True:
200
+ truncation_strategy = (
201
+ TruncationStrategy.LONGEST_FIRST
202
+ ) # Default to truncate the longest sequences in pairs of inputs
203
+ elif not isinstance(truncation, TruncationStrategy):
204
+ truncation_strategy = TruncationStrategy(truncation)
205
+ elif isinstance(truncation, TruncationStrategy):
206
+ truncation_strategy = truncation
207
+ else:
208
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
209
+
210
+ # Set max length if needed
211
+ if max_length is None:
212
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
213
+ if self.model_max_length > LARGE_INTEGER:
214
+ if verbose:
215
+ if not self.deprecation_warnings.get(
216
+ "Asking-to-pad-to-max_length", False
217
+ ):
218
+ logger.warning(
219
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
220
+ "Default to no padding."
221
+ )
222
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
223
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
224
+ else:
225
+ max_length = self.model_max_length
226
+
227
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
228
+ if self.model_max_length > LARGE_INTEGER:
229
+ if verbose:
230
+ if not self.deprecation_warnings.get(
231
+ "Asking-to-truncate-to-max_length", False
232
+ ):
233
+ logger.warning(
234
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
235
+ "Default to no truncation."
236
+ )
237
+ self.deprecation_warnings[
238
+ "Asking-to-truncate-to-max_length"
239
+ ] = True
240
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
241
+ else:
242
+ max_length = self.model_max_length
243
+
244
+ # Test if we have a padding token
245
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
246
+ not self.pad_token or self.pad_token_id < 0
247
+ ):
248
+ raise ValueError(
249
+ "Asking to pad but the tokenizer does not have a padding token. "
250
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
251
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
252
+ )
253
+
254
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
255
+ if (
256
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
257
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
258
+ and pad_to_multiple_of is not None
259
+ and max_length is not None
260
+ and (max_length % pad_to_multiple_of != 0)
261
+ ):
262
+ raise ValueError(
263
+ f"Truncation and padding are both activated but "
264
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
265
+ )
266
+
267
+ return padding_strategy, truncation_strategy, max_length, kwargs
268
+
269
+ def pad(
270
+ self,
271
+ encoded_inputs: Union[
272
+ BatchEncoding,
273
+ List[BatchEncoding],
274
+ Dict[str, EncodedInput],
275
+ Dict[str, List[EncodedInput]],
276
+ List[Dict[str, EncodedInput]],
277
+ ],
278
+ padding: Union[bool, str, PaddingStrategy] = True,
279
+ max_length: Optional[int] = None,
280
+ pad_to_multiple_of: Optional[int] = None,
281
+ return_attention_mask: Optional[bool] = True,
282
+ return_tensors: Optional[Union[str, TensorType]] = None,
283
+ verbose: bool = True,
284
+ ) -> BatchEncoding:
285
+ """
286
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
287
+ in the batch.
288
+
289
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
290
+ ``self.pad_token_id`` and ``self.pad_token_type_id``)
291
+
292
+ .. note::
293
+
294
+ If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
295
+ result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
296
+ case of PyTorch tensors, you will lose the specific device of your tensors however.
297
+
298
+ Args:
299
+ encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
300
+ Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
301
+ List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
302
+ List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
303
+ well as in a PyTorch Dataloader collate function.
304
+
305
+ Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
306
+ see the note above for the return type.
307
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
308
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
309
+ index) among:
310
+
311
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
312
+ single sequence if provided).
313
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
314
+ maximum acceptable input length for the model if that argument is not provided.
315
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
316
+ different lengths).
317
+ max_length (:obj:`int`, `optional`):
318
+ Maximum length of the returned list and optionally padding length (see above).
319
+ pad_to_multiple_of (:obj:`int`, `optional`):
320
+ If set will pad the sequence to a multiple of the provided value.
321
+
322
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
323
+ >= 7.5 (Volta).
324
+ return_attention_mask (:obj:`bool`, `optional`):
325
+ Whether to return the attention mask. If left to the default, will return the attention mask according
326
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
327
+
328
+ `What are attention masks? <../glossary.html#attention-mask>`__
329
+ return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
330
+ If set, will return tensors instead of list of python integers. Acceptable values are:
331
+
332
+ * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
333
+ * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
334
+ * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
335
+ verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
336
+ Whether or not to print more information and warnings.
337
+ """
338
+ # If we have a list of dicts, let's convert it in a dict of lists
339
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
340
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(
341
+ encoded_inputs[0], (dict, BatchEncoding)
342
+ ):
343
+ encoded_inputs = {
344
+ key: [example[key] for example in encoded_inputs]
345
+ for key in encoded_inputs[0].keys()
346
+ }
347
+
348
+ # The model's main input name, usually `input_ids`, has be passed for padding
349
+ if self.model_input_names[0] not in encoded_inputs:
350
+ raise ValueError(
351
+ "You should supply an encoding or a list of encodings to this method"
352
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
353
+ )
354
+
355
+ required_input = encoded_inputs[self.model_input_names[0]]
356
+
357
+ if not required_input:
358
+ if return_attention_mask:
359
+ encoded_inputs["attention_mask"] = []
360
+ return encoded_inputs
361
+
362
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
363
+ # and rebuild them afterwards if no return_tensors is specified
364
+ # Note that we lose the specific device the tensor may be on for PyTorch
365
+
366
+ first_element = required_input[0]
367
+ if isinstance(first_element, (list, tuple)):
368
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
369
+ index = 0
370
+ while len(required_input[index]) == 0:
371
+ index += 1
372
+ if index < len(required_input):
373
+ first_element = required_input[index][0]
374
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
375
+ if not isinstance(first_element, (int, list, tuple)):
376
+ if is_tf_available() and _is_tensorflow(first_element):
377
+ return_tensors = "tf" if return_tensors is None else return_tensors
378
+ elif is_torch_available() and _is_torch(first_element):
379
+ return_tensors = "pt" if return_tensors is None else return_tensors
380
+ elif isinstance(first_element, np.ndarray):
381
+ return_tensors = "np" if return_tensors is None else return_tensors
382
+ else:
383
+ raise ValueError(
384
+ f"type of {first_element} unknown: {type(first_element)}. "
385
+ f"Should be one of a python, numpy, pytorch or tensorflow object."
386
+ )
387
+
388
+ for key, value in encoded_inputs.items():
389
+ encoded_inputs[key] = to_py_obj(value)
390
+
391
+ # Convert padding_strategy in PaddingStrategy
392
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
393
+ padding=padding, max_length=max_length, verbose=verbose
394
+ )
395
+
396
+ required_input = encoded_inputs[self.model_input_names[0]]
397
+ if required_input and not isinstance(required_input[0], (list, tuple)):
398
+ encoded_inputs = self._pad(
399
+ encoded_inputs,
400
+ max_length=max_length,
401
+ padding_strategy=padding_strategy,
402
+ pad_to_multiple_of=pad_to_multiple_of,
403
+ return_attention_mask=return_attention_mask,
404
+ )
405
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
406
+
407
+ batch_size = len(required_input)
408
+ assert all(
409
+ len(v) == batch_size for v in encoded_inputs.values()
410
+ ), "Some items in the output dictionary have a different batch size than others."
411
+
412
+ if padding_strategy == PaddingStrategy.LONGEST:
413
+ max_length = max(len(inputs) for inputs in required_input)
414
+ padding_strategy = PaddingStrategy.MAX_LENGTH
415
+
416
+ batch_outputs = {}
417
+ for i in range(batch_size):
418
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
419
+ outputs = self._pad(
420
+ inputs,
421
+ max_length=max_length,
422
+ padding_strategy=padding_strategy,
423
+ pad_to_multiple_of=pad_to_multiple_of,
424
+ return_attention_mask=return_attention_mask,
425
+ )
426
+
427
+ for key, value in outputs.items():
428
+ if key not in batch_outputs:
429
+ batch_outputs[key] = []
430
+ batch_outputs[key].append(value)
431
+
432
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
433
+
434
+ def _pad(
435
+ self,
436
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
437
+ max_length: Optional[int] = None,
438
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
439
+ pad_to_multiple_of: Optional[int] = None,
440
+ return_attention_mask: Optional[bool] = None,
441
+ ) -> dict:
442
+ """
443
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
444
+
445
+ Args:
446
+ encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
447
+ max_length: maximum length of the returned list and optionally padding length (see below).
448
+ Will truncate by taking into account the special tokens.
449
+ padding_strategy: PaddingStrategy to use for padding.
450
+
451
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
452
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
453
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
454
+ The tokenizer padding sides are defined in self.padding_side:
455
+
456
+ - 'left': pads on the left of the sequences
457
+ - 'right': pads on the right of the sequences
458
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
459
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
460
+ >= 7.5 (Volta).
461
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
462
+ """
463
+ # Load from model defaults
464
+ if return_attention_mask is None:
465
+ return_attention_mask = "attention_mask" in self.model_input_names
466
+
467
+ required_input = encoded_inputs[self.model_input_names[0]]
468
+
469
+ if padding_strategy == PaddingStrategy.LONGEST:
470
+ max_length = len(required_input)
471
+
472
+ if (
473
+ max_length is not None
474
+ and pad_to_multiple_of is not None
475
+ and (max_length % pad_to_multiple_of != 0)
476
+ ):
477
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
478
+
479
+ needs_to_be_padded = (
480
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
481
+ and len(required_input) != max_length
482
+ )
483
+
484
+ if needs_to_be_padded:
485
+ difference = max_length - len(required_input)
486
+ if self.padding_side == "right":
487
+ if return_attention_mask:
488
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [
489
+ 0
490
+ ] * difference
491
+ if "token_type_ids" in encoded_inputs:
492
+ encoded_inputs["token_type_ids"] = (
493
+ encoded_inputs["token_type_ids"]
494
+ + [self.pad_token_type_id] * difference
495
+ )
496
+ if "special_tokens_mask" in encoded_inputs:
497
+ encoded_inputs["special_tokens_mask"] = (
498
+ encoded_inputs["special_tokens_mask"] + [1] * difference
499
+ )
500
+ encoded_inputs[self.model_input_names[0]] = (
501
+ required_input + [self.pad_token_id] * difference
502
+ )
503
+ elif self.padding_side == "left":
504
+ if return_attention_mask:
505
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
506
+ required_input
507
+ )
508
+ if "token_type_ids" in encoded_inputs:
509
+ encoded_inputs["token_type_ids"] = [
510
+ self.pad_token_type_id
511
+ ] * difference + encoded_inputs["token_type_ids"]
512
+ if "special_tokens_mask" in encoded_inputs:
513
+ encoded_inputs["special_tokens_mask"] = [
514
+ 1
515
+ ] * difference + encoded_inputs["special_tokens_mask"]
516
+ encoded_inputs[self.model_input_names[0]] = [
517
+ self.pad_token_id
518
+ ] * difference + required_input
519
+ else:
520
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
521
+ elif return_attention_mask and "attention_mask" not in encoded_inputs:
522
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
523
+
524
+ return encoded_inputs
525
+
526
+ def get_special_tokens_mask(
527
+ self,
528
+ token_ids_0: List[int],
529
+ token_ids_1: Optional[List[int]] = None,
530
+ already_has_special_tokens: bool = False,
531
+ ) -> List[int]:
532
+ """
533
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
534
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
535
+ Args:
536
+ token_ids_0 (:obj:`List[int]`):
537
+ List of ids of the first sequence.
538
+ token_ids_1 (:obj:`List[int]`, `optional`):
539
+ List of ids of the second sequence.
540
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
541
+ Whether or not the token list is already formatted with special tokens for the model.
542
+ Returns:
543
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
544
+ """
545
+ assert already_has_special_tokens and token_ids_1 is None, (
546
+ "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
547
+ "Please use a slow (full python) tokenizer to activate this argument."
548
+ "Or set `return_special_tokens_mask=True` when calling the encoding method "
549
+ "to get the special tokens mask in any tokenizer. "
550
+ )
551
+
552
+ all_special_ids = self.all_special_ids # cache the property
553
+
554
+ special_tokens_mask = [
555
+ 1 if token in all_special_ids else 0 for token in token_ids_0
556
+ ]
557
+
558
+ return special_tokens_mask
559
+
560
+ def convert_tokens_to_ids(
561
+ self, tokens: Union[str, List[str]]
562
+ ) -> Union[int, List[int]]:
563
+ """
564
+ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
565
+ vocabulary.
566
+ Args:
567
+ tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
568
+ Returns:
569
+ :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
570
+ """
571
+ if tokens is None:
572
+ return None
573
+
574
+ if isinstance(tokens, str):
575
+ return self._convert_token_to_id_with_added_voc(tokens)
576
+
577
+ ids = []
578
+ for token in tokens:
579
+ ids.append(self._convert_token_to_id_with_added_voc(token))
580
+ return ids
581
+
582
+ def _convert_token_to_id_with_added_voc(self, token):
583
+ if token is None:
584
+ return None
585
+
586
+ return self.token_dictionary.get(token)
587
+
588
+ def __len__(self):
589
+ return len(self.token_dictionary)
590
+
591
+
592
+ class GeneformerTrainer(Trainer):
593
+ def __init__(self, *args, **kwargs):
594
+ data_collator = kwargs.get("data_collator")
595
+ token_dictionary = kwargs.get("token_dictionary")
596
+
597
+ if data_collator is None:
598
+ precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
599
+
600
+ # # Data Collator Functions
601
+ data_collator = DataCollatorForLanguageModeling(
602
+ tokenizer=precollator, mlm=True, mlm_probability=0.15
603
+ )
604
+ kwargs["data_collator"] = data_collator
605
+
606
+ super().__init__(*args, **kwargs)
607
+
608
+ # load previously saved length vector for dataset to speed up LengthGroupedSampler
609
+ # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
610
+ if kwargs.get("example_lengths_file"):
611
+ with open(kwargs.get("example_lengths_file"), "rb") as f:
612
+ self.example_lengths = pickle.load(f)
613
+ else:
614
+ raise Exception(
615
+ "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
616
+ )
617
+
618
+ # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
619
+ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
620
+ if not isinstance(self.train_dataset, collections.abc.Sized):
621
+ return None
622
+
623
+ generator = None
624
+ if self.args.world_size <= 1 and _is_torch_generator_available:
625
+ generator = torch.Generator()
626
+ generator.manual_seed(
627
+ int(torch.empty((), dtype=torch.int64).random_().item())
628
+ )
629
+
630
+ # Build the sampler.
631
+ if self.args.group_by_length:
632
+ if is_datasets_available() and isinstance(self.train_dataset, Dataset):
633
+ lengths = self.example_lengths
634
+ else:
635
+ lengths = None
636
+ print(f"Lengths: {len(lengths)}")
637
+ model_input_name = (
638
+ self.tokenizer.model_input_names[0]
639
+ if self.tokenizer is not None
640
+ else None
641
+ )
642
+ if self.args.world_size <= 1:
643
+ return LengthGroupedSampler(
644
+ self.train_dataset,
645
+ self.args.train_batch_size,
646
+ lengths=lengths,
647
+ model_input_name=model_input_name,
648
+ generator=generator,
649
+ )
650
+ else:
651
+ return CustomDistributedLengthGroupedSampler(
652
+ self.train_dataset,
653
+ self.args.train_batch_size,
654
+ num_replicas=self.args.world_size,
655
+ rank=self.args.process_index,
656
+ lengths=lengths,
657
+ model_input_name=model_input_name,
658
+ seed=self.args.seed,
659
+ )
660
+
661
+ else:
662
+ if self.args.world_size <= 1:
663
+ if _is_torch_generator_available:
664
+ return RandomSampler(self.train_dataset, generator=generator)
665
+ return RandomSampler(self.train_dataset)
666
+ elif (
667
+ self.args.parallel_mode
668
+ in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
669
+ and not self.args.dataloader_drop_last
670
+ ):
671
+ # Use a loop for TPUs when drop_last is False to have all batches have the same size.
672
+ return DistributedSamplerWithLoop(
673
+ self.train_dataset,
674
+ batch_size=self.args.per_device_train_batch_size,
675
+ num_replicas=self.args.world_size,
676
+ rank=self.args.process_index,
677
+ seed=self.args.seed,
678
+ )
679
+ else:
680
+ return DistributedSampler(
681
+ self.train_dataset,
682
+ num_replicas=self.args.world_size,
683
+ rank=self.args.process_index,
684
+ seed=self.args.seed,
685
+ )
686
+
687
+
688
+ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
689
+ r"""
690
+ Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
691
+ length while keeping a bit of randomness.
692
+ """
693
+ # Copied and adapted from PyTorch DistributedSampler.
694
+ def __init__(
695
+ self,
696
+ dataset: Dataset,
697
+ batch_size: int,
698
+ num_replicas: Optional[int] = None,
699
+ rank: Optional[int] = None,
700
+ seed: int = 0,
701
+ drop_last: bool = False,
702
+ lengths: Optional[List[int]] = None,
703
+ model_input_name: Optional[str] = None,
704
+ ):
705
+ if num_replicas is None:
706
+ if not dist.is_available():
707
+ raise RuntimeError("Requires distributed package to be available")
708
+ num_replicas = dist.get_world_size()
709
+ if rank is None:
710
+ if not dist.is_available():
711
+ raise RuntimeError("Requires distributed package to be available")
712
+ rank = dist.get_rank()
713
+ self.dataset = dataset
714
+ self.batch_size = batch_size
715
+ self.num_replicas = num_replicas
716
+ self.rank = rank
717
+ self.epoch = 0
718
+ self.drop_last = drop_last
719
+ # If the dataset length is evenly divisible by # of replicas, then there
720
+ # is no need to drop any data, since the dataset will be split equally.
721
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0:
722
+ # Split to nearest available length that is evenly divisible.
723
+ # This is to ensure each rank receives the same amount of data when
724
+ # using this Sampler.
725
+ self.num_samples = math.ceil(
726
+ (len(self.dataset) - self.num_replicas) / self.num_replicas
727
+ )
728
+ else:
729
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
730
+ self.total_size = self.num_samples * self.num_replicas
731
+ self.seed = seed
732
+ self.model_input_name = (
733
+ model_input_name if model_input_name is not None else "input_ids"
734
+ )
735
+
736
+ if lengths is None:
737
+ print("Lengths is none - calculating lengths.")
738
+ if (
739
+ not (
740
+ isinstance(dataset[0], dict)
741
+ or isinstance(dataset[0], BatchEncoding)
742
+ )
743
+ or self.model_input_name not in dataset[0]
744
+ ):
745
+ raise ValueError(
746
+ "Can only automatically infer lengths for datasets whose items are dictionaries with an "
747
+ f"'{self.model_input_name}' key."
748
+ )
749
+ lengths = [len(feature[self.model_input_name]) for feature in dataset]
750
+ self.lengths = lengths
751
+
752
+ def __iter__(self) -> Iterator:
753
+ # Deterministically shuffle based on epoch and seed
754
+ g = torch.Generator()
755
+ g.manual_seed(self.seed + self.epoch)
756
+
757
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
758
+
759
+ if not self.drop_last:
760
+ # add extra samples to make it evenly divisible
761
+ indices += indices[: (self.total_size - len(indices))]
762
+ else:
763
+ # remove tail of data to make it evenly divisible.
764
+ indices = indices[: self.total_size]
765
+ assert len(indices) == self.total_size
766
+
767
+ # subsample
768
+ indices = indices[self.rank : self.total_size : self.num_replicas]
769
+ assert len(indices) == self.num_samples
770
+
771
+ return iter(indices)
772
+
773
+
774
+ def get_length_grouped_indices(
775
+ lengths, batch_size, mega_batch_mult=None, generator=None
776
+ ):
777
+ """
778
+ Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
779
+ similar lengths. To do this, the indices are:
780
+
781
+ - randomly permuted
782
+ - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
783
+ - sorted by length in each mega-batch
784
+
785
+ The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
786
+ maximum length placed first, so that an OOM happens sooner rather than later.
787
+ """
788
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
789
+ if mega_batch_mult is None:
790
+ # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
791
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
792
+ # Just in case, for tiny datasets
793
+ if mega_batch_mult == 0:
794
+ mega_batch_mult = 1
795
+
796
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
797
+ indices = torch.randperm(len(lengths), generator=generator)
798
+ megabatch_size = mega_batch_mult * batch_size
799
+ megabatches = [
800
+ indices[i : i + megabatch_size].tolist()
801
+ for i in range(0, len(lengths), megabatch_size)
802
+ ]
803
+ megabatches = [
804
+ list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
805
+ for megabatch in megabatches
806
+ ]
807
+
808
+ # The rest is to get the biggest batch first.
809
+ # Since each megabatch is sorted by descending length, the longest element is the first
810
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
811
+ max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
812
+ # Switch to put the longest element in first position
813
+ megabatches[0][0], megabatches[max_idx][0] = (
814
+ megabatches[max_idx][0],
815
+ megabatches[0][0],
816
+ )
817
+
818
+ return [item for sublist in megabatches for item in sublist]