Christina Theodoris commited on
Commit
402ba9b
·
1 Parent(s): b925dcc

Subclass collator for cell classification

Browse files
examples/cell_classification.ipynb CHANGED
@@ -1890,6 +1890,7 @@
1890
  " \"do_train\": True,\n",
1891
  " \"do_eval\": True,\n",
1892
  " \"evaluation_strategy\": \"epoch\",\n",
 
1893
  " \"logging_steps\": logging_steps,\n",
1894
  " \"group_by_length\": True,\n",
1895
  " \"length_column_name\": \"length\",\n",
@@ -1927,7 +1928,7 @@
1927
  ],
1928
  "metadata": {
1929
  "kernelspec": {
1930
- "display_name": "Python 3.8.6 64-bit ('3.8.6')",
1931
  "language": "python",
1932
  "name": "python3"
1933
  },
@@ -1941,7 +1942,7 @@
1941
  "name": "python",
1942
  "nbconvert_exporter": "python",
1943
  "pygments_lexer": "ipython3",
1944
- "version": "3.8.6"
1945
  },
1946
  "vscode": {
1947
  "interpreter": {
 
1890
  " \"do_train\": True,\n",
1891
  " \"do_eval\": True,\n",
1892
  " \"evaluation_strategy\": \"epoch\",\n",
1893
+ " \"save_strategy\": \"epoch\",\n",
1894
  " \"logging_steps\": logging_steps,\n",
1895
  " \"group_by_length\": True,\n",
1896
  " \"length_column_name\": \"length\",\n",
 
1928
  ],
1929
  "metadata": {
1930
  "kernelspec": {
1931
+ "display_name": "Python 3 (ipykernel)",
1932
  "language": "python",
1933
  "name": "python3"
1934
  },
 
1942
  "name": "python",
1943
  "nbconvert_exporter": "python",
1944
  "pygments_lexer": "ipython3",
1945
+ "version": "3.10.11"
1946
  },
1947
  "vscode": {
1948
  "interpreter": {
examples/gene_classification.ipynb CHANGED
@@ -2,7 +2,6 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
- "id": "234afff3",
6
  "metadata": {},
7
  "source": [
8
  "## Geneformer Fine-Tuning for Classification of Dosage-Sensitive vs. -Insensitive Transcription Factors (TFs)"
@@ -448,7 +447,6 @@
448
  {
449
  "cell_type": "code",
450
  "execution_count": null,
451
- "id": "d24e1ab7-0131-44bd-b458-1ce5ba31853e",
452
  "metadata": {},
453
  "outputs": [],
454
  "source": [
@@ -2385,7 +2383,7 @@
2385
  ],
2386
  "metadata": {
2387
  "kernelspec": {
2388
- "display_name": "Python 3.8.6 64-bit ('3.8.6')",
2389
  "language": "python",
2390
  "name": "python3"
2391
  },
@@ -2399,7 +2397,7 @@
2399
  "name": "python",
2400
  "nbconvert_exporter": "python",
2401
  "pygments_lexer": "ipython3",
2402
- "version": "3.8.6"
2403
  },
2404
  "vscode": {
2405
  "interpreter": {
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
 
5
  "metadata": {},
6
  "source": [
7
  "## Geneformer Fine-Tuning for Classification of Dosage-Sensitive vs. -Insensitive Transcription Factors (TFs)"
 
447
  {
448
  "cell_type": "code",
449
  "execution_count": null,
 
450
  "metadata": {},
451
  "outputs": [],
452
  "source": [
 
2383
  ],
2384
  "metadata": {
2385
  "kernelspec": {
2386
+ "display_name": "Python 3 (ipykernel)",
2387
  "language": "python",
2388
  "name": "python3"
2389
  },
 
2397
  "name": "python",
2398
  "nbconvert_exporter": "python",
2399
  "pygments_lexer": "ipython3",
2400
+ "version": "3.10.11"
2401
  },
2402
  "vscode": {
2403
  "interpreter": {
geneformer/__init__.py CHANGED
@@ -1,12 +1,11 @@
1
  from . import tokenizer
2
  from . import pretrainer
3
- from . import collator_for_cell_classification
4
- from . import collator_for_gene_classification
5
  from . import in_silico_perturber
6
  from . import in_silico_perturber_stats
7
  from .tokenizer import TranscriptomeTokenizer
8
  from .pretrainer import GeneformerPretrainer
9
- from .collator_for_gene_classification import DataCollatorForGeneClassification
10
- from .collator_for_cell_classification import DataCollatorForCellClassification
11
  from .in_silico_perturber import InSilicoPerturber
12
  from .in_silico_perturber_stats import InSilicoPerturberStats
 
1
  from . import tokenizer
2
  from . import pretrainer
3
+ from . import collator_for_classification
 
4
  from . import in_silico_perturber
5
  from . import in_silico_perturber_stats
6
  from .tokenizer import TranscriptomeTokenizer
7
  from .pretrainer import GeneformerPretrainer
8
+ from .collator_for_classification import DataCollatorForGeneClassification
9
+ from .collator_for_classification import DataCollatorForCellClassification
10
  from .in_silico_perturber import InSilicoPerturber
11
  from .in_silico_perturber_stats import InSilicoPerturberStats
geneformer/{collator_for_cell_classification.py → collator_for_classification.py} RENAMED
@@ -1,7 +1,7 @@
1
  """
2
- Geneformer collator for cell classification.
3
 
4
- Huggingface data collator modified to accommodate single-cell transcriptomics data for cell classification.
5
  """
6
  import numpy as np
7
  import torch
@@ -30,18 +30,6 @@ LARGE_INTEGER = int(
30
 
31
  # precollator functions
32
 
33
- def run_once(f):
34
- def wrapper(*args, **kwargs):
35
- if not wrapper.has_run:
36
- wrapper.has_run = True
37
- return f(*args, **kwargs)
38
- wrapper.has_run = False
39
- return wrapper
40
-
41
- @run_once
42
- def check_output_once(output):
43
- return print(output)
44
-
45
  class ExplicitEnum(Enum):
46
  """
47
  Enum with more explicit error message for missing values.
@@ -91,7 +79,7 @@ class TensorType(ExplicitEnum):
91
  JAX = "jax"
92
 
93
 
94
- class PrecollatorForCellClassification(SpecialTokensMixin):
95
  mask_token = "<mask>"
96
  mask_token_id = token_dictionary.get("<mask>")
97
  pad_token = "<pad>"
@@ -240,6 +228,7 @@ class PrecollatorForCellClassification(SpecialTokensMixin):
240
  Dict[str, List[EncodedInput]],
241
  List[Dict[str, EncodedInput]],
242
  ],
 
243
  padding: Union[bool, str, PaddingStrategy] = True,
244
  max_length: Optional[int] = None,
245
  pad_to_multiple_of: Optional[int] = None,
@@ -357,6 +346,7 @@ class PrecollatorForCellClassification(SpecialTokensMixin):
357
  if required_input and not isinstance(required_input[0], (list, tuple)):
358
  encoded_inputs = self._pad(
359
  encoded_inputs,
 
360
  max_length=max_length,
361
  padding_strategy=padding_strategy,
362
  pad_to_multiple_of=pad_to_multiple_of,
@@ -378,6 +368,7 @@ class PrecollatorForCellClassification(SpecialTokensMixin):
378
  inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
379
  outputs = self._pad(
380
  inputs,
 
381
  max_length=max_length,
382
  padding_strategy=padding_strategy,
383
  pad_to_multiple_of=pad_to_multiple_of,
@@ -388,12 +379,14 @@ class PrecollatorForCellClassification(SpecialTokensMixin):
388
  if key not in batch_outputs:
389
  batch_outputs[key] = []
390
  batch_outputs[key].append(value)
391
- del batch_outputs["label"]
 
392
  return BatchEncoding(batch_outputs, tensor_type=return_tensors)
393
 
394
  def _pad(
395
  self,
396
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
 
397
  max_length: Optional[int] = None,
398
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
399
  pad_to_multiple_of: Optional[int] = None,
@@ -446,6 +439,8 @@ class PrecollatorForCellClassification(SpecialTokensMixin):
446
  if "special_tokens_mask" in encoded_inputs:
447
  encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
448
  encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
 
 
449
  elif self.padding_side == "left":
450
  if return_attention_mask:
451
  encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
@@ -456,13 +451,13 @@ class PrecollatorForCellClassification(SpecialTokensMixin):
456
  if "special_tokens_mask" in encoded_inputs:
457
  encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
458
  encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
 
 
459
  else:
460
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
461
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
462
  encoded_inputs["attention_mask"] = [1] * len(required_input)
463
 
464
- # check_output_once(encoded_inputs)
465
-
466
  return encoded_inputs
467
 
468
  def get_special_tokens_mask(
@@ -526,7 +521,7 @@ class PrecollatorForCellClassification(SpecialTokensMixin):
526
 
527
  # collator functions
528
 
529
- class DataCollatorForCellClassification(DataCollatorForTokenClassification):
530
  """
531
  Data collator that will dynamically pad the inputs received, as well as the labels.
532
  Args:
@@ -551,22 +546,49 @@ class DataCollatorForCellClassification(DataCollatorForTokenClassification):
551
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
552
  """
553
 
554
- tokenizer: PrecollatorForCellClassification()
 
555
  padding: Union[bool, str, PaddingStrategy] = True
556
  max_length: Optional[int] = None
557
  pad_to_multiple_of: Optional[int] = None
558
  label_pad_token_id: int = -100
 
 
 
 
 
 
 
 
 
559
 
560
- def __call__(self, features):
561
  label_name = "label" if "label" in features[0].keys() else "labels"
562
  labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
563
  batch = self.tokenizer.pad(
564
  features,
 
565
  padding=self.padding,
566
  max_length=self.max_length,
567
  pad_to_multiple_of=self.pad_to_multiple_of,
568
  return_tensors="pt",
569
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
 
571
  # Special handling for labels.
572
  # Ensure that tensor is created with the correct type
@@ -576,6 +598,5 @@ class DataCollatorForCellClassification(DataCollatorForTokenClassification):
576
  label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
577
  dtype = torch.long if isinstance(label, int) else torch.float
578
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
579
-
580
- batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
581
  return batch
 
1
  """
2
+ Geneformer collator for gene and cell classification.
3
 
4
+ Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
  import numpy as np
7
  import torch
 
30
 
31
  # precollator functions
32
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class ExplicitEnum(Enum):
34
  """
35
  Enum with more explicit error message for missing values.
 
79
  JAX = "jax"
80
 
81
 
82
+ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
83
  mask_token = "<mask>"
84
  mask_token_id = token_dictionary.get("<mask>")
85
  pad_token = "<pad>"
 
228
  Dict[str, List[EncodedInput]],
229
  List[Dict[str, EncodedInput]],
230
  ],
231
+ class_type, # options: "gene" or "cell"
232
  padding: Union[bool, str, PaddingStrategy] = True,
233
  max_length: Optional[int] = None,
234
  pad_to_multiple_of: Optional[int] = None,
 
346
  if required_input and not isinstance(required_input[0], (list, tuple)):
347
  encoded_inputs = self._pad(
348
  encoded_inputs,
349
+ class_type=class_type,
350
  max_length=max_length,
351
  padding_strategy=padding_strategy,
352
  pad_to_multiple_of=pad_to_multiple_of,
 
368
  inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
369
  outputs = self._pad(
370
  inputs,
371
+ class_type=class_type,
372
  max_length=max_length,
373
  padding_strategy=padding_strategy,
374
  pad_to_multiple_of=pad_to_multiple_of,
 
379
  if key not in batch_outputs:
380
  batch_outputs[key] = []
381
  batch_outputs[key].append(value)
382
+ if class_type == "cell":
383
+ del batch_outputs["label"]
384
  return BatchEncoding(batch_outputs, tensor_type=return_tensors)
385
 
386
  def _pad(
387
  self,
388
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
389
+ class_type, # options: "gene" or "cell"
390
  max_length: Optional[int] = None,
391
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
392
  pad_to_multiple_of: Optional[int] = None,
 
439
  if "special_tokens_mask" in encoded_inputs:
440
  encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
441
  encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
442
+ if class_type == "gene":
443
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
444
  elif self.padding_side == "left":
445
  if return_attention_mask:
446
  encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
 
451
  if "special_tokens_mask" in encoded_inputs:
452
  encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
453
  encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
454
+ if class_type == "gene":
455
+ encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
456
  else:
457
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
458
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
459
  encoded_inputs["attention_mask"] = [1] * len(required_input)
460
 
 
 
461
  return encoded_inputs
462
 
463
  def get_special_tokens_mask(
 
521
 
522
  # collator functions
523
 
524
+ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
525
  """
526
  Data collator that will dynamically pad the inputs received, as well as the labels.
527
  Args:
 
546
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
547
  """
548
 
549
+ tokenizer = PrecollatorForGeneAndCellClassification()
550
+ class_type = "gene"
551
  padding: Union[bool, str, PaddingStrategy] = True
552
  max_length: Optional[int] = None
553
  pad_to_multiple_of: Optional[int] = None
554
  label_pad_token_id: int = -100
555
+
556
+ def __init__(self, *args, **kwargs) -> None:
557
+ super().__init__(
558
+ tokenizer=self.tokenizer,
559
+ padding=self.padding,
560
+ max_length=self.max_length,
561
+ pad_to_multiple_of=self.pad_to_multiple_of,
562
+ label_pad_token_id=self.label_pad_token_id,
563
+ *args, **kwargs)
564
 
565
+ def _prepare_batch(self, features):
566
  label_name = "label" if "label" in features[0].keys() else "labels"
567
  labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
568
  batch = self.tokenizer.pad(
569
  features,
570
+ class_type=self.class_type,
571
  padding=self.padding,
572
  max_length=self.max_length,
573
  pad_to_multiple_of=self.pad_to_multiple_of,
574
  return_tensors="pt",
575
  )
576
+ return batch
577
+
578
+ def __call__(self, features):
579
+ batch = self._prepare_batch(features)
580
+
581
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
582
+ return batch
583
+
584
+
585
+ class DataCollatorForCellClassification(DataCollatorForGeneClassification):
586
+
587
+ class_type = "cell"
588
+
589
+ def _prepare_batch(self, features):
590
+
591
+ batch = super()._prepare_batch(features)
592
 
593
  # Special handling for labels.
594
  # Ensure that tensor is created with the correct type
 
598
  label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
599
  dtype = torch.long if isinstance(label, int) else torch.float
600
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
601
+
 
602
  return batch
geneformer/collator_for_gene_classification.py DELETED
@@ -1,561 +0,0 @@
1
- """
2
- Geneformer collator for gene classification.
3
-
4
- Huggingface data collator modified to accommodate single-cell transcriptomics data for gene classification.
5
- """
6
- import numpy as np
7
- import torch
8
- import warnings
9
- from enum import Enum
10
- from typing import Dict, List, Optional, Union
11
-
12
- from transformers import (
13
- DataCollatorForTokenClassification,
14
- SpecialTokensMixin,
15
- BatchEncoding,
16
- )
17
- from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
- from transformers.utils.generic import _is_tensorflow, _is_torch
19
-
20
- from .pretrainer import token_dictionary
21
-
22
- EncodedInput = List[int]
23
- logger = logging.get_logger(__name__)
24
- VERY_LARGE_INTEGER = int(
25
- 1e30
26
- ) # This is used to set the max input length for a model with infinite size input
27
- LARGE_INTEGER = int(
28
- 1e20
29
- ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
30
-
31
- # precollator functions
32
-
33
- class ExplicitEnum(Enum):
34
- """
35
- Enum with more explicit error message for missing values.
36
- """
37
-
38
- @classmethod
39
- def _missing_(cls, value):
40
- raise ValueError(
41
- "%r is not a valid %s, please select one of %s"
42
- % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
43
- )
44
-
45
- class TruncationStrategy(ExplicitEnum):
46
- """
47
- Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
48
- tab-completion in an IDE.
49
- """
50
-
51
- ONLY_FIRST = "only_first"
52
- ONLY_SECOND = "only_second"
53
- LONGEST_FIRST = "longest_first"
54
- DO_NOT_TRUNCATE = "do_not_truncate"
55
-
56
-
57
-
58
- class PaddingStrategy(ExplicitEnum):
59
- """
60
- Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
61
- in an IDE.
62
- """
63
-
64
- LONGEST = "longest"
65
- MAX_LENGTH = "max_length"
66
- DO_NOT_PAD = "do_not_pad"
67
-
68
-
69
-
70
- class TensorType(ExplicitEnum):
71
- """
72
- Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
73
- tab-completion in an IDE.
74
- """
75
-
76
- PYTORCH = "pt"
77
- TENSORFLOW = "tf"
78
- NUMPY = "np"
79
- JAX = "jax"
80
-
81
-
82
- class PrecollatorForGeneClassification(SpecialTokensMixin):
83
- mask_token = "<mask>"
84
- mask_token_id = token_dictionary.get("<mask>")
85
- pad_token = "<pad>"
86
- pad_token_id = token_dictionary.get("<pad>")
87
- padding_side = "right"
88
- all_special_ids = [
89
- token_dictionary.get("<mask>"),
90
- token_dictionary.get("<pad>")
91
- ]
92
- model_input_names = ["input_ids"]
93
-
94
- def _get_padding_truncation_strategies(
95
- self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
96
- ):
97
- """
98
- Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
99
- and pad_to_max_length) and behaviors.
100
- """
101
- old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
102
- old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
103
-
104
- # Backward compatibility for previous behavior, maybe we should deprecate it:
105
- # If you only set max_length, it activates truncation for max_length
106
- if max_length is not None and padding is False and truncation is False:
107
- if verbose:
108
- if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
109
- logger.warning(
110
- "Truncation was not explicitly activated but `max_length` is provided a specific value, "
111
- "please use `truncation=True` to explicitly truncate examples to max length. "
112
- "Defaulting to 'longest_first' truncation strategy. "
113
- "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
114
- "more precisely by providing a specific strategy to `truncation`."
115
- )
116
- self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
117
- truncation = "longest_first"
118
-
119
- # Get padding strategy
120
- if padding is False and old_pad_to_max_length:
121
- if verbose:
122
- warnings.warn(
123
- "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
124
- "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
125
- "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
126
- "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
127
- "maximal input size of the model (e.g. 512 for Bert).",
128
- FutureWarning,
129
- )
130
- if max_length is None:
131
- padding_strategy = PaddingStrategy.LONGEST
132
- else:
133
- padding_strategy = PaddingStrategy.MAX_LENGTH
134
- elif padding is not False:
135
- if padding is True:
136
- padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
137
- elif not isinstance(padding, PaddingStrategy):
138
- padding_strategy = PaddingStrategy(padding)
139
- elif isinstance(padding, PaddingStrategy):
140
- padding_strategy = padding
141
- else:
142
- padding_strategy = PaddingStrategy.DO_NOT_PAD
143
-
144
- # Get truncation strategy
145
- if truncation is False and old_truncation_strategy != "do_not_truncate":
146
- if verbose:
147
- warnings.warn(
148
- "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
149
- "use `truncation=True` to truncate examples to a max length. You can give a specific "
150
- "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
151
- "maximal input size of the model (e.g. 512 for Bert). "
152
- " If you have pairs of inputs, you can give a specific truncation strategy selected among "
153
- "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
154
- "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
155
- "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
156
- FutureWarning,
157
- )
158
- truncation_strategy = TruncationStrategy(old_truncation_strategy)
159
- elif truncation is not False:
160
- if truncation is True:
161
- truncation_strategy = (
162
- TruncationStrategy.LONGEST_FIRST
163
- ) # Default to truncate the longest sequences in pairs of inputs
164
- elif not isinstance(truncation, TruncationStrategy):
165
- truncation_strategy = TruncationStrategy(truncation)
166
- elif isinstance(truncation, TruncationStrategy):
167
- truncation_strategy = truncation
168
- else:
169
- truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
170
-
171
- # Set max length if needed
172
- if max_length is None:
173
- if padding_strategy == PaddingStrategy.MAX_LENGTH:
174
- if self.model_max_length > LARGE_INTEGER:
175
- if verbose:
176
- if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
177
- logger.warning(
178
- "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
179
- "Default to no padding."
180
- )
181
- self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
182
- padding_strategy = PaddingStrategy.DO_NOT_PAD
183
- else:
184
- max_length = self.model_max_length
185
-
186
- if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
187
- if self.model_max_length > LARGE_INTEGER:
188
- if verbose:
189
- if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
190
- logger.warning(
191
- "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
192
- "Default to no truncation."
193
- )
194
- self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
195
- truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
196
- else:
197
- max_length = self.model_max_length
198
-
199
- # Test if we have a padding token
200
- if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
201
- raise ValueError(
202
- "Asking to pad but the tokenizer does not have a padding token. "
203
- "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
204
- "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
205
- )
206
-
207
- # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
208
- if (
209
- truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
210
- and padding_strategy != PaddingStrategy.DO_NOT_PAD
211
- and pad_to_multiple_of is not None
212
- and max_length is not None
213
- and (max_length % pad_to_multiple_of != 0)
214
- ):
215
- raise ValueError(
216
- f"Truncation and padding are both activated but "
217
- f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
218
- )
219
-
220
- return padding_strategy, truncation_strategy, max_length, kwargs
221
-
222
- def pad(
223
- self,
224
- encoded_inputs: Union[
225
- BatchEncoding,
226
- List[BatchEncoding],
227
- Dict[str, EncodedInput],
228
- Dict[str, List[EncodedInput]],
229
- List[Dict[str, EncodedInput]],
230
- ],
231
- padding: Union[bool, str, PaddingStrategy] = True,
232
- max_length: Optional[int] = None,
233
- pad_to_multiple_of: Optional[int] = None,
234
- return_attention_mask: Optional[bool] = True,
235
- return_tensors: Optional[Union[str, TensorType]] = None,
236
- verbose: bool = True,
237
- ) -> BatchEncoding:
238
- """
239
- Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
240
- in the batch.
241
-
242
- Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
243
- ``self.pad_token_id`` and ``self.pad_token_type_id``)
244
-
245
- .. note::
246
-
247
- If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
248
- result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
249
- case of PyTorch tensors, you will lose the specific device of your tensors however.
250
-
251
- Args:
252
- encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
253
- Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
254
- List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
255
- List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
256
- well as in a PyTorch Dataloader collate function.
257
-
258
- Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
259
- see the note above for the return type.
260
- padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
261
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
262
- index) among:
263
-
264
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
265
- single sequence if provided).
266
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
267
- maximum acceptable input length for the model if that argument is not provided.
268
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
269
- different lengths).
270
- max_length (:obj:`int`, `optional`):
271
- Maximum length of the returned list and optionally padding length (see above).
272
- pad_to_multiple_of (:obj:`int`, `optional`):
273
- If set will pad the sequence to a multiple of the provided value.
274
-
275
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
276
- >= 7.5 (Volta).
277
- return_attention_mask (:obj:`bool`, `optional`):
278
- Whether to return the attention mask. If left to the default, will return the attention mask according
279
- to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
280
-
281
- `What are attention masks? <../glossary.html#attention-mask>`__
282
- return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
283
- If set, will return tensors instead of list of python integers. Acceptable values are:
284
-
285
- * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
286
- * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
287
- * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
288
- verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
289
- Whether or not to print more information and warnings.
290
- """
291
- # If we have a list of dicts, let's convert it in a dict of lists
292
- # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
293
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
294
- encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
295
-
296
- # The model's main input name, usually `input_ids`, has be passed for padding
297
- if self.model_input_names[0] not in encoded_inputs:
298
- raise ValueError(
299
- "You should supply an encoding or a list of encodings to this method"
300
- f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
301
- )
302
-
303
- required_input = encoded_inputs[self.model_input_names[0]]
304
-
305
- if not required_input:
306
- if return_attention_mask:
307
- encoded_inputs["attention_mask"] = []
308
- return encoded_inputs
309
-
310
- # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
311
- # and rebuild them afterwards if no return_tensors is specified
312
- # Note that we lose the specific device the tensor may be on for PyTorch
313
-
314
- first_element = required_input[0]
315
- if isinstance(first_element, (list, tuple)):
316
- # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
317
- index = 0
318
- while len(required_input[index]) == 0:
319
- index += 1
320
- if index < len(required_input):
321
- first_element = required_input[index][0]
322
- # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
323
- if not isinstance(first_element, (int, list, tuple)):
324
- if is_tf_available() and _is_tensorflow(first_element):
325
- return_tensors = "tf" if return_tensors is None else return_tensors
326
- elif is_torch_available() and _is_torch(first_element):
327
- return_tensors = "pt" if return_tensors is None else return_tensors
328
- elif isinstance(first_element, np.ndarray):
329
- return_tensors = "np" if return_tensors is None else return_tensors
330
- else:
331
- raise ValueError(
332
- f"type of {first_element} unknown: {type(first_element)}. "
333
- f"Should be one of a python, numpy, pytorch or tensorflow object."
334
- )
335
-
336
- for key, value in encoded_inputs.items():
337
- encoded_inputs[key] = to_py_obj(value)
338
-
339
- # Convert padding_strategy in PaddingStrategy
340
- padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
341
- padding=padding, max_length=max_length, verbose=verbose
342
- )
343
-
344
- required_input = encoded_inputs[self.model_input_names[0]]
345
- if required_input and not isinstance(required_input[0], (list, tuple)):
346
- encoded_inputs = self._pad(
347
- encoded_inputs,
348
- max_length=max_length,
349
- padding_strategy=padding_strategy,
350
- pad_to_multiple_of=pad_to_multiple_of,
351
- return_attention_mask=return_attention_mask,
352
- )
353
- return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
354
-
355
- batch_size = len(required_input)
356
- assert all(
357
- len(v) == batch_size for v in encoded_inputs.values()
358
- ), "Some items in the output dictionary have a different batch size than others."
359
-
360
- if padding_strategy == PaddingStrategy.LONGEST:
361
- max_length = max(len(inputs) for inputs in required_input)
362
- padding_strategy = PaddingStrategy.MAX_LENGTH
363
-
364
- batch_outputs = {}
365
- for i in range(batch_size):
366
- inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
367
- outputs = self._pad(
368
- inputs,
369
- max_length=max_length,
370
- padding_strategy=padding_strategy,
371
- pad_to_multiple_of=pad_to_multiple_of,
372
- return_attention_mask=return_attention_mask,
373
- )
374
-
375
- for key, value in outputs.items():
376
- if key not in batch_outputs:
377
- batch_outputs[key] = []
378
- batch_outputs[key].append(value)
379
-
380
- return BatchEncoding(batch_outputs, tensor_type=return_tensors)
381
-
382
- def _pad(
383
- self,
384
- encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
385
- max_length: Optional[int] = None,
386
- padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
387
- pad_to_multiple_of: Optional[int] = None,
388
- return_attention_mask: Optional[bool] = True,
389
- ) -> dict:
390
- """
391
- Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
392
-
393
- Args:
394
- encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
395
- max_length: maximum length of the returned list and optionally padding length (see below).
396
- Will truncate by taking into account the special tokens.
397
- padding_strategy: PaddingStrategy to use for padding.
398
-
399
- - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
400
- - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
401
- - PaddingStrategy.DO_NOT_PAD: Do not pad
402
- The tokenizer padding sides are defined in self.padding_side:
403
-
404
- - 'left': pads on the left of the sequences
405
- - 'right': pads on the right of the sequences
406
- pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
407
- This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
408
- >= 7.5 (Volta).
409
- return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
410
- """
411
- # Load from model defaults
412
- if return_attention_mask is None:
413
- return_attention_mask = "attention_mask" in self.model_input_names
414
-
415
- required_input = encoded_inputs[self.model_input_names[0]]
416
-
417
- if padding_strategy == PaddingStrategy.LONGEST:
418
- max_length = len(required_input)
419
-
420
- if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
421
- max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
422
-
423
- needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
424
-
425
- if needs_to_be_padded:
426
- difference = max_length - len(required_input)
427
- if self.padding_side == "right":
428
- if return_attention_mask:
429
- encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
430
- if "token_type_ids" in encoded_inputs:
431
- encoded_inputs["token_type_ids"] = (
432
- encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
433
- )
434
- if "special_tokens_mask" in encoded_inputs:
435
- encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
436
- encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
437
- encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
438
- elif self.padding_side == "left":
439
- if return_attention_mask:
440
- encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
441
- if "token_type_ids" in encoded_inputs:
442
- encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
443
- "token_type_ids"
444
- ]
445
- if "special_tokens_mask" in encoded_inputs:
446
- encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
447
- encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
448
- encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
449
- else:
450
- raise ValueError("Invalid padding strategy:" + str(self.padding_side))
451
- elif return_attention_mask and "attention_mask" not in encoded_inputs:
452
- encoded_inputs["attention_mask"] = [1] * len(required_input)
453
-
454
- # check_output_once(encoded_inputs)
455
-
456
- return encoded_inputs
457
-
458
- def get_special_tokens_mask(
459
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
460
- ) -> List[int]:
461
- """
462
- Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
463
- special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
464
- Args:
465
- token_ids_0 (:obj:`List[int]`):
466
- List of ids of the first sequence.
467
- token_ids_1 (:obj:`List[int]`, `optional`):
468
- List of ids of the second sequence.
469
- already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
470
- Whether or not the token list is already formatted with special tokens for the model.
471
- Returns:
472
- A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
473
- """
474
- assert already_has_special_tokens and token_ids_1 is None, (
475
- "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
476
- "Please use a slow (full python) tokenizer to activate this argument."
477
- "Or set `return_special_tokens_mask=True` when calling the encoding method "
478
- "to get the special tokens mask in any tokenizer. "
479
- )
480
-
481
- all_special_ids = self.all_special_ids # cache the property
482
-
483
- special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
484
-
485
- return special_tokens_mask
486
-
487
- def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
488
- """
489
- Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
490
- vocabulary.
491
- Args:
492
- tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
493
- Returns:
494
- :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
495
- """
496
- if tokens is None:
497
- return None
498
-
499
- if isinstance(tokens, str):
500
- return self._convert_token_to_id_with_added_voc(tokens)
501
-
502
- ids = []
503
- for token in tokens:
504
- ids.append(self._convert_token_to_id_with_added_voc(token))
505
- return ids
506
-
507
- def _convert_token_to_id_with_added_voc(self, token):
508
- if token is None:
509
- return None
510
-
511
- return token_dictionary.get(token)
512
-
513
- def __len__(self):
514
- return len(token_dictionary)
515
-
516
- # collator functions
517
-
518
- class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
519
- """
520
- Data collator that will dynamically pad the inputs received, as well as the labels.
521
- Args:
522
- tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
523
- The tokenizer used for encoding the data.
524
- padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
525
- Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
526
- among:
527
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
528
- sequence if provided).
529
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
530
- maximum acceptable input length for the model if that argument is not provided.
531
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
532
- different lengths).
533
- max_length (:obj:`int`, `optional`):
534
- Maximum length of the returned list and optionally padding length (see above).
535
- pad_to_multiple_of (:obj:`int`, `optional`):
536
- If set will pad the sequence to a multiple of the provided value.
537
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
538
- 7.5 (Volta).
539
- label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
540
- The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
541
- """
542
-
543
- tokenizer: PrecollatorForGeneClassification()
544
- padding: Union[bool, str, PaddingStrategy] = True
545
- max_length: Optional[int] = None
546
- pad_to_multiple_of: Optional[int] = None
547
- label_pad_token_id: int = -100
548
-
549
- def __call__(self, features):
550
- label_name = "label" if "label" in features[0].keys() else "labels"
551
- labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
552
- batch = self.tokenizer.pad(
553
- features,
554
- padding=self.padding,
555
- max_length=self.max_length,
556
- pad_to_multiple_of=self.pad_to_multiple_of,
557
- return_tensors="pt",
558
- )
559
-
560
- batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
561
- return batch