File size: 36,601 Bytes
bcc03e8 088ea6e bcc03e8 088ea6e bcc03e8 fb130e6 bcc03e8 fb130e6 b925dcc bcc03e8 1366905 bcc03e8 088ea6e bcc03e8 fb130e6 b925dcc fb130e6 bcc03e8 fb130e6 bcc03e8 b925dcc bcc03e8 b925dcc bcc03e8 b925dcc bcc03e8 b925dcc bcc03e8 fb130e6 bcc03e8 fb130e6 bcc03e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 |
"""
Geneformer precollator and pretrainer.
Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data.
"""
import collections
import math
import pickle
import warnings
from enum import Enum
from typing import Dict, Iterator, List, Optional, Union
import numpy as np
import torch
from datasets import Dataset
from packaging import version
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from transformers import (
BatchEncoding,
DataCollatorForLanguageModeling,
SpecialTokensMixin,
Trainer,
)
from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
LengthGroupedSampler,
)
from transformers.training_args import ParallelMode
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
from transformers.utils.generic import _is_tensorflow, _is_torch
# Module-level logger, obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Alias for one tokenized example: a flat list of integer token ids.
EncodedInput = List[int]

# This is used to set the max input length for a model with infinite size input
VERY_LARGE_INTEGER = int(1e30)
# This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
LARGE_INTEGER = int(1e20)

# Bind `dist` to the SageMaker data-parallel backend when it is enabled,
# otherwise to the regular torch.distributed package.
if is_sagemaker_dp_enabled():
    import smdistributed.dataparallel.torch.distributed as dist
else:
    import torch.distributed as dist

# True when the installed torch is at least 1.6; gates the use of torch.Generator below.
_is_torch_generator_available = version.parse(torch.__version__) >= version.parse(
    "1.6"
)
class ExplicitEnum(Enum):
    """
    Enum base class whose failed value lookups list the accepted values.
    """

    @classmethod
    def _missing_(cls, value):
        # Enum invokes this hook when `value` matches no member; it always raises.
        accepted_values = str(list(cls._value2member_map_.keys()))
        message = "%r is not a valid %s, please select one of %s" % (
            value,
            cls.__name__,
            accepted_values,
        )
        raise ValueError(message)
class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.

    Members can be looked up from their string value, e.g. ``TruncationStrategy("only_first")``.
    """

    # Truncate only the first sequence of a pair.
    ONLY_FIRST = "only_first"
    # Truncate only the second sequence of a pair.
    ONLY_SECOND = "only_second"
    # Iteratively remove tokens from the longest sequence of a pair.
    LONGEST_FIRST = "longest_first"
    # No truncation.
    DO_NOT_TRUNCATE = "do_not_truncate"
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.

    Members can be looked up from their string value, e.g. ``PaddingStrategy("longest")``.
    """

    # Pad to the longest sequence in the batch.
    LONGEST = "longest"
    # Pad to a given maximum length.
    MAX_LENGTH = "max_length"
    # No padding.
    DO_NOT_PAD = "do_not_pad"
class TensorType(ExplicitEnum):
    """
    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.

    Members can be looked up from their string value, e.g. ``TensorType("pt")``.
    """

    # PyTorch tensors.
    PYTORCH = "pt"
    # TensorFlow tensors.
    TENSORFLOW = "tf"
    # NumPy arrays.
    NUMPY = "np"
    # Presumably JAX arrays; this value is not used elsewhere in this module.
    JAX = "jax"
class GeneformerPreCollator(SpecialTokensMixin):
    """
    Tokenizer-like helper that pads pre-tokenized examples.

    It implements the small part of the tokenizer interface defined below (``pad``,
    ``get_special_tokens_mask``, ``convert_tokens_to_ids``, ``__len__``) on top of a plain
    ``token_dictionary`` lookup, with ``<mask>`` and ``<pad>`` registered as the mask and
    padding tokens.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Keyword Args:
            token_dictionary: Mapping used for all token lookups in this class.
            model_max_length (:obj:`int`, `optional`): Length used when padding to
                ``'max_length'`` without an explicit ``max_length``. Defaults to
                ``VERY_LARGE_INTEGER``, i.e. "no predefined maximum length".
        """
        super().__init__(mask_token="<mask>", pad_token="<pad>")
        self.token_dictionary = kwargs.get("token_dictionary")
        self.padding_side = "right"
        self.model_input_names = ["input_ids"]
        # `_get_padding_truncation_strategies` reads both attributes below; without them
        # some argument combinations would fail with an AttributeError.
        self.model_max_length = kwargs.get("model_max_length", VERY_LARGE_INTEGER)
        self.deprecation_warnings = {}

    def convert_ids_to_tokens(self, value):
        """Return ``self.token_dictionary.get(value)`` (``None`` when the key is absent)."""
        # NOTE(review): `convert_tokens_to_ids` looks tokens up in the same dictionary, so an
        # id passed here is only found if the dictionary also has ids as keys - confirm.
        return self.token_dictionary.get(value)

    def _get_padding_truncation_strategies(
        self,
        padding=False,
        truncation=False,
        max_length=None,
        pad_to_multiple_of=None,
        verbose=True,
        **kwargs,
    ):
        """
        Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
        and pad_to_max_length) and behaviors.

        Returns:
            Tuple ``(padding_strategy, truncation_strategy, max_length, kwargs)``; ``kwargs`` holds whatever
            keyword arguments remain once the legacy ones have been removed.

        Raises:
            ValueError: If padding is requested but no usable pad token/id is available, or if both padding and
                truncation are active and ``max_length`` is not a multiple of ``pad_to_multiple_of``.
        """
        old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
        old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)

        # Backward compatibility for previous behavior, maybe we should deprecate it:
        # If you only set max_length, it activates truncation for max_length
        if max_length is not None and padding is False and truncation is False:
            if verbose:
                if not self.deprecation_warnings.get(
                    "Truncation-not-explicitly-activated", False
                ):
                    logger.warning(
                        "Truncation was not explicitly activated but `max_length` is provided a specific value, "
                        "please use `truncation=True` to explicitly truncate examples to max length. "
                        "Defaulting to 'longest_first' truncation strategy. "
                        "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
                        "more precisely by providing a specific strategy to `truncation`."
                    )
                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
            truncation = "longest_first"

        # Get padding strategy
        if padding is False and old_pad_to_max_length:
            if verbose:
                warnings.warn(
                    "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                    "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                    "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                    "maximal input size of the model (e.g. 512 for Bert).",
                    FutureWarning,
                )
            if max_length is None:
                padding_strategy = PaddingStrategy.LONGEST
            else:
                padding_strategy = PaddingStrategy.MAX_LENGTH
        elif padding is not False:
            if padding is True:
                # Default to pad to the longest sequence in the batch
                padding_strategy = PaddingStrategy.LONGEST
            elif not isinstance(padding, PaddingStrategy):
                padding_strategy = PaddingStrategy(padding)
            else:
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD

        # Get truncation strategy
        if truncation is False and old_truncation_strategy != "do_not_truncate":
            if verbose:
                warnings.warn(
                    "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                    "use `truncation=True` to truncate examples to a max length. You can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
                    "maximal input size of the model (e.g. 512 for Bert). "
                    " If you have pairs of inputs, you can give a specific truncation strategy selected among "
                    "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
                    "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
                    "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
                    FutureWarning,
                )
            truncation_strategy = TruncationStrategy(old_truncation_strategy)
        elif truncation is not False:
            if truncation is True:
                # Default to truncate the longest sequences in pairs of inputs
                truncation_strategy = TruncationStrategy.LONGEST_FIRST
            elif not isinstance(truncation, TruncationStrategy):
                truncation_strategy = TruncationStrategy(truncation)
            else:
                truncation_strategy = truncation
        else:
            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE

        # Set max length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get(
                            "Asking-to-pad-to-max_length", False
                        ):
                            logger.warning(
                                "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no padding."
                            )
                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                    padding_strategy = PaddingStrategy.DO_NOT_PAD
                else:
                    max_length = self.model_max_length

            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get(
                            "Asking-to-truncate-to-max_length", False
                        ):
                            logger.warning(
                                "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no truncation."
                            )
                        self.deprecation_warnings[
                            "Asking-to-truncate-to-max_length"
                        ] = True
                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
                else:
                    max_length = self.model_max_length

        # Test if we have a padding token
        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
            pad_token_id = self.pad_token_id if self.pad_token else None
            # A pad token missing from the dictionary yields a `None` id; treat that like
            # "no padding token" instead of failing on `None < 0`.
            if not self.pad_token or pad_token_id is None or pad_token_id < 0:
                raise ValueError(
                    "Asking to pad but the tokenizer does not have a padding token. "
                    "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
                    "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
                )

        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
        if (
            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
            and padding_strategy != PaddingStrategy.DO_NOT_PAD
            and pad_to_multiple_of is not None
            and max_length is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            raise ValueError(
                f"Truncation and padding are both activated but "
                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
            )

        return padding_strategy, truncation_strategy, max_length, kwargs

    def pad(
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
        in the batch.

        Padding side (left/right) and padding token ids are defined on the instance (``self.padding_side``,
        ``self.pad_token_id`` and ``self.pad_token_type_id``).

        .. note::
            If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
            result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
            case of PyTorch tensors, you will lose the specific device of your tensors however.

        Args:
            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
                well as in a PyTorch Dataloader collate function.

                Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
                see the note above for the return type.
            padding (:obj:`bool`, :obj:`str` or :class:`PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Select a strategy to pad the returned sequences among:

                * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch.
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to
                  ``self.model_max_length`` if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
                  different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.
            return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to return the attention mask.
            return_tensors (:obj:`str` or :class:`TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to print more information and warnings.

        Returns:
            :class:`~transformers.BatchEncoding` with the padded inputs. When the main input is empty, the
            (possibly converted) ``encoded_inputs`` mapping is returned as is.

        Raises:
            ValueError: If the main model input (``self.model_input_names[0]``) is missing, or if the element type
                of the inputs is not recognised.
        """
        main_input_name = self.model_input_names[0]

        # If we have a list of dicts, let's convert it in a dict of lists
        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(encoded_inputs, (list, tuple)) and isinstance(
            encoded_inputs[0], (dict, BatchEncoding)
        ):
            encoded_inputs = {
                key: [example[key] for example in encoded_inputs]
                for key in encoded_inputs[0].keys()
            }

        # The model's main input name, usually `input_ids`, has be passed for padding
        if main_input_name not in encoded_inputs:
            raise ValueError(
                "You should supply an encoding or a list of encodings to this method "
                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
            )

        required_input = encoded_inputs[main_input_name]

        # Use an explicit length test: the truth value of a multi-element array is ambiguous.
        if required_input is None or (
            isinstance(required_input, collections.abc.Sized)
            and len(required_input) == 0
        ):
            if return_attention_mask:
                encoded_inputs["attention_mask"] = []
            return encoded_inputs

        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
        # and rebuild them afterwards if no return_tensors is specified
        # Note that we lose the specific device the tensor may be on for PyTorch
        first_element = required_input[0]
        if isinstance(first_element, (list, tuple)):
            # first_element might be an empty list/tuple in some edge cases so we grab the first
            # non empty element. If every sequence is empty, `first_element` is left untouched.
            for sequence in required_input:
                if len(sequence) != 0:
                    first_element = sequence[0]
                    break

        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
        if not isinstance(first_element, (int, list, tuple)):
            if is_tf_available() and _is_tensorflow(first_element):
                return_tensors = "tf" if return_tensors is None else return_tensors
            elif is_torch_available() and _is_torch(first_element):
                return_tensors = "pt" if return_tensors is None else return_tensors
            elif isinstance(first_element, np.ndarray):
                return_tensors = "np" if return_tensors is None else return_tensors
            else:
                raise ValueError(
                    f"type of {first_element} unknown: {type(first_element)}. "
                    f"Should be one of a python, numpy, pytorch or tensorflow object."
                )

            for key, value in encoded_inputs.items():
                encoded_inputs[key] = to_py_obj(value)

        # Convert padding_strategy in PaddingStrategy
        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=verbose
        )

        required_input = encoded_inputs[main_input_name]

        # A flat list of ids is a single example: pad it directly.
        if required_input and not isinstance(required_input[0], (list, tuple)):
            encoded_inputs = self._pad(
                encoded_inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        batch_size = len(required_input)
        assert all(
            len(v) == batch_size for v in encoded_inputs.values()
        ), "Some items in the output dictionary have a different batch size than others."

        # "Longest" is resolved once for the whole batch, then applied per example.
        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = max(len(inputs) for inputs in required_input)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        batch_outputs = {}
        for i in range(batch_size):
            inputs = {k: v[i] for k, v in encoded_inputs.items()}
            outputs = self._pad(
                inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            for key, value in outputs.items():
                batch_outputs.setdefault(key, []).append(value)

        return BatchEncoding(batch_outputs, tensor_type=return_tensors)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad one encoded example (on left/right and up to predefined length or max length in the batch).

        Args:
            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) for a single example. It is updated in
                place and returned.
            max_length: Target length when padding.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST: the target length becomes the length of this example
                - PaddingStrategy.MAX_LENGTH: pad to ``max_length``
                - PaddingStrategy.DO_NOT_PAD (default): do not pad

                The padding side is defined in ``self.padding_side``:

                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) If set, ``max_length`` is rounded up to a multiple of this value.
            return_attention_mask: (optional) Set to False to avoid returning the attention mask. When ``None``,
                the mask is returned only if ``"attention_mask"`` is listed in ``self.model_input_names``.

        Raises:
            ValueError: If ``self.padding_side`` is neither ``"left"`` nor ``"right"``.
        """
        main_input_name = self.model_input_names[0]

        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        required_input = encoded_inputs[main_input_name]
        current_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = current_length

        # Round the target length up to the next multiple when requested.
        if (
            max_length is not None
            and pad_to_multiple_of is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = (
            padding_strategy != PaddingStrategy.DO_NOT_PAD
            and current_length != max_length
        )

        if needs_to_be_padded:
            # A non-positive difference yields empty padding lists below.
            difference = max_length - current_length
            if self.padding_side == "right":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * current_length + [
                        0
                    ] * difference
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"]
                        + [self.pad_token_type_id] * difference
                    )
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = (
                        encoded_inputs["special_tokens_mask"] + [1] * difference
                    )
                encoded_inputs[main_input_name] = (
                    required_input + [self.pad_token_id] * difference
                )
            elif self.padding_side == "left":
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + [
                        1
                    ] * current_length
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = [
                        self.pad_token_type_id
                    ] * difference + encoded_inputs["token_type_ids"]
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = [
                        1
                    ] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs[main_input_name] = [
                    self.pad_token_id
                ] * difference + required_input
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
        elif return_attention_mask and "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * current_length

        return encoded_inputs

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Mark which ids of ``token_ids_0`` are special tokens.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`):
                List of ids of the second sequence. Must be :obj:`None`.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model. Must be
                :obj:`True`; the default value is rejected by the assertion below.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        assert already_has_special_tokens and token_ids_1 is None, (
            "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
            "Please use a slow (full python) tokenizer to activate this argument."
            "Or set `return_special_tokens_mask=True` when calling the encoding method "
            "to get the special tokens mask in any tokenizer. "
        )

        all_special_ids = self.all_special_ids  # cache the property

        return [1 if token in all_special_ids else 0 for token in token_ids_0]

    def convert_tokens_to_ids(
        self, tokens: Union[str, List[str]]
    ) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
        vocabulary.

        Args:
            tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).

        Returns:
            :obj:`int` or :obj:`List[int]`: The token id or list of token ids. :obj:`None` is returned for a
            :obj:`None` input and for tokens missing from the dictionary.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        return [self._convert_token_to_id_with_added_voc(token) for token in tokens]

    def _convert_token_to_id_with_added_voc(self, token):
        """Look up one token in ``self.token_dictionary``; ``None`` stays ``None``."""
        if token is None:
            return None

        return self.token_dictionary.get(token)

    def __len__(self):
        """Number of entries in ``self.token_dictionary``."""
        return len(self.token_dictionary)
class GeneformerPretrainer(Trainer):
    """
    ``Trainer`` subclass that builds a language-modeling data collator around
    ``GeneformerPreCollator`` and uses a pre-computed list of example lengths for
    length-grouped sampling.
    """

    def __init__(self, *args, **kwargs):
        """
        All positional and remaining keyword arguments are forwarded to ``Trainer.__init__``.

        Keyword Args:
            token_dictionary: Required. Passed to ``GeneformerPreCollator`` when no
                ``data_collator`` is supplied.
            mlm (:obj:`bool`, `optional`, defaults to :obj:`True`): Forwarded to
                ``DataCollatorForLanguageModeling``.
            mlm_probability (:obj:`float`, `optional`, defaults to 0.15): Forwarded to
                ``DataCollatorForLanguageModeling``.
            data_collator (`optional`): Used as is when provided; otherwise one is built.
            example_lengths_file: Required. Path to a pickle file whose content is stored in
                ``self.example_lengths``.

        Raises:
            KeyError: If ``token_dictionary`` is not provided.
            Exception: If ``example_lengths_file`` is missing or empty.
        """
        data_collator = kwargs.get("data_collator", None)
        token_dictionary = kwargs.pop("token_dictionary")
        mlm = kwargs.pop("mlm", True)
        mlm_probability = kwargs.pop("mlm_probability", 0.15)

        if data_collator is None:
            precollator = GeneformerPreCollator(token_dictionary=token_dictionary)

            # # Data Collator Functions
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=precollator, mlm=mlm, mlm_probability=mlm_probability
            )
            kwargs["data_collator"] = data_collator

        # load previously saved length vector for dataset to speed up LengthGroupedSampler
        # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
        # A default of None lets an omitted argument reach the explanatory error below
        # instead of surfacing as a bare KeyError.
        example_lengths_file = kwargs.pop("example_lengths_file", None)
        if not example_lengths_file:
            raise Exception(
                "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
            )
        # NOTE: unpickling can execute arbitrary code; only load lengths files from a trusted source.
        with open(example_lengths_file, "rb") as f:
            self.example_lengths = pickle.load(f)

        super().__init__(*args, **kwargs)

    # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        """
        Build the sampler for the training set.

        Returns:
            ``None`` when the training dataset has no length; otherwise a length-grouped sampler when
            ``args.group_by_length`` is set, or a random/distributed sampler depending on
            ``args.world_size`` and ``args.parallel_mode``.
        """
        if not isinstance(self.train_dataset, collections.abc.Sized):
            return None

        single_process = self.args.world_size <= 1

        generator = None
        if single_process and _is_torch_generator_available:
            # Seed a dedicated generator from the current torch RNG state.
            generator = torch.Generator()
            generator.manual_seed(
                int(torch.empty((), dtype=torch.int64).random_().item())
            )

        # Build the sampler.
        if self.args.group_by_length:
            # Use the pre-loaded lengths instead of reading a length column from the dataset.
            if is_datasets_available() and isinstance(self.train_dataset, Dataset):
                lengths = self.example_lengths
            else:
                lengths = None
            model_input_name = (
                self.tokenizer.model_input_names[0]
                if self.tokenizer is not None
                else None
            )
            if single_process:
                return LengthGroupedSampler(
                    dataset=self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            return CustomDistributedLengthGroupedSampler(
                dataset=self.train_dataset,
                batch_size=self.args.train_batch_size,
                num_replicas=self.args.world_size,
                rank=self.args.process_index,
                lengths=lengths,
                model_input_name=model_input_name,
                seed=self.args.seed,
            )

        if single_process:
            if _is_torch_generator_available:
                return RandomSampler(self.train_dataset, generator=generator)
            return RandomSampler(self.train_dataset)

        if (
            self.args.parallel_mode
            in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
            and not self.args.dataloader_drop_last
        ):
            # Use a loop for TPUs when drop_last is False to have all batches have the same size.
            return DistributedSamplerWithLoop(
                self.train_dataset,
                batch_size=self.args.per_device_train_batch_size,
                num_replicas=self.args.world_size,
                rank=self.args.process_index,
                seed=self.args.seed,
            )

        return DistributedSampler(
            self.train_dataset,
            num_replicas=self.args.world_size,
            rank=self.args.process_index,
            seed=self.args.seed,
        )
class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
    r"""
    Distributed sampler that yields, for one rank, a strided share of length-grouped indices, so
    that examples of roughly the same length end up together while some randomness is kept.
    """

    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        # Fall back to the process group for whichever of num_replicas / rank is missing.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        dataset_size = len(self.dataset)
        uneven_split = dataset_size % self.num_replicas != 0
        if self.drop_last and uneven_split:
            # Shrink to the nearest size that divides evenly so that every rank
            # receives the same amount of data.
            per_rank = math.ceil((dataset_size - self.num_replicas) / self.num_replicas)
        else:
            # Either the split is already even or the tail is padded in __iter__.
            per_rank = math.ceil(dataset_size / self.num_replicas)
        self.num_samples = per_rank
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed

        if model_input_name is None:
            model_input_name = "input_ids"
        self.model_input_name = model_input_name

        if lengths is None:
            print("Lengths is none - calculating lengths.")
            first_item = dataset[0]
            if (
                not isinstance(first_item, (dict, BatchEncoding))
                or self.model_input_name not in first_item
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{self.model_input_name}' key."
                )
            lengths = [len(feature[self.model_input_name]) for feature in dataset]
        self.lengths = lengths

    def __iter__(self) -> Iterator:
        # Deterministically shuffle based on epoch and seed
        rng = torch.Generator()
        rng.manual_seed(self.seed + self.epoch)

        indices = get_length_grouped_indices(
            self.lengths, self.batch_size, generator=rng
        )

        if self.drop_last:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        else:
            # add extra samples to make it evenly divisible
            shortfall = self.total_size - len(indices)
            indices.extend(indices[:shortfall])
        assert len(indices) == self.total_size

        # subsample: every num_replicas-th index, starting at this rank
        rank_indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(rank_indices) == self.num_samples

        return iter(rank_indices)
def get_length_grouped_indices(
    lengths, batch_size, mega_batch_mult=None, generator=None
):
    """
    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
    similar lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.

    Args:
        lengths: Sequence giving the length of each example; indexed by example position.
        batch_size (:obj:`int`): Number of examples per batch.
        mega_batch_mult (:obj:`int`, `optional`): Number of batches per mega-batch. When :obj:`None`, it is 1000 or
            the number that gives 4 mega-batches, whichever is smaller, and at least 1.
        generator (:obj:`torch.Generator`, `optional`): Random generator used for the permutation.

    Returns:
        :obj:`List[int]`: A permutation of ``range(len(lengths))``; an empty list when ``lengths`` is empty.
    """
    num_examples = len(lengths)
    # Nothing to group; an empty input would otherwise fail in torch.argmax below.
    if num_examples == 0:
        return []

    # Default for mega_batch_mult: 1000 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(num_examples // (batch_size * 4), 1000)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(num_examples, generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [
        indices[start : start + megabatch_size].tolist()
        for start in range(0, num_examples, megabatch_size)
    ]
    megabatches = [
        sorted(megabatch, key=lambda i: lengths[i], reverse=True)
        for megabatch in megabatches
    ]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = (
        megabatches[max_idx][0],
        megabatches[0][0],
    )

    return [item for sublist in megabatches for item in sublist]
|