kiddothe2b
committed on
Commit
·
e917aad
1
Parent(s):
7c46478
Add HAT implementation files
Browse files- configuration_hat.py +150 -0
- modelling_hat.py +0 -0
- tokenization_hat.py +249 -0
configuration_hat.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
""" HAT configuration"""
|
14 |
+
from collections import OrderedDict
|
15 |
+
from typing import Mapping
|
16 |
+
|
17 |
+
from transformers.onnx import OnnxConfig
|
18 |
+
from transformers.utils import logging
|
19 |
+
from transformers import PretrainedConfig
|
20 |
+
|
21 |
+
|
22 |
+
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)

# Maps HAT checkpoint identifiers to the URL of the matching `config.json`
# file on the Hugging Face Hub.
HAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "kiddothe2b/hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096/resolve/main/config.json",
    "kiddothe2b/adhoc-hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096/resolve/main/config.json",
}
|
28 |
+
|
29 |
+
|
30 |
+
class HATConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a HAT (hierarchical transformer) model.
    It is used to instantiate a HAT model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
    to that of the HAT `kiddothe2b/hierarchical-transformer-base-4096
    <https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 30522):
            Vocabulary size of the HAT model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed to the model.
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        max_sentences (:obj:`int`, `optional`, defaults to 64):
            The maximum number of sentences that this model might ever be used with.
        max_sentence_size (:obj:`int`, `optional`, defaults to 128):
            The maximum sentence length that this model might ever be used with. Also readable and writable through
            the alias :obj:`max_sentence_length`.
        model_max_length (:obj:`int`, `optional`, defaults to 8192):
            The maximum sequence length (max_sentences * max_sentence_size) that this model might ever be used with.
            The value is stored as given; it is not validated against the product.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
            The vocabulary size of the :obj:`token_type_ids` passed to the model.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        pad_token_id (:obj:`int`, `optional`, defaults to 0):
            Id of the padding token; forwarded to :class:`~transformers.PretrainedConfig`.
        position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
            Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
            :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
            :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
            <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
            `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
            <https://arxiv.org/abs/2009.13658>`__.
        encoder_layout (:obj:`Dict`, `optional`, defaults to :obj:`None`):
            The sentence/document encoder layout.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if ``config.is_decoder=True``.
        classifier_dropout (:obj:`float`, `optional`):
            The dropout ratio for the classification head.
    """
    model_type = "hierarchical-transformer"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        max_sentences=64,
        max_sentence_size=128,
        model_max_length=8192,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        encoder_layout=None,
        use_cache=True,
        classifier_dropout=None,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_sentences = max_sentences
        self.max_sentence_size = max_sentence_size
        self.model_max_length = model_max_length
        self.encoder_layout = encoder_layout
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout

    @property
    def max_sentence_length(self):
        """Alias of :obj:`max_sentence_size`, for code that reads the sentence size under this name."""
        return self.max_sentence_size

    @max_sentence_length.setter
    def max_sentence_length(self, value):
        # A setter is required so that a `max_sentence_length` entry arriving through
        # **kwargs (e.g. from a serialized config) can be assigned instead of failing
        # on a read-only property.
        self.max_sentence_size = value
|
140 |
+
|
141 |
+
|
142 |
+
class HATOnnxConfig(OnnxConfig):
    """ONNX export configuration describing the dynamic axes of the HAT inputs."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return an ordered mapping from input name to its dynamic axes.

        Both ``input_ids`` and ``attention_mask`` declare axis 0 as ``"batch"``
        and axis 1 as ``"sequence"``.
        """
        input_names = ("input_ids", "attention_mask")
        axes = OrderedDict()
        for input_name in input_names:
            # Each input gets its own dict so the entries are never shared objects.
            axes[input_name] = {0: "batch", 1: "sequence"}
        return axes
|
modelling_hat.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenization_hat.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
"""Tokenization classes for HAT."""
|
14 |
+
import torch
|
15 |
+
from transformers import RobertaTokenizer, BertTokenizer
|
16 |
+
from .configuration_hat import HATConfig
|
17 |
+
from transformers.utils import logging
|
18 |
+
# NLTK provides the sentence splitter used by HATTokenizer.sentence_splitting.
# Only a failed import is translated into the installation hint; any other
# error (including KeyboardInterrupt) propagates unchanged, and the original
# import failure stays attached as the cause.
try:
    from nltk import sent_tokenize
except ImportError as err:
    raise ImportError('NLTK is not installed! Install it with `pip install nltk`...') from err
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class HATTokenizer:
    """Wrap a Hugging Face tokenizer and lay inputs out as fixed-size sentence chunks.

    Every chunk produced by this class starts with a leading special value
    followed by up to ``sentence_size - 1`` values taken from the input.
    ``type2id`` stores, for every model input key, the pair
    ``(leading value, padding value)`` used when building chunks for that key.
    """

    def __init__(self, tokenizer=None):
        """Store ``tokenizer`` and load the HAT configuration from its ``name_or_path``.

        Raises:
            ValueError: if ``tokenizer`` is ``None``.
        """
        if tokenizer is None:
            raise ValueError('HATTokenizer requires a tokenizer instance; '
                             'use HATTokenizer.from_pretrained(...) to build one.')
        self._tokenizer = tokenizer
        self.config = HATConfig.from_pretrained(self._tokenizer.name_or_path)
        self._tokenizer.model_max_length = self.model_max_length
        self.type2id = {'input_ids': (self._tokenizer.cls_token_id, self._tokenizer.pad_token_id),
                        'token_type_ids': (0, 0),
                        'attention_mask': (1, 0),
                        'special_tokens_mask': (1, -100)}

    @property
    def _sentence_size(self):
        """Chunk size read from the config.

        ``max_sentence_length`` is used when the config carries it (the name
        this class historically read); otherwise ``max_sentence_size`` is used.
        """
        legacy = getattr(self.config, 'max_sentence_length', None)
        return legacy if legacy is not None else self.config.max_sentence_size

    @property
    def model_max_length(self):
        return self.config.model_max_length

    @property
    def mask_token(self):
        return self._tokenizer.mask_token

    @property
    def mask_token_id(self):
        return self._tokenizer.mask_token_id

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def cls_token_id(self):
        return self._tokenizer.cls_token_id

    @property
    def sep_token_id(self):
        return self._tokenizer.sep_token_id

    @property
    def vocab(self):
        return self._tokenizer.vocab

    def __len__(self):
        """
        Size of the full vocabulary with the added tokens.
        """
        return len(self._tokenizer)

    def pad(self, *args, **kwargs):
        return self._tokenizer.pad(*args, **kwargs)

    def convert_tokens_to_ids(self, *args, **kwargs):
        return self._tokenizer.convert_tokens_to_ids(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        return self._tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self._tokenizer.decode(*args, **kwargs)

    def tokenize(self, text, **kwargs):
        return self._tokenizer.tokenize(text, **kwargs)

    def encode(self, text, **kwargs):
        """Encode ``text`` and return its input ids laid out in uniform chunks.

        The ids are truncated to ``model_max_length - max_sentences`` before
        chunking.
        """
        encoding = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs)
        # encode_plus returns a mapping of model inputs; only the ids are chunked.
        # NOTE(review): assumes 'input_ids' is a flat sequence (no batch axis) - confirm
        # for callers that pass return_tensors.
        input_ids = encoding['input_ids']
        budget = self.model_max_length - self.config.max_sentences
        return self.chunks(input_ids[:budget],
                           chunk_size=self._sentence_size,
                           special_id=self.type2id['input_ids'])

    def get_special_tokens_mask(self, *args, **kwargs):
        return self._tokenizer.get_special_tokens_mask(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a HATTokenizer around a RobertaTokenizer, falling back to BertTokenizer."""
        try:
            tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        except Exception:
            # Any loading failure triggers the fallback, but KeyboardInterrupt and
            # SystemExit are no longer swallowed as they were by a bare `except:`.
            tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(tokenizer=tokenizer)

    def save_pretrained(self, *args, **kwargs):
        return self._tokenizer.save_pretrained(*args, **kwargs)

    def __call__(self, text, **kwargs):
        """Tokenize a batch of documents into chunked model inputs.

        Args:
            text: a list of documents. Documents given as lists are handled by
                ``auto_chunking``; string documents are handled by
                ``uniform_chunking`` when ``greedy_chunking`` is truthy and by
                ``sentence_splitting`` otherwise. A single string is treated as
                a batch of one document.
            **kwargs: ``greedy_chunking`` and ``text_pair`` are consumed here;
                the remaining options are forwarded to the chunking method.

        Returns:
            The batch produced by the selected chunking method. When
            ``text_pair`` is given, one extra chunk built from the paired text
            is written into every example.
        """
        greedy_chunking = kwargs.pop('greedy_chunking', None)
        text_pair = kwargs.pop('text_pair', None)
        # A bare string would otherwise be iterated character by character.
        if isinstance(text, str):
            text = [text]
        if isinstance(text_pair, str):
            text_pair = [text_pair]

        if isinstance(text[0], list):
            batch = self.auto_chunking(text, **kwargs)
        elif greedy_chunking:
            # fixed uniform chunking
            batch = self.uniform_chunking(text, **kwargs)
        else:
            # dynamic sentence splitting and grouping
            batch = self.sentence_splitting(text, **kwargs)

        if text_pair:
            size = self._sentence_size
            pad_id = self.pad_token_id
            batch_b = self._tokenizer(text_pair, add_special_tokens=False,
                                      padding=False, truncation=False)
            for idx, sample in enumerate(batch['input_ids']):
                # Count the chunks whose first token is not padding. The previous
                # code summed the token ids found at chunk starts, which is not
                # a sentence count.
                n_sentences = sum(1 for first_token in sample[::size] if first_token != pad_id)
                start = size * n_sentences
                # NOTE(review): when the example is a list with no free slot left,
                # the slice assignment below extends the list - confirm intended.
                for input_key in batch:
                    special_id = ((self.sep_token_id, pad_id)
                                  if input_key == 'input_ids' else self.type2id[input_key])
                    batch[input_key][idx][start:start + size] = self.pad_sentence(
                        batch_b[input_key][idx], chunk_size=size, special_id=special_id)

        return batch

    def _maybe_pad(self, batch, options):
        """Pad ``batch`` via the wrapped tokenizer when ``options['padding']`` is truthy.

        ``padding`` and ``max_length`` are optional; a missing ``padding``
        leaves the batch untouched.
        """
        padding = options.get('padding')
        if not padding:
            return batch
        max_length = options.get('max_length')
        return self.pad(batch,
                        padding=padding,
                        max_length=max_length,
                        pad_to_multiple_of=max_length)

    def uniform_chunking(self, texts, **kwargs):
        """Tokenize ``texts`` and cut every example into fixed-size chunks."""
        original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs)
        budget = self.model_max_length - self.config.max_sentences
        size = self._sentence_size
        batch = {}
        for input_type in original_batch:
            chunked = [self.chunks(example[:budget],
                                   chunk_size=size,
                                   special_id=self.type2id[input_type])
                       for example in original_batch[input_type]]
            # Tensor examples are stacked into one tensor; list examples stay a list.
            if chunked and not isinstance(chunked[0], list):
                chunked = torch.stack(chunked)
            batch[input_type] = chunked

        return self._maybe_pad(batch, kwargs)

    def auto_chunking(self, texts, **kwargs):
        """Tokenize documents that are already split into sentences.

        Each document is a list of sentences; at most ``max_sentences``
        sentences are kept and each one is padded to the configured chunk size.
        """
        size = self._sentence_size
        batch = {}
        for text in texts:
            example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs)
            for input_key in example_batch:
                padded = [self.pad_sentence(sentence, chunk_size=size,
                                            special_id=self.type2id[input_key])
                          for sentence in example_batch[input_key][:self.config.max_sentences]]
                if not padded or isinstance(padded[0], list):
                    merged = [token for sentence in padded for token in sentence]
                else:
                    # NOTE(review): the tensor path keeps a (sentences, tokens) layout
                    # while the list path is flat - confirm which one is intended.
                    merged = torch.stack(padded)
                batch.setdefault(input_key, []).append(merged)

        return self._maybe_pad(batch, kwargs)

    def chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
        """Dispatch to ``list_chunks`` or ``tensor_chunks`` based on the input type."""
        if isinstance(flat_inputs, list):
            return self.list_chunks(flat_inputs, chunk_size, special_id)
        return self.tensor_chunks(flat_inputs, chunk_size, special_id)

    def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
        """Split a flat list into pieces of ``chunk_size - 1`` values, each with a leading value.

        The leading value is ``special_id[0]`` when the piece sums to a
        non-zero value and ``special_id[1]`` otherwise. The result is a flat
        list; the last piece is not padded.
        """
        body = chunk_size - 1
        structured = []
        for offset in range(0, len(flat_inputs), body):
            piece = flat_inputs[offset:offset + body]
            structured.append(special_id[0] if sum(piece) else special_id[1])
            structured.extend(piece)
        return structured

    def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
        """Tensor counterpart of ``list_chunks``; returns a flat 1-D tensor.

        Pieces are concatenated rather than stacked, so an input whose length
        is not a multiple of ``chunk_size - 1`` is accepted. Leading values
        are created with the dtype and device of ``flat_inputs``.
        """
        if len(flat_inputs) == 0:
            return flat_inputs.reshape(-1)
        body = chunk_size - 1
        pieces = []
        for offset in range(0, len(flat_inputs), body):
            piece = flat_inputs[offset:offset + body]
            lead = special_id[0] if piece.sum() else special_id[1]
            pieces.append(flat_inputs.new_tensor([lead]))
            pieces.append(piece)
        return torch.cat(pieces)

    def sentence_splitting(self, texts, **kwargs):
        """Split every text into sentences, tokenize them and group them into chunks."""
        documents = []
        for text in texts:
            # sentence splitting
            sentences = sent_tokenize(text)
            # tokenization of sentences
            encoded = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False)
            # sentence grouping - merging short sentences to minimize padding
            documents.append(self.sentence_grouping(encoded))

        # batchify examples; the keys come from the last processed document
        keys = documents[-1] if documents else {}
        batch = {}
        for input_type in keys:
            column = [document[input_type] for document in documents]
            if not isinstance(column[0], list):
                column = torch.stack(column)
            batch[input_type] = column

        return self._maybe_pad(batch, kwargs)

    def sentence_grouping(self, sentences):
        """Greedily merge consecutive sentences into chunks of the configured size.

        Args:
            sentences: mapping from input key to a list of per-sentence value lists.

        Returns:
            A dict mapping every input key to a flat list made of at most
            ``max_sentences`` padded chunks.
        """
        size = self._sentence_size
        capacity = size - 1  # one slot of every chunk is taken by the leading value
        max_sentences = self.config.max_sentences
        doc_out = {}
        for input_type in sentences:
            special_id = self.type2id[input_type]
            tmp_doc = []
            tmp_sentence = []
            for example in sentences[input_type]:
                if len(tmp_doc) >= max_sentences:
                    break
                if len(tmp_sentence) + len(example) <= capacity:
                    tmp_sentence.extend(example)
                elif tmp_sentence:
                    # The pending group is full: flush it and start a new group
                    # with a copy of the current sentence.
                    tmp_doc.append(self.pad_sentence(tmp_sentence, chunk_size=size,
                                                     special_id=special_id))
                    tmp_sentence = list(example)
                else:
                    # A single sentence longer than one chunk: emit its first
                    # `capacity` values and carry the rest over. The remainder
                    # starts at `capacity`, matching what pad_sentence keeps, so
                    # no value is dropped at the boundary.
                    tmp_doc.append(self.pad_sentence(example, chunk_size=size,
                                                     special_id=special_id))
                    tmp_sentence = list(example[capacity:])
            if tmp_sentence and len(tmp_doc) < max_sentences:
                tmp_doc.append(self.pad_sentence(tmp_sentence, chunk_size=size,
                                                 special_id=special_id))
            doc_out[input_type] = [token for sentence in tmp_doc for token in sentence]
        return doc_out

    def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)):
        """Build one chunk of exactly ``chunk_size`` values from ``flat_input``.

        The chunk is a leading value, then the first ``chunk_size - 1`` input
        values, then padding with ``special_id[1]``. Padding uses the per-key
        padding value rather than the tokenizer pad token id, so that e.g. an
        ``attention_mask`` is padded with its own padding value.

        For list input the leading value is always ``special_id[0]``. For
        tensor input it is ``special_id[0]`` when the kept values sum to a
        non-zero value and ``special_id[1]`` otherwise.
        """
        body = chunk_size - 1
        n_pad = max(0, body - len(flat_input))
        if isinstance(flat_input, list):
            return [special_id[0]] + flat_input[:body] + [special_id[1]] * n_pad

        kept = flat_input[:body]
        lead = special_id[0] if kept.sum() else special_id[1]
        # new_tensor / new_full reuse the dtype and device of the input tensor.
        return torch.cat((flat_input.new_tensor([lead]),
                          kept,
                          flat_input.new_full((n_pad,), special_id[1])))
|
249 |
+
|