nthngdy commited on
Commit
56d4bda
·
1 Parent(s): 59d4c00

Create modeling_manta.py

Browse files
Files changed (1) hide show
  1. modeling_manta.py +1039 -0
modeling_manta.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Mesh TensorFlow authors, Manta Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Manta model."""
16
+
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ import warnings
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers.models.longformer import LongformerConfig, LongformerModel
30
+ from transformers.models.t5.configuration_t5 import T5Config
31
+ from transformers.models.t5.modeling_t5 import (
32
+ __HEAD_MASK_WARNING_MSG,
33
+ T5Attention,
34
+ T5Stack,
35
+ )
36
+ from transformers.utils import (
37
+ DUMMY_INPUTS,
38
+ DUMMY_MASK,
39
+ add_start_docstrings,
40
+ add_end_docstrings,
41
+ is_torch_fx_proxy,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from configuration_manta import MantaConfig
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ _CONFIG_FOR_DOC = "MantaConfig"
51
+ _TOKENIZER_FOR_DOC = "ByT5Tokenizer"
52
+
53
+ MANTA_PRETRAINED_MODEL_ARCHIVE_LIST = []
54
+
55
+
56
+ def gaussian_pdf(x):
57
+ return torch.exp(-x * x / 2.0)
58
+
59
+
60
+ def pad_block_embeddings(block_embeddings, pad_length):
61
+ if not pad_length:
62
+ return block_embeddings
63
+
64
+ padding_tensor_len = max(pad_length - block_embeddings.size(1), 0)
65
+
66
+ padding_tensor = torch.zeros(
67
+ (block_embeddings.size(0), padding_tensor_len, block_embeddings.size(2)),
68
+ device=block_embeddings.device,
69
+ dtype=block_embeddings.dtype,
70
+ )
71
+ return torch.cat([block_embeddings[:, :pad_length, :], padding_tensor], dim=1)
72
+
73
+
74
+ @add_end_docstrings()
75
+ @dataclass
76
+ class MantaSeq2SeqLMOutput(Seq2SeqLMOutput):
77
+ """
78
+ Base class for Manta encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
79
+ decoding.
80
+
81
+ Args:
82
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
83
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
84
+
85
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
86
+ hidden_size)` is output.
87
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
88
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
89
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
90
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
91
+
92
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
93
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
94
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
95
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
96
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
97
+
98
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
99
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
100
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
101
+ sequence_length)`.
102
+
103
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
104
+ self-attention heads.
105
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
106
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
107
+ sequence_length)`.
108
+
109
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
110
+ weighted average in the cross-attention heads.
111
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
112
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
113
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
114
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
115
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
116
+
117
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
118
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
119
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
120
+ sequence_length)`.
121
+
122
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
123
+ self-attention heads.
124
+ frontier_predictions: (`torch.FloatTensor`, *optional*, of shape `(batch_size, sequence_length, 1)`):
125
+ Probability scores of being a frontier as predicted by the FrontierPredictor module.
126
+ """
127
+
128
+ frontier_predictions: Optional[torch.FloatTensor] = None
129
+
130
+
131
+ @dataclass
132
+ class MantaBaseModelOutput(BaseModelOutput):
133
+ """
134
+ Base class for Manta's outputs, with potential hidden states, attentions and Manta's frontier predictions.
135
+
136
+ Args:
137
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
138
+ Sequence of hidden-states at the output of the last layer of the model.
139
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
140
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
141
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
142
+
143
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
144
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
145
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
146
+ sequence_length)`.
147
+
148
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
149
+ heads.
150
+ frontier_predictions: (`torch.FloatTensor`, *optional*, of shape `(batch_size, sequence_length, 1)`):
151
+ Probability scores of being a frontier as predicted by the FrontierPredictor module.
152
+ """
153
+
154
+ frontier_predictions: Optional[torch.FloatTensor] = None
155
+
156
+
157
+ class MantaFrontierPredictor(nn.Module):
158
+ def __init__(
159
+ self,
160
+ hidden_size,
161
+ num_layers,
162
+ num_attention_heads,
163
+ dropout_rate,
164
+ attention_window,
165
+ max_length,
166
+ ):
167
+ super().__init__()
168
+
169
+ # First, find out what the maximum position will be after tensors are padded to a multiple of local_transformer_attention_window.
170
+ # Then, add 1 because LongFormer position embeddings are bugged when passed inputs_embeds.
171
+ max_position_embeddings = (max_length // attention_window + 1) * attention_window + 1
172
+ self.hidden_size = hidden_size
173
+
174
+ self.config = LongformerConfig(
175
+ attention_probs_dropout_prob=dropout_rate,
176
+ attention_window=attention_window,
177
+ hidden_act="gelu",
178
+ hidden_dropout_prob=dropout_rate,
179
+ hidden_size=hidden_size,
180
+ intermediate_size=hidden_size * 4,
181
+ max_position_embeddings=max_position_embeddings,
182
+ num_attention_heads=num_attention_heads,
183
+ num_hidden_layers=num_layers,
184
+ position_embedding_type="absolute", # Actually cannot be changed
185
+ vocab_size=1, # Remove almost entirely the embeddings
186
+ pad_token_id=0,
187
+ )
188
+ self.local_transformer = LongformerModel(self.config)
189
+
190
+ self.output_projection = nn.Linear(hidden_size, 1)
191
+
192
+ def forward(self, embeddings, attention_mask):
193
+ longformer_output = self.local_transformer(inputs_embeds=embeddings, attention_mask=attention_mask)
194
+
195
+ projection_outputs = self.output_projection(longformer_output.last_hidden_state)
196
+
197
+ frontier_predictions = torch.sigmoid(projection_outputs.squeeze(-1))
198
+
199
+ return frontier_predictions
200
+
201
+
202
+ class MantaConvFeatures(nn.Module):
203
+ def __init__(
204
+ self,
205
+ in_channels,
206
+ out_channels,
207
+ kernel_size,
208
+ groups,
209
+ padding,
210
+ ):
211
+ """
212
+ This nn.Module "decomposes" the convolution in order to extract and cache feature maps. This amounts to
213
+ computing an element-wise multiplication between weights of size (hidden_dim, kernel_size) and the input.
214
+ """
215
+ super().__init__()
216
+ self.in_channels = in_channels
217
+ self.out_channels = out_channels
218
+ self.kernel_size = kernel_size
219
+ self.groups = groups
220
+ self.padding = padding
221
+
222
+ if groups == in_channels:
223
+ assert (
224
+ in_channels == out_channels
225
+ ), "When using `groups = in_channels`, make sure to have `in_channels == out_channels`"
226
+ self.weight = nn.Parameter(torch.Tensor(1, 1, kernel_size, out_channels))
227
+ elif self.groups == 1:
228
+ self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels, kernel_size))
229
+ else:
230
+ raise ValueError("MantaConvFeatures only supports `groups = 1` or `groups = in_channels`")
231
+
232
+ left_pad = (kernel_size - 1) // 2
233
+ self.pad = (left_pad, kernel_size - 1 - left_pad)
234
+
235
+ self.reset_parameters()
236
+
237
+ def reset_parameters(self):
238
+ """
239
+ See https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv1d, in the `_ConvNd` class :
240
+ > Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
241
+ > uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
242
+ > For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573"
243
+
244
+ The reason we permute the weights before init is because `kaiming_uniform_` uses the number of in and out
245
+ features for initialization, which are computed as tensor.size(0) and tensor.size(1). However, these
246
+ dimensions do not correspond for my weights.
247
+ """
248
+ if self.groups == self.out_channels:
249
+ nn.init.kaiming_uniform_(self.weight.permute(3, 0, 1, 2), a=math.sqrt(5))
250
+ else:
251
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
252
+
253
+ def forward(self, x: torch.Tensor):
254
+ if self.groups == 1:
255
+ return self.forward_matmul(x)
256
+ else:
257
+ return self.forward_elementwise(x)
258
+
259
+ def forward_matmul(self, x: torch.Tensor):
260
+
261
+ if self.padding == "same":
262
+ padded_x = self._pad_pre_conv(x)
263
+ else:
264
+ padded_x = x
265
+
266
+ bs, _, seq_len = padded_x.size()
267
+
268
+ padded_x = padded_x.transpose(-1, -2)
269
+ # Size: (bs, seq_len+pad, hidden)
270
+
271
+ out = padded_x.matmul(self.weight.view(self.weight.size(0), -1)).view(bs, seq_len, self.out_channels, -1)
272
+ # Size: (bs, seq_len+pad, hidden, kernel_size)
273
+
274
+ return out.permute(0, 2, 3, 1)
275
+
276
+ def forward_elementwise(self, x: torch.Tensor):
277
+ assert len(x.size()) == 3
278
+ assert x.size(1) == self.out_channels
279
+ # Size: (bs, hidden, seq_len)
280
+
281
+ if self.padding == "same":
282
+ padded_x = self._pad_pre_conv(x)
283
+ else:
284
+ padded_x = x
285
+
286
+ # Unsqueeze for broadcasting with the kernel_size dim of the filters
287
+ padded_x = padded_x.transpose(-1, -2).unsqueeze(2)
288
+ # Size: (bs, seq_len, 1, hidden)
289
+
290
+ out = padded_x * self.weight
291
+ # Size: (bs, seq_len, kernel_size, hidden)
292
+
293
+ return out.transpose(1, 3)
294
+
295
+ def _pad_pre_conv(self, inp: torch.Tensor):
296
+ """
297
+ Pad with zeros at the beginning and end just like `nn.Conv1d`.
298
+ """
299
+ return nn.functional.pad(inp, self.pad, "constant", 0.0)
300
+
301
+ def extra_repr(self):
302
+ return "in_features={}, out_features={}, kernel_size={}, groups={}".format(
303
+ self.in_channels, self.out_channels, self.kernel_size, self.groups
304
+ )
305
+
306
+
307
+ class MantaCachedConvolutionPooling(nn.Module):
308
+ def __init__(
309
+ self,
310
+ padding_length,
311
+ output_dim,
312
+ kernel_size,
313
+ hidden_dim,
314
+ depthwise_convolution,
315
+ variance_regularization,
316
+ mean_pool,
317
+ ):
318
+ super().__init__()
319
+ self.padding_length = padding_length
320
+ self.output_dim = output_dim
321
+ self.kernel_size = kernel_size
322
+ self.hidden_dim = hidden_dim
323
+ self.depthwise_convolution = depthwise_convolution
324
+ self.variance_regularization = variance_regularization
325
+ self.mean_pool = mean_pool
326
+
327
+ if isinstance(self.kernel_size, int):
328
+ self.kernel_size = [[self.kernel_size, hidden_dim]]
329
+
330
+ self.conv_output_dim = sum([k_dim[1] for k_dim in self.kernel_size])
331
+
332
+ # Since the sum of the hidden dimensions of all the filters might not match the language model hidden size, we
333
+ # specify it here
334
+ self.out_projection = nn.Linear(self.conv_output_dim, self.output_dim, bias=True)
335
+
336
+ self.conv_layers = nn.Sequential(
337
+ *[
338
+ MantaConvFeatures(self.hidden_dim, h, k, groups=h if self.depthwise_convolution else 1, padding="same")
339
+ for (k, h) in self.kernel_size
340
+ ]
341
+ )
342
+
343
+ self.eps = None
344
+ self.conv_layer = None
345
+
346
+ def forward(self, unconstrained_separation_probs: torch.Tensor, byte_embeddings: torch.Tensor):
347
+ device = unconstrained_separation_probs.device
348
+ if self.eps is None:
349
+ self.eps = 5 * torch.finfo(unconstrained_separation_probs.dtype).resolution
350
+ self.variance_regularization = max(self.eps, self.variance_regularization)
351
+
352
+ if self.conv_layer is not None:
353
+ self.conv_layer = self.conv_layer.to(device)
354
+ batch_size, seq_len = byte_embeddings.shape[:2]
355
+
356
+ # We set the probability of the first token to be 0 therwise the cumsum will not work
357
+ separation_probs = unconstrained_separation_probs.clone()
358
+ separation_probs[:, 0] = 0
359
+
360
+ assert separation_probs.shape == (batch_size, seq_len)
361
+
362
+ # Compute the moments of the block_id random variable
363
+ block_id_expectation = separation_probs.cumsum(axis=-1)
364
+ block_id_std = torch.sqrt(
365
+ (separation_probs * (1.0 - separation_probs)).cumsum(axis=-1) + self.variance_regularization
366
+ )
367
+
368
+ # Get the maximum number of blocks
369
+ max_nb_blocks = min(seq_len, (block_id_expectation + 3 * block_id_std).max().int().item() + 1)
370
+ possible_blocks_id = torch.arange(max_nb_blocks).to(device)
371
+
372
+ # Get the block/byte proba using the Gaussian PDF
373
+ log_scale = block_id_std[:, None, :].log()
374
+ log_proba = (
375
+ -((block_id_expectation[:, None, :] - possible_blocks_id[None, :, None]) ** 2)
376
+ / (2 * block_id_std[:, None, :])
377
+ - log_scale
378
+ - math.log((2 * math.pi) ** 0.5)
379
+ )
380
+ block_byte_proba = log_proba.softmax(-2)
381
+
382
+ token_size = block_byte_proba.sum(-1, keepdim=True)
383
+ regularized_token_size = torch.maximum(token_size, torch.ones_like(token_size))
384
+
385
+ if self.mean_pool:
386
+ block_byte_proba_normalized = block_byte_proba / regularized_token_size
387
+ else:
388
+ # Makes no sense to regularize using sequence length in the max_pooling case.
389
+ block_byte_proba_normalized = block_byte_proba
390
+
391
+ block_embeddings = self.pooling(byte_embeddings, block_byte_proba_normalized)
392
+
393
+ pad_length = min(self.padding_length, max_nb_blocks)
394
+
395
+ block_embeddings = pad_block_embeddings(block_embeddings, pad_length)
396
+ block_embeddings = self.out_projection(block_embeddings)
397
+
398
+ return block_embeddings
399
+
400
+ def pooling(self, embeddings: torch.Tensor, block_byte_proba: torch.Tensor):
401
+ block_embeddings = []
402
+
403
+ for conv_layer in self.conv_layers:
404
+ # First, compute the convolution maps SEPARATELY, i.e. without summing them together, only the element wise multiplication
405
+ # This is similar to a cache that we'll reuse for each block probabilities.
406
+ features = conv_layer(embeddings.transpose(1, 2)).permute(0, 3, 1, 2)
407
+ # Size : (batch_size, seq_len + padding, hidden_dim, kernel_size)
408
+
409
+ pad = conv_layer.pad
410
+
411
+ for i in range(0, conv_layer.kernel_size):
412
+ # We shift like that to match the padding done inside `conv_layer`
413
+ features[..., i] = features[..., i].roll(pad[0] - i, 1)
414
+ # Cut out the padded vector to obtain the right sequence length at the end
415
+ features = features[:, pad[1] : features.size(1) - pad[0]]
416
+ # Size : (batch_size, seq_len, hidden_dim, kernel_size)
417
+
418
+ # Then, artificially sum the convolution features by shifting the input bytes
419
+ padded_block_byte_proba = nn.functional.pad(block_byte_proba, pad, "constant", 0.0)
420
+ expanded_block_byte_proba = []
421
+ for i in range(0, conv_layer.kernel_size):
422
+ rolled_proba = padded_block_byte_proba.clone().roll(pad[0] - i, -1)
423
+ expanded_block_byte_proba.append(rolled_proba)
424
+ expanded_block_byte_proba = torch.stack(expanded_block_byte_proba, -1)
425
+ # We use :tensor.size(2) - pad instead of just :-pad because if pad = 0, we have an undesired behaviour where the whole sequence is removed
426
+ expanded_block_byte_proba = expanded_block_byte_proba[
427
+ :, :, pad[1] : expanded_block_byte_proba.size(2) - pad[0], :
428
+ ]
429
+ # Size : (batch_size, block_size, seq_len, kernel_size)
430
+
431
+ if self.mean_pool:
432
+ convolved = torch.einsum("b s h k, b B s k -> b B h", features, expanded_block_byte_proba)
433
+ else:
434
+ convolved = torch.einsum("b s h k, b B s k -> b B s h", features, expanded_block_byte_proba)
435
+ convolved = convolved.max(dim=-2).values
436
+
437
+ block_embeddings.append(convolved)
438
+
439
+ block_embeddings = torch.cat(block_embeddings, dim=-1)
440
+
441
+ return block_embeddings
442
+
443
+
444
+ class MantaPreTrainedModel(PreTrainedModel):
445
+ """
446
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
447
+ models.
448
+ """
449
+
450
+ config_class = MantaConfig
451
+ base_model_prefix = "transformer"
452
+ supports_gradient_checkpointing = True
453
+
454
+ def _init_weights(self, module):
455
+ """Initialize the weights"""
456
+ pass
457
+
458
+ def _set_gradient_checkpointing(self, module, value=False):
459
+ if isinstance(module, (T5Attention, T5Stack)):
460
+ module.gradient_checkpointing = value
461
+
462
+ def _shift_right(self, input_ids):
463
+ decoder_start_token_id = self.config.decoder_start_token_id
464
+ pad_token_id = self.config.pad_token_id
465
+
466
+ assert decoder_start_token_id is not None, (
467
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
468
+ " See T5 docs for more information"
469
+ )
470
+
471
+ # shift inputs to the right
472
+ if is_torch_fx_proxy(input_ids):
473
+ # Item assignment is not supported natively for proxies.
474
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
475
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
476
+ else:
477
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
478
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
479
+ shifted_input_ids[..., 0] = decoder_start_token_id
480
+
481
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
482
+ # replace possible -100 values in labels by `pad_token_id`
483
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
484
+
485
+ return shifted_input_ids
486
+
487
+
488
+ @add_start_docstrings(
489
+ "The bare Manta Model transformer outputting encoder's raw hidden-states without any specific head on top."
490
+ )
491
+ class MantaEncoderModel(MantaPreTrainedModel):
492
+ authorized_missing_keys = [
493
+ r"encoder.embed_tokens.weight",
494
+ ]
495
+
496
+ def __init__(self, config: MantaConfig):
497
+ super().__init__(config)
498
+ self.byte_embeddings = nn.Embedding(config.vocab_size, config.byte_embedding_dim)
499
+
500
+ self.frontier_predictor = MantaFrontierPredictor(
501
+ hidden_size=config.byte_embedding_dim,
502
+ num_layers=config.frontier_predictor_num_layers,
503
+ num_attention_heads=config.frontier_predictor_num_attention_heads,
504
+ dropout_rate=config.dropout_rate,
505
+ attention_window=config.frontier_predictor_attention_window,
506
+ max_length=config.max_length_inputs,
507
+ )
508
+
509
+ self.pooler = MantaCachedConvolutionPooling(
510
+ padding_length=config.max_length_encoder_decoder,
511
+ output_dim=config.d_model,
512
+ kernel_size=config.pooling_kernel_size,
513
+ hidden_dim=config.byte_embedding_dim,
514
+ depthwise_convolution=config.pooling_depthwise_convolution,
515
+ variance_regularization=config.pooling_variance_regularization,
516
+ mean_pool=config.pooling_mean_pool,
517
+ )
518
+
519
+ self.t5_encoder = T5Stack(
520
+ T5Config(
521
+ d_model=config.d_model,
522
+ d_kv=config.d_kv,
523
+ d_ff=config.d_ff,
524
+ num_layers=config.num_layers,
525
+ num_heads=config.num_heads,
526
+ relative_attention_num_buckets=config.relative_attention_num_buckets,
527
+ relative_attention_max_distance=config.relative_attention_max_distance,
528
+ dropout_rate=config.dropout_rate,
529
+ layer_norm_epsilon=config.layer_norm_epsilon,
530
+ initializer_factor=config.initializer_factor,
531
+ feed_forward_proj=config.feed_forward_proj,
532
+ pad_token_id=config.pad_token_id,
533
+ eos_token_id=config.eos_token_id,
534
+ is_decoder=False,
535
+ use_cache=False,
536
+ )
537
+ )
538
+
539
+ # Initialize weights and apply final processing
540
+ self.post_init()
541
+
542
+ def get_input_embeddings(self):
543
+ return self.byte_embeddings
544
+
545
+ def set_input_embeddings(self, new_embeddings):
546
+ self.byte_embeddings = new_embeddings
547
+
548
+ def _prune_heads(self, heads_to_prune):
549
+ """
550
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
551
+ class PreTrainedModel
552
+ """
553
+ for layer, heads in heads_to_prune.items():
554
+ self.t5_encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
555
+
556
+ def _compute_pooled_representations(
557
+ self,
558
+ input_ids: Optional[torch.LongTensor] = None,
559
+ attention_mask: Optional[torch.FloatTensor] = None,
560
+ inputs_embeds: Optional[torch.FloatTensor] = None,
561
+ ):
562
+ if inputs_embeds is None and input_ids is None:
563
+ return None
564
+
565
+ byte_embeddings = inputs_embeds if inputs_embeds is not None else self.byte_embeddings(input_ids)
566
+
567
+ frontier_predictions = self.frontier_predictor(byte_embeddings, attention_mask)
568
+
569
+ pooled_representations = self.pooler(frontier_predictions, byte_embeddings)
570
+
571
+ return pooled_representations, frontier_predictions
572
+
573
+ @replace_return_docstrings(output_type=MantaBaseModelOutput, config_class=_CONFIG_FOR_DOC)
574
+ def forward(
575
+ self,
576
+ input_ids: Optional[torch.LongTensor] = None,
577
+ attention_mask: Optional[torch.FloatTensor] = None,
578
+ head_mask: Optional[torch.FloatTensor] = None,
579
+ inputs_embeds: Optional[torch.FloatTensor] = None,
580
+ output_attentions: Optional[bool] = None,
581
+ output_hidden_states: Optional[bool] = None,
582
+ return_dict: Optional[bool] = None,
583
+ ) -> Union[Tuple[torch.FloatTensor], MantaBaseModelOutput]:
584
+ r"""
585
+ Returns:
586
+
587
+ Example:
588
+
589
+ ```python
590
+ >>> from transformers import ByT5Tokenizer, MantaEncoderModel
591
+
592
+ >>> tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
593
+ >>> model = MantaEncoderModel.from_pretrained("nthngdy/manta-small")
594
+ >>> input_ids = tokenizer(
595
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
596
+ ... ).input_ids # Batch size 1
597
+ >>> outputs = model(input_ids=input_ids)
598
+ >>> last_hidden_states = outputs.last_hidden_state
599
+ ```"""
600
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
601
+ output_hidden_states = (
602
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
603
+ )
604
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
605
+
606
+ pooled_representations, frontier_predictions = self._compute_pooled_representations(
607
+ input_ids, attention_mask, inputs_embeds
608
+ )
609
+
610
+ encoder_outputs = self.t5_encoder(
611
+ inputs_embeds=pooled_representations,
612
+ head_mask=head_mask,
613
+ output_attentions=output_attentions,
614
+ output_hidden_states=output_hidden_states,
615
+ return_dict=return_dict,
616
+ )
617
+
618
+ if not return_dict:
619
+ return encoder_outputs + (frontier_predictions,)
620
+
621
+ return MantaBaseModelOutput(frontier_predictions=frontier_predictions, **encoder_outputs)
622
+
623
+
624
+ class MantaModel(MantaPreTrainedModel):
625
+ _keys_to_ignore_on_load_missing = [
626
+ r"encoder_decoder.encoder.embed_tokens.weight",
627
+ r"encoder_decoder.decoder.embed_tokens.weight",
628
+ ]
629
+ _keys_to_ignore_on_load_unexpected = [
630
+ r"encoder_decoder.decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
631
+ ]
632
+
633
+ def __init__(self, config: MantaConfig):
634
+ super().__init__(config)
635
+
636
+ self.encoder = MantaEncoderModel(config)
637
+
638
+ self.decoder_embeddings = nn.Embedding(config.vocab_size, config.d_model)
639
+ self.decoder = T5Stack(
640
+ T5Config(
641
+ vocab_size=config.vocab_size,
642
+ d_model=config.d_model,
643
+ d_kv=config.d_kv,
644
+ d_ff=config.d_ff,
645
+ num_layers=config.num_decoder_layers,
646
+ num_heads=config.num_heads,
647
+ relative_attention_num_buckets=config.relative_attention_num_buckets,
648
+ relative_attention_max_distance=config.relative_attention_max_distance,
649
+ dropout_rate=config.dropout_rate,
650
+ layer_norm_epsilon=config.layer_norm_epsilon,
651
+ initializer_factor=config.initializer_factor,
652
+ feed_forward_proj=config.feed_forward_proj,
653
+ use_cache=config.use_cache,
654
+ pad_token_id=config.pad_token_id,
655
+ eos_token_id=config.eos_token_id,
656
+ is_decoder=True,
657
+ is_encoder_decoder=False,
658
+ ),
659
+ self.decoder_embeddings,
660
+ )
661
+
662
+ # Initialize weights and apply final processing
663
+ self.post_init()
664
+
665
+ def get_input_embeddings(self):
666
+ return self.encoder.get_input_embeddings()
667
+
668
+ def set_input_embeddings(self, new_embeddings):
669
+ self.encoder.set_input_embeddings(new_embeddings)
670
+
671
+ def get_encoder(self):
672
+ return self.encoder
673
+
674
+ def get_decoder(self):
675
+ return self.decoder
676
+
677
+ def _prune_heads(self, heads_to_prune):
678
+ """
679
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
680
+ class PreTrainedModel
681
+ """
682
+ for layer, heads in heads_to_prune.items():
683
+ self.encoder.layer[layer].attention.prune_heads(heads)
684
+
685
+ @replace_return_docstrings(output_type=MantaSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
686
+ def forward(
687
+ self,
688
+ input_ids: Optional[torch.LongTensor] = None,
689
+ attention_mask: Optional[torch.FloatTensor] = None,
690
+ decoder_input_ids: Optional[torch.LongTensor] = None,
691
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
692
+ head_mask: Optional[torch.FloatTensor] = None,
693
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
694
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
695
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
696
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
697
+ inputs_embeds: Optional[torch.Tensor] = None,
698
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
699
+ use_cache: Optional[bool] = None,
700
+ output_attentions: Optional[bool] = None,
701
+ output_hidden_states: Optional[bool] = None,
702
+ return_dict: Optional[bool] = None,
703
+ ) -> Union[Tuple[torch.FloatTensor], MantaSeq2SeqLMOutput]:
704
+ r"""
705
+ Returns:
706
+
707
+ Example:
708
+
709
+ ```python
710
+ >>> from transformers import ByT5Tokenizer, MantaModel
711
+
712
+ >>> tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
713
+ >>> model = MantaModel.from_pretrained("nthngdy/manta-small")
714
+
715
+ >>> input_ids = tokenizer(
716
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
717
+ ... ).input_ids # Batch size 1
718
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
719
+
720
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for MantaModel.
721
+ >>> # This is not needed for torch's MantaForConditionalGeneration as it does this internally using labels arg.
722
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
723
+
724
+ >>> # forward pass
725
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
726
+ >>> last_hidden_states = outputs.last_hidden_state
727
+ ```"""
728
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
729
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
730
+ output_hidden_states = (
731
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
732
+ )
733
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
734
+
735
+ if encoder_outputs is None:
736
+ encoder_outputs = self.encoder(
737
+ input_ids=input_ids,
738
+ attention_mask=attention_mask,
739
+ inputs_embeds=inputs_embeds,
740
+ head_mask=head_mask,
741
+ output_attentions=output_attentions,
742
+ output_hidden_states=output_hidden_states,
743
+ return_dict=return_dict,
744
+ )
745
+ elif return_dict and not isinstance(encoder_outputs, MantaBaseModelOutput):
746
+ encoder_outputs = MantaBaseModelOutput(
747
+ last_hidden_state=encoder_outputs[0],
748
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
749
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
750
+ frontier_predictions=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
751
+ )
752
+
753
+ hidden_states = encoder_outputs[0]
754
+
755
+ decoder_outputs = self.decoder(
756
+ input_ids=decoder_input_ids,
757
+ attention_mask=decoder_attention_mask,
758
+ encoder_hidden_states=hidden_states,
759
+ encoder_attention_mask=attention_mask,
760
+ inputs_embeds=decoder_inputs_embeds,
761
+ head_mask=decoder_head_mask,
762
+ cross_attn_head_mask=cross_attn_head_mask,
763
+ past_key_values=past_key_values,
764
+ use_cache=use_cache,
765
+ output_attentions=output_attentions,
766
+ output_hidden_states=output_hidden_states,
767
+ return_dict=return_dict,
768
+ )
769
+
770
+ if not return_dict:
771
+ return decoder_outputs + encoder_outputs
772
+
773
+ return MantaSeq2SeqLMOutput(
774
+ last_hidden_state=decoder_outputs.last_hidden_state,
775
+ past_key_values=decoder_outputs.past_key_values,
776
+ decoder_hidden_states=decoder_outputs.hidden_states,
777
+ decoder_attentions=decoder_outputs.attentions,
778
+ cross_attentions=decoder_outputs.cross_attentions,
779
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
780
+ encoder_hidden_states=encoder_outputs.hidden_states,
781
+ encoder_attentions=encoder_outputs.attentions,
782
+ frontier_predictions=encoder_outputs.frontier_predictions,
783
+ )
784
+
785
+
786
+ @add_start_docstrings("""Manta Model with a `language modeling` head on top.""")
787
+ class MantaForConditionalGeneration(MantaPreTrainedModel):
788
+ _keys_to_ignore_on_load_missing = [
789
+ r"encoder.embed_tokens.weight",
790
+ r"decoder.embed_tokens.weight",
791
+ r"lm_head.weight",
792
+ ]
793
+ _keys_to_ignore_on_load_unexpected = [
794
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
795
+ ]
796
+
797
+ def __init__(self, config: MantaConfig):
798
+ super().__init__(config)
799
+ self.model_dim = config.d_model
800
+
801
+ self.encoder = MantaEncoderModel(config)
802
+
803
+ self.decoder_embeddings = nn.Embedding(config.vocab_size, config.d_model)
804
+ self.decoder = T5Stack(
805
+ T5Config(
806
+ vocab_size=config.vocab_size,
807
+ d_model=config.d_model,
808
+ d_kv=config.d_kv,
809
+ d_ff=config.d_ff,
810
+ num_layers=config.num_decoder_layers,
811
+ num_heads=config.num_heads,
812
+ relative_attention_num_buckets=config.relative_attention_num_buckets,
813
+ relative_attention_max_distance=config.relative_attention_max_distance,
814
+ dropout_rate=config.dropout_rate,
815
+ layer_norm_epsilon=config.layer_norm_epsilon,
816
+ initializer_factor=config.initializer_factor,
817
+ feed_forward_proj=config.feed_forward_proj,
818
+ use_cache=config.use_cache,
819
+ pad_token_id=config.pad_token_id,
820
+ eos_token_id=config.eos_token_id,
821
+ is_decoder=True,
822
+ is_encoder_decoder=False,
823
+ ),
824
+ self.decoder_embeddings,
825
+ )
826
+
827
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
828
+
829
+ # Initialize weights and apply final processing
830
+ self.post_init()
831
+
832
+ def get_input_embeddings(self):
833
+ return self.encoder.get_input_embeddings()
834
+
835
+ def set_input_embeddings(self, new_embeddings):
836
+ self.encoder.set_input_embeddings(new_embeddings)
837
+
838
+ def set_output_embeddings(self, new_embeddings):
839
+ self.lm_head = new_embeddings
840
+
841
+ def get_output_embeddings(self):
842
+ return self.lm_head
843
+
844
+ def get_encoder(self):
845
+ return self.encoder
846
+
847
+ def get_decoder(self):
848
+ return self.decoder
849
+
850
+ @replace_return_docstrings(output_type=MantaSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
851
+ def forward(
852
+ self,
853
+ input_ids: Optional[torch.LongTensor] = None,
854
+ attention_mask: Optional[torch.FloatTensor] = None,
855
+ decoder_input_ids: Optional[torch.LongTensor] = None,
856
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
857
+ head_mask: Optional[torch.FloatTensor] = None,
858
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
859
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
860
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
861
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
862
+ inputs_embeds: Optional[torch.FloatTensor] = None,
863
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
864
+ labels: Optional[torch.LongTensor] = None,
865
+ use_cache: Optional[bool] = None,
866
+ output_attentions: Optional[bool] = None,
867
+ output_hidden_states: Optional[bool] = None,
868
+ return_dict: Optional[bool] = None,
869
+ ) -> Union[Tuple[torch.FloatTensor], MantaSeq2SeqLMOutput]:
870
+ r"""
871
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
872
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
873
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
874
+ labels in `[0, ..., config.vocab_size]`
875
+
876
+ Returns:
877
+
878
+ Examples:
879
+
880
+ ```python
881
+ >>> from transformers import ByT5Tokenizer, MantaForConditionalGeneration
882
+
883
+ >>> tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
884
+ >>> model = MantaForConditionalGeneration.from_pretrained("nthngdy/manta-small")
885
+
886
+ >>> # training
887
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
888
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
889
+ >>> outputs = model(input_ids=input_ids, labels=labels)
890
+ >>> loss = outputs.loss
891
+ >>> logits = outputs.logits
892
+
893
+ >>> # inference
894
+ >>> input_ids = tokenizer(
895
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
896
+ ... ).input_ids # Batch size 1
897
+ >>> outputs = model.generate(input_ids)
898
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
899
+ >>> # studies have shown that owning a dog is good for you.
900
+ ```"""
901
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
902
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
903
+ output_hidden_states = (
904
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
905
+ )
906
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
907
+
908
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
909
+ if head_mask is not None and decoder_head_mask is None:
910
+ if self.config.num_layers == self.config.num_decoder_layers:
911
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
912
+ decoder_head_mask = head_mask
913
+
914
+ # Encode if needed (training, first prediction pass)
915
+ if encoder_outputs is None:
916
+ encoder_outputs = self.encoder(
917
+ input_ids=input_ids,
918
+ attention_mask=attention_mask,
919
+ inputs_embeds=inputs_embeds,
920
+ head_mask=head_mask,
921
+ output_attentions=output_attentions,
922
+ output_hidden_states=output_hidden_states,
923
+ return_dict=return_dict,
924
+ )
925
+ elif return_dict and not isinstance(encoder_outputs, MantaBaseModelOutput):
926
+ encoder_outputs = BaseModelOutput(
927
+ last_hidden_state=encoder_outputs[0],
928
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
929
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
930
+ frontier_predictions=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
931
+ )
932
+
933
+ hidden_states = encoder_outputs[0]
934
+
935
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
936
+ # get decoder inputs from shifting lm labels to the right
937
+ decoder_input_ids = self._shift_right(labels)
938
+
939
+ # Decode
940
+ decoder_outputs = self.decoder(
941
+ input_ids=decoder_input_ids,
942
+ attention_mask=decoder_attention_mask,
943
+ inputs_embeds=decoder_inputs_embeds,
944
+ past_key_values=past_key_values,
945
+ encoder_hidden_states=hidden_states,
946
+ head_mask=decoder_head_mask,
947
+ cross_attn_head_mask=cross_attn_head_mask,
948
+ use_cache=use_cache,
949
+ output_attentions=output_attentions,
950
+ output_hidden_states=output_hidden_states,
951
+ return_dict=return_dict,
952
+ )
953
+
954
+ sequence_output = decoder_outputs[0]
955
+
956
+ if self.config.tie_word_embeddings:
957
+ # Rescale output before projecting on vocab
958
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
959
+ sequence_output = sequence_output * (self.model_dim**-0.5)
960
+
961
+ lm_logits = self.lm_head(sequence_output)
962
+
963
+ loss = None
964
+ if labels is not None:
965
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
966
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
967
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
968
+
969
+ if not return_dict:
970
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
971
+ return ((loss,) + output) if loss is not None else output
972
+
973
+ return MantaSeq2SeqLMOutput(
974
+ loss=loss,
975
+ logits=lm_logits,
976
+ past_key_values=decoder_outputs.past_key_values,
977
+ decoder_hidden_states=decoder_outputs.hidden_states,
978
+ decoder_attentions=decoder_outputs.attentions,
979
+ cross_attentions=decoder_outputs.cross_attentions,
980
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
981
+ encoder_hidden_states=encoder_outputs.hidden_states,
982
+ encoder_attentions=encoder_outputs.attentions,
983
+ frontier_predictions=encoder_outputs.frontier_predictions,
984
+ )
985
+
986
+ def prepare_inputs_for_generation(
987
+ self,
988
+ input_ids,
989
+ past=None,
990
+ attention_mask=None,
991
+ head_mask=None,
992
+ decoder_head_mask=None,
993
+ cross_attn_head_mask=None,
994
+ use_cache=None,
995
+ encoder_outputs=None,
996
+ **kwargs
997
+ ):
998
+
999
+ # cut decoder_input_ids if past is used
1000
+ if past is not None:
1001
+ input_ids = input_ids[:, -1:]
1002
+
1003
+ return {
1004
+ "decoder_input_ids": input_ids,
1005
+ "past_key_values": past,
1006
+ "encoder_outputs": encoder_outputs,
1007
+ "attention_mask": attention_mask,
1008
+ "head_mask": head_mask,
1009
+ "decoder_head_mask": decoder_head_mask,
1010
+ "cross_attn_head_mask": cross_attn_head_mask,
1011
+ "use_cache": use_cache,
1012
+ }
1013
+
1014
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1015
+ return self._shift_right(labels)
1016
+
1017
+ def _reorder_cache(self, past, beam_idx):
1018
+ # if decoder past is not included in output
1019
+ # speedy decoding is disabled and no need to reorder
1020
+ if past is None:
1021
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1022
+ return past
1023
+
1024
+ reordered_decoder_past = ()
1025
+ for layer_past_states in past:
1026
+ # get the correct batch idx from layer past batch dim
1027
+ # batch dim of `past` is at 2nd position
1028
+ reordered_layer_past_states = ()
1029
+ for layer_past_state in layer_past_states:
1030
+ # need to set correct `past` for each of the four key / value states
1031
+ reordered_layer_past_states = reordered_layer_past_states + (
1032
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1033
+ )
1034
+
1035
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1036
+ assert len(reordered_layer_past_states) == len(layer_past_states)
1037
+
1038
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1039
+ return reordered_decoder_past