Poorvaja commited on
Commit
15338e9
·
verified ·
1 Parent(s): c9020d8

Upload modeling_falcon.py

Browse files
Files changed (1) hide show
  1. modeling_falcon.py +1262 -0
modeling_falcon.py ADDED
@@ -0,0 +1,1262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Falcon model."""
16
+
17
+ import math
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
24
+ from torch.nn import functional as F
25
+
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPastAndCrossAttentions,
28
+ CausalLMOutputWithCrossAttentions,
29
+ QuestionAnsweringModelOutput,
30
+ SequenceClassifierOutputWithPast,
31
+ TokenClassifierOutput,
32
+ )
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
35
+ from .configuration_falcon import FalconConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
+ "tiiuae/falcon-40b",
42
+ "tiiuae/falcon-40b-instruct",
43
+ "tiiuae/falcon-7b",
44
+ "tiiuae/falcon-7b-instruct",
45
+ "tiiuae/falcon-rw-7b",
46
+ "tiiuae/falcon-rw-1b",
47
+ ]
48
+ _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
49
+ _CONFIG_FOR_DOC = "FalconConfig"
50
+
51
+
52
+ # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
53
+ # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
54
+ class FalconLinear(nn.Linear):
55
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
56
+ hidden_states = input @ self.weight.T
57
+ if self.bias is None:
58
+ return hidden_states
59
+ return hidden_states + self.bias
60
+
61
+
62
+ # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
63
+ def rotate_half(x):
64
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
+ return torch.cat((-x2, x1), dim=-1)
66
+
67
+
68
+ class FalconRotaryEmbedding(nn.Module):
69
+ """Implementation of RotaryEmbedding from GPT-NeoX.
70
+ This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
71
+ n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
72
+ """
73
+
74
+ def __init__(self, head_dim: int, base=10000):
75
+ super().__init__()
76
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
77
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
78
+ self.head_dim = head_dim
79
+ self.seq_len_cached = -1
80
+ self.cos_cached: torch.Tensor | None = None
81
+ self.sin_cached: torch.Tensor | None = None
82
+
83
+ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
84
+ total_length = seq_len + past_key_values_length
85
+ if total_length > self.seq_len_cached:
86
+ self.seq_len_cached = total_length
87
+ t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
88
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
89
+ emb = torch.cat((freqs, freqs), dim=-1).to(device)
90
+
91
+ if dtype in [torch.float16, torch.bfloat16]:
92
+ emb = emb.float()
93
+
94
+ self.cos_cached = emb.cos()[None, :, :]
95
+ self.sin_cached = emb.sin()[None, :, :]
96
+
97
+ self.cos_cached = self.cos_cached.type(dtype)
98
+ self.sin_cached = self.sin_cached.type(dtype)
99
+
100
+ return (
101
+ self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
102
+ self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
103
+ )
104
+
105
+ def forward(self, query, key, past_key_values_length=0):
106
+ batch, seq_len, head_dim = query.shape
107
+ cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
108
+ return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
109
+
110
+
111
+ def _make_causal_mask(
112
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
113
+ ) -> torch.BoolTensor:
114
+ """
115
+ Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
116
+ just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
117
+ target_length, target_length+past_key_values_length]`.
118
+ """
119
+ batch_size, target_length = input_ids_shape
120
+
121
+ mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
122
+ # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
123
+ # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
124
+ # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
125
+ past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
126
+ mask = torch.cat([past_mask, mask], dim=-1)
127
+ expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
128
+ return expanded_mask
129
+
130
+
131
+ def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
132
+ """
133
+ Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
134
+ """
135
+ batch_size, total_length = mask.shape
136
+ seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
137
+
138
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
139
+ return expanded_mask.expand(batch_size, 1, seq_length, total_length)
140
+
141
+
142
+ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
143
+ batch_size, seq_length = attention_mask.shape
144
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
145
+ base = torch.tensor(
146
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
147
+ )
148
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
149
+ slopes = torch.pow(base, powers)
150
+
151
+ if closest_power_of_2 != num_heads:
152
+ extra_base = torch.tensor(
153
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
154
+ )
155
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
156
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
157
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
158
+
159
+ # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
160
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
161
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
162
+ # => the query_length dimension will then be broadcasted correctly
163
+ # This is more or less identical to T5's relative position bias:
164
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
165
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
166
+ alibi = slopes[..., None].bfloat16() * arange_tensor
167
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
168
+
169
+
170
+ # Copied from transformers.models.bloom.modeling_bloom.dropout_add
171
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
+ """
173
+ Dropout add function
174
+
175
+ Args:
176
+ x (`torch.tensor`, *required*):
177
+ input tensor
178
+ residual (`torch.tensor`, *required*):
179
+ residual tensor
180
+ prob (`float`, *required*):
181
+ dropout probability
182
+ training (`bool`, *required*):
183
+ training mode
184
+ """
185
+ out = F.dropout(x, p=prob, training=training)
186
+ out = residual + out
187
+ return out
188
+
189
+
190
+ class FalconAttention(nn.Module):
191
+ def __init__(self, config: FalconConfig):
192
+ super().__init__()
193
+
194
+ self.hidden_size = config.hidden_size
195
+ self.num_heads = config.num_attention_heads
196
+ self.head_dim = self.hidden_size // self.num_heads
197
+ self.split_size = self.hidden_size
198
+ self.hidden_dropout = config.hidden_dropout
199
+
200
+ if self.head_dim * self.num_heads != self.hidden_size:
201
+ raise ValueError(
202
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
203
+ f" {self.num_heads})."
204
+ )
205
+
206
+ self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
207
+
208
+ # Layer-wise attention scaling
209
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
210
+ self.beta = self.inv_norm_factor
211
+ if config.new_decoder_architecture:
212
+ qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
213
+ elif config.multi_query:
214
+ qkv_out_dim = self.hidden_size + 2 * self.head_dim
215
+ else:
216
+ qkv_out_dim = 3 * self.hidden_size
217
+ self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
218
+ self.new_decoder_architecture = config.new_decoder_architecture
219
+ self.multi_query = config.multi_query
220
+ self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
221
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
222
+ self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
223
+
224
+ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
+ """
226
+ Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
227
+
228
+ Args:
229
+ fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
230
+
231
+ Returns:
232
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
233
+ value: [batch_size, seq_length, num_heads, head_dim]
234
+ """
235
+ if self.new_decoder_architecture:
236
+ batch, seq_len, _ = fused_qkv.shape
237
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
238
+ query = qkv[:, :, :, :-2]
239
+ key = qkv[:, :, :, [-2]]
240
+ value = qkv[:, :, :, [-1]]
241
+ key = torch.broadcast_to(key, query.shape)
242
+ value = torch.broadcast_to(value, query.shape)
243
+
244
+ query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
245
+ return query, key, value
246
+ elif not self.multi_query:
247
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
248
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
249
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
250
+ else:
251
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
252
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
253
+ return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
254
+
255
+ # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
256
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
+ """
258
+ Merge heads together over the last dimenstion
259
+
260
+ Args:
261
+ x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
+
263
+ Returns:
264
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
265
+ """
266
+ # What we want to achieve is:
267
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
268
+ batch_size_and_num_heads, seq_length, _ = x.shape
269
+ batch_size = batch_size_and_num_heads // self.num_heads
270
+
271
+ # First view to decompose the batch size
272
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
273
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
274
+
275
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
276
+ x = x.permute(0, 2, 1, 3)
277
+
278
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
279
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
280
+
281
+ def forward(
282
+ self,
283
+ hidden_states: torch.Tensor,
284
+ alibi: Optional[torch.Tensor],
285
+ attention_mask: torch.Tensor,
286
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
287
+ head_mask: Optional[torch.Tensor] = None,
288
+ use_cache: bool = False,
289
+ output_attentions: bool = False,
290
+ ):
291
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
292
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
293
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
294
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
295
+
296
+ batch_size, query_length, _, _ = query_layer.shape
297
+
298
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
299
+ key_layer = key_layer.transpose(1, 2).reshape(
300
+ batch_size * num_kv_heads,
301
+ query_length,
302
+ self.head_dim,
303
+ )
304
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
305
+
306
+ past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
307
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
308
+
309
+ if layer_past is not None:
310
+ past_key, past_value = layer_past
311
+ # concatenate along seq_length dimension:
312
+ # - key: [batch_size * self.num_heads, kv_length, head_dim]
313
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
314
+ key_layer = torch.cat((past_key, key_layer), dim=1)
315
+ value_layer = torch.cat((past_value, value_layer), dim=1)
316
+
317
+ _, kv_length, _ = key_layer.shape
318
+ if use_cache:
319
+ present = (key_layer, value_layer)
320
+ else:
321
+ present = None
322
+
323
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
324
+
325
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
326
+ key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
327
+ value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
328
+
329
+ if alibi is None:
330
+ if output_attentions:
331
+ # F.scaled_dot_product_attention doesn't return the attention weights, so we have
332
+ # to do it by hand if we want them
333
+ attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
334
+ attention_scores /= math.sqrt(self.head_dim)
335
+
336
+ attention_scores = F.softmax(
337
+ attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
338
+ )
339
+ attn_output = attention_scores @ value_layer_
340
+ else:
341
+ attn_output = F.scaled_dot_product_attention(
342
+ query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
343
+ )
344
+ attention_scores = None
345
+
346
+ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
347
+ attn_output = attn_output.permute(0, 2, 1, 3)
348
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
349
+
350
+ output_tensor = self.dense(attn_output)
351
+
352
+ if output_attentions:
353
+ return output_tensor, present, attention_scores
354
+ else:
355
+ return output_tensor, present
356
+
357
+ else:
358
+ matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
359
+
360
+ # change view to [batch_size, num_heads, q_length, kv_length]
361
+ attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
362
+
363
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
364
+ input_dtype = attention_scores.dtype
365
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
366
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
367
+ attention_scores = attention_scores.to(torch.float32)
368
+ # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
369
+ # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
370
+ # equivalent and more performant, but there might be a numerical difference. If you're reading this
371
+ # and you'd like to experiment and maybe file a PR, feel free!
372
+ attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
373
+ attention_logits *= self.inv_norm_factor
374
+ attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
375
+ # [batch_size, num_heads, q_length, kv_length]
376
+ attention_probs = self.attention_dropout(attention_probs)
377
+
378
+ if head_mask is not None:
379
+ attention_probs = attention_probs * head_mask
380
+
381
+ # change view [batch_size, num_heads, q_length, kv_length]
382
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
383
+
384
+ # matmul: [batch_size * num_heads, q_length, head_dim]
385
+ context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
386
+
387
+ # change view [batch_size, num_heads, q_length, head_dim]
388
+ context_layer = self._merge_heads(context_layer)
389
+
390
+ output_tensor = self.dense(context_layer)
391
+
392
+ if output_attentions:
393
+ return output_tensor, present, attention_probs
394
+ else:
395
+ return output_tensor, present
396
+
397
+
398
+ class FalconMLP(nn.Module):
399
+ def __init__(self, config: FalconConfig):
400
+ super().__init__()
401
+ hidden_size = config.hidden_size
402
+
403
+ self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
404
+ self.act = nn.GELU()
405
+ self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
406
+ self.hidden_dropout = config.hidden_dropout
407
+
408
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
409
+ x = self.act(self.dense_h_to_4h(x))
410
+ x = self.dense_4h_to_h(x)
411
+ return x
412
+
413
+
414
+ class FalconDecoderLayer(nn.Module):
415
+ def __init__(self, config: FalconConfig):
416
+ super().__init__()
417
+ hidden_size = config.hidden_size
418
+ self.num_heads = config.num_attention_heads
419
+ self.self_attention = FalconAttention(config)
420
+ self.mlp = FalconMLP(config)
421
+ self.hidden_dropout = config.hidden_dropout
422
+ self.config = config
423
+
424
+ if config.new_decoder_architecture:
425
+ # The layer norm before self-attention
426
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
+ # The layer norm before the MLP
428
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
429
+ else:
430
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
431
+ if not config.parallel_attn:
432
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
433
+
434
+ def forward(
435
+ self,
436
+ hidden_states: torch.Tensor,
437
+ alibi: Optional[torch.Tensor],
438
+ attention_mask: torch.Tensor,
439
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
440
+ head_mask: Optional[torch.Tensor] = None,
441
+ use_cache: bool = False,
442
+ output_attentions: bool = False,
443
+ ):
444
+ residual = hidden_states
445
+
446
+ if self.config.new_decoder_architecture:
447
+ attention_layernorm_out = self.ln_attn(hidden_states)
448
+ mlp_layernorm_out = self.ln_mlp(hidden_states)
449
+ else:
450
+ attention_layernorm_out = self.input_layernorm(hidden_states)
451
+
452
+ # Self attention.
453
+ attn_outputs = self.self_attention(
454
+ attention_layernorm_out,
455
+ layer_past=layer_past,
456
+ attention_mask=attention_mask,
457
+ alibi=alibi,
458
+ head_mask=head_mask,
459
+ use_cache=use_cache,
460
+ output_attentions=output_attentions,
461
+ )
462
+
463
+ attention_output = attn_outputs[0]
464
+
465
+ if not self.config.new_decoder_architecture:
466
+ if self.config.parallel_attn:
467
+ mlp_layernorm_out = attention_layernorm_out
468
+ else:
469
+ residual = dropout_add(
470
+ attention_output, residual, self.config.attention_dropout, training=self.training
471
+ )
472
+ mlp_layernorm_out = self.post_attention_layernorm(residual)
473
+
474
+ outputs = attn_outputs[1:]
475
+
476
+ # MLP.
477
+ mlp_output = self.mlp(mlp_layernorm_out)
478
+
479
+ if self.config.new_decoder_architecture or self.config.parallel_attn:
480
+ mlp_output += attention_output
481
+
482
+ output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
483
+
484
+ if use_cache:
485
+ outputs = (output,) + outputs
486
+ else:
487
+ outputs = (output,) + outputs[1:]
488
+
489
+ return outputs # hidden_states, present, attentions
490
+
491
+
492
+ FALCON_START_DOCSTRING = r"""
493
+
494
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
495
+ library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
496
+
497
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
498
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
499
+ and behavior.
500
+
501
+ Parameters:
502
+ config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
503
+ Initializing with a config file does not load the weights associated with the model, only the
504
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
505
+ """
506
+
507
+ FALCON_INPUTS_DOCSTRING = r"""
508
+ Args:
509
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
510
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
511
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
512
+
513
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
514
+ `input_ids`.
515
+
516
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
+ [`PreTrainedTokenizer.__call__`] for details.
518
+
519
+ [What are input IDs?](../glossary#input-ids)
520
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
521
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
522
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
523
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
524
+
525
+ Each element of `past_key_values` is a tuple (past_key, past_value):
526
+ - past_key: [batch_size * num_heads, head_dim, kv_length]
527
+ - past_value: [batch_size * num_heads, kv_length, head_dim]
528
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
529
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
530
+
531
+ - 1 for tokens that are **not masked**,
532
+ - 0 for tokens that are **masked**.
533
+
534
+ [What are attention masks?](../glossary#attention-mask)
535
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
536
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
537
+
538
+ - 1 indicates the head is **not masked**,
539
+ - 0 indicates the head is **masked**.
540
+
541
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
+ model's internal embedding lookup matrix.
545
+
546
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
547
+ `past_key_values`).
548
+ use_cache (`bool`, *optional*):
549
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
550
+ `past_key_values`).
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
+ tensors for more detail.
554
+ output_hidden_states (`bool`, *optional*):
555
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
+ more detail.
557
+ return_dict (`bool`, *optional*):
558
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
559
+ """
560
+
561
+
562
+ class FalconPreTrainedModel(PreTrainedModel):
563
+ """
564
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
565
+ models.
566
+ """
567
+
568
+ config_class = FalconConfig
569
+ base_model_prefix = "transformer"
570
+ supports_gradient_checkpointing = True
571
+ _no_split_modules = ["FalconDecoderLayer"]
572
+
573
+ def __init__(self, *inputs, **kwargs):
574
+ super().__init__(*inputs, **kwargs)
575
+
576
+ def _init_weights(self, module: nn.Module):
577
+ """Initialize the weights."""
578
+ if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
579
+ # Slightly different from the TF version which uses truncated_normal for initialization
580
+ # cf https://github.com/pytorch/pytorch/pull/5617
581
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
582
+ if module.bias is not None:
583
+ module.bias.data.zero_()
584
+ elif isinstance(module, nn.Embedding):
585
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
586
+ if module.padding_idx is not None:
587
+ module.weight.data[module.padding_idx].zero_()
588
+ elif isinstance(module, LayerNorm):
589
+ module.bias.data.zero_()
590
+ module.weight.data.fill_(1.0)
591
+
592
+ # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
593
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
594
+ if isinstance(module, FalconModel):
595
+ module.gradient_checkpointing = value
596
+
597
+ @staticmethod
598
+ def _convert_cache_to_standard_format(
599
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
600
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
601
+ """
602
+ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
603
+ num_heads, ...]))
604
+ """
605
+ batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
606
+ # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
607
+ # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
608
+ # on whether we use multi_query attention.
609
+ num_heads = batch_size_times_num_heads // batch_size
610
+ return tuple(
611
+ (
612
+ layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
613
+ layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
614
+ )
615
+ for layer_past in past_key_value
616
+ )
617
+
618
+ @staticmethod
619
+ def _convert_to_rw_cache(
620
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
621
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
622
+ batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
623
+ batch_size_times_num_heads = batch_size * num_heads
624
+ # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
625
+ return tuple(
626
+ (
627
+ layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
628
+ layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
629
+ )
630
+ for layer_past in past_key_value
631
+ )
632
+
633
+
634
+ @add_start_docstrings(
635
+ "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
636
+ FALCON_START_DOCSTRING,
637
+ )
638
+ class FalconModel(FalconPreTrainedModel):
639
+ def __init__(self, config: FalconConfig):
640
+ super().__init__(config)
641
+
642
+ self.embed_dim = config.hidden_size
643
+ self.num_heads = config.num_attention_heads
644
+ self.use_alibi = config.alibi
645
+
646
+ # Embedding + LN Embedding
647
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
648
+
649
+ # Transformer blocks
650
+ self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
651
+
652
+ # Final Layer Norm
653
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
654
+
655
+ self.gradient_checkpointing = False
656
+
657
+ # Initialize weights and apply final processing
658
+ self.post_init()
659
+
660
+ def get_input_embeddings(self):
661
+ return self.word_embeddings
662
+
663
+ @staticmethod
664
+ def _prepare_attn_mask(
665
+ attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
666
+ ) -> torch.BoolTensor:
667
+ # Create a causal mask
668
+ # The attention mask we receive as input should cover the whole extended sequence, including any past
669
+ # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
670
+ # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
671
+ if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
672
+ raise ValueError(
673
+ "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
674
+ f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
675
+ f" {past_key_values_length}."
676
+ )
677
+ combined_attention_mask = None
678
+ device = attention_mask.device
679
+ _, seq_length = input_shape
680
+
681
+ if seq_length > 1:
682
+ combined_attention_mask = _make_causal_mask(
683
+ input_shape, device=device, past_key_values_length=past_key_values_length
684
+ )
685
+
686
+ # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
687
+ expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
688
+ combined_attention_mask = (
689
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
690
+ )
691
+
692
+ return combined_attention_mask
693
+
694
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
695
+ self.word_embeddings = new_embeddings
696
+
697
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
698
+ @add_code_sample_docstrings(
699
+ checkpoint=_CHECKPOINT_FOR_DOC,
700
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
701
+ config_class=_CONFIG_FOR_DOC,
702
+ )
703
+ def forward(
704
+ self,
705
+ input_ids: Optional[torch.LongTensor] = None,
706
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
707
+ attention_mask: Optional[torch.Tensor] = None,
708
+ head_mask: Optional[torch.LongTensor] = None,
709
+ inputs_embeds: Optional[torch.LongTensor] = None,
710
+ use_cache: Optional[bool] = None,
711
+ output_attentions: Optional[bool] = None,
712
+ output_hidden_states: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None,
714
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
715
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
+ output_hidden_states = (
717
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
+ )
719
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
720
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
721
+
722
+ if input_ids is not None and inputs_embeds is not None:
723
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
724
+ elif input_ids is not None:
725
+ batch_size, seq_length = input_ids.shape
726
+ elif inputs_embeds is not None:
727
+ batch_size, seq_length, _ = inputs_embeds.shape
728
+ else:
729
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
730
+
731
+ if past_key_values is None:
732
+ past_key_values = tuple([None] * len(self.h))
733
+ else:
734
+ past_key_values = self._convert_to_rw_cache(past_key_values)
735
+
736
+ # Prepare head mask if needed
737
+ # 1.0 in head_mask indicate we keep the head
738
+ # attention_probs has shape batch_size x num_heads x N x N
739
+ # head_mask has shape n_layer x batch x num_heads x N x N
740
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
741
+
742
+ if inputs_embeds is None:
743
+ inputs_embeds = self.word_embeddings(input_ids)
744
+
745
+ hidden_states = inputs_embeds
746
+
747
+ presents = () if use_cache else None
748
+ all_self_attentions = () if output_attentions else None
749
+ all_hidden_states = () if output_hidden_states else None
750
+
751
+ # Compute alibi tensor: check build_alibi_tensor documentation
752
+ past_key_values_length = 0
753
+ if past_key_values[0] is not None:
754
+ past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
755
+ if attention_mask is None:
756
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
757
+ else:
758
+ attention_mask = attention_mask.to(hidden_states.device)
759
+
760
+ if self.use_alibi:
761
+ alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
762
+ else:
763
+ alibi = None
764
+
765
+ causal_mask = self._prepare_attn_mask(
766
+ attention_mask,
767
+ input_shape=(batch_size, seq_length),
768
+ past_key_values_length=past_key_values_length,
769
+ )
770
+
771
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
772
+ if output_hidden_states:
773
+ all_hidden_states = all_hidden_states + (hidden_states,)
774
+
775
+ if self.gradient_checkpointing and self.training:
776
+ if use_cache:
777
+ logger.warning(
778
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
779
+ )
780
+ use_cache = False
781
+
782
+ def create_custom_forward(module):
783
+ def custom_forward(*inputs):
784
+ # None for past_key_value
785
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
786
+
787
+ return custom_forward
788
+
789
+ outputs = torch.utils.checkpoint.checkpoint(
790
+ create_custom_forward(block),
791
+ hidden_states,
792
+ alibi,
793
+ causal_mask,
794
+ head_mask[i],
795
+ )
796
+ else:
797
+ outputs = block(
798
+ hidden_states,
799
+ layer_past=layer_past,
800
+ attention_mask=causal_mask,
801
+ head_mask=head_mask[i],
802
+ use_cache=use_cache,
803
+ output_attentions=output_attentions,
804
+ alibi=alibi,
805
+ )
806
+
807
+ hidden_states = outputs[0]
808
+ if use_cache is True:
809
+ presents = presents + (outputs[1],)
810
+
811
+ if output_attentions:
812
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
813
+
814
+ # Add last hidden state
815
+ hidden_states = self.ln_f(hidden_states)
816
+
817
+ if output_hidden_states:
818
+ all_hidden_states = all_hidden_states + (hidden_states,)
819
+
820
+ if presents is not None:
821
+ presents = self._convert_cache_to_standard_format(presents, batch_size)
822
+
823
+ if not return_dict:
824
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
825
+
826
+ return BaseModelOutputWithPastAndCrossAttentions(
827
+ last_hidden_state=hidden_states,
828
+ past_key_values=presents,
829
+ hidden_states=all_hidden_states,
830
+ attentions=all_self_attentions,
831
+ )
832
+
833
+
834
+ @add_start_docstrings(
835
+ "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
836
+ FALCON_START_DOCSTRING,
837
+ )
838
+ class FalconForCausalLM(FalconPreTrainedModel):
839
+ _tied_weights_keys = ["lm_head.weight"]
840
+
841
+ def __init__(self, config: FalconConfig):
842
+ super().__init__(config)
843
+ self.transformer = FalconModel(config)
844
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
845
+
846
+ # Initialize weights and apply final processing
847
+ self.post_init()
848
+
849
+ def get_output_embeddings(self):
850
+ return self.lm_head
851
+
852
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
853
+ self.lm_head = new_embeddings
854
+
855
+ def prepare_inputs_for_generation(
856
+ self,
857
+ input_ids: torch.LongTensor,
858
+ past_key_values: Optional[torch.Tensor] = None,
859
+ attention_mask: Optional[torch.Tensor] = None,
860
+ **kwargs,
861
+ ) -> dict:
862
+ if past_key_values is not None:
863
+ input_ids = input_ids[:, -1:]
864
+
865
+ return {
866
+ "input_ids": input_ids,
867
+ "past_key_values": past_key_values,
868
+ "use_cache": kwargs.get("use_cache"),
869
+ "attention_mask": attention_mask,
870
+ }
871
+
872
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
873
+ @add_code_sample_docstrings(
874
+ checkpoint=_CHECKPOINT_FOR_DOC,
875
+ output_type=CausalLMOutputWithCrossAttentions,
876
+ config_class=_CONFIG_FOR_DOC,
877
+ )
878
+ def forward(
879
+ self,
880
+ input_ids: Optional[torch.LongTensor] = None,
881
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
882
+ attention_mask: Optional[torch.Tensor] = None,
883
+ head_mask: Optional[torch.Tensor] = None,
884
+ inputs_embeds: Optional[torch.Tensor] = None,
885
+ labels: Optional[torch.Tensor] = None,
886
+ use_cache: Optional[bool] = None,
887
+ output_attentions: Optional[bool] = None,
888
+ output_hidden_states: Optional[bool] = None,
889
+ return_dict: Optional[bool] = None,
890
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
891
+ r"""
892
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
893
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
894
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
895
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
896
+ """
897
+
898
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
899
+
900
+ transformer_outputs = self.transformer(
901
+ input_ids,
902
+ past_key_values=past_key_values,
903
+ attention_mask=attention_mask,
904
+ head_mask=head_mask,
905
+ inputs_embeds=inputs_embeds,
906
+ use_cache=use_cache,
907
+ output_attentions=output_attentions,
908
+ output_hidden_states=output_hidden_states,
909
+ return_dict=return_dict,
910
+ )
911
+ hidden_states = transformer_outputs[0]
912
+
913
+ lm_logits = self.lm_head(hidden_states)
914
+
915
+ loss = None
916
+ if labels is not None:
917
+ # Shift so that tokens < n predict n
918
+ shift_logits = lm_logits[..., :-1, :].contiguous()
919
+ shift_labels = labels[..., 1:].contiguous()
920
+ batch_size, seq_length, vocab_size = shift_logits.shape
921
+ # Flatten the tokens
922
+ loss_fct = CrossEntropyLoss()
923
+ loss = loss_fct(
924
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
925
+ )
926
+
927
+ if not return_dict:
928
+ output = (lm_logits,) + transformer_outputs[1:]
929
+ return ((loss,) + output) if loss is not None else output
930
+
931
+ return CausalLMOutputWithCrossAttentions(
932
+ loss=loss,
933
+ logits=lm_logits,
934
+ past_key_values=transformer_outputs.past_key_values,
935
+ hidden_states=transformer_outputs.hidden_states,
936
+ attentions=transformer_outputs.attentions,
937
+ )
938
+
939
+ def _reorder_cache(
940
+ self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
941
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
942
+ """
943
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
944
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
945
+ beam_idx at every generation step.
946
+
947
+ Output shares the same memory storage as `past`.
948
+ """
949
+
950
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
951
+ device_to_beam_idx = {
952
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
953
+ }
954
+ reordered_past = tuple(
955
+ (
956
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
957
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
958
+ )
959
+ for layer_past in past
960
+ )
961
+ return reordered_past
962
+
963
+
964
+ @add_start_docstrings(
965
+ """
966
+ The Falcon Model transformer with a sequence classification head on top (linear layer).
967
+
968
+ [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
969
+ (e.g. GPT-1) do.
970
+
971
+ Since it does classification on the last token, it requires to know the position of the last token. If a
972
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
973
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
974
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
975
+ each row of the batch).
976
+ """,
977
+ FALCON_START_DOCSTRING,
978
+ )
979
+ class FalconForSequenceClassification(FalconPreTrainedModel):
980
+ def __init__(self, config: FalconConfig):
981
+ super().__init__(config)
982
+ self.num_labels = config.num_labels
983
+ self.transformer = FalconModel(config)
984
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
985
+
986
+ # Initialize weights and apply final processing
987
+ self.post_init()
988
+
989
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
990
+ @add_code_sample_docstrings(
991
+ checkpoint=_CHECKPOINT_FOR_DOC,
992
+ output_type=SequenceClassifierOutputWithPast,
993
+ config_class=_CONFIG_FOR_DOC,
994
+ )
995
+ def forward(
996
+ self,
997
+ input_ids: Optional[torch.LongTensor] = None,
998
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
999
+ attention_mask: Optional[torch.Tensor] = None,
1000
+ head_mask: Optional[torch.Tensor] = None,
1001
+ inputs_embeds: Optional[torch.Tensor] = None,
1002
+ labels: Optional[torch.Tensor] = None,
1003
+ use_cache: Optional[bool] = None,
1004
+ output_attentions: Optional[bool] = None,
1005
+ output_hidden_states: Optional[bool] = None,
1006
+ return_dict: Optional[bool] = None,
1007
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1008
+ r"""
1009
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1010
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1011
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1012
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1013
+ """
1014
+
1015
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
+
1017
+ transformer_outputs = self.transformer(
1018
+ input_ids,
1019
+ past_key_values=past_key_values,
1020
+ attention_mask=attention_mask,
1021
+ head_mask=head_mask,
1022
+ inputs_embeds=inputs_embeds,
1023
+ use_cache=use_cache,
1024
+ output_attentions=output_attentions,
1025
+ output_hidden_states=output_hidden_states,
1026
+ return_dict=return_dict,
1027
+ )
1028
+
1029
+ hidden_states = transformer_outputs[0]
1030
+ logits = self.score(hidden_states)
1031
+
1032
+ if input_ids is not None:
1033
+ batch_size = input_ids.shape[0]
1034
+ else:
1035
+ batch_size = inputs_embeds.shape[0]
1036
+
1037
+ if self.config.pad_token_id is None and batch_size != 1:
1038
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1039
+ if self.config.pad_token_id is None:
1040
+ sequence_lengths = -1
1041
+ else:
1042
+ if input_ids is not None:
1043
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
1044
+ else:
1045
+ sequence_lengths = -1
1046
+ logger.warning(
1047
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1048
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1049
+ )
1050
+
1051
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1052
+
1053
+ loss = None
1054
+ if labels is not None:
1055
+ if self.config.problem_type is None:
1056
+ if self.num_labels == 1:
1057
+ self.config.problem_type = "regression"
1058
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1059
+ self.config.problem_type = "single_label_classification"
1060
+ else:
1061
+ self.config.problem_type = "multi_label_classification"
1062
+
1063
+ if self.config.problem_type == "regression":
1064
+ loss_fct = MSELoss()
1065
+ if self.num_labels == 1:
1066
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1067
+ else:
1068
+ loss = loss_fct(pooled_logits, labels)
1069
+ elif self.config.problem_type == "single_label_classification":
1070
+ loss_fct = CrossEntropyLoss()
1071
+ loss = loss_fct(pooled_logits, labels)
1072
+ elif self.config.problem_type == "multi_label_classification":
1073
+ loss_fct = BCEWithLogitsLoss()
1074
+ loss = loss_fct(pooled_logits, labels)
1075
+ if not return_dict:
1076
+ output = (pooled_logits,) + transformer_outputs[1:]
1077
+ return ((loss,) + output) if loss is not None else output
1078
+
1079
+ return SequenceClassifierOutputWithPast(
1080
+ loss=loss,
1081
+ logits=pooled_logits,
1082
+ past_key_values=transformer_outputs.past_key_values,
1083
+ hidden_states=transformer_outputs.hidden_states,
1084
+ attentions=transformer_outputs.attentions,
1085
+ )
1086
+
1087
+
1088
+ @add_start_docstrings(
1089
+ """
1090
+ Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1091
+ Named-Entity-Recognition (NER) tasks.
1092
+ """,
1093
+ FALCON_START_DOCSTRING,
1094
+ )
1095
+ class FalconForTokenClassification(FalconPreTrainedModel):
1096
+ def __init__(self, config: FalconConfig):
1097
+ super().__init__(config)
1098
+ self.num_labels = config.num_labels
1099
+
1100
+ self.transformer = FalconModel(config)
1101
+ if getattr(config, "classifier_dropout", None) is not None:
1102
+ classifier_dropout = config.classifier_dropout
1103
+ elif getattr(config, "hidden_dropout", None) is not None:
1104
+ classifier_dropout = config.hidden_dropout
1105
+ else:
1106
+ classifier_dropout = 0.1
1107
+ self.dropout = nn.Dropout(classifier_dropout)
1108
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1109
+
1110
+ # Initialize weights and apply final processing
1111
+ self.post_init()
1112
+
1113
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1114
+ @add_code_sample_docstrings(
1115
+ checkpoint=_CHECKPOINT_FOR_DOC,
1116
+ output_type=TokenClassifierOutput,
1117
+ config_class=_CONFIG_FOR_DOC,
1118
+ )
1119
+ def forward(
1120
+ self,
1121
+ input_ids: Optional[torch.LongTensor] = None,
1122
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1123
+ attention_mask: Optional[torch.Tensor] = None,
1124
+ head_mask: Optional[torch.Tensor] = None,
1125
+ inputs_embeds: Optional[torch.Tensor] = None,
1126
+ labels: Optional[torch.Tensor] = None,
1127
+ use_cache: Optional[bool] = None,
1128
+ output_attentions: Optional[bool] = None,
1129
+ output_hidden_states: Optional[bool] = None,
1130
+ return_dict: Optional[bool] = None,
1131
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1132
+ r"""
1133
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1134
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1135
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1136
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
+ """
1138
+
1139
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1140
+
1141
+ transformer_outputs = self.transformer(
1142
+ input_ids,
1143
+ past_key_values=past_key_values,
1144
+ attention_mask=attention_mask,
1145
+ head_mask=head_mask,
1146
+ inputs_embeds=inputs_embeds,
1147
+ use_cache=use_cache,
1148
+ output_attentions=output_attentions,
1149
+ output_hidden_states=output_hidden_states,
1150
+ return_dict=return_dict,
1151
+ )
1152
+
1153
+ hidden_states = transformer_outputs[0]
1154
+ hidden_states = self.dropout(hidden_states)
1155
+ logits = self.classifier(hidden_states)
1156
+
1157
+ loss = None
1158
+ if labels is not None:
1159
+ batch_size, seq_length = labels.shape
1160
+ loss_fct = CrossEntropyLoss()
1161
+ loss = loss_fct(
1162
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1163
+ )
1164
+
1165
+ if not return_dict:
1166
+ output = (logits,) + transformer_outputs[2:]
1167
+ return ((loss,) + output) if loss is not None else output
1168
+
1169
+ return TokenClassifierOutput(
1170
+ loss=loss,
1171
+ logits=logits,
1172
+ hidden_states=transformer_outputs.hidden_states,
1173
+ attentions=transformer_outputs.attentions,
1174
+ )
1175
+
1176
+
1177
+ @add_start_docstrings(
1178
+ """
1179
+ The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1180
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1181
+ """,
1182
+ FALCON_START_DOCSTRING,
1183
+ )
1184
+ class FalconForQuestionAnswering(FalconPreTrainedModel):
1185
+ def __init__(self, config):
1186
+ super().__init__(config)
1187
+ self.transformer = FalconModel(config)
1188
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1189
+
1190
+ # Initialize weights and apply final processing
1191
+ self.post_init()
1192
+
1193
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1194
+ def forward(
1195
+ self,
1196
+ input_ids: Optional[torch.LongTensor] = None,
1197
+ attention_mask: Optional[torch.FloatTensor] = None,
1198
+ head_mask: Optional[torch.FloatTensor] = None,
1199
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1200
+ start_positions: Optional[torch.LongTensor] = None,
1201
+ end_positions: Optional[torch.LongTensor] = None,
1202
+ output_attentions: Optional[bool] = None,
1203
+ output_hidden_states: Optional[bool] = None,
1204
+ return_dict: Optional[bool] = None,
1205
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1206
+ r"""
1207
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1208
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1209
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1210
+ are not taken into account for computing the loss.
1211
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1212
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1213
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1214
+ are not taken into account for computing the loss.
1215
+ """
1216
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1217
+
1218
+ outputs = self.transformer(
1219
+ input_ids,
1220
+ attention_mask=attention_mask,
1221
+ head_mask=head_mask,
1222
+ inputs_embeds=inputs_embeds,
1223
+ output_attentions=output_attentions,
1224
+ output_hidden_states=output_hidden_states,
1225
+ return_dict=return_dict,
1226
+ )
1227
+
1228
+ sequence_output = outputs[0]
1229
+
1230
+ logits = self.qa_outputs(sequence_output)
1231
+ start_logits, end_logits = logits.split(1, dim=-1)
1232
+ start_logits = start_logits.squeeze(-1).contiguous()
1233
+ end_logits = end_logits.squeeze(-1).contiguous()
1234
+
1235
+ total_loss = None
1236
+ if start_positions is not None and end_positions is not None:
1237
+ # If we are on multi-GPU, split add a dimension
1238
+ if len(start_positions.size()) > 1:
1239
+ start_positions = start_positions.squeeze(-1)
1240
+ if len(end_positions.size()) > 1:
1241
+ end_positions = end_positions.squeeze(-1)
1242
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1243
+ ignored_index = start_logits.size(1)
1244
+ start_positions = start_positions.clamp(0, ignored_index)
1245
+ end_positions = end_positions.clamp(0, ignored_index)
1246
+
1247
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1248
+ start_loss = loss_fct(start_logits, start_positions)
1249
+ end_loss = loss_fct(end_logits, end_positions)
1250
+ total_loss = (start_loss + end_loss) / 2
1251
+
1252
+ if not return_dict:
1253
+ output = (start_logits, end_logits) + outputs[2:]
1254
+ return ((total_loss,) + output) if total_loss is not None else output
1255
+
1256
+ return QuestionAnsweringModelOutput(
1257
+ loss=total_loss,
1258
+ start_logits=start_logits,
1259
+ end_logits=end_logits,
1260
+ hidden_states=outputs.hidden_states,
1261
+ attentions=outputs.attentions,
1262
+ )