jefson08 commited on
Commit
b9ee0d2
·
verified ·
1 Parent(s): b2e47b3

Model save

Browse files
README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - generated_from_trainer
5
+ model-index:
6
+ - name: indictrans-en-ne-checkpoint-1B
7
+ results: []
8
+ ---
9
+
10
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
11
+ should probably proofread and complete it, then remove this comment. -->
12
+
13
+ # indictrans-en-ne-checkpoint-1B
14
+
15
+ This model was trained from scratch on an unknown dataset.
16
+ It achieves the following results on the evaluation set:
17
+ - Loss: 8.7896
18
+
19
+ ## Model description
20
+
21
+ More information needed
22
+
23
+ ## Intended uses & limitations
24
+
25
+ More information needed
26
+
27
+ ## Training and evaluation data
28
+
29
+ More information needed
30
+
31
+ ## Training procedure
32
+
33
+ ### Training hyperparameters
34
+
35
+ The following hyperparameters were used during training:
36
+ - learning_rate: 1e-05
37
+ - train_batch_size: 16
38
+ - eval_batch_size: 2
39
+ - seed: 42
40
+ - gradient_accumulation_steps: 32
41
+ - total_train_batch_size: 512
42
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
43
+ - lr_scheduler_type: linear
44
+ - num_epochs: 2
45
+
46
+ ### Training results
47
+
48
+ | Training Loss | Epoch | Step | Validation Loss |
49
+ |:-------------:|:------:|:----:|:---------------:|
50
+ | 5.3443 | 0.6824 | 1000 | 8.8698 |
51
+ | 5.3574 | 1.3648 | 2000 | 8.7896 |
52
+
53
+
54
+ ### Framework versions
55
+
56
+ - Transformers 4.44.2
57
+ - Pytorch 2.2.1+cu121
58
+ - Datasets 2.21.0
59
+ - Tokenizers 0.19.1
configuration_indictrans.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans config."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, Mapping, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
24
+ from transformers.onnx.utils import compute_effective_axis_dimension
25
+ from transformers.utils import TensorType, is_torch_available
26
+
27
+
28
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
29
+ class IndicTransConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
32
+ IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
+ with the defaults will yield a similar configuration to that of the IT2
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50265):
41
+ Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`IT2Model`] or
43
+ d_model (`int`, *optional*, defaults to 1024):
44
+ Dimensionality of the layers and the pooler layer.
45
+ encoder_layers (`int`, *optional*, defaults to 12):
46
+ Number of encoder layers.
47
+ decoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of decoder layers.
49
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
55
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ activation_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for activations inside the fully connected layer.
66
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for classifier.
68
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
69
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
70
+ just in case (e.g., 512 or 1024 or 2048).
71
+ init_std (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
74
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
75
+ for more details.
76
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
77
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
78
+ for more details.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the model should return the last key/values attentions (not used by all models).
81
+ ```"""
82
+ model_type = "IndicTrans"
83
+ keys_to_ignore_at_inference = ["past_key_values"]
84
+ attribute_map = {
85
+ "num_attention_heads": "encoder_attention_heads",
86
+ "hidden_size": "d_model",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ encoder_vocab_size=None,
92
+ decoder_vocab_size=None,
93
+ encoder_embed_dim=512,
94
+ decoder_embed_dim=512,
95
+ max_source_positions=210,
96
+ max_target_positions=210,
97
+ encoder_layers=6,
98
+ encoder_ffn_dim=2048,
99
+ encoder_attention_heads=8,
100
+ decoder_layers=6,
101
+ decoder_ffn_dim=2048,
102
+ decoder_attention_heads=8,
103
+ encoder_layerdrop=0.00,
104
+ decoder_layerdrop=0.00,
105
+ use_cache=True,
106
+ is_encoder_decoder=True,
107
+ activation_function="relu",
108
+ encoder_normalize_before=False,
109
+ decoder_normalize_before=False,
110
+ layernorm_embedding=False,
111
+ share_decoder_input_output_embed=False,
112
+ dropout=0.1,
113
+ attention_dropout=0.0,
114
+ activation_dropout=0.0,
115
+ init_std=0.02,
116
+ scale_embedding=True,
117
+ decoder_start_token_id=2,
118
+ pad_token_id=1,
119
+ bos_token_id=0,
120
+ eos_token_id=2,
121
+ attn_implementation="eager",
122
+ **kwargs,
123
+ ):
124
+ self.encoder_vocab_size = encoder_vocab_size
125
+ self.decoder_vocab_size = decoder_vocab_size
126
+ self.encoder_normalize_before = encoder_normalize_before
127
+ self.decoder_normalize_before = decoder_normalize_before
128
+ self.layernorm_embedding = layernorm_embedding
129
+ self.max_source_positions = max_source_positions
130
+ self.max_target_positions = max_target_positions
131
+ self.encoder_embed_dim = encoder_embed_dim
132
+ self.decoder_embed_dim = decoder_embed_dim
133
+ self.encoder_ffn_dim = encoder_ffn_dim
134
+ self.encoder_layers = encoder_layers
135
+ self.encoder_attention_heads = encoder_attention_heads
136
+ self.decoder_ffn_dim = decoder_ffn_dim
137
+ self.decoder_layers = decoder_layers
138
+ self.decoder_attention_heads = decoder_attention_heads
139
+ self.dropout = dropout
140
+ self.attention_dropout = attention_dropout
141
+ self.activation_dropout = activation_dropout
142
+ self.activation_function = activation_function
143
+ self.init_std = init_std
144
+ self.encoder_layerdrop = encoder_layerdrop
145
+ self.decoder_layerdrop = decoder_layerdrop
146
+ self.use_cache = use_cache
147
+ self.num_hidden_layers = encoder_layers
148
+ self.scale_embedding = scale_embedding
149
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
150
+ self.attn_implementation = attn_implementation
151
+
152
+ super().__init__(
153
+ pad_token_id=pad_token_id,
154
+ bos_token_id=bos_token_id,
155
+ eos_token_id=eos_token_id,
156
+ is_encoder_decoder=is_encoder_decoder,
157
+ decoder_start_token_id=decoder_start_token_id,
158
+ **kwargs,
159
+ )
160
+
161
+
162
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
163
+ @property
164
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
165
+ common_inputs = OrderedDict(
166
+ [
167
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
168
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
169
+ ]
170
+ )
171
+
172
+ if self.use_past:
173
+ common_inputs["decoder_input_ids"] = {0: "batch"}
174
+ common_inputs["decoder_attention_mask"] = {
175
+ 0: "batch",
176
+ 1: "past_decoder_sequence + sequence",
177
+ }
178
+ else:
179
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
180
+ common_inputs["decoder_attention_mask"] = {
181
+ 0: "batch",
182
+ 1: "decoder_sequence",
183
+ }
184
+
185
+ if self.use_past:
186
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
187
+ return common_inputs
188
+
189
+ # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
190
+ # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
191
+ # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
192
+ # was done for BART so that it can be updated if need be.
193
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
194
+ self,
195
+ tokenizer: PreTrainedTokenizer,
196
+ batch_size: int = -1,
197
+ seq_length: int = -1,
198
+ is_pair: bool = False,
199
+ framework: Optional[TensorType] = None,
200
+ ) -> Mapping[str, Any]:
201
+ # Copied from OnnxConfig.generate_dummy_inputs
202
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
203
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
204
+ batch_size = compute_effective_axis_dimension(
205
+ batch_size,
206
+ fixed_dimension=OnnxConfig.default_fixed_batch,
207
+ num_token_to_add=0,
208
+ )
209
+
210
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
211
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
212
+ seq_length = compute_effective_axis_dimension(
213
+ seq_length,
214
+ fixed_dimension=OnnxConfig.default_fixed_sequence,
215
+ num_token_to_add=token_to_add,
216
+ )
217
+
218
+ # Generate dummy inputs according to compute batch and sequence
219
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
220
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
221
+ return common_inputs
222
+
223
+ # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
224
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
225
+ self,
226
+ tokenizer: PreTrainedTokenizer,
227
+ batch_size: int = -1,
228
+ seq_length: int = -1,
229
+ is_pair: bool = False,
230
+ framework: Optional[TensorType] = None,
231
+ ) -> Mapping[str, Any]:
232
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
233
+ tokenizer, batch_size, seq_length, is_pair, framework
234
+ )
235
+
236
+ # Generate decoder inputs
237
+ decoder_seq_length = seq_length if not self.use_past else 1
238
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
239
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
240
+ )
241
+ decoder_inputs = {
242
+ f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
243
+ }
244
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
245
+
246
+ if self.use_past:
247
+ if not is_torch_available():
248
+ raise ValueError(
249
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
250
+ )
251
+ else:
252
+ import torch
253
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
254
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
255
+ (
256
+ num_encoder_attention_heads,
257
+ num_decoder_attention_heads,
258
+ ) = self.num_attention_heads
259
+ encoder_shape = (
260
+ batch,
261
+ num_encoder_attention_heads,
262
+ encoder_seq_length,
263
+ self._config.hidden_size // num_encoder_attention_heads,
264
+ )
265
+ decoder_past_length = decoder_seq_length + 3
266
+ decoder_shape = (
267
+ batch,
268
+ num_decoder_attention_heads,
269
+ decoder_past_length,
270
+ self._config.hidden_size // num_decoder_attention_heads,
271
+ )
272
+
273
+ common_inputs["decoder_attention_mask"] = torch.cat(
274
+ [
275
+ common_inputs["decoder_attention_mask"],
276
+ torch.ones(batch, decoder_past_length),
277
+ ],
278
+ dim=1,
279
+ )
280
+
281
+ common_inputs["past_key_values"] = []
282
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
283
+ num_encoder_layers, num_decoder_layers = self.num_layers
284
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
285
+ max_num_layers = (
286
+ max(num_encoder_layers, num_decoder_layers) - min_num_layers
287
+ )
288
+ remaining_side_name = (
289
+ "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
290
+ )
291
+
292
+ for _ in range(min_num_layers):
293
+ common_inputs["past_key_values"].append(
294
+ (
295
+ torch.zeros(decoder_shape),
296
+ torch.zeros(decoder_shape),
297
+ torch.zeros(encoder_shape),
298
+ torch.zeros(encoder_shape),
299
+ )
300
+ )
301
+ # TODO: test this.
302
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
303
+ for _ in range(min_num_layers, max_num_layers):
304
+ common_inputs["past_key_values"].append(
305
+ (torch.zeros(shape), torch.zeros(shape))
306
+ )
307
+ return common_inputs
308
+
309
+ generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
generation_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "decoder_start_token_id": 2,
5
+ "eos_token_id": 2,
6
+ "pad_token_id": 1,
7
+ "transformers_version": "4.44.2"
8
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:18c168d82ac3779b336e4e660cc5381ed4cbb05c7778b094a455d211426045f6
3
  size 2231178416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d524ff246017aa6469a29817518dad6f01f70e907d9c12d3194000645967b670
3
  size 2231178416
modeling_indictrans.py ADDED
@@ -0,0 +1,1801 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+
27
+ from transformers.modeling_attn_mask_utils import (
28
+ _prepare_4d_attention_mask,
29
+ _prepare_4d_attention_mask_for_sdpa,
30
+ _prepare_4d_causal_attention_mask,
31
+ _prepare_4d_causal_attention_mask_for_sdpa,
32
+ )
33
+
34
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutput,
37
+ BaseModelOutputWithPastAndCrossAttentions,
38
+ Seq2SeqLMOutput,
39
+ Seq2SeqModelOutput
40
+ )
41
+
42
+ from transformers.utils import (
43
+ logging,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10,
46
+ )
47
+
48
+ from transformers.modeling_utils import PreTrainedModel
49
+
50
+ from .configuration_indictrans import IndicTransConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
+
57
+ try:
58
+ if is_flash_attn_2_available():
59
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
60
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
61
+ except:
62
+ pass
63
+
64
+
65
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
66
+ def _get_unpad_data(attention_mask):
67
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
70
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
76
+
77
+
78
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
79
+ def shift_tokens_right(
80
+ input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
81
+ ):
82
+ """
83
+ Shift input ids one token to the right.
84
+ """
85
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
86
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
87
+ shifted_input_ids[:, 0] = decoder_start_token_id
88
+
89
+ if pad_token_id is None:
90
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
91
+ # replace possible -100 values in labels by `pad_token_id`
92
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
93
+
94
+ return shifted_input_ids
95
+
96
+
97
+ def create_position_ids_from_input_ids(
98
+ input_ids, padding_idx, past_key_values_length=0
99
+ ):
100
+ """
101
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
102
+ are ignored. This is modified from fairseq's `utils.make_positions`.
103
+ """
104
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
105
+ mask = input_ids.ne(padding_idx).int()
106
+ incremental_indices = (
107
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
108
+ ) * mask
109
+ return incremental_indices.long() + padding_idx
110
+
111
+
112
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
113
+ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
114
+ """This module produces sinusoidal positional embeddings of any length."""
115
+
116
+ def __init__(
117
+ self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
118
+ ):
119
+ super().__init__()
120
+ self.offset = 2
121
+ self.embedding_dim = embedding_dim
122
+ self.padding_idx = padding_idx
123
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
124
+
125
+ def make_weights(
126
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
127
+ ):
128
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
129
+ if hasattr(self, "weights"):
130
+ # in forward put the weights on the correct dtype and device of the param
131
+ emb_weights = emb_weights.to(
132
+ dtype=self.weights.dtype, device=self.weights.device
133
+ )
134
+
135
+ self.register_buffer("weights", emb_weights, persistent=False)
136
+
137
+ @staticmethod
138
+ def get_embedding(
139
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
140
+ ):
141
+ """
142
+ Build sinusoidal embeddings.
143
+
144
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
145
+ "Attention Is All You Need".
146
+ """
147
+ half_dim = embedding_dim // 2
148
+ emb = math.log(10000) / (half_dim - 1)
149
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
150
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
151
+ 1
152
+ ) * emb.unsqueeze(0)
153
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
154
+ num_embeddings, -1
155
+ )
156
+ if embedding_dim % 2 == 1:
157
+ # zero pad
158
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
159
+ if padding_idx is not None:
160
+ emb[padding_idx, :] = 0
161
+
162
+ return emb.to(torch.get_default_dtype())
163
+
164
+ @torch.no_grad()
165
+ def forward(
166
+ self,
167
+ input_ids: torch.Tensor = None,
168
+ inputs_embeds: torch.Tensor = None,
169
+ past_key_values_length: int = 0,
170
+ ):
171
+ if input_ids is not None:
172
+ bsz, seq_len = input_ids.size()
173
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
174
+ position_ids = create_position_ids_from_input_ids(
175
+ input_ids, self.padding_idx, past_key_values_length
176
+ ).to(input_ids.device)
177
+ else:
178
+ bsz, seq_len = inputs_embeds.size()[:-1]
179
+ position_ids = self.create_position_ids_from_inputs_embeds(
180
+ inputs_embeds, past_key_values_length
181
+ )
182
+
183
+ # expand embeddings if needed
184
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
185
+ if max_pos > self.weights.size(0):
186
+ self.make_weights(
187
+ max_pos + self.offset, self.embedding_dim, self.padding_idx
188
+ )
189
+
190
+ return (
191
+ self.weights.index_select(0, position_ids.view(-1))
192
+ .view(bsz, seq_len, self.weights.shape[-1])
193
+ .detach()
194
+ )
195
+
196
+ def create_position_ids_from_inputs_embeds(
197
+ self, inputs_embeds, past_key_values_length
198
+ ):
199
+ """
200
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
201
+
202
+ Args:
203
+ inputs_embeds: torch.Tensor
204
+
205
+ Returns: torch.Tensor
206
+ """
207
+ input_shape = inputs_embeds.size()[:-1]
208
+ sequence_length = input_shape[1]
209
+
210
+ position_ids = torch.arange(
211
+ self.padding_idx + 1,
212
+ sequence_length + self.padding_idx + 1,
213
+ dtype=torch.long,
214
+ device=inputs_embeds.device,
215
+ )
216
+ return (
217
+ position_ids.unsqueeze(0).expand(input_shape).contiguous()
218
+ + past_key_values_length
219
+ )
220
+
221
+
222
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
223
+ class IndicTransAttention(nn.Module):
224
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
225
+
226
+ def __init__(
227
+ self,
228
+ embed_dim: int,
229
+ num_heads: int,
230
+ dropout: float = 0.0,
231
+ is_decoder: bool = False,
232
+ bias: bool = True,
233
+ is_causal: bool = False,
234
+ config: Optional[IndicTransConfig] = None,
235
+ ):
236
+ super().__init__()
237
+ self.embed_dim = embed_dim
238
+ self.num_heads = num_heads
239
+ self.dropout = dropout
240
+ self.head_dim = embed_dim // num_heads
241
+ self.config = config
242
+
243
+ if (self.head_dim * num_heads) != self.embed_dim:
244
+ raise ValueError(
245
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
246
+ f" and `num_heads`: {num_heads})."
247
+ )
248
+ self.scaling = self.head_dim**-0.5
249
+ self.is_decoder = is_decoder
250
+ self.is_causal = is_causal
251
+
252
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
253
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
254
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
255
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
256
+
257
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
258
+ return (
259
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
260
+ .transpose(1, 2)
261
+ .contiguous()
262
+ )
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ key_value_states: Optional[torch.Tensor] = None,
268
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
269
+ attention_mask: Optional[torch.Tensor] = None,
270
+ layer_head_mask: Optional[torch.Tensor] = None,
271
+ output_attentions: bool = False,
272
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
273
+ """Input shape: Batch x Time x Channel"""
274
+
275
+ # if key_value_states are provided this layer is used as a cross-attention layer
276
+ # for the decoder
277
+ is_cross_attention = key_value_states is not None
278
+
279
+ bsz, tgt_len, _ = hidden_states.size()
280
+
281
+ # get query proj
282
+ query_states = self.q_proj(hidden_states) * self.scaling
283
+ # get key, value proj
284
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
285
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
286
+ # the provided `key_value_states` to support prefix tuning
287
+ if (
288
+ is_cross_attention
289
+ and past_key_value is not None
290
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
291
+ ):
292
+ # reuse k,v, cross_attentions
293
+ key_states = past_key_value[0]
294
+ value_states = past_key_value[1]
295
+ elif is_cross_attention:
296
+ # cross_attentions
297
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
298
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
299
+ elif past_key_value is not None:
300
+ # reuse k, v, self_attention
301
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
302
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
303
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
304
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
305
+ else:
306
+ # self_attention
307
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
308
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
309
+
310
+ if self.is_decoder:
311
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
312
+ # Further calls to cross_attention layer can then reuse all cross-attention
313
+ # key/value_states (first "if" case)
314
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
315
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
316
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
317
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
318
+ past_key_value = (key_states, value_states)
319
+
320
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
321
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
322
+ key_states = key_states.reshape(*proj_shape)
323
+ value_states = value_states.reshape(*proj_shape)
324
+
325
+ src_len = key_states.size(1)
326
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
327
+
328
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
329
+ raise ValueError(
330
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
331
+ f" {attn_weights.size()}"
332
+ )
333
+
334
+ if attention_mask is not None:
335
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
336
+ raise ValueError(
337
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
338
+ )
339
+ attn_weights = (
340
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
341
+ + attention_mask
342
+ )
343
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
344
+
345
+ attn_weights = F.softmax(attn_weights, dim=-1)
346
+
347
+ if layer_head_mask is not None:
348
+ if layer_head_mask.size() != (self.num_heads,):
349
+ raise ValueError(
350
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
351
+ f" {layer_head_mask.size()}"
352
+ )
353
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
354
+ bsz, self.num_heads, tgt_len, src_len
355
+ )
356
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
357
+
358
+ if output_attentions:
359
+ # this operation is a bit awkward, but it's required to
360
+ # make sure that attn_weights keeps its gradient.
361
+ # In order to do so, attn_weights have to be reshaped
362
+ # twice and have to be reused in the following
363
+ attn_weights_reshaped = attn_weights.view(
364
+ bsz, self.num_heads, tgt_len, src_len
365
+ )
366
+ attn_weights = attn_weights_reshaped.view(
367
+ bsz * self.num_heads, tgt_len, src_len
368
+ )
369
+ else:
370
+ attn_weights_reshaped = None
371
+
372
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
373
+
374
+ attn_output = torch.bmm(attn_probs, value_states)
375
+
376
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
377
+ raise ValueError(
378
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
379
+ f" {attn_output.size()}"
380
+ )
381
+
382
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
383
+ attn_output = attn_output.transpose(1, 2)
384
+
385
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
386
+ # partitioned across GPUs when using tensor-parallelism.
387
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
388
+
389
+ attn_output = self.out_proj(attn_output)
390
+
391
+ return attn_output, attn_weights_reshaped, past_key_value
392
+
393
+
394
+ class IndicTransFlashAttention2(IndicTransAttention):
395
+ """
396
+ IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stays
397
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
398
+ flash attention and deal with padding tokens in case the input contains any of them.
399
+ """
400
+
401
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
402
+ def __init__(self, *args, **kwargs):
403
+ super().__init__(*args, **kwargs)
404
+
405
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
406
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
407
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
408
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
409
+
410
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
411
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ key_value_states: Optional[torch.Tensor] = None,
417
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
418
+ attention_mask: Optional[torch.Tensor] = None,
419
+ layer_head_mask: Optional[torch.Tensor] = None,
420
+ output_attentions: bool = False,
421
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
422
+ # IndicTransFlashAttention2 attention does not support output_attentions
423
+ if output_attentions:
424
+ raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
425
+
426
+ # if key_value_states are provided this layer is used as a cross-attention layer
427
+ # for the decoder
428
+ is_cross_attention = key_value_states is not None
429
+
430
+ bsz, q_len, _ = hidden_states.size()
431
+
432
+ # get query proj
433
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
434
+ # get key, value proj
435
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
436
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
437
+ # the provided `key_value_states` to support prefix tuning
438
+ if (
439
+ is_cross_attention
440
+ and past_key_value is not None
441
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
442
+ ):
443
+ # reuse k,v, cross_attentions
444
+ key_states = past_key_value[0].transpose(1, 2)
445
+ value_states = past_key_value[1].transpose(1, 2)
446
+ elif is_cross_attention:
447
+ # cross_attentions
448
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
449
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
450
+ elif past_key_value is not None:
451
+ # reuse k, v, self_attention
452
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
453
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
454
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
455
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
456
+ else:
457
+ # self_attention
458
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
459
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
460
+
461
+ if self.is_decoder:
462
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
463
+ # Further calls to cross_attention layer can then reuse all cross-attention
464
+ # key/value_states (first "if" case)
465
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
466
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
467
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
468
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
469
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
470
+
471
+ kv_seq_len = key_states.shape[-2]
472
+ if past_key_value is not None:
473
+ kv_seq_len += past_key_value[0].shape[-2]
474
+
475
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
476
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
477
+ # cast them back in the correct dtype just to be sure everything works as expected.
478
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
479
+ # in fp32. (LlamaRMSNorm handles it correctly)
480
+
481
+ input_dtype = query_states.dtype
482
+ if input_dtype == torch.float32:
483
+ if torch.is_autocast_enabled():
484
+ target_dtype = torch.get_autocast_gpu_dtype()
485
+ # Handle the case where the model is quantized
486
+ elif hasattr(self.config, "_pre_quantization_dtype"):
487
+ target_dtype = self.config._pre_quantization_dtype
488
+ else:
489
+ target_dtype = self.q_proj.weight.dtype
490
+
491
+ logger.warning_once(
492
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
493
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
494
+ f" {target_dtype}."
495
+ )
496
+
497
+ query_states = query_states.to(target_dtype)
498
+ key_states = key_states.to(target_dtype)
499
+ value_states = value_states.to(target_dtype)
500
+
501
+ attn_output = self._flash_attention_forward(
502
+ query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
503
+ )
504
+
505
+ attn_output = attn_output.reshape(bsz, q_len, -1)
506
+ attn_output = self.out_proj(attn_output)
507
+
508
+ if not output_attentions:
509
+ attn_weights = None
510
+
511
+ return attn_output, attn_weights, past_key_value
512
+
513
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
514
+ def _flash_attention_forward(
515
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
516
+ ):
517
+ """
518
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
519
+ first unpad the input, then computes the attention scores and pad the final attention scores.
520
+
521
+ Args:
522
+ query_states (`torch.Tensor`):
523
+ Input query states to be passed to Flash Attention API
524
+ key_states (`torch.Tensor`):
525
+ Input key states to be passed to Flash Attention API
526
+ value_states (`torch.Tensor`):
527
+ Input value states to be passed to Flash Attention API
528
+ attention_mask (`torch.Tensor`):
529
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
530
+ position of padding tokens and 1 for the position of non-padding tokens.
531
+ dropout (`float`):
532
+ Attention dropout
533
+ softmax_scale (`float`, *optional*):
534
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
535
+ """
536
+ if not self._flash_attn_uses_top_left_mask:
537
+ causal = self.is_causal
538
+ else:
539
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
540
+ causal = self.is_causal and query_length != 1
541
+
542
+ # Contains at least one padding token in the sequence
543
+ if attention_mask is not None:
544
+ batch_size = query_states.shape[0]
545
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
546
+ query_states, key_states, value_states, attention_mask, query_length
547
+ )
548
+
549
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
550
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
551
+
552
+ attn_output_unpad = flash_attn_varlen_func(
553
+ query_states,
554
+ key_states,
555
+ value_states,
556
+ cu_seqlens_q=cu_seqlens_q,
557
+ cu_seqlens_k=cu_seqlens_k,
558
+ max_seqlen_q=max_seqlen_in_batch_q,
559
+ max_seqlen_k=max_seqlen_in_batch_k,
560
+ dropout_p=dropout,
561
+ softmax_scale=softmax_scale,
562
+ causal=causal,
563
+ )
564
+
565
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
566
+ else:
567
+ attn_output = flash_attn_func(
568
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
569
+ )
570
+
571
+ return attn_output
572
+
573
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
574
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
575
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
576
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
577
+
578
+ key_layer = index_first_axis(
579
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
580
+ )
581
+ value_layer = index_first_axis(
582
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
583
+ )
584
+ if query_length == kv_seq_len:
585
+ query_layer = index_first_axis(
586
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
587
+ )
588
+ cu_seqlens_q = cu_seqlens_k
589
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
590
+ indices_q = indices_k
591
+ elif query_length == 1:
592
+ max_seqlen_in_batch_q = 1
593
+ cu_seqlens_q = torch.arange(
594
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
595
+ ) # There is a memcpy here, that is very bad.
596
+ indices_q = cu_seqlens_q[:-1]
597
+ query_layer = query_layer.squeeze(1)
598
+ else:
599
+ # The -q_len: slice assumes left padding.
600
+ attention_mask = attention_mask[:, -query_length:]
601
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
602
+
603
+ return (
604
+ query_layer,
605
+ key_layer,
606
+ value_layer,
607
+ indices_q,
608
+ (cu_seqlens_q, cu_seqlens_k),
609
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
610
+ )
611
+
612
+
613
+ class IndicTransSdpaAttention(IndicTransAttention):
614
+ def forward(
615
+ self,
616
+ hidden_states: torch.Tensor,
617
+ key_value_states: Optional[torch.Tensor] = None,
618
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ layer_head_mask: Optional[torch.Tensor] = None,
621
+ output_attentions: bool = False,
622
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
623
+ """Input shape: Batch x Time x Channel"""
624
+ if output_attentions or layer_head_mask is not None:
625
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
626
+ logger.warning_once(
627
+ "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
628
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
629
+ )
630
+ return super().forward(
631
+ hidden_states,
632
+ key_value_states=key_value_states,
633
+ past_key_value=past_key_value,
634
+ attention_mask=attention_mask,
635
+ layer_head_mask=layer_head_mask,
636
+ output_attentions=output_attentions,
637
+ )
638
+
639
+ # if key_value_states are provided this layer is used as a cross-attention layer
640
+ # for the decoder
641
+ is_cross_attention = key_value_states is not None
642
+
643
+ bsz, tgt_len, _ = hidden_states.size()
644
+
645
+ # get query proj
646
+ query_states = self.q_proj(hidden_states)
647
+ # get key, value proj
648
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
649
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
650
+ # the provided `key_value_states` to support prefix tuning
651
+ if (
652
+ is_cross_attention
653
+ and past_key_value is not None
654
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
655
+ ):
656
+ # reuse k,v, cross_attentions
657
+ key_states = past_key_value[0]
658
+ value_states = past_key_value[1]
659
+ elif is_cross_attention:
660
+ # cross_attentions
661
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
662
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
663
+ elif past_key_value is not None:
664
+ # reuse k, v, self_attention
665
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
666
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
667
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
668
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
669
+ else:
670
+ # self_attention
671
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
672
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
673
+
674
+ if self.is_decoder:
675
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
676
+ # Further calls to cross_attention layer can then reuse all cross-attention
677
+ # key/value_states (first "if" case)
678
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
679
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
680
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
681
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
682
+ past_key_value = (key_states, value_states)
683
+
684
+ query_states = self._shape(query_states, tgt_len, bsz)
685
+
686
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
687
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
688
+ attn_output = F.scaled_dot_product_attention(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ attn_mask=attention_mask,
693
+ dropout_p=self.dropout if self.training else 0.0,
694
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
695
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
696
+ )
697
+
698
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
699
+ raise ValueError(
700
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
701
+ f" {attn_output.size()}"
702
+ )
703
+
704
+ attn_output = attn_output.transpose(1, 2)
705
+
706
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
707
+ # partitioned across GPUs when using tensor-parallelism.
708
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
709
+
710
+ attn_output = self.out_proj(attn_output)
711
+
712
+ return attn_output, None, past_key_value
713
+
714
+
715
+ INDICTRANS_ATTENTION_CLASSES = {
716
+ "eager": IndicTransAttention,
717
+ "sdpa": IndicTransSdpaAttention,
718
+ "flash_attention_2": IndicTransFlashAttention2,
719
+ }
720
+
721
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
722
+ class IndicTransEncoderLayer(nn.Module):
723
+ def __init__(self, config: IndicTransConfig):
724
+ super().__init__()
725
+ self.embed_dim = config.encoder_embed_dim
726
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
727
+ embed_dim=self.embed_dim,
728
+ num_heads=config.encoder_attention_heads,
729
+ dropout=config.attention_dropout,
730
+ config=config,
731
+ )
732
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
733
+ self.dropout = config.dropout
734
+ self.activation_fn = ACT2FN[config.activation_function]
735
+ self.activation_dropout = config.activation_dropout
736
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
737
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
738
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
739
+ self.normalize_before = config.encoder_normalize_before
740
+
741
+ def forward(
742
+ self,
743
+ hidden_states: torch.Tensor,
744
+ attention_mask: torch.Tensor,
745
+ layer_head_mask: torch.Tensor,
746
+ output_attentions: bool = False,
747
+ ) -> torch.Tensor:
748
+ """
749
+ Args:
750
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
751
+ attention_mask (`torch.FloatTensor`): attention mask of size
752
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
753
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
754
+ `(encoder_attention_heads,)`.
755
+ output_attentions (`bool`, *optional*):
756
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
757
+ returned tensors for more detail.
758
+ """
759
+ residual = hidden_states
760
+ if self.normalize_before:
761
+ hidden_states = self.self_attn_layer_norm(hidden_states)
762
+ hidden_states, attn_weights, _ = self.self_attn(
763
+ hidden_states=hidden_states,
764
+ attention_mask=attention_mask,
765
+ layer_head_mask=layer_head_mask,
766
+ output_attentions=output_attentions,
767
+ )
768
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
769
+ hidden_states = residual + hidden_states
770
+ if not self.normalize_before:
771
+ hidden_states = self.self_attn_layer_norm(hidden_states)
772
+
773
+ residual = hidden_states
774
+ if self.normalize_before:
775
+ hidden_states = self.final_layer_norm(hidden_states)
776
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
777
+ hidden_states = F.dropout(
778
+ hidden_states, p=self.activation_dropout, training=self.training
779
+ )
780
+ hidden_states = self.fc2(hidden_states)
781
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
782
+ hidden_states = residual + hidden_states
783
+ if not self.normalize_before:
784
+ hidden_states = self.final_layer_norm(hidden_states)
785
+
786
+ if hidden_states.dtype == torch.float16 and (
787
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
788
+ ):
789
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
790
+ hidden_states = torch.clamp(
791
+ hidden_states, min=-clamp_value, max=clamp_value
792
+ )
793
+
794
+ outputs = (hidden_states,)
795
+
796
+ if output_attentions:
797
+ outputs += (attn_weights,)
798
+
799
+ return outputs
800
+
801
+
802
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
803
+ class IndicTransDecoderLayer(nn.Module):
804
+ def __init__(self, config: IndicTransConfig):
805
+ super().__init__()
806
+ self.embed_dim = config.decoder_embed_dim
807
+
808
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
809
+ embed_dim=self.embed_dim,
810
+ num_heads=config.decoder_attention_heads,
811
+ dropout=config.attention_dropout,
812
+ is_decoder=True,
813
+ is_causal=True,
814
+ config=config,
815
+ )
816
+ self.dropout = config.dropout
817
+ self.activation_fn = ACT2FN[config.activation_function]
818
+ self.activation_dropout = config.activation_dropout
819
+
820
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
821
+ self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
822
+ self.embed_dim,
823
+ config.decoder_attention_heads,
824
+ dropout=config.attention_dropout,
825
+ is_decoder=True,
826
+ config=config,
827
+ )
828
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
829
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
830
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
831
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
832
+ self.normalize_before = config.decoder_normalize_before
833
+
834
+ def forward(
835
+ self,
836
+ hidden_states: torch.Tensor,
837
+ attention_mask: Optional[torch.Tensor] = None,
838
+ encoder_hidden_states: Optional[torch.Tensor] = None,
839
+ encoder_attention_mask: Optional[torch.Tensor] = None,
840
+ layer_head_mask: Optional[torch.Tensor] = None,
841
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
842
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
843
+ output_attentions: Optional[bool] = False,
844
+ use_cache: Optional[bool] = True,
845
+ ) -> torch.Tensor:
846
+ """
847
+ Args:
848
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
849
+ attention_mask (`torch.FloatTensor`): attention mask of size
850
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
851
+ encoder_hidden_states (`torch.FloatTensor`):
852
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
853
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
854
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
855
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
856
+ `(encoder_attention_heads,)`.
857
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
858
+ size `(decoder_attention_heads,)`.
859
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
860
+ output_attentions (`bool`, *optional*):
861
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
862
+ returned tensors for more detail.
863
+ """
864
+ residual = hidden_states
865
+ if self.normalize_before:
866
+ hidden_states = self.self_attn_layer_norm(hidden_states)
867
+
868
+ # Self Attention
869
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
870
+ self_attn_past_key_value = (
871
+ past_key_value[:2] if past_key_value is not None else None
872
+ )
873
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
874
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
875
+ hidden_states=hidden_states,
876
+ past_key_value=self_attn_past_key_value,
877
+ attention_mask=attention_mask,
878
+ layer_head_mask=layer_head_mask,
879
+ output_attentions=output_attentions,
880
+ )
881
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
882
+ hidden_states = residual + hidden_states
883
+ if not self.normalize_before:
884
+ hidden_states = self.self_attn_layer_norm(hidden_states)
885
+
886
+ # Cross-Attention Block
887
+ cross_attn_present_key_value = None
888
+ cross_attn_weights = None
889
+ if encoder_hidden_states is not None:
890
+ residual = hidden_states
891
+ if self.normalize_before:
892
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
893
+
894
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
895
+ cross_attn_past_key_value = (
896
+ past_key_value[-2:] if past_key_value is not None else None
897
+ )
898
+ (
899
+ hidden_states,
900
+ cross_attn_weights,
901
+ cross_attn_present_key_value,
902
+ ) = self.encoder_attn(
903
+ hidden_states=hidden_states,
904
+ key_value_states=encoder_hidden_states,
905
+ attention_mask=encoder_attention_mask,
906
+ layer_head_mask=cross_attn_layer_head_mask,
907
+ past_key_value=cross_attn_past_key_value,
908
+ output_attentions=output_attentions,
909
+ )
910
+ hidden_states = F.dropout(
911
+ hidden_states, p=self.dropout, training=self.training
912
+ )
913
+ hidden_states = residual + hidden_states
914
+ if not self.normalize_before:
915
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
916
+
917
+ # add cross-attn to positions 3,4 of present_key_value tuple
918
+ present_key_value = present_key_value + cross_attn_present_key_value
919
+
920
+ # Fully Connected
921
+ residual = hidden_states
922
+ if self.normalize_before:
923
+ hidden_states = self.final_layer_norm(hidden_states)
924
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
925
+ hidden_states = F.dropout(
926
+ hidden_states, p=self.activation_dropout, training=self.training
927
+ )
928
+ hidden_states = self.fc2(hidden_states)
929
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
930
+ hidden_states = residual + hidden_states
931
+ if not self.normalize_before:
932
+ hidden_states = self.final_layer_norm(hidden_states)
933
+
934
+ outputs = (hidden_states,)
935
+
936
+ if output_attentions:
937
+ outputs += (self_attn_weights, cross_attn_weights)
938
+
939
+ if use_cache:
940
+ outputs += (present_key_value,)
941
+
942
+ return outputs
943
+
944
+
945
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTrans
946
+ class IndicTransPreTrainedModel(PreTrainedModel):
947
+ config_class = IndicTransConfig
948
+ base_model_prefix = "model"
949
+ supports_gradient_checkpointing = True
950
+ _no_split_modules = ["IndicTransAttention"]
951
+
952
+ def _init_weights(self, module):
953
+ std = self.config.init_std
954
+ if isinstance(module, nn.Linear):
955
+ module.weight.data.normal_(mean=0.0, std=std)
956
+ if module.bias is not None:
957
+ module.bias.data.zero_()
958
+ elif isinstance(module, nn.Embedding):
959
+ module.weight.data.normal_(mean=0.0, std=std)
960
+ if module.padding_idx is not None:
961
+ module.weight.data[module.padding_idx].zero_()
962
+
963
+ def _set_gradient_checkpointing(self, module, value=False):
964
+ if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
965
+ module.gradient_checkpointing = value
966
+
967
+
968
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTrans
969
+ class IndicTransEncoder(IndicTransPreTrainedModel):
970
+ """
971
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
972
+ [`IndicTransEncoderLayer`].
973
+
974
+ Args:
975
+ config: IndicTransConfig
976
+ embed_tokens (nn.Embedding): output embedding
977
+ """
978
+
979
+ def __init__(
980
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
981
+ ):
982
+ super().__init__(config)
983
+
984
+ self.dropout = config.dropout
985
+ self.layerdrop = config.encoder_layerdrop
986
+
987
+ embed_dim = config.encoder_embed_dim
988
+ self.padding_idx = config.pad_token_id
989
+ self.max_source_positions = config.max_source_positions
990
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
991
+
992
+ self.embed_tokens = nn.Embedding(
993
+ config.encoder_vocab_size, embed_dim, self.padding_idx
994
+ )
995
+
996
+ if embed_tokens is not None:
997
+ self.embed_tokens.weight = embed_tokens.weight
998
+
999
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1000
+ config.max_source_positions,
1001
+ embed_dim,
1002
+ self.padding_idx,
1003
+ )
1004
+ self.layers = nn.ModuleList(
1005
+ [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
1006
+ )
1007
+ self.layer_norm = (
1008
+ nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
1009
+ )
1010
+ self.layernorm_embedding = (
1011
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1012
+ )
1013
+
1014
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1015
+ self._use_sdpa = config._attn_implementation == "sdpa"
1016
+
1017
+ self.gradient_checkpointing = False
1018
+ # Initialize weights and apply final processing
1019
+ self.post_init()
1020
+
1021
+ def forward(
1022
+ self,
1023
+ input_ids: Optional[torch.Tensor] = None,
1024
+ attention_mask: Optional[torch.Tensor] = None,
1025
+ head_mask: Optional[torch.Tensor] = None,
1026
+ inputs_embeds: Optional[torch.Tensor] = None,
1027
+ output_attentions: Optional[bool] = None,
1028
+ output_hidden_states: Optional[bool] = None,
1029
+ return_dict: Optional[bool] = None,
1030
+ ):
1031
+ r"""
1032
+ Args:
1033
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1034
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1035
+ provide it.
1036
+
1037
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1038
+ [`PreTrainedTokenizer.__call__`] for details.
1039
+
1040
+ [What are input IDs?](../glossary#input-ids)
1041
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1042
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1043
+
1044
+ - 1 for tokens that are **not masked**,
1045
+ - 0 for tokens that are **masked**.
1046
+
1047
+ [What are attention masks?](../glossary#attention-mask)
1048
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
1049
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1050
+
1051
+ - 1 indicates the head is **not masked**,
1052
+ - 0 indicates the head is **masked**.
1053
+
1054
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1055
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1056
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1057
+ than the model's internal embedding lookup matrix.
1058
+ output_attentions (`bool`, *optional*):
1059
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1060
+ returned tensors for more detail.
1061
+ output_hidden_states (`bool`, *optional*):
1062
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1063
+ for more detail.
1064
+ return_dict (`bool`, *optional*):
1065
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1066
+ """
1067
+ output_attentions = (
1068
+ output_attentions
1069
+ if output_attentions is not None
1070
+ else self.config.output_attentions
1071
+ )
1072
+ output_hidden_states = (
1073
+ output_hidden_states
1074
+ if output_hidden_states is not None
1075
+ else self.config.output_hidden_states
1076
+ )
1077
+ return_dict = (
1078
+ return_dict if return_dict is not None else self.config.use_return_dict
1079
+ )
1080
+
1081
+ # retrieve input_ids and inputs_embeds
1082
+ if input_ids is not None and inputs_embeds is not None:
1083
+ raise ValueError(
1084
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1085
+ )
1086
+ elif input_ids is not None:
1087
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1088
+ input_shape = input_ids.size()
1089
+ input_ids = input_ids.view(-1, input_shape[-1])
1090
+ elif inputs_embeds is not None:
1091
+ input_shape = inputs_embeds.size()[:-1]
1092
+ else:
1093
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1094
+
1095
+ if inputs_embeds is None:
1096
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1097
+
1098
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
1099
+ embed_pos = embed_pos.to(inputs_embeds.device)
1100
+
1101
+ hidden_states = inputs_embeds + embed_pos
1102
+ if self.layernorm_embedding is not None:
1103
+ hidden_states = self.layernorm_embedding(hidden_states)
1104
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1105
+
1106
+ if attention_mask is not None:
1107
+ if self._use_flash_attention_2:
1108
+ attention_mask = attention_mask if 0 in attention_mask else None
1109
+ elif self._use_sdpa and head_mask is None and not output_attentions:
1110
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1111
+ # the manual implementation that requires a 4D causal mask in all cases.
1112
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1113
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
1114
+ else:
1115
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1116
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1117
+
1118
+
1119
+ encoder_states = () if output_hidden_states else None
1120
+ all_attentions = () if output_attentions else None
1121
+
1122
+ # check if head_mask has a correct number of layers specified if desired
1123
+ if head_mask is not None:
1124
+ if head_mask.size()[0] != len(self.layers):
1125
+ raise ValueError(
1126
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
1127
+ f" {head_mask.size()[0]}."
1128
+ )
1129
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1130
+
1131
+ for idx, encoder_layer in enumerate(self.layers):
1132
+ if output_hidden_states:
1133
+ encoder_states = encoder_states + (hidden_states,)
1134
+
1135
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1136
+ dropout_probability = torch.rand([])
1137
+
1138
+ skip_the_layer = (
1139
+ True
1140
+ if self.training and (dropout_probability < self.layerdrop)
1141
+ else False
1142
+ )
1143
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1144
+ # under deepspeed zero3 all gpus must run in sync
1145
+
1146
+ if self.gradient_checkpointing and self.training:
1147
+ # create gradient checkpointing function
1148
+ def create_custom_forward(module):
1149
+ def custom_forward(*inputs):
1150
+ return module(*inputs, output_attentions)
1151
+
1152
+ return custom_forward
1153
+
1154
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1155
+ create_custom_forward(encoder_layer),
1156
+ hidden_states,
1157
+ attention_mask,
1158
+ (head_mask[idx] if head_mask is not None else None),
1159
+ )
1160
+ else:
1161
+ layer_outputs = encoder_layer(
1162
+ hidden_states,
1163
+ attention_mask,
1164
+ layer_head_mask=(
1165
+ head_mask[idx] if head_mask is not None else None
1166
+ ),
1167
+ output_attentions=output_attentions,
1168
+ )
1169
+
1170
+ hidden_states = layer_outputs[0]
1171
+
1172
+ if skip_the_layer:
1173
+ layer_outputs = (None, None)
1174
+
1175
+ if output_attentions:
1176
+ all_attentions = all_attentions + (layer_outputs[1],)
1177
+
1178
+ if self.layer_norm is not None:
1179
+ hidden_states = self.layer_norm(hidden_states)
1180
+
1181
+ if output_hidden_states:
1182
+ encoder_states = encoder_states + (hidden_states,)
1183
+
1184
+ if not return_dict:
1185
+ return tuple(
1186
+ v
1187
+ for v in [hidden_states, encoder_states, all_attentions]
1188
+ if v is not None
1189
+ )
1190
+ return BaseModelOutput(
1191
+ last_hidden_state=hidden_states,
1192
+ hidden_states=encoder_states,
1193
+ attentions=all_attentions,
1194
+ )
1195
+
1196
+
1197
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTrans
1198
+ class IndicTransDecoder(IndicTransPreTrainedModel):
1199
+ """
1200
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
1201
+
1202
+ Args:
1203
+ config: IndicTransConfig
1204
+ embed_tokens (nn.Embedding): output embedding
1205
+ """
1206
+
1207
+ def __init__(
1208
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
1209
+ ):
1210
+ super().__init__(config)
1211
+ self.dropout = config.dropout
1212
+ self.layerdrop = config.decoder_layerdrop
1213
+
1214
+ embed_dim = config.encoder_embed_dim
1215
+ self.padding_idx = config.pad_token_id
1216
+ self.max_target_positions = config.max_target_positions
1217
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
1218
+
1219
+ self.embed_tokens = nn.Embedding(
1220
+ config.decoder_vocab_size, embed_dim, self.padding_idx
1221
+ )
1222
+
1223
+ if embed_tokens is not None:
1224
+ self.embed_tokens.weight = embed_tokens.weight
1225
+
1226
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1227
+ config.max_target_positions,
1228
+ embed_dim,
1229
+ self.padding_idx,
1230
+ )
1231
+ self.layers = nn.ModuleList(
1232
+ [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
1233
+ )
1234
+ self.layer_norm = (
1235
+ nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
1236
+ )
1237
+ self.layernorm_embedding = (
1238
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1239
+ )
1240
+
1241
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1242
+ self._use_sdpa = config._attn_implementation == "sdpa"
1243
+
1244
+ self.gradient_checkpointing = False
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+ def forward(
1249
+ self,
1250
+ input_ids: Optional[torch.Tensor] = None,
1251
+ attention_mask: Optional[torch.Tensor] = None,
1252
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1253
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1254
+ head_mask: Optional[torch.Tensor] = None,
1255
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1256
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1257
+ inputs_embeds: Optional[torch.Tensor] = None,
1258
+ use_cache: Optional[bool] = None,
1259
+ output_attentions: Optional[bool] = None,
1260
+ output_hidden_states: Optional[bool] = None,
1261
+ return_dict: Optional[bool] = None,
1262
+ ):
1263
+ r"""
1264
+ Args:
1265
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1266
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1267
+ provide it.
1268
+
1269
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1270
+ [`PreTrainedTokenizer.__call__`] for details.
1271
+
1272
+ [What are input IDs?](../glossary#input-ids)
1273
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1274
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1275
+
1276
+ - 1 for tokens that are **not masked**,
1277
+ - 0 for tokens that are **masked**.
1278
+
1279
+ [What are attention masks?](../glossary#attention-mask)
1280
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1281
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1282
+ of the decoder.
1283
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1284
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
1285
+ selected in `[0, 1]`:
1286
+
1287
+ - 1 for tokens that are **not masked**,
1288
+ - 0 for tokens that are **masked**.
1289
+
1290
+ [What are attention masks?](../glossary#attention-mask)
1291
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1292
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1293
+
1294
+ - 1 indicates the head is **not masked**,
1295
+ - 0 indicates the head is **masked**.
1296
+
1297
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1298
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1299
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1300
+
1301
+ - 1 indicates the head is **not masked**,
1302
+ - 0 indicates the head is **masked**.
1303
+
1304
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1305
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1306
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1307
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1308
+
1309
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1310
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1311
+
1312
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1313
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1314
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
1315
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
1316
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
1317
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
1318
+ embedding lookup matrix.
1319
+ output_attentions (`bool`, *optional*):
1320
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1321
+ returned tensors for more detail.
1322
+ output_hidden_states (`bool`, *optional*):
1323
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1324
+ for more detail.
1325
+ return_dict (`bool`, *optional*):
1326
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1327
+ """
1328
+ output_attentions = (
1329
+ output_attentions
1330
+ if output_attentions is not None
1331
+ else self.config.output_attentions
1332
+ )
1333
+ output_hidden_states = (
1334
+ output_hidden_states
1335
+ if output_hidden_states is not None
1336
+ else self.config.output_hidden_states
1337
+ )
1338
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1339
+ return_dict = (
1340
+ return_dict if return_dict is not None else self.config.use_return_dict
1341
+ )
1342
+
1343
+ # retrieve input_ids and inputs_embeds
1344
+ if input_ids is not None and inputs_embeds is not None:
1345
+ raise ValueError(
1346
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1347
+ )
1348
+ elif input_ids is not None:
1349
+ input_shape = input_ids.size()
1350
+ input_ids = input_ids.view(-1, input_shape[-1])
1351
+ elif inputs_embeds is not None:
1352
+ input_shape = inputs_embeds.size()[:-1]
1353
+ else:
1354
+ raise ValueError(
1355
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1356
+ )
1357
+
1358
+ # past_key_values_length
1359
+ past_key_values_length = (
1360
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1361
+ )
1362
+
1363
+ if inputs_embeds is None:
1364
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1365
+
1366
+
1367
+ if self._use_flash_attention_2:
1368
+ # 2d mask is passed through the layers
1369
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1370
+ elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1371
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1372
+ # the manual implementation that requires a 4D causal mask in all cases.
1373
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1374
+ attention_mask,
1375
+ input_shape,
1376
+ inputs_embeds,
1377
+ past_key_values_length,
1378
+ )
1379
+ else:
1380
+ # 4d mask is passed through the layers
1381
+ attention_mask = _prepare_4d_causal_attention_mask(
1382
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1383
+ )
1384
+
1385
+ # expand encoder attention mask
1386
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1387
+ if self._use_flash_attention_2:
1388
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
1389
+ elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
1390
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1391
+ # the manual implementation that requires a 4D causal mask in all cases.
1392
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1393
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1394
+ encoder_attention_mask,
1395
+ inputs_embeds.dtype,
1396
+ tgt_len=input_shape[-1],
1397
+ )
1398
+ else:
1399
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1400
+ encoder_attention_mask = _prepare_4d_attention_mask(
1401
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1402
+ )
1403
+
1404
+ # embed positions
1405
+ positions = self.embed_positions(
1406
+ input_ids, inputs_embeds, past_key_values_length
1407
+ )
1408
+ positions = positions.to(inputs_embeds.device)
1409
+
1410
+ hidden_states = inputs_embeds + positions
1411
+ if self.layernorm_embedding is not None:
1412
+ hidden_states = self.layernorm_embedding(hidden_states)
1413
+
1414
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1415
+
1416
+ if self.gradient_checkpointing and self.training:
1417
+ if use_cache:
1418
+ logger.warning_once(
1419
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting"
1420
+ " `use_cache=False`..."
1421
+ )
1422
+ use_cache = False
1423
+
1424
+ # decoder layers
1425
+ all_hidden_states = () if output_hidden_states else None
1426
+ all_self_attns = () if output_attentions else None
1427
+ all_cross_attentions = () if output_attentions else None
1428
+ next_decoder_cache = () if use_cache else None
1429
+
1430
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1431
+ for attn_mask, mask_name in zip(
1432
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1433
+ ):
1434
+ if attn_mask is not None:
1435
+ if attn_mask.size()[0] != len(self.layers):
1436
+ raise ValueError(
1437
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1438
+ f" {head_mask.size()[0]}."
1439
+ )
1440
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1441
+
1442
+ for idx, decoder_layer in enumerate(self.layers):
1443
+ if output_hidden_states:
1444
+ all_hidden_states += (hidden_states,)
1445
+
1446
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1447
+ dropout_probability = torch.rand([])
1448
+
1449
+ skip_the_layer = (
1450
+ True
1451
+ if self.training and (dropout_probability < self.layerdrop)
1452
+ else False
1453
+ )
1454
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1455
+ # under deepspeed zero3 all gpus must run in sync
1456
+
1457
+ past_key_value = (
1458
+ past_key_values[idx] if past_key_values is not None else None
1459
+ )
1460
+
1461
+ if self.gradient_checkpointing and self.training:
1462
+
1463
+ def create_custom_forward(module):
1464
+ def custom_forward(*inputs):
1465
+ # None for past_key_value
1466
+ return module(*inputs, output_attentions, use_cache)
1467
+
1468
+ return custom_forward
1469
+
1470
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1471
+ create_custom_forward(decoder_layer),
1472
+ hidden_states,
1473
+ attention_mask,
1474
+ encoder_hidden_states,
1475
+ encoder_attention_mask,
1476
+ head_mask[idx] if head_mask is not None else None,
1477
+ cross_attn_head_mask[idx]
1478
+ if cross_attn_head_mask is not None
1479
+ else None,
1480
+ None,
1481
+ )
1482
+ else:
1483
+ layer_outputs = decoder_layer(
1484
+ hidden_states,
1485
+ attention_mask=attention_mask,
1486
+ encoder_hidden_states=encoder_hidden_states,
1487
+ encoder_attention_mask=encoder_attention_mask,
1488
+ layer_head_mask=(
1489
+ head_mask[idx] if head_mask is not None else None
1490
+ ),
1491
+ cross_attn_layer_head_mask=(
1492
+ cross_attn_head_mask[idx]
1493
+ if cross_attn_head_mask is not None
1494
+ else None
1495
+ ),
1496
+ past_key_value=past_key_value,
1497
+ output_attentions=output_attentions,
1498
+ use_cache=use_cache,
1499
+ )
1500
+
1501
+ hidden_states = layer_outputs[0]
1502
+
1503
+ if skip_the_layer:
1504
+ continue
1505
+
1506
+ if use_cache:
1507
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1508
+
1509
+ if output_attentions:
1510
+ all_self_attns += (layer_outputs[1],)
1511
+ all_cross_attentions += (layer_outputs[2],)
1512
+
1513
+ if self.layer_norm is not None:
1514
+ hidden_states = self.layer_norm(hidden_states)
1515
+
1516
+ # add hidden states from the last decoder layer
1517
+ if output_hidden_states:
1518
+ all_hidden_states += (hidden_states,)
1519
+
1520
+ next_cache = next_decoder_cache if use_cache else None
1521
+ if not return_dict:
1522
+ return tuple(
1523
+ v
1524
+ for v in [
1525
+ hidden_states,
1526
+ next_cache,
1527
+ all_hidden_states,
1528
+ all_self_attns,
1529
+ all_cross_attentions,
1530
+ ]
1531
+ if v is not None
1532
+ )
1533
+ return BaseModelOutputWithPastAndCrossAttentions(
1534
+ last_hidden_state=hidden_states,
1535
+ past_key_values=next_cache,
1536
+ hidden_states=all_hidden_states,
1537
+ attentions=all_self_attns,
1538
+ cross_attentions=all_cross_attentions,
1539
+ )
1540
+
1541
+
1542
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
1543
+ class IndicTransModel(IndicTransPreTrainedModel):
1544
+ _tied_weights_keys = None
1545
+
1546
+ def __init__(self, config: IndicTransConfig):
1547
+ super().__init__(config)
1548
+
1549
+ self.encoder = IndicTransEncoder(config)
1550
+ self.decoder = IndicTransDecoder(config)
1551
+
1552
+ # Initialize weights and apply final processing
1553
+ self.post_init()
1554
+
1555
+ def get_encoder(self):
1556
+ return self.encoder
1557
+
1558
+ def get_decoder(self):
1559
+ return self.decoder
1560
+
1561
+ def forward(
1562
+ self,
1563
+ input_ids: Optional[torch.LongTensor] = None,
1564
+ attention_mask: Optional[torch.Tensor] = None,
1565
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1566
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1567
+ head_mask: Optional[torch.Tensor] = None,
1568
+ decoder_head_mask: Optional[torch.Tensor] = None,
1569
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1570
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1571
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1572
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1573
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1574
+ use_cache: Optional[bool] = None,
1575
+ output_attentions: Optional[bool] = None,
1576
+ output_hidden_states: Optional[bool] = None,
1577
+ return_dict: Optional[bool] = None,
1578
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
1579
+ output_attentions = (
1580
+ output_attentions
1581
+ if output_attentions is not None
1582
+ else self.config.output_attentions
1583
+ )
1584
+ output_hidden_states = (
1585
+ output_hidden_states
1586
+ if output_hidden_states is not None
1587
+ else self.config.output_hidden_states
1588
+ )
1589
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1590
+ return_dict = (
1591
+ return_dict if return_dict is not None else self.config.use_return_dict
1592
+ )
1593
+
1594
+ if encoder_outputs is None:
1595
+ encoder_outputs = self.encoder(
1596
+ input_ids=input_ids,
1597
+ attention_mask=attention_mask,
1598
+ head_mask=head_mask,
1599
+ inputs_embeds=inputs_embeds,
1600
+ output_attentions=output_attentions,
1601
+ output_hidden_states=output_hidden_states,
1602
+ return_dict=return_dict,
1603
+ )
1604
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1605
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1606
+ encoder_outputs = BaseModelOutput(
1607
+ last_hidden_state=encoder_outputs[0],
1608
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1609
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1610
+ )
1611
+
1612
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1613
+ decoder_outputs = self.decoder(
1614
+ input_ids=decoder_input_ids,
1615
+ attention_mask=decoder_attention_mask,
1616
+ encoder_hidden_states=encoder_outputs[0],
1617
+ encoder_attention_mask=attention_mask,
1618
+ head_mask=decoder_head_mask,
1619
+ cross_attn_head_mask=cross_attn_head_mask,
1620
+ past_key_values=past_key_values,
1621
+ inputs_embeds=decoder_inputs_embeds,
1622
+ use_cache=use_cache,
1623
+ output_attentions=output_attentions,
1624
+ output_hidden_states=output_hidden_states,
1625
+ return_dict=return_dict,
1626
+ )
1627
+
1628
+ if not return_dict:
1629
+ return decoder_outputs + encoder_outputs
1630
+
1631
+ return Seq2SeqModelOutput(
1632
+ last_hidden_state=decoder_outputs.last_hidden_state,
1633
+ past_key_values=decoder_outputs.past_key_values,
1634
+ decoder_hidden_states=decoder_outputs.hidden_states,
1635
+ decoder_attentions=decoder_outputs.attentions,
1636
+ cross_attentions=decoder_outputs.cross_attentions,
1637
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1638
+ encoder_hidden_states=encoder_outputs.hidden_states,
1639
+ encoder_attentions=encoder_outputs.attentions,
1640
+ )
1641
+
1642
+
1643
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1644
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1645
+ base_model_prefix = "model"
1646
+ _tied_weights_keys = None
1647
+ _label_smoothing = 0.0
1648
+
1649
+ def __init__(self, config: IndicTransConfig):
1650
+ super().__init__(config)
1651
+ self.model = IndicTransModel(config)
1652
+ self.lm_head = nn.Linear(
1653
+ config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1654
+ )
1655
+
1656
+ if config.share_decoder_input_output_embed:
1657
+ self.lm_head.weight = self.model.decoder.embed_tokens.weight
1658
+
1659
+ self.post_init()
1660
+
1661
+ def tie_weights(self):
1662
+ pass
1663
+
1664
+ def get_encoder(self):
1665
+ return self.model.get_encoder()
1666
+
1667
+ def get_decoder(self):
1668
+ return self.model.get_decoder()
1669
+
1670
+ def get_output_embeddings(self):
1671
+ return self.lm_head
1672
+
1673
+ def set_output_embeddings(self, new_embeddings):
1674
+ self.lm_head = new_embeddings
1675
+
1676
+ def set_label_smoothing(self, label_smoothing):
1677
+ self._label_smoothing = label_smoothing
1678
+
1679
+ def forward(
1680
+ self,
1681
+ input_ids: Optional[torch.LongTensor] = None,
1682
+ attention_mask: Optional[torch.Tensor] = None,
1683
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1684
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1685
+ head_mask: Optional[torch.Tensor] = None,
1686
+ decoder_head_mask: Optional[torch.Tensor] = None,
1687
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1688
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1689
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1690
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1691
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1692
+ labels: Optional[torch.LongTensor] = None,
1693
+ use_cache: Optional[bool] = None,
1694
+ output_attentions: Optional[bool] = None,
1695
+ output_hidden_states: Optional[bool] = None,
1696
+ return_dict: Optional[bool] = None,
1697
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
1698
+ r"""
1699
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1700
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1701
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1702
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1703
+
1704
+ Returns:
1705
+ """
1706
+ return_dict = (
1707
+ return_dict if return_dict is not None else self.config.use_return_dict
1708
+ )
1709
+
1710
+ if labels is not None:
1711
+ if decoder_input_ids is None:
1712
+ decoder_input_ids = shift_tokens_right(
1713
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1714
+ )
1715
+
1716
+ outputs = self.model(
1717
+ input_ids,
1718
+ attention_mask=attention_mask,
1719
+ decoder_input_ids=decoder_input_ids,
1720
+ encoder_outputs=encoder_outputs,
1721
+ decoder_attention_mask=decoder_attention_mask,
1722
+ head_mask=head_mask,
1723
+ decoder_head_mask=decoder_head_mask,
1724
+ cross_attn_head_mask=cross_attn_head_mask,
1725
+ past_key_values=past_key_values,
1726
+ inputs_embeds=inputs_embeds,
1727
+ decoder_inputs_embeds=decoder_inputs_embeds,
1728
+ use_cache=use_cache,
1729
+ output_attentions=output_attentions,
1730
+ output_hidden_states=output_hidden_states,
1731
+ return_dict=return_dict,
1732
+ )
1733
+ lm_logits = self.lm_head(outputs[0])
1734
+
1735
+ masked_lm_loss = None
1736
+ if labels is not None:
1737
+ # move labels to the correct device to enable PP
1738
+ labels = labels.to(lm_logits.device)
1739
+ masked_lm_loss = F.cross_entropy(
1740
+ input=lm_logits.view(-1, self.config.decoder_vocab_size),
1741
+ target=labels.view(-1),
1742
+ ignore_index=-100,
1743
+ label_smoothing=self._label_smoothing,
1744
+ )
1745
+
1746
+ if not return_dict:
1747
+ output = (lm_logits,) + outputs[1:]
1748
+ return (
1749
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1750
+ )
1751
+
1752
+ return Seq2SeqLMOutput(
1753
+ loss=masked_lm_loss,
1754
+ logits=lm_logits,
1755
+ past_key_values=outputs.past_key_values,
1756
+ decoder_hidden_states=outputs.decoder_hidden_states,
1757
+ decoder_attentions=outputs.decoder_attentions,
1758
+ cross_attentions=outputs.cross_attentions,
1759
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1760
+ encoder_hidden_states=outputs.encoder_hidden_states,
1761
+ encoder_attentions=outputs.encoder_attentions,
1762
+ )
1763
+
1764
+ def prepare_inputs_for_generation(
1765
+ self,
1766
+ decoder_input_ids,
1767
+ past_key_values=None,
1768
+ attention_mask=None,
1769
+ head_mask=None,
1770
+ decoder_head_mask=None,
1771
+ cross_attn_head_mask=None,
1772
+ use_cache=None,
1773
+ encoder_outputs=None,
1774
+ **kwargs,
1775
+ ):
1776
+ # cut decoder_input_ids if past is used
1777
+ if past_key_values is not None:
1778
+ decoder_input_ids = decoder_input_ids[:, -1:]
1779
+
1780
+ return {
1781
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1782
+ "encoder_outputs": encoder_outputs,
1783
+ "past_key_values": past_key_values,
1784
+ "decoder_input_ids": decoder_input_ids,
1785
+ "attention_mask": attention_mask,
1786
+ "head_mask": head_mask,
1787
+ "decoder_head_mask": decoder_head_mask,
1788
+ "cross_attn_head_mask": cross_attn_head_mask,
1789
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1790
+ }
1791
+
1792
+ @staticmethod
1793
+ def _reorder_cache(past_key_values, beam_idx):
1794
+ reordered_past = ()
1795
+ for layer_past in past_key_values:
1796
+ reordered_past += (
1797
+ tuple(
1798
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1799
+ ),
1800
+ )
1801
+ return reordered_past
runs/Aug26_07-49-43_ip-10-192-11-38/events.out.tfevents.1724658584.ip-10-192-11-38.8568.0 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f069cf9c3bb28fd588be83fa6450fd2065b8611e5c9a4742a57a5923e7938182
3
- size 129916
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d79327fa84d4cb241882c99c1df01e6ef18d4d3c4bdbad2deb5489e775b7423
3
+ size 130270