zhiqu22 committed on
Commit 0517e25 · 1 Parent(s): 4b2109e

update codes

configuration_mitre.py ADDED
@@ -0,0 +1,63 @@
1
+ # coding=utf-8
2
+ """Mitre model configuration"""
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+ class MitreConfig(PretrainedConfig):
11
+ model_type = "mitre"
12
+ keys_to_ignore_at_inference = ["past_key_values"]
13
+ attribute_map = {"num_attention_heads": "decoder_attention_heads", "hidden_size": "d_model"}
14
+
15
+ def __init__(
16
+ self,
17
+ vocab_size=160025,
18
+ max_position_embeddings=256,
19
+ decoder_layers=24,
20
+ decoder_ffn_dim=4096,
21
+ decoder_attention_heads=16,
22
+ use_cache=True,
23
+ is_encoder_decoder=False,
24
+ activation_function="relu",
25
+ d_model=1024,
26
+ dropout=0.1,
27
+ attention_dropout=0.1,
28
+ activation_dropout=0.0,
29
+ init_std=0.02,
30
+ decoder_start_token_id=2,
31
+ scale_embedding=True,
32
+ pad_token_id=1,
33
+ bos_token_id=0,
34
+ eos_token_id=2,
35
+ **kwargs,
36
+ ):
37
+ self.vocab_size = vocab_size
38
+ self.max_position_embeddings = max_position_embeddings
39
+ self.d_model = d_model
40
+ self.decoder_ffn_dim = decoder_ffn_dim
41
+ self.decoder_layers = decoder_layers
42
+ self.decoder_attention_heads = decoder_attention_heads
43
+ self.dropout = dropout
44
+ self.attention_dropout = attention_dropout
45
+ self.activation_dropout = activation_dropout
46
+ self.activation_function = activation_function
47
+ self.init_std = init_std
48
+ self.use_cache = use_cache
49
+ self.num_hidden_layers = decoder_layers
50
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
51
+ self.is_decoder = True
52
+ self.is_encoder_decoder = False
53
+
54
+ super().__init__(
55
+ pad_token_id=pad_token_id,
56
+ bos_token_id=bos_token_id,
57
+ eos_token_id=eos_token_id,
58
+ is_encoder_decoder=is_encoder_decoder,
59
+ decoder_start_token_id=decoder_start_token_id,
60
+ **kwargs,
61
+ )
62
+
63
+ MitreConfig.register_for_auto_class("AutoConfig")
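A minimal usage sketch (assuming the module above is importable from the working directory); the attribute_map aliases num_attention_heads and hidden_size onto the decoder fields:

from configuration_mitre import MitreConfig

config = MitreConfig()                # defaults as defined above
print(config.model_type)              # "mitre"
print(config.num_attention_heads)     # 16, aliased to decoder_attention_heads
print(config.hidden_size)             # 1024, aliased to d_model
print(config.is_encoder_decoder)      # False: MITRE is a decoder-only stack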
mitre_spm.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7d00755ffecaf04eec1af3b31bb5a1bdb79e93cda9ff44ec1ee08656f6bfd84
3
+ size 3158318
modeling_mitre.py ADDED
@@ -0,0 +1,841 @@
1
+ # coding=utf-8
2
+
3
+ import math
4
+ from typing import List, Optional, Tuple, Union, Dict, Any
5
+
6
+ import torch
7
+ from torch import nn
8
+ from .configuration_mitre import MitreConfig
9
+ from transformers.utils import logging
10
+
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.modeling_utils import PreTrainedModel
13
+ from transformers.activations import ACT2FN
14
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
15
+ from transformers.integrations.fsdp import is_fsdp_managed_module
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutputWithPastAndCrossAttentions,
18
+ Seq2SeqLMOutput,
19
+ Seq2SeqModelOutput,
20
+ )
21
+ from transformers.generation.configuration_utils import GenerationConfig
22
+ from transformers.generation.beam_search import BeamSearchScorer
23
+ from transformers.generation.logits_process import LogitsProcessorList
24
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
29
+ """
30
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
31
+ are ignored. This is modified from fairseq's `utils.make_positions`.
32
+ """
33
+ mask = input_ids.ne(padding_idx).int()
34
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
35
+ return incremental_indices.long() + padding_idx
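A hand-checked illustration of the helper above, using padding_idx = 1 as in MitreConfig:

import torch

ids = torch.tensor([[1, 1, 5, 6, 7]])  # two left pads followed by three real tokens
print(create_position_ids_from_input_ids(ids, padding_idx=1))
# tensor([[1, 1, 2, 3, 4]]): pads keep padding_idx, real tokens count up from padding_idx + 1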
36
+
37
+
38
+ # Modified from transformers.models.m2m_100.modeling_m2m_100.M2M100Attention
39
+ # and transformers.models.m2m_100.modeling_m2m_100.M2M100SdpaAttention
40
+ class MitreSdpaAttention(nn.Module):
41
+
42
+ def __init__(
43
+ self,
44
+ embed_dim: int,
45
+ num_heads: int,
46
+ dropout: float = 0.0,
47
+ bias: bool = True,
48
+ config: Optional[MitreConfig] = None,
49
+ ):
50
+ super().__init__()
51
+ self.embed_dim = embed_dim
52
+ self.num_heads = num_heads
53
+ self.dropout = dropout
54
+ self.head_dim = embed_dim // num_heads
55
+ self.config = config
56
+
57
+ if (self.head_dim * num_heads) != self.embed_dim:
58
+ raise ValueError(
59
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
60
+ f" and `num_heads`: {num_heads})."
61
+ )
62
+ self.scaling = self.head_dim**-0.5
63
+
64
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
65
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
66
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
67
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
68
+
69
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
70
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
76
+ attention_mask: Optional[torch.Tensor] = None,
77
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
78
+ """
79
+ Input shape: Batch x Time x Channel
80
+ Output objects: attn_output, attn_weights (always None), past_key_value
81
+ """
82
+ """
83
+ 1. MitreModel uses MitreSdpaAttention, which is modified from M2M100SdpaAttention.
84
+ Notably, neither of them supports `output_attentions=True` or a non-None `layer_head_mask`,
85
+ so 'attn_weights' is always None in the output.
86
+ Improving this point is a low priority.
87
+ 2. We plan to improve this code with Flash Attention v2.
88
+ """
89
+ bsz, tgt_len, _ = hidden_states.size()
90
+
91
+ # get query proj
92
+ query_states = self.q_proj(hidden_states)
93
+ if past_key_value is not None:
94
+ # reuse k, v, self_attention
95
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
96
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
97
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
98
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
99
+ else:
100
+ # self_attention
101
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
102
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
103
+
104
+ past_key_value = (key_states, value_states)
105
+
106
+ query_states = self._shape(query_states, tgt_len, bsz)
107
+
108
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
109
+ query_states,
110
+ key_states,
111
+ value_states,
112
+ attn_mask=attention_mask,
113
+ dropout_p=self.dropout if self.training else 0.0,
114
+ is_causal=False,
115
+ )
116
+
117
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
118
+ raise ValueError(
119
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
120
+ f" {attn_output.size()}"
121
+ )
122
+
123
+ attn_output = attn_output.transpose(1, 2)
124
+
125
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
126
+ # partitioned across GPUs when using tensor-parallelism.
127
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
128
+
129
+ attn_output = self.out_proj(attn_output)
130
+
131
+ return attn_output, None, past_key_value
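A small shape check of the block above (an illustrative sketch with random weights and no attention mask):

import torch

attn = MitreSdpaAttention(embed_dim=1024, num_heads=16)
x = torch.randn(2, 5, 1024)     # (batch, target length, embed_dim)
out, weights, (k, v) = attn(x)
print(out.shape)                # torch.Size([2, 5, 1024])
print(weights)                  # None, as noted in the docstring
print(k.shape)                  # torch.Size([2, 16, 5, 64]), cached keys for the next step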
132
+
133
+
134
+ # Modified from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer
135
+ class MitreDecoderLayer(nn.Module):
136
+ def __init__(self, config: MitreConfig):
137
+ super().__init__()
138
+ self.embed_dim = config.d_model
139
+
140
+ self.self_attn = MitreSdpaAttention(
141
+ embed_dim=self.embed_dim,
142
+ num_heads=config.decoder_attention_heads,
143
+ dropout=config.attention_dropout,
144
+ config=config,
145
+ )
146
+ self.dropout = config.dropout
147
+ self.activation_fn = ACT2FN[config.activation_function]
148
+ self.activation_dropout = config.activation_dropout
149
+
150
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
151
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
152
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
153
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
154
+
155
+ def forward(
156
+ self,
157
+ hidden_states: torch.Tensor,
158
+ attention_mask: Optional[torch.Tensor] = None,
159
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
160
+ use_cache: Optional[bool] = True,
161
+ ) -> torch.Tensor:
162
+ """
163
+ Args:
164
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
165
+ attention_mask (`torch.FloatTensor`): attention mask of size
166
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
167
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
168
+ """
169
+ residual = hidden_states
170
+ hidden_states = self.self_attn_layer_norm(hidden_states)
171
+
172
+ # Self Attention
173
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
174
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
175
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
176
+ hidden_states, _, present_key_value = self.self_attn(
177
+ hidden_states=hidden_states,
178
+ past_key_value=self_attn_past_key_value,
179
+ attention_mask=attention_mask,
180
+ )
181
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
182
+ hidden_states = residual + hidden_states
183
+
184
+ # Fully Connected
185
+ residual = hidden_states
186
+ hidden_states = self.final_layer_norm(hidden_states)
187
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
188
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
189
+ hidden_states = self.fc2(hidden_states)
190
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
191
+ hidden_states = residual + hidden_states
192
+
193
+ outputs = (hidden_states,)
194
+
195
+ if use_cache:
196
+ outputs += (present_key_value,)
197
+
198
+ return outputs
199
+
200
+
201
+ class MitrePreTrainedModel(PreTrainedModel):
202
+ config_class = MitreConfig
203
+ base_model_prefix = "model"
204
+ supports_gradient_checkpointing = True
205
+ _no_split_modules = ["MitreDecoderLayer"]
206
+ # we plan to implement code for flash attention v2
207
+ _supports_flash_attn_2 = False
208
+ _supports_sdpa = True
209
+
210
+ def _init_weights(self, module):
211
+ std = self.config.init_std
212
+ if isinstance(module, nn.Linear):
213
+ module.weight.data.normal_(mean=0.0, std=std)
214
+ if module.bias is not None:
215
+ module.bias.data.zero_()
216
+ elif isinstance(module, nn.Embedding):
217
+ module.weight.data.normal_(mean=0.0, std=std)
218
+ if module.padding_idx is not None:
219
+ module.weight.data[module.padding_idx].zero_()
220
+
221
+
222
+ class MitreDecoder(MitrePreTrainedModel):
223
+ """
224
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MitreDecoderLayer`]
225
+
226
+ Args:
227
+ config: MitreConfig
228
+ embed_tokens (nn.Embedding): output embedding
229
+ """
230
+
231
+ def __init__(self, config: MitreConfig):
232
+ super().__init__(config)
233
+ self.dropout = config.dropout
234
+ self.padding_idx = config.pad_token_id
235
+ self.max_target_positions = config.max_position_embeddings
236
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
237
+
238
+ self.embed_tokens = MitreScaledWordEmbedding(
239
+ config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
240
+ )
241
+
242
+ self.src_embed_positions = MitreSinusoidalPositionalEmbedding(
243
+ config.max_position_embeddings,
244
+ config.d_model,
245
+ self.padding_idx,
246
+ )
247
+ self.register_embed_positions = MitreSinusoidalPositionalEmbedding(
248
+ config.max_position_embeddings,
249
+ config.d_model,
250
+ self.padding_idx,
251
+ )
252
+ self.tgt_embed_positions = MitreSinusoidalPositionalEmbedding(
253
+ config.max_position_embeddings,
254
+ config.d_model,
255
+ self.padding_idx,
256
+ )
257
+ self.layers = nn.ModuleList([MitreDecoderLayer(config) for _ in range(config.decoder_layers)])
258
+ if config._attn_implementation != "sdpa":
259
+ raise NotImplementedError("Other attention mechanisms are not implemented yet.")
260
+
261
+ # TODO: implement flash attention v2 for MITRE
262
+ # self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
263
+ self._use_sdpa = config._attn_implementation == "sdpa"
264
+ self.layer_norm = nn.LayerNorm(config.d_model)
265
+
266
+ self.gradient_checkpointing = False
267
+ self._future_mask = torch.empty(0)
268
+ # Initialize weights and apply final processing
269
+ self.post_init()
270
+
271
+ def create_registers(self, input_ids):
272
+ '''
273
+ create registers by duplicating the language tag corresponding to each sentence.
274
+ length(registers) = length(real_tokens) = length(tokens) - length(pads)
275
+ '''
276
+ register_nums = (~input_ids.eq(self.padding_idx)).sum(dim=1)
277
+ max_register_nums = register_nums.max().item()
278
+ total_token_nums = input_ids.size(1) + max_register_nums
279
+ batch_size = input_ids.size(0)
280
+ registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums)
281
+ return registers, register_nums, total_token_nums
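A standalone re-derivation of the method body on a toy batch (padding_idx = 1; the specific ids are illustrative, with language-tag ids at the top of the vocabulary so that argmax picks the tag token):

import torch

input_ids = torch.tensor([[1, 160026, 7, 8, 2],    # pad, __lang__, tok, tok, </s>
                          [160030, 5, 6, 9, 2]])   # __lang__, tok, tok, tok, </s>
register_nums = (~input_ids.eq(1)).sum(dim=1)      # tensor([4, 5]): non-pad tokens per row
max_register_nums = register_nums.max().item()     # 5
registers = input_ids[range(2), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums)
# registers: each sentence's language tag repeated 5 times -> [[160026] * 5, [160030] * 5]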
282
+
283
+ def combine_src_and_registers(self, input_ids, registers, register_nums, total_token_nums):
284
+ '''
285
+ return an expanded_src_tokens tensor for positional embedding.
286
+ '''
287
+ pads = torch.full_like(registers, self.padding_idx)
288
+ expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1)
289
+ indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1).to(input_ids.device)
290
+ indices = indices + register_nums.unsqueeze(1)
291
+
292
+ batch_indices = torch.arange(input_ids.shape[0]).unsqueeze(1).expand(-1, indices.size(1)).contiguous()
293
+ return expanded_src_tokens, batch_indices, indices
294
+
295
+ def fill_with_neg_inf(self, t):
296
+ return t.float().fill_(float("-inf")).type_as(t)
297
+
298
+ def build_future_mask(self, embeds, src_length, register_nums, padding_mask=None, past_key_values_length=0):
299
+ b = register_nums.size(0)
300
+ ns = src_length - register_nums
301
+ if past_key_values_length == 0:
302
+ # in training
303
+ # 1. create mask by cache
304
+ dim = embeds.size(1)
305
+ if (
306
+ self._future_mask.size(0) == 0
307
+ or self._future_mask.size(0) < dim
308
+ ):
309
+ self._future_mask = torch.triu(self.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
310
+ if self._future_mask.device == embeds.device:
311
+ mask = self._future_mask[:dim, :dim].clone()
312
+ else:
313
+ mask = self._future_mask[:dim, :dim].to(embeds, copy=True)
314
+
315
+ # 2. bi-directional attention in source tokens and registers
316
+ mask[ :src_length, :src_length] = 0.
317
+
318
+ # 3. create batch mask
319
+ batch_mask = mask.unsqueeze(0).expand(b, -1, -1).clone().contiguous()
320
+
321
+ # 4. mask source tokens -> registers
322
+ # 5. mask target -> source tokens
323
+ batch_indices = torch.arange(b).to(batch_mask.device).view(-1, 1, 1).expand(b, dim, dim).contiguous()
324
+ row_indices = torch.arange(dim).to(batch_mask.device).view(1, -1, 1).expand(b, dim, dim).contiguous()
325
+ col_indices = torch.arange(dim).to(batch_mask.device).view(1, 1, -1).expand(b, dim, dim).contiguous()
326
+ source_indices = (row_indices < ns.view(-1, 1, 1)) & (col_indices >= ns.view(-1, 1, 1)) & (col_indices < (ns + register_nums).view(-1, 1, 1)).contiguous()
327
+ target_indices = (row_indices >= (ns + register_nums).view(-1, 1, 1)) & (col_indices < ns.view(-1, 1, 1)).contiguous()
328
+ # 4
329
+ batch_mask[batch_indices[source_indices], row_indices[source_indices], col_indices[source_indices]] = float('-inf')
330
+ # 5
331
+ batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf')
332
+ # shape: batch_size, head_num (1 for broadcasting), seq_len, seq_len
333
+ batch_mask = batch_mask.unsqueeze(1)
334
+ # 6. masking pads
335
+ if padding_mask is not None:
336
+ if padding_mask.any():
337
+ padding_mask = padding_mask.to(batch_mask.device).unsqueeze(1).unsqueeze(2)
338
+ batch_mask = batch_mask.masked_fill(padding_mask == 1, float('-inf'))
339
+
340
+ elif past_key_values_length > 0:
341
+ # in generation
342
+ mask = torch.zeros(past_key_values_length + 1)
343
+ mask = mask.to(embeds, copy=True)
344
+ batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous()
345
+
346
+ batch_indices = torch.arange(b).view(-1, 1).expand(b, past_key_values_length + 1).to(batch_mask.device)
347
+ token_indices = torch.arange(past_key_values_length + 1).view(1, -1).expand(b, past_key_values_length + 1).to(batch_mask.device)
348
+ target_to_source_mask = token_indices < ns.view(-1, 1)
349
+
350
+ batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf')
351
+ batch_mask = batch_mask.unsqueeze(1)
352
+
353
+ # ensure contiguous
354
+ batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
355
+ return batch_mask
356
+
357
+
358
+ def forward(
359
+ self,
360
+ input_ids: Optional[torch.Tensor] = None,
361
+ decoder_input_ids: Optional[torch.Tensor] = None,
362
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
363
+ use_cache: Optional[bool] = None,
364
+ output_attentions: Optional[bool] = None,
365
+ output_hidden_states: Optional[bool] = None,
366
+ registering_cache: dict = None,
367
+ ):
368
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
369
+ output_hidden_states = (
370
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
371
+ )
372
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
373
+
374
+ # past_key_values_length
375
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
376
+
377
+ decoder_input_shape = decoder_input_ids.size()
378
+ decoder_input_ids = decoder_input_ids.view(-1, decoder_input_shape[-1])
379
+ padding_mask = None
380
+
381
+ if past_key_values_length > 0:
382
+ register_nums = registering_cache["register_nums"]
383
+ src_length = registering_cache["src_length"]
384
+
385
+ if input_ids is not None and past_key_values_length == 0:
386
+ # .view() additionally ensures that the memory is contiguous
387
+ input_shape = input_ids.size()
388
+ input_ids = input_ids.view(-1, input_shape[-1])
389
+
390
+ registers, register_nums, total_token_nums = self.create_registers(input_ids)
391
+ expanded_src_tokens, batch_indices, indices = self.combine_src_and_registers(input_ids, registers, register_nums, total_token_nums)
392
+
393
+ # positional embedding for source tokens and registers
394
+ inputs_embeds = self.embed_tokens(expanded_src_tokens)
395
+ inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums])
396
+ inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:])
397
+ inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1)
398
+ inputs_embeds = inputs_embeds[batch_indices, indices]
399
+
400
+
401
+ # padding mask
402
+ source_tokens = expanded_src_tokens[batch_indices, indices]
403
+ src_length = source_tokens.shape[1]
404
+
405
+ # replace the inference trigger with langtok
406
+ # namely, enc-tgt-dec-tgt strategy
407
+ if decoder_input_ids[0][0].item() != source_tokens[0][-1].item():
408
+ decoder_input_ids[:, 0] = source_tokens[:, -1]
409
+
410
+ tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
411
+ padding_mask = tokens.eq(self.padding_idx)
412
+
413
+ decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
414
+ decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
415
+ if past_key_values_length == 0:
416
+ hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
417
+ else:
418
+ hidden_states = decoder_inputs_embeds
419
+
420
+ attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, padding_mask, past_key_values_length)
421
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
422
+
423
+ if self.gradient_checkpointing and self.training:
424
+ if use_cache:
425
+ logger.warning_once(
426
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting" " `use_cache=False`..."
427
+ )
428
+ use_cache = False
429
+
430
+ # decoder layers
431
+ all_hidden_states = () if output_hidden_states else None
432
+ all_self_attns = () if output_attentions else None
433
+ all_cross_attentions = () if output_attentions else None
434
+ next_decoder_cache = () if use_cache else None
435
+
436
+ for idx, decoder_layer in enumerate(self.layers):
437
+ if output_hidden_states:
438
+ all_hidden_states += (hidden_states,)
439
+
440
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
441
+
442
+ if self.gradient_checkpointing and self.training:
443
+ layer_outputs = self._gradient_checkpointing_func(
444
+ decoder_layer.__call__,
445
+ hidden_states,
446
+ attention_mask,
447
+ past_key_value=None,
448
+ use_cache=use_cache,
449
+ )
450
+ else:
451
+ layer_outputs = decoder_layer(
452
+ hidden_states,
453
+ attention_mask=attention_mask,
454
+ past_key_value=past_key_value,
455
+ use_cache=use_cache,
456
+ )
457
+
458
+ hidden_states = layer_outputs[0]
459
+
460
+ if use_cache:
461
+ next_decoder_cache += (layer_outputs[1],)
462
+
463
+ if past_key_values_length == 0:
464
+ hidden_states = hidden_states[:,src_length:,:]
465
+
466
+ hidden_states = self.layer_norm(hidden_states)
467
+
468
+ # add hidden states from the last decoder layer
469
+ if output_hidden_states:
470
+ all_hidden_states += (hidden_states,)
471
+
472
+ next_cache = next_decoder_cache if use_cache else None
473
+
474
+ model_output = BaseModelOutputWithPastAndCrossAttentions(
475
+ last_hidden_state=hidden_states,
476
+ past_key_values=next_cache,
477
+ hidden_states=all_hidden_states,
478
+ attentions=all_self_attns,
479
+ cross_attentions=all_cross_attentions,
480
+ )
481
+ model_output.registering_cache = {
482
+ "register_nums": register_nums,
483
+ "src_length": src_length
484
+ }
485
+ return model_output
486
+
487
+
488
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding
489
+ class MitreScaledWordEmbedding(nn.Embedding):
490
+ """
491
+ This module overrides nn.Embedding's forward by multiplying the embeddings by a scale factor.
492
+ """
493
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
494
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
495
+ self.embed_scale = embed_scale
496
+
497
+ def forward(self, input_ids: torch.Tensor):
498
+ return super().forward(input_ids) * self.embed_scale
499
+
500
+
501
+ class MitreSinusoidalPositionalEmbedding(nn.Module):
502
+ """This module produces sinusoidal positional embeddings of any length."""
503
+
504
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
505
+ super().__init__()
506
+ self.offset = 2
507
+ self.embedding_dim = embedding_dim
508
+ self.padding_idx = padding_idx
509
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
510
+
511
+ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
512
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
513
+ if hasattr(self, "weights"):
514
+ # in forward put the weights on the correct dtype and device of the param
515
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
516
+
517
+ self.register_buffer("weights", emb_weights, persistent=False)
518
+
519
+ @staticmethod
520
+ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
521
+ """
522
+ Build sinusoidal embeddings.
523
+
524
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
525
+ "Attention Is All You Need".
526
+ """
527
+ half_dim = embedding_dim // 2
528
+ emb = math.log(10000) / (half_dim - 1)
529
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
530
+ emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
531
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
532
+ if embedding_dim % 2 == 1:
533
+ # zero pad
534
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
535
+ if padding_idx is not None:
536
+ emb[padding_idx, :] = 0
537
+
538
+ return emb.to(torch.get_default_dtype())
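A quick sanity check of the table built above: an even embedding_dim yields shape (num_embeddings, embedding_dim) and the padding row is zeroed:

import torch

emb = MitreSinusoidalPositionalEmbedding.get_embedding(num_embeddings=10, embedding_dim=8, padding_idx=1)
print(emb.shape)              # torch.Size([10, 8])
print(emb[1].abs().sum())     # tensor(0.): the padding row is zeroed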
539
+
540
+ @torch.no_grad()
541
+ def forward(
542
+ self, input_ids: torch.Tensor = None, past_key_values_length: int = 0, src_length: int = 0
543
+ ):
544
+ bsz, seq_len = input_ids.size()
545
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
546
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
547
+ input_ids.device
548
+ )
549
+
550
+ if past_key_values_length > 0 and src_length > 0:
551
+ position_ids = torch.where(position_ids == 1, position_ids, position_ids - src_length)
552
+
553
+ # expand embeddings if needed
554
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
555
+
556
+ if max_pos > self.weights.size(0):
557
+ self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
558
+
559
+ return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
560
+
561
+ class MitreModel(MitrePreTrainedModel):
562
+ _tied_weights_keys = ["decoder.embed_tokens.weight"]
563
+
564
+ def __init__(self, config: MitreConfig):
565
+ super().__init__(config)
566
+
567
+ self.decoder = MitreDecoder(config)
568
+
569
+ # Initialize weights and apply final processing
570
+ self.post_init()
571
+
572
+ def get_input_embeddings(self):
573
+ return self.decoder.embed_tokens
574
+
575
+ def get_decoder(self):
576
+ return self.decoder
577
+
578
+ def forward(
579
+ self,
580
+ input_ids: Optional[torch.LongTensor] = None,
581
+ decoder_input_ids: Optional[torch.Tensor] = None,
582
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
583
+ use_cache: Optional[bool] = None,
584
+ output_attentions: Optional[bool] = None,
585
+ output_hidden_states: Optional[bool] = None,
586
+ registering_cache: dict = None,
587
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
588
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
589
+ output_hidden_states = (
590
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
591
+ )
592
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
593
+
594
+ decoder_outputs = self.decoder(
595
+ input_ids=input_ids,
596
+ decoder_input_ids=decoder_input_ids,
597
+ past_key_values=past_key_values,
598
+ use_cache=use_cache,
599
+ output_hidden_states=output_hidden_states,
600
+ registering_cache=registering_cache
601
+ )
602
+
603
+ model_output = Seq2SeqModelOutput(
604
+ last_hidden_state=decoder_outputs.last_hidden_state,
605
+ past_key_values=decoder_outputs.past_key_values,
606
+ decoder_hidden_states=decoder_outputs.hidden_states,
607
+ decoder_attentions=decoder_outputs.attentions,
608
+ )
609
+ model_output.registering_cache = decoder_outputs.registering_cache
610
+ return model_output
611
+
612
+ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
613
+ base_model_prefix = "model"
614
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
615
+
616
+ def __init__(self, config: MitreConfig):
617
+ super().__init__(config)
618
+ self.model = MitreModel(config)
619
+ self.lm_head = nn.Linear(config.d_model, self.model.decoder.embed_tokens.num_embeddings, bias=False)
620
+
621
+ # Initialize weights and apply final processing
622
+ self.post_init()
623
+
624
+ def get_decoder(self):
625
+ return self.model.get_decoder()
626
+
627
+ def get_output_embeddings(self):
628
+ return self.lm_head
629
+
630
+ def set_output_embeddings(self, new_embeddings):
631
+ self.lm_head = new_embeddings
632
+
633
+ def forward(
634
+ self,
635
+ input_ids: Optional[torch.LongTensor] = None,
636
+ decoder_input_ids: Optional[torch.LongTensor] = None,
637
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
638
+ labels: Optional[torch.LongTensor] = None,
639
+ use_cache: Optional[bool] = None,
640
+ output_hidden_states: Optional[bool] = None,
641
+ registering_cache: dict = None,
642
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
643
+ outputs = self.model(
644
+ input_ids=input_ids,
645
+ decoder_input_ids=decoder_input_ids,
646
+ past_key_values=past_key_values,
647
+ use_cache=use_cache,
648
+ output_hidden_states=output_hidden_states,
649
+ registering_cache=registering_cache,
650
+ )
651
+
652
+ lm_logits = self.lm_head(outputs[0])
653
+
654
+ if labels is not None:
655
+ raise NotImplementedError("Please implement your loss function here.")
656
+
657
+ model_output = Seq2SeqLMOutput(
658
+ loss=None,
659
+ logits=lm_logits,
660
+ past_key_values=outputs.past_key_values,
661
+ decoder_hidden_states=outputs.decoder_hidden_states,
662
+ decoder_attentions=outputs.decoder_attentions,
663
+ )
664
+ model_output.registering_cache = outputs.registering_cache
665
+ return model_output
666
+
667
+ @staticmethod
668
+ def _reorder_cache(past_key_values, beam_idx):
669
+ reordered_past = ()
670
+ for layer_past in past_key_values:
671
+ reordered_past += (
672
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
673
+ )
674
+ return reordered_past
675
+
676
+ @staticmethod
677
+ def _reorder_register_nums(register_nums, beam_idx):
678
+ return register_nums.index_select(0, beam_idx.to(register_nums.device))
679
+
680
+ @staticmethod
681
+ def _expand_inputs_for_generation(
682
+ input_ids: Optional[torch.LongTensor] = None,
683
+ beam_size: int = 1,
684
+ ) -> torch.LongTensor:
685
+ """
686
+ Expands input_ids from [batch_size, len(tokens)] to [batch_size * beam_size, len(tokens)].
687
+ This is simplified from 'transformers.generation.utils.GenerationMixin._expand_inputs_for_generation'
688
+ """
689
+ if beam_size == 1:
690
+ return input_ids
691
+
692
+ return input_ids.repeat_interleave(beam_size, dim=0)
693
+
694
+ def generate(self,
695
+ input_ids: Optional[torch.Tensor] = None,
696
+ generation_config: Optional[GenerationConfig] = None,
697
+ **kwargs: Dict
698
+ ):
699
+ """
700
+ Inference with beam search.
701
+ This code is simplified from 'transformers.generation.utils.GenerationMixin.generate'.
702
+ This code follows the style of m2m and nllb.
703
+ Therefore, there are two points that need improvement.
704
+ TODO
705
+ 1. early_stop in beam search.
706
+ Current early_stop is at the beam search level instead of the model level. Specifically,
707
+ although the beam scorer appends eos to a sequence, the sequence is then filled with 'pad(1)'.
708
+ As a result, a sequence that has already finished will still be processed by the model
709
+ continuously. We plan to remove finished sequences, following Fairseq's style.
710
+ 2. build self-attention mask.
711
+ Currently, the mask is built within the model. Thus, when running beam search, we have to
712
+ create a mask whose size is (beam_size * batch_size) from scratch. If we create the mask
713
+ outside of the model, we can create it by duplicating the batch-level mask beam_size times.
714
+ Moreover, we can cache the mask during beam search to avoid creating it many times.
715
+ """
716
+ if generation_config is not None:
717
+ assert type(generation_config) is GenerationConfig
718
+ self.generation_config = generation_config
719
+ self.generation_config.update(**kwargs)
720
+
721
+ generation_config = self.generation_config
722
+
723
+ batch_size = input_ids.shape[0]
724
+ beam_size = generation_config.num_beams
725
+ device = input_ids.device
726
+ max_cache_length = generation_config.max_length
727
+ eos_token_id = torch.Tensor([generation_config.eos_token_id])
728
+
729
+ # initial the target tokens
730
+ decoder_input_ids = torch.full(
731
+ (batch_size, 1),
732
+ self.generation_config.decoder_start_token_id,
733
+ dtype=input_ids.dtype,
734
+ device=device
735
+ )
736
+
737
+ beam_scorer = BeamSearchScorer(
738
+ batch_size=batch_size,
739
+ num_beams=beam_size,
740
+ device=device,
741
+ length_penalty=self.generation_config.length_penalty,
742
+ do_early_stopping=self.generation_config.early_stopping,
743
+ num_beam_hyps_to_keep=self.generation_config.num_return_sequences,
744
+ max_length=max_cache_length,
745
+ )
746
+
747
+ input_ids = self._expand_inputs_for_generation(input_ids, beam_size)
748
+ decoder_input_ids = self._expand_inputs_for_generation(decoder_input_ids, beam_size)
749
+ # decoder_input_ids.to(device)
750
+ cur_len = decoder_input_ids.shape[1]
751
+
752
+ this_peer_finished = False
753
+ past_key_values = None
754
+ registering_cache = None
755
+
756
+ logits_processor = LogitsProcessorList()
757
+ stopping_criteria = StoppingCriteriaList()
758
+
759
+ beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device)
760
+ beam_scores[:, 1:] = -1e9
761
+ beam_scores = beam_scores.view((batch_size * beam_size,))
762
+ while not this_peer_finished:
763
+
764
+ if past_key_values is not None:
765
+ decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
766
+ else:
767
+ decoder_input_ids_for_generation = decoder_input_ids
768
+
769
+ outputs = self(input_ids, decoder_input_ids_for_generation, past_key_values=past_key_values, use_cache=True, registering_cache=registering_cache)
770
+
771
+ del input_ids
772
+ input_ids = None
773
+
774
+ past_key_values = outputs.past_key_values
775
+ registering_cache = outputs.registering_cache
776
+
777
+ next_token_logits = outputs.logits[:, -1, :].clone().float()
778
+ next_token_logits = next_token_logits.to(device)
779
+
780
+ next_token_scores = nn.functional.log_softmax(
781
+ next_token_logits, dim=-1
782
+ ) # (batch_size * num_beams, vocab_size)
783
+
784
+ next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores)
785
+ next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
786
+ next_token_scores_processed
787
+ )
788
+
789
+ # reshape for beam search
790
+ vocab_size = next_token_scores.shape[-1]
791
+ next_token_scores = next_token_scores.view(batch_size, beam_size * vocab_size)
792
+
793
+ # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
794
+ # non eos token per beam.
795
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
796
+ n_tokens_to_keep = max(2, 1 + n_eos_tokens) * beam_size
797
+ next_token_scores, next_tokens = torch.topk(
798
+ next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
799
+ )
800
+
801
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
802
+ next_tokens = next_tokens % vocab_size
803
+ beam_outputs = beam_scorer.process(
804
+ decoder_input_ids,
805
+ next_token_scores,
806
+ next_tokens,
807
+ next_indices,
808
+ pad_token_id=generation_config.pad_token_id,
809
+ eos_token_id=generation_config.eos_token_id,
810
+ decoder_prompt_len=1,
811
+ )
812
+ beam_scores = beam_outputs["next_beam_scores"]
813
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
814
+ beam_idx = beam_outputs["next_beam_indices"]
815
+ decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
816
+
817
+ del outputs
818
+
819
+ past_key_values = self._reorder_cache(past_key_values, beam_idx)
820
+ registering_cache["register_nums"] = self._reorder_register_nums(registering_cache["register_nums"], beam_idx)
821
+
822
+ cur_len = cur_len + 1
823
+
824
+ if beam_scorer.is_done:
825
+ this_peer_finished = True
826
+
827
+ sequence_outputs = beam_scorer.finalize(
828
+ decoder_input_ids,
829
+ beam_scores,
830
+ next_tokens,
831
+ next_indices,
832
+ pad_token_id=generation_config.pad_token_id,
833
+ eos_token_id=eos_token_id,
834
+ max_length=stopping_criteria.max_length,
835
+ decoder_prompt_len=1,
836
+ )
837
+
838
+ return sequence_outputs["sequences"]
839
+
840
+
841
+ MitreForConditionalGeneration.register_for_auto_class("AutoModel")
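A minimal end-to-end generation sketch, assuming the checkpoint id naist-nlp/mitre_913m from tokenizer_config.json and the auto-class registrations above; the beam settings are illustrative and passed explicitly because this custom generate() builds a BeamSearchScorer:

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("naist-nlp/mitre_913m", trust_remote_code=True)
model = AutoModel.from_pretrained("naist-nlp/mitre_913m", trust_remote_code=True)  # resolves to MitreForConditionalGeneration

input_ids = tokenizer.encode_source_tokens_to_input_ids(["Hello, how are you?"], target_language="de")
generated = model.generate(
    input_ids,
    generation_config=model.generation_config,  # kwargs below are only applied when a config is passed
    num_beams=5,
    max_length=256,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))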
tokenization_mitre.py ADDED
@@ -0,0 +1,211 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from shutil import copyfile
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+ import torch
7
+
8
+ import sentencepiece
9
+
10
+ from transformers.tokenization_utils import PreTrainedTokenizer
11
+ from transformers.utils import logging
12
+
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+ SPIECE_UNDERLINE = "▁"
17
+
18
+ VOCAB_FILES_NAMES = {
19
+ "vocab_file": "vocab.json",
20
+ "spm_file": "mitre_spm.model",
21
+ "tokenizer_config_file": "tokenizer_config.json",
22
+ }
23
+
24
+ # two-letter language codes (ISO 639-1)
25
+ FAIRSEQ_LANGUAGE_CODES = ["en", "de", "nl", "sv", "da", "af", "fr", "es", "it", "pt", "ro", "ru", "cs", "pl", "bg", "uk", "id", "jv", "ms", "tl", "ja", "zh", "ko", "vi"]
26
+
27
+ # This is the tokenizer of MITRE.
28
+ # This code is modified from transformers.models.m2m_100.tokenization_m2m_100.M2M100Tokenizer
29
+ class MitreTokenizer(PreTrainedTokenizer):
30
+ vocab_files_names = VOCAB_FILES_NAMES
31
+ model_input_names = ["input_ids", "attention_mask"]
32
+
33
+ prefix_tokens: List[int] = []
34
+ suffix_tokens: List[int] = []
35
+
36
+ def __init__(
37
+ self,
38
+ vocab_file,
39
+ spm_file,
40
+ bos_token="<s>",
41
+ eos_token="</s>",
42
+ sep_token="</s>",
43
+ pad_token="<pad>",
44
+ unk_token="<unk>",
45
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
46
+ **kwargs,
47
+ ) -> None:
48
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
49
+ fairseq_language_code = FAIRSEQ_LANGUAGE_CODES
50
+ self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in fairseq_language_code}
51
+
52
+ additional_special_tokens = kwargs.pop("additional_special_tokens", [])
53
+ for lang_code in fairseq_language_code:
54
+ token = self.get_lang_token(lang_code)
55
+ if token not in additional_special_tokens:
56
+ additional_special_tokens.append(token)
57
+
58
+ self.vocab_file = vocab_file
59
+ self.encoder = load_json(vocab_file)
60
+ self.decoder = {v: k for k, v in self.encoder.items()}
61
+ self.spm_file = spm_file
62
+ self.sp_model = load_spm(spm_file, self.sp_model_kwargs)
63
+
64
+ self.encoder_size = len(self.encoder)
65
+
66
+ self.lang_token_to_id = {
67
+ self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)
68
+ }
69
+ self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)}
70
+ self.id_to_lang_token = {v: k for k, v in self.lang_token_to_id.items()}
71
+ # default
72
+ self.tgt_lang = "en"
73
+
74
+ super().__init__(
75
+ bos_token=bos_token,
76
+ eos_token=eos_token,
77
+ sep_token=sep_token,
78
+ unk_token=unk_token,
79
+ pad_token=pad_token,
80
+ sp_model_kwargs=self.sp_model_kwargs,
81
+ additional_special_tokens=additional_special_tokens,
82
+ **kwargs,
83
+ )
84
+
85
+ @property
86
+ def vocab_size(self) -> int:
87
+ return len(self.encoder)
88
+
89
+ def get_vocab(self) -> Dict:
90
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
91
+ vocab.update(self.added_tokens_encoder)
92
+ return vocab
93
+
94
+ def _tokenize(self, text: str) -> List[str]:
95
+ return self.sp_model.encode(text, out_type=str)
96
+
97
+ def _convert_token_to_id(self, token):
98
+ if token in self.lang_token_to_id:
99
+ return self.lang_token_to_id[token]
100
+ return self.encoder.get(token, self.encoder[self.unk_token])
101
+
102
+ def _convert_id_to_token(self, index: int) -> str:
103
+ """Converts an index (integer) into a token (str) using the decoder."""
104
+ if index in self.id_to_lang_token:
105
+ return self.id_to_lang_token[index]
106
+ return self.decoder.get(index, self.unk_token)
107
+
108
+ def convert_tokens_to_string(self, tokens):
109
+ """Converts a sequence of tokens (strings) into a single string."""
110
+ current_sub_tokens = []
111
+ out_string = ""
112
+ for token in tokens:
113
+ # make sure that special tokens are not decoded using sentencepiece model
114
+ if token in self.all_special_tokens:
115
+ out_string += self.sp_model.decode(current_sub_tokens) + token
116
+ current_sub_tokens = []
117
+ else:
118
+ current_sub_tokens.append(token)
119
+ out_string += self.sp_model.decode(current_sub_tokens)
120
+ return out_string.strip()
121
+
122
+ def __getstate__(self) -> Dict:
123
+ state = self.__dict__.copy()
124
+ state["sp_model"] = None
125
+ return state
126
+
127
+ def __setstate__(self, d: Dict) -> None:
128
+ self.__dict__ = d
129
+
130
+ # for backward compatibility
131
+ if not hasattr(self, "sp_model_kwargs"):
132
+ self.sp_model_kwargs = {}
133
+
134
+ self.sp_model = load_spm(self.spm_file, self.sp_model_kwargs)
135
+
136
+ def build_inputs_with_special_tokens(
137
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
138
+ ) -> List[int]:
139
+ if token_ids_1 is None:
140
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
141
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
142
+
143
+ def _switch_to_input_mode(self):
144
+ self.set_tgt_lang_special_tokens(self.tgt_lang)
145
+
146
+ def _switch_to_target_mode(self):
147
+ self.clear_lang_special_tokens()
148
+
149
+ def clear_lang_special_tokens(self) -> None:
150
+ self.prefix_tokens = []
151
+ self.suffix_tokens = [self.eos_token_id]
152
+
153
+ def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
154
+ """Reset the special tokens to the target language setting. Prefix=[tgt_lang_code] and suffix=[eos]."""
155
+ lang_token = self.get_lang_token(tgt_lang)
156
+ self.cur_lang_id = self.lang_token_to_id[lang_token]
157
+ self.prefix_tokens = [self.cur_lang_id]
158
+ self.suffix_tokens = [self.eos_token_id]
159
+
160
+ def get_lang_token(self, lang: str) -> str:
161
+ return self.lang_code_to_token[lang]
162
+
163
+ def get_lang_id(self, lang: str) -> int:
164
+ lang_token = self.get_lang_token(lang)
165
+ return self.lang_token_to_id[lang_token]
166
+
167
+ def encode_source_tokens_to_input_ids(self, inputs, target_language="en"):
168
+ """pads + target language id + source tokens id + eos id"""
169
+ self.tgt_lang = target_language
170
+ input_ids = self.__call__(inputs, add_special_tokens=True, padding_side='left', padding=True, return_attention_mask=False, return_tensors="pt")
171
+ return input_ids["input_ids"]
172
+
173
+ def encode_source_tokens_to_input_ids_with_different_tags(self, inputs_text, target_languages_list: list):
174
+ """
175
+ 'encode_source_tokens_to_input_ids' only supports a single language tag,
176
+ but several sentences in a batch could have different language tags.
177
+ """
178
+ self.tgt_lang = "en"
179
+ input_ids = self.__call__(inputs_text, add_special_tokens=True, padding_side='left', padding=True, return_attention_mask=False, return_tensors="pt")["input_ids"]
180
+ _, max_indices = torch.max(input_ids, dim=1)
181
+ input_ids[torch.arange(max_indices.shape[0]), max_indices] = torch.LongTensor([self.lang_token_to_id[self.get_lang_token(lang_code)] for lang_code in target_languages_list])
182
+ return input_ids
183
+
184
+ def encode_target_tokens_to_labels(self, inputs_text):
185
+ """target tokens id + eos id + pads"""
186
+ input_ids = self.__call__(text_target=inputs_text, add_special_tokens=True, padding_side='right', padding=True, return_attention_mask=False, return_tensors="pt")
187
+ return input_ids["input_ids"]
188
+
189
+ def encode_target_tokens_to_input_ids(self, inputs_text):
190
+ """eos id + target tokens id + pads, namely, left shifted"""
191
+ input_ids = self.__call__(text_target=inputs_text, add_special_tokens=False, padding_side='right', padding=True, return_attention_mask=False, return_tensors="pt")
192
+ labels_without_eos = input_ids["input_ids"]
193
+ return torch.cat((torch.full((labels_without_eos.size(0), 1), self.eos_token_id), labels_without_eos), dim=1)
194
+
195
+
196
+ def load_spm(path: str, sp_model_kwargs: Dict[str, Any]) -> sentencepiece.SentencePieceProcessor:
197
+ spm = sentencepiece.SentencePieceProcessor(**sp_model_kwargs)
198
+ spm.Load(str(path))
199
+ return spm
200
+
201
+
202
+ def load_json(path: str) -> Union[Dict, List]:
203
+ with open(path, "r") as f:
204
+ return json.load(f)
205
+
206
+
207
+ def save_json(data, path: str) -> None:
208
+ with open(path, "w") as f:
209
+ json.dump(data, f, indent=2)
210
+
211
+ MitreTokenizer.register_for_auto_class("AutoTokenizer")
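An offline sketch of the three encoding helpers above, using the file names added in this commit (the resulting token ids depend on the actual vocabulary):

tokenizer = MitreTokenizer(vocab_file="vocab.json", spm_file="mitre_spm.model")

src_ids = tokenizer.encode_source_tokens_to_input_ids(["How are you?"], target_language="de")
# pads + __de__ + source token ids + </s>, left-padded across the batch
labels = tokenizer.encode_target_tokens_to_labels(["Wie geht es dir?"])
# target token ids + </s> + pads
decoder_in = tokenizer.encode_target_tokens_to_input_ids(["Wie geht es dir?"])
# </s> + target token ids + pads (the left-shifted decoder input)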
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "src_lang": null,
3
+ "tgt_lang": null,
4
+ "bos_token": "<s>",
5
+ "eos_token": "</s>",
6
+ "sep_token": "</s>",
7
+ "unk_token": "<unk>",
8
+ "pad_token": "<pad>",
9
+ "model_max_length": 256,
10
+ "name_or_path": "naist-nlp/mitre_913m",
11
+ "tokenizer_class": "MitreTokenizer",
12
+ "auto_map": {
13
+ "AutoTokenizer": ["tokenization_mitre.MitreTokenizer"]
14
+ }
15
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff