sgugger commited on
Commit
561820f
·
1 Parent(s): f1ba7d3

Revert in-library PR (#90)

Browse files

- Revert "Move to in-library checkpoint (#81)" (46f48b52085123273538fbbbfc655047720a63d2)

README.md CHANGED
@@ -16,7 +16,7 @@ license: apache-2.0
16
 
17
  *Paper coming soon 😊.*
18
 
19
- ⚠️ Falcon is now available as a core model in the `transformers` library! To use the in-library version, please install the latest version of `transformers` with `pip install git+https://github.com/huggingface/transformers.git`, then simply remove the `trust_remote_code=True` argument from `from_pretrained()`.
20
 
21
  # Call for Proposals : Falcon 40B - World's Top Ranked AI Model Empowers Exceptional Use Cases with Training Compute Power in Call for Proposals
22
 
@@ -57,6 +57,7 @@ pipeline = transformers.pipeline(
57
  model=model,
58
  tokenizer=tokenizer,
59
  torch_dtype=torch.bfloat16,
 
60
  device_map="auto",
61
  )
62
  sequences = pipeline(
@@ -127,6 +128,7 @@ pipeline = transformers.pipeline(
127
  model=model,
128
  tokenizer=tokenizer,
129
  torch_dtype=torch.bfloat16,
 
130
  device_map="auto",
131
  )
132
  sequences = pipeline(
@@ -267,4 +269,4 @@ To learn more about the pretraining dataset, see the 📓 [RefinedWeb paper](htt
267
  Falcon-40B is made available under the Apache 2.0 license.
268
 
269
  ## Contact
270
 
16
 
17
  *Paper coming soon 😊.*
18
 
19
+
20
 
21
  # Call for Proposals : Falcon 40B - World's Top Ranked AI Model Empowers Exceptional Use Cases with Training Compute Power in Call for Proposals
22
 
 
57
  model=model,
58
  tokenizer=tokenizer,
59
  torch_dtype=torch.bfloat16,
60
+ trust_remote_code=True,
61
  device_map="auto",
62
  )
63
  sequences = pipeline(
 
128
  model=model,
129
  tokenizer=tokenizer,
130
  torch_dtype=torch.bfloat16,
131
+ trust_remote_code=True,
132
  device_map="auto",
133
  )
134
  sequences = pipeline(
 
269
  Falcon-40B is made available under the Apache 2.0 license.
270
 
271
  ## Contact
272
config.json CHANGED
@@ -2,16 +2,16 @@
2
  "alibi": false,
3
  "apply_residual_connection_post_layernorm": false,
4
  "architectures": [
5
- "FalconForCausalLM"
6
  ],
7
  "attention_dropout": 0.0,
8
  "auto_map": {
9
- "AutoConfig": "configuration_falcon.FalconConfig",
10
- "AutoModel": "modeling_falcon.FalconModel",
11
- "AutoModelForSequenceClassification": "modeling_falcon.FalconForSequenceClassification",
12
- "AutoModelForTokenClassification": "modeling_falcon.FalconForTokenClassification",
13
- "AutoModelForQuestionAnswering": "modeling_falcon.FalconForQuestionAnswering",
14
- "AutoModelForCausalLM": "modeling_falcon.FalconForCausalLM"
15
  },
16
  "bias": false,
17
  "bos_token_id": 11,
@@ -20,11 +20,10 @@
20
  "hidden_size": 8192,
21
  "initializer_range": 0.02,
22
  "layer_norm_epsilon": 1e-05,
23
- "model_type": "falcon",
24
- "new_decoder_architecture": true,
25
- "num_attention_heads": 128,
26
- "num_hidden_layers": 60,
27
- "num_kv_heads": 8,
28
  "parallel_attn": true,
29
  "torch_dtype": "bfloat16",
30
  "transformers_version": "4.27.4",
 
2
  "alibi": false,
3
  "apply_residual_connection_post_layernorm": false,
4
  "architectures": [
5
+ "RWForCausalLM"
6
  ],
7
  "attention_dropout": 0.0,
8
  "auto_map": {
9
+ "AutoConfig": "configuration_RW.RWConfig",
10
+ "AutoModel": "modelling_RW.RWModel",
11
+ "AutoModelForSequenceClassification": "modelling_RW.RWForSequenceClassification",
12
+ "AutoModelForTokenClassification": "modelling_RW.RWForTokenClassification",
13
+ "AutoModelForQuestionAnswering": "modelling_RW.RWForQuestionAnswering",
14
+ "AutoModelForCausalLM": "modelling_RW.RWForCausalLM"
15
  },
16
  "bias": false,
17
  "bos_token_id": 11,
 
20
  "hidden_size": 8192,
21
  "initializer_range": 0.02,
22
  "layer_norm_epsilon": 1e-05,
23
+ "model_type": "RefinedWeb",
24
+ "n_head": 128,
25
+ "n_head_kv": 8,
26
+ "n_layer": 60,
 
27
  "parallel_attn": true,
28
  "torch_dtype": "bfloat16",
29
  "transformers_version": "4.27.4",
configuration_RW.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Bloom configuration"""
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ class RWConfig(PretrainedConfig):
24
+ model_type = "RefinedWeb"
25
+ keys_to_ignore_at_inference = ["past_key_values"]
26
+ attribute_map = {
27
+ "num_hidden_layers": "n_layer",
28
+ "num_attention_heads": "n_head",
29
+ }
30
+
31
+ def __init__(
32
+ self,
33
+ vocab_size=250880,
34
+ hidden_size=64,
35
+ n_layer=2,
36
+ n_head=8,
37
+ layer_norm_epsilon=1e-5,
38
+ initializer_range=0.02,
39
+ use_cache=True,
40
+ bos_token_id=1,
41
+ eos_token_id=2,
42
+ apply_residual_connection_post_layernorm=False,
43
+ hidden_dropout=0.0,
44
+ attention_dropout=0.0,
45
+ n_head_kv=None,
46
+ alibi=False,
47
+ **kwargs,
48
+ ):
49
+ self.vocab_size = vocab_size
50
+ # Backward compatibility with n_embed kwarg
51
+ n_embed = kwargs.pop("n_embed", None)
52
+ self.hidden_size = hidden_size if n_embed is None else n_embed
53
+ self.n_layer = n_layer
54
+ self.n_head = n_head
55
+ self.layer_norm_epsilon = layer_norm_epsilon
56
+ self.initializer_range = initializer_range
57
+ self.use_cache = use_cache
58
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
59
+ self.hidden_dropout = hidden_dropout
60
+ self.attention_dropout = attention_dropout
61
+
62
+ self.bos_token_id = bos_token_id
63
+ self.eos_token_id = eos_token_id
64
+ self.n_head_kv = n_head if n_head_kv is None else n_head_kv
65
+ self.alibi = alibi
66
+
67
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
68
+
69
+ @property
70
+ def head_dim(self):
71
+ return self.hidden_size // self.n_head
72
+
73
+ @property
74
+ def rotary(self):
75
+ return not self.alibi
configuration_falcon.py DELETED
@@ -1,147 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Falcon configuration"""
16
- from transformers.configuration_utils import PretrainedConfig
17
- from transformers.utils import logging
18
-
19
-
20
- logger = logging.get_logger(__name__)
21
-
22
- FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
23
- "tiiuae/falcon-40b": "https://huggingface.co/tiiuae/falcon-40b/resolve/main/config.json",
24
- "tiiuae/falcon-7b": "https://huggingface.co/tiiuae/falcon-7b/resolve/main/config.json",
25
- }
26
-
27
-
28
- class FalconConfig(PretrainedConfig):
29
- r"""
30
- This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
31
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
- defaults will yield a similar configuration to that of the
33
- [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
34
-
35
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
- documentation from [`PretrainedConfig`] for more information.
37
-
38
-
39
- Args:
40
- vocab_size (`int`, *optional*, defaults to 65024):
41
- Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
42
- `inputs_ids` passed when calling [`FalconModel`]
43
- hidden_size (`int`, *optional*, defaults to 4544):
44
- Dimension of the hidden representations.
45
- num_hidden_layers (`int`, *optional*, defaults to 32):
46
- Number of hidden layers in the Transformer decoder.
47
- num_attention_heads (`int`, *optional*, defaults to 71):
48
- Number of attention heads for each attention layer in the Transformer encoder.
49
- initializer_range (`float`, *optional*, defaults to 0.02):
50
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
51
- use_cache (`bool`, *optional*, defaults to `True`):
52
- Whether the model should return the last key/values attentions (not used by all models). Only relevant if
53
- `config.is_decoder=True`.
54
- layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
55
- The epsilon used by the layer normalization layers.
56
- hidden_dropout (`float`, *optional*, defaults to 0.0):
57
- The dropout probability for MLP layers.
58
- attention_dropout (`float`, *optional*, defaults to 0.0):
59
- The dropout probability for attention layers.
60
- num_kv_heads (`int`, *optional*):
61
- Number of key-value heads to use per attention layer. If unset, defaults to the same value as
62
- `num_attention_heads`.
63
- alibi (`bool`, *optional*, defaults to `False`):
64
- Whether to use ALiBi positional biases during self-attention.
65
- new_decoder_architecture (`bool`, *optional*, defaults to `False`):
66
- Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
67
- arguments are ignored, as the new decoder always uses parallel attention.
68
- multi_query (`bool`, *optional*, defaults to `True`):
69
- Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
70
- parallel_attn (`bool`, *optional*, defaults to `True`):
71
- Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
72
- instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
73
- bias (`bool`, *optional*, defaults to `False`):
74
- Whether to use bias on Linear layers.
75
- bos_token_id (`int`, *optional*, defaults to 11):
76
- The id of the "beginning-of-sequence" token.
77
- eos_token_id (`int`, *optional*, defaults to 11):
78
- The id of the "end-of-sequence" token.
79
-
80
- Example:
81
-
82
- ```python
83
- >>> from transformers import FalconModel, FalconConfig
84
-
85
- >>> # Initializing a small (2-layer) Falcon configuration
86
- >>> configuration = FalconConfig(num_hidden_layers=2)
87
-
88
- >>> # Initializing a model from the small configuration
89
- >>> model = FalconModel(configuration)
90
-
91
- >>> # Accessing the model configuration
92
- >>> configuration = model.config
93
- ```"""
94
- model_type = "falcon"
95
- keys_to_ignore_at_inference = ["past_key_values"]
96
-
97
- def __init__(
98
- self,
99
- vocab_size=65024,
100
- hidden_size=4544,
101
- num_hidden_layers=32,
102
- num_attention_heads=71,
103
- layer_norm_epsilon=1e-5,
104
- initializer_range=0.02,
105
- use_cache=True,
106
- hidden_dropout=0.0,
107
- attention_dropout=0.0,
108
- num_kv_heads=None,
109
- alibi=False,
110
- new_decoder_architecture=False,
111
- multi_query=True,
112
- parallel_attn=True,
113
- bias=False,
114
- bos_token_id=11,
115
- eos_token_id=11,
116
- **kwargs,
117
- ):
118
- self.vocab_size = vocab_size
119
- # Backward compatibility with n_embed kwarg
120
- n_embed = kwargs.pop("n_embed", None)
121
- self.hidden_size = hidden_size if n_embed is None else n_embed
122
- self.num_hidden_layers = num_hidden_layers
123
- self.num_attention_heads = num_attention_heads
124
- self.layer_norm_epsilon = layer_norm_epsilon
125
- self.initializer_range = initializer_range
126
- self.use_cache = use_cache
127
- self.hidden_dropout = hidden_dropout
128
- self.attention_dropout = attention_dropout
129
-
130
- self.bos_token_id = bos_token_id
131
- self.eos_token_id = eos_token_id
132
- self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
133
- self.alibi = alibi
134
- self.new_decoder_architecture = new_decoder_architecture
135
- self.multi_query = multi_query # Ignored when new_decoder_architecture is True
136
- self.parallel_attn = parallel_attn
137
- self.bias = bias
138
-
139
- super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
140
-
141
- @property
142
- def head_dim(self):
143
- return self.hidden_size // self.num_attention_heads
144
-
145
- @property
146
- def rotary(self):
147
- return not self.alibi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generation_config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "_from_model_config": true,
3
- "bos_token_id": 11,
4
- "eos_token_id": 11,
5
- "transformers_version": "4.31.0.dev0"
6
- }
 
1
  {
2
  "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.27.4"
6
+ }
modeling_falcon.py → modelling_RW.py RENAMED
@@ -1,20 +1,9 @@
1
- # coding=utf-8
2
- # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch Falcon model."""
16
 
17
  import math
 
18
  from typing import Optional, Tuple, Union
19
 
20
  import torch
@@ -31,60 +20,59 @@ from transformers.modeling_outputs import (
31
  TokenClassifierOutput,
32
  )
33
  from transformers.modeling_utils import PreTrainedModel
34
- from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
35
- from .configuration_falcon import FalconConfig
36
-
37
 
38
  logger = logging.get_logger(__name__)
39
 
40
- FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
- "tiiuae/falcon-40b",
42
- "tiiuae/falcon-40b-instruct",
43
- "tiiuae/falcon-7b",
44
- "tiiuae/falcon-7b-instruct",
45
- "tiiuae/falcon-rw-7b",
46
- "tiiuae/falcon-rw-1b",
47
- ]
48
- _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
49
- _CONFIG_FOR_DOC = "FalconConfig"
50
-
51
-
52
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
53
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
54
- class FalconLinear(nn.Linear):
55
  def forward(self, input: torch.Tensor) -> torch.Tensor:
56
- hidden_states = input @ self.weight.T
57
  if self.bias is None:
58
- return hidden_states
59
- return hidden_states + self.bias
 
 
60
 
 
61
 
62
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
63
  def rotate_half(x):
64
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
- return torch.cat((-x2, x1), dim=-1)
66
 
67
 
68
- class FalconRotaryEmbedding(nn.Module):
69
  """Implementation of RotaryEmbedding from GPT-NeoX.
70
- This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
71
- n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
72
  """
73
 
74
- def __init__(self, head_dim: int, base=10000):
 
 
 
 
75
  super().__init__()
76
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
77
  self.register_buffer("inv_freq", inv_freq, persistent=False)
78
  self.head_dim = head_dim
79
- self.seq_len_cached = -1
 
80
  self.cos_cached: torch.Tensor | None = None
81
  self.sin_cached: torch.Tensor | None = None
82
 
83
- def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
84
- total_length = seq_len + past_key_values_length
85
- if total_length > self.seq_len_cached:
86
- self.seq_len_cached = total_length
87
- t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
 
 
 
 
88
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
89
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
90
 
@@ -97,46 +85,36 @@ class FalconRotaryEmbedding(nn.Module):
97
  self.cos_cached = self.cos_cached.type(dtype)
98
  self.sin_cached = self.sin_cached.type(dtype)
99
 
100
- return (
101
- self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
102
- self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
103
- )
104
 
105
- def forward(self, query, key, past_key_values_length=0):
106
- batch, seq_len, head_dim = query.shape
107
- cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
108
- return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
109
 
110
 
111
  def _make_causal_mask(
112
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
113
  ) -> torch.BoolTensor:
114
- """
115
- Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
116
- just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
117
- target_length, target_length+past_key_values_length]`.
118
- """
119
  batch_size, target_length = input_ids_shape
 
 
 
 
 
 
 
120
 
121
- mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
122
- # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
123
- # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
124
- # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
125
- past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
126
- mask = torch.cat([past_mask, mask], dim=-1)
127
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
128
  return expanded_mask
129
 
130
 
131
- def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
132
- """
133
- Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
134
- """
135
- batch_size, total_length = mask.shape
136
- seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
137
 
138
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
139
- return expanded_mask.expand(batch_size, 1, seq_length, total_length)
140
 
141
 
142
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -167,32 +145,18 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
167
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
168
 
169
 
170
- # Copied from transformers.models.bloom.modeling_bloom.dropout_add
171
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
- """
173
- Dropout add function
174
-
175
- Args:
176
- x (`torch.tensor`, *required*):
177
- input tensor
178
- residual (`torch.tensor`, *required*):
179
- residual tensor
180
- prob (`float`, *required*):
181
- dropout probability
182
- training (`bool`, *required*):
183
- training mode
184
- """
185
  out = F.dropout(x, p=prob, training=training)
186
  out = residual + out
187
  return out
188
 
189
 
190
- class FalconAttention(nn.Module):
191
- def __init__(self, config: FalconConfig):
192
  super().__init__()
193
 
194
  self.hidden_size = config.hidden_size
195
- self.num_heads = config.num_attention_heads
196
  self.head_dim = self.hidden_size // self.num_heads
197
  self.split_size = self.hidden_size
198
  self.hidden_dropout = config.hidden_dropout
@@ -203,62 +167,59 @@ class FalconAttention(nn.Module):
203
  f" {self.num_heads})."
204
  )
205
 
206
- self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
207
 
208
  # Layer-wise attention scaling
209
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
210
  self.beta = self.inv_norm_factor
211
- if config.new_decoder_architecture:
212
- qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
213
- elif config.multi_query:
214
- qkv_out_dim = self.hidden_size + 2 * self.head_dim
215
- else:
216
- qkv_out_dim = 3 * self.hidden_size
217
- self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
218
- self.new_decoder_architecture = config.new_decoder_architecture
219
- self.multi_query = config.multi_query
220
- self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
221
  self.attention_dropout = nn.Dropout(config.attention_dropout)
222
- self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
223
 
224
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
  """
226
- Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
 
227
 
228
  Args:
229
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
230
 
231
  Returns:
232
- query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
 
233
  value: [batch_size, seq_length, num_heads, head_dim]
234
  """
235
- if self.new_decoder_architecture:
236
- batch, seq_len, _ = fused_qkv.shape
237
- qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
238
- query = qkv[:, :, :, :-2]
239
- key = qkv[:, :, :, [-2]]
240
- value = qkv[:, :, :, [-1]]
241
- key = torch.broadcast_to(key, query.shape)
242
- value = torch.broadcast_to(value, query.shape)
243
-
244
- query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
245
- return query, key, value
246
- elif not self.multi_query:
247
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
248
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
249
- return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
250
- else:
251
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
252
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
253
- return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
254
 
255
- # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
256
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
  """
258
  Merge heads together over the last dimenstion
259
 
260
  Args:
261
- x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
 
263
  Returns:
264
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
@@ -281,7 +242,7 @@ class FalconAttention(nn.Module):
281
  def forward(
282
  self,
283
  hidden_states: torch.Tensor,
284
- alibi: Optional[torch.Tensor],
285
  attention_mask: torch.Tensor,
286
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
287
  head_mask: Optional[torch.Tensor] = None,
@@ -289,120 +250,106 @@ class FalconAttention(nn.Module):
289
  output_attentions: bool = False,
290
  ):
291
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
292
- num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
293
  # 3 x [batch_size, seq_length, num_heads, head_dim]
294
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
295
 
296
- batch_size, query_length, _, _ = query_layer.shape
297
 
298
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
299
  key_layer = key_layer.transpose(1, 2).reshape(
300
- batch_size * num_kv_heads,
301
- query_length,
302
  self.head_dim,
303
  )
304
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
305
 
306
- past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
307
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
308
 
309
  if layer_past is not None:
310
  past_key, past_value = layer_past
311
  # concatenate along seq_length dimension:
312
- # - key: [batch_size * self.num_heads, kv_length, head_dim]
313
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
314
  key_layer = torch.cat((past_key, key_layer), dim=1)
315
  value_layer = torch.cat((past_value, value_layer), dim=1)
316
 
317
  _, kv_length, _ = key_layer.shape
318
- if use_cache:
 
319
  present = (key_layer, value_layer)
320
  else:
321
  present = None
322
 
323
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
324
-
325
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
326
- key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
327
- value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
328
-
329
  if alibi is None:
330
- if output_attentions:
331
- # F.scaled_dot_product_attention doesn't return the attention weights, so we have
332
- # to do it by hand if we want them
333
- attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
334
- attention_scores /= math.sqrt(self.head_dim)
335
 
336
- attention_scores = F.softmax(
337
- attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
338
- )
339
- attn_output = attention_scores @ value_layer_
340
- else:
341
- attn_output = F.scaled_dot_product_attention(
342
- query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
343
- )
344
- attention_scores = None
345
 
346
- attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
347
- attn_output = attn_output.permute(0, 2, 1, 3)
348
- attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
349
 
350
  output_tensor = self.dense(attn_output)
351
 
352
- if output_attentions:
353
- return output_tensor, present, attention_scores
354
- else:
355
- return output_tensor, present
356
-
357
  else:
358
- matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
 
359
 
360
  # change view to [batch_size, num_heads, q_length, kv_length]
361
- attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
362
 
363
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
364
  input_dtype = attention_scores.dtype
365
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
366
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
367
  attention_scores = attention_scores.to(torch.float32)
368
- # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
369
- # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
370
- # equivalent and more performant, but there might be a numerical difference. If you're reading this
371
- # and you'd like to experiment and maybe file a PR, feel free!
372
- attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
373
- attention_logits *= self.inv_norm_factor
374
- attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
375
  # [batch_size, num_heads, q_length, kv_length]
376
  attention_probs = self.attention_dropout(attention_probs)
377
 
378
  if head_mask is not None:
379
  attention_probs = attention_probs * head_mask
380
 
381
- # change view [batch_size, num_heads, q_length, kv_length]
382
- attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
383
 
384
  # matmul: [batch_size * num_heads, q_length, head_dim]
385
- context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
386
 
387
  # change view [batch_size, num_heads, q_length, head_dim]
388
  context_layer = self._merge_heads(context_layer)
389
 
390
  output_tensor = self.dense(context_layer)
391
 
 
392
  if output_attentions:
393
- return output_tensor, present, attention_probs
394
- else:
395
- return output_tensor, present
396
 
397
 
398
- class FalconMLP(nn.Module):
399
- def __init__(self, config: FalconConfig):
400
  super().__init__()
401
  hidden_size = config.hidden_size
402
 
403
- self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
404
  self.act = nn.GELU()
405
- self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
406
  self.hidden_dropout = config.hidden_dropout
407
 
408
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -411,47 +358,43 @@ class FalconMLP(nn.Module):
411
  return x
412
 
413
 
414
- class FalconDecoderLayer(nn.Module):
415
- def __init__(self, config: FalconConfig):
416
  super().__init__()
417
  hidden_size = config.hidden_size
418
- self.num_heads = config.num_attention_heads
419
- self.self_attention = FalconAttention(config)
420
- self.mlp = FalconMLP(config)
 
 
 
 
 
 
 
421
  self.hidden_dropout = config.hidden_dropout
422
- self.config = config
423
 
424
- if config.new_decoder_architecture:
425
- # The layer norm before self-attention
426
- self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
- # The layer norm before the MLP
428
- self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
429
- else:
430
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
431
- if not config.parallel_attn:
432
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
433
 
434
  def forward(
435
  self,
436
  hidden_states: torch.Tensor,
437
- alibi: Optional[torch.Tensor],
438
  attention_mask: torch.Tensor,
439
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
440
  head_mask: Optional[torch.Tensor] = None,
441
  use_cache: bool = False,
442
  output_attentions: bool = False,
443
  ):
444
- residual = hidden_states
445
 
446
- if self.config.new_decoder_architecture:
447
- attention_layernorm_out = self.ln_attn(hidden_states)
448
- mlp_layernorm_out = self.ln_mlp(hidden_states)
449
- else:
450
- attention_layernorm_out = self.input_layernorm(hidden_states)
451
 
452
  # Self attention.
453
  attn_outputs = self.self_attention(
454
- attention_layernorm_out,
455
  layer_past=layer_past,
456
  attention_mask=attention_mask,
457
  alibi=alibi,
@@ -462,24 +405,14 @@ class FalconDecoderLayer(nn.Module):
462
 
463
  attention_output = attn_outputs[0]
464
 
465
- if not self.config.new_decoder_architecture:
466
- if self.config.parallel_attn:
467
- mlp_layernorm_out = attention_layernorm_out
468
- else:
469
- residual = dropout_add(
470
- attention_output, residual, self.config.attention_dropout, training=self.training
471
- )
472
- mlp_layernorm_out = self.post_attention_layernorm(residual)
473
-
474
  outputs = attn_outputs[1:]
475
 
476
  # MLP.
477
- mlp_output = self.mlp(mlp_layernorm_out)
478
 
479
- if self.config.new_decoder_architecture or self.config.parallel_attn:
480
- mlp_output += attention_output
481
-
482
- output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
483
 
484
  if use_cache:
485
  outputs = (output,) + outputs
@@ -489,93 +422,24 @@ class FalconDecoderLayer(nn.Module):
489
  return outputs # hidden_states, present, attentions
490
 
491
 
492
- FALCON_START_DOCSTRING = r"""
493
-
494
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
495
- library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
496
-
497
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
498
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
499
- and behavior.
500
-
501
- Parameters:
502
- config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
503
- Initializing with a config file does not load the weights associated with the model, only the
504
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
505
- """
506
-
507
- FALCON_INPUTS_DOCSTRING = r"""
508
- Args:
509
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
510
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
511
- (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
512
-
513
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
514
- `input_ids`.
515
-
516
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
- [`PreTrainedTokenizer.__call__`] for details.
518
-
519
- [What are input IDs?](../glossary#input-ids)
520
- past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
521
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
522
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
523
- their past given to this model should not be passed as `input_ids` as they have already been computed.
524
-
525
- Each element of `past_key_values` is a tuple (past_key, past_value):
526
- - past_key: [batch_size * num_heads, head_dim, kv_length]
527
- - past_value: [batch_size * num_heads, kv_length, head_dim]
528
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
529
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
530
-
531
- - 1 for tokens that are **not masked**,
532
- - 0 for tokens that are **masked**.
533
-
534
- [What are attention masks?](../glossary#attention-mask)
535
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
536
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
537
-
538
- - 1 indicates the head is **not masked**,
539
- - 0 indicates the head is **masked**.
540
-
541
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
- model's internal embedding lookup matrix.
545
-
546
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
547
- `past_key_values`).
548
- use_cache (`bool`, *optional*):
549
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
550
- `past_key_values`).
551
- output_attentions (`bool`, *optional*):
552
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
- tensors for more detail.
554
- output_hidden_states (`bool`, *optional*):
555
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
- more detail.
557
- return_dict (`bool`, *optional*):
558
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
559
- """
560
-
561
-
562
- class FalconPreTrainedModel(PreTrainedModel):
563
  """
564
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
565
  models.
566
  """
567
 
568
- config_class = FalconConfig
569
  base_model_prefix = "transformer"
570
  supports_gradient_checkpointing = True
571
- _no_split_modules = ["FalconDecoderLayer"]
572
 
573
  def __init__(self, *inputs, **kwargs):
574
  super().__init__(*inputs, **kwargs)
575
 
576
  def _init_weights(self, module: nn.Module):
577
  """Initialize the weights."""
578
- if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
579
  # Slightly different from the TF version which uses truncated_normal for initialization
580
  # cf https://github.com/pytorch/pytorch/pull/5617
581
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -589,28 +453,26 @@ class FalconPreTrainedModel(PreTrainedModel):
589
  module.bias.data.zero_()
590
  module.weight.data.fill_(1.0)
591
 
592
- # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
593
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
594
- if isinstance(module, FalconModel):
595
  module.gradient_checkpointing = value
596
 
597
  @staticmethod
598
- def _convert_cache_to_standard_format(
599
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
600
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
601
  """
602
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
603
  num_heads, ...]))
604
  """
605
- batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
606
- # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
607
- # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
608
- # on whether we use multi_query attention.
609
  num_heads = batch_size_times_num_heads // batch_size
 
 
610
  return tuple(
611
  (
612
- layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
613
- layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
614
  )
615
  for layer_past in past_key_value
616
  )
@@ -619,35 +481,32 @@ class FalconPreTrainedModel(PreTrainedModel):
619
  def _convert_to_rw_cache(
620
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
621
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
622
- batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
623
  batch_size_times_num_heads = batch_size * num_heads
624
- # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
 
625
  return tuple(
626
  (
627
- layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
628
- layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
629
  )
630
  for layer_past in past_key_value
631
  )
632
 
633
 
634
- @add_start_docstrings(
635
- "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
636
- FALCON_START_DOCSTRING,
637
- )
638
- class FalconModel(FalconPreTrainedModel):
639
- def __init__(self, config: FalconConfig):
640
  super().__init__(config)
641
 
642
  self.embed_dim = config.hidden_size
643
- self.num_heads = config.num_attention_heads
644
- self.use_alibi = config.alibi
645
 
646
  # Embedding + LN Embedding
647
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
648
 
649
  # Transformer blocks
650
- self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
651
 
652
  # Final Layer Norm
653
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -660,31 +519,22 @@ class FalconModel(FalconPreTrainedModel):
660
  def get_input_embeddings(self):
661
  return self.word_embeddings
662
 
663
- @staticmethod
664
  def _prepare_attn_mask(
665
- attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
666
  ) -> torch.BoolTensor:
667
- # Create a causal mask
668
- # The attention mask we receive as input should cover the whole extended sequence, including any past
669
- # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
670
- # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
671
- if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
672
- raise ValueError(
673
- "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
674
- f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
675
- f" {past_key_values_length}."
676
- )
677
  combined_attention_mask = None
678
  device = attention_mask.device
679
- _, seq_length = input_shape
680
 
681
- if seq_length > 1:
682
  combined_attention_mask = _make_causal_mask(
683
  input_shape, device=device, past_key_values_length=past_key_values_length
684
  )
685
 
686
- # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
687
- expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
688
  combined_attention_mask = (
689
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
690
  )
@@ -694,12 +544,6 @@ class FalconModel(FalconPreTrainedModel):
694
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
695
  self.word_embeddings = new_embeddings
696
 
697
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
698
- @add_code_sample_docstrings(
699
- checkpoint=_CHECKPOINT_FOR_DOC,
700
- output_type=BaseModelOutputWithPastAndCrossAttentions,
701
- config_class=_CONFIG_FOR_DOC,
702
- )
703
  def forward(
704
  self,
705
  input_ids: Optional[torch.LongTensor] = None,
@@ -711,7 +555,18 @@ class FalconModel(FalconPreTrainedModel):
711
  output_attentions: Optional[bool] = None,
712
  output_hidden_states: Optional[bool] = None,
713
  return_dict: Optional[bool] = None,
 
714
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
715
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
  output_hidden_states = (
717
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -730,14 +585,12 @@ class FalconModel(FalconPreTrainedModel):
730
 
731
  if past_key_values is None:
732
  past_key_values = tuple([None] * len(self.h))
733
- else:
734
- past_key_values = self._convert_to_rw_cache(past_key_values)
735
 
736
  # Prepare head mask if needed
737
  # 1.0 in head_mask indicate we keep the head
738
  # attention_probs has shape batch_size x num_heads x N x N
739
  # head_mask has shape n_layer x batch x num_heads x N x N
740
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
741
 
742
  if inputs_embeds is None:
743
  inputs_embeds = self.word_embeddings(input_ids)
@@ -749,15 +602,17 @@ class FalconModel(FalconPreTrainedModel):
749
  all_hidden_states = () if output_hidden_states else None
750
 
751
  # Compute alibi tensor: check build_alibi_tensor documentation
 
752
  past_key_values_length = 0
753
  if past_key_values[0] is not None:
754
- past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
 
755
  if attention_mask is None:
756
- attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
757
  else:
758
  attention_mask = attention_mask.to(hidden_states.device)
759
 
760
- if self.use_alibi:
761
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
762
  else:
763
  alibi = None
@@ -769,10 +624,12 @@ class FalconModel(FalconPreTrainedModel):
769
  )
770
 
771
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
772
  if output_hidden_states:
773
  all_hidden_states = all_hidden_states + (hidden_states,)
774
 
775
  if self.gradient_checkpointing and self.training:
 
776
  if use_cache:
777
  logger.warning(
778
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -817,9 +674,6 @@ class FalconModel(FalconPreTrainedModel):
817
  if output_hidden_states:
818
  all_hidden_states = all_hidden_states + (hidden_states,)
819
 
820
- if presents is not None:
821
- presents = self._convert_cache_to_standard_format(presents, batch_size)
822
-
823
  if not return_dict:
824
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
825
 
@@ -831,16 +685,12 @@ class FalconModel(FalconPreTrainedModel):
831
  )
832
 
833
 
834
- @add_start_docstrings(
835
- "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
836
- FALCON_START_DOCSTRING,
837
- )
838
- class FalconForCausalLM(FalconPreTrainedModel):
839
- _tied_weights_keys = ["lm_head.weight"]
840
 
841
- def __init__(self, config: FalconConfig):
842
  super().__init__(config)
843
- self.transformer = FalconModel(config)
844
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
845
 
846
  # Initialize weights and apply final processing
@@ -855,26 +705,25 @@ class FalconForCausalLM(FalconPreTrainedModel):
855
  def prepare_inputs_for_generation(
856
  self,
857
  input_ids: torch.LongTensor,
858
- past_key_values: Optional[torch.Tensor] = None,
859
  attention_mask: Optional[torch.Tensor] = None,
860
  **kwargs,
861
  ) -> dict:
862
- if past_key_values is not None:
863
- input_ids = input_ids[:, -1:]
 
 
 
 
 
864
 
865
  return {
866
  "input_ids": input_ids,
867
- "past_key_values": past_key_values,
868
  "use_cache": kwargs.get("use_cache"),
869
  "attention_mask": attention_mask,
870
  }
871
 
872
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
873
- @add_code_sample_docstrings(
874
- checkpoint=_CHECKPOINT_FOR_DOC,
875
- output_type=CausalLMOutputWithCrossAttentions,
876
- config_class=_CONFIG_FOR_DOC,
877
- )
878
  def forward(
879
  self,
880
  input_ids: Optional[torch.LongTensor] = None,
@@ -887,6 +736,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
887
  output_attentions: Optional[bool] = None,
888
  output_hidden_states: Optional[bool] = None,
889
  return_dict: Optional[bool] = None,
 
890
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
891
  r"""
892
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -894,6 +744,15 @@ class FalconForCausalLM(FalconPreTrainedModel):
894
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
895
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
896
  """
 
 
 
 
 
 
 
 
 
897
 
898
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
899
 
@@ -946,6 +805,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
946
 
947
  Output shares the same memory storage as `past`.
948
  """
 
949
 
950
  # Get a copy of `beam_idx` on all the devices where we need those indices.
951
  device_to_beam_idx = {
@@ -956,42 +816,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
956
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
957
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
958
  )
959
- for layer_past in past
960
  )
961
- return reordered_past
962
 
963
 
964
- @add_start_docstrings(
965
- """
966
- The Falcon Model transformer with a sequence classification head on top (linear layer).
967
-
968
- [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
969
- (e.g. GPT-1) do.
970
-
971
- Since it does classification on the last token, it requires to know the position of the last token. If a
972
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
973
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
974
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
975
- each row of the batch).
976
- """,
977
- FALCON_START_DOCSTRING,
978
- )
979
- class FalconForSequenceClassification(FalconPreTrainedModel):
980
- def __init__(self, config: FalconConfig):
981
  super().__init__(config)
982
  self.num_labels = config.num_labels
983
- self.transformer = FalconModel(config)
984
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
985
 
986
  # Initialize weights and apply final processing
987
  self.post_init()
988
 
989
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
990
- @add_code_sample_docstrings(
991
- checkpoint=_CHECKPOINT_FOR_DOC,
992
- output_type=SequenceClassifierOutputWithPast,
993
- config_class=_CONFIG_FOR_DOC,
994
- )
995
  def forward(
996
  self,
997
  input_ids: Optional[torch.LongTensor] = None,
@@ -1004,6 +845,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1004
  output_attentions: Optional[bool] = None,
1005
  output_hidden_states: Optional[bool] = None,
1006
  return_dict: Optional[bool] = None,
 
1007
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1008
  r"""
1009
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1011,6 +853,15 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1011
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1012
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1013
  """
 
 
 
 
 
 
 
 
 
1014
 
1015
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
 
@@ -1085,22 +936,17 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1085
  )
1086
 
1087
 
1088
- @add_start_docstrings(
1089
- """
1090
- Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1091
- Named-Entity-Recognition (NER) tasks.
1092
- """,
1093
- FALCON_START_DOCSTRING,
1094
- )
1095
- class FalconForTokenClassification(FalconPreTrainedModel):
1096
- def __init__(self, config: FalconConfig):
1097
  super().__init__(config)
1098
  self.num_labels = config.num_labels
1099
 
1100
- self.transformer = FalconModel(config)
1101
- if getattr(config, "classifier_dropout", None) is not None:
1102
  classifier_dropout = config.classifier_dropout
1103
- elif getattr(config, "hidden_dropout", None) is not None:
1104
  classifier_dropout = config.hidden_dropout
1105
  else:
1106
  classifier_dropout = 0.1
@@ -1110,12 +956,6 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1110
  # Initialize weights and apply final processing
1111
  self.post_init()
1112
 
1113
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1114
- @add_code_sample_docstrings(
1115
- checkpoint=_CHECKPOINT_FOR_DOC,
1116
- output_type=TokenClassifierOutput,
1117
- config_class=_CONFIG_FOR_DOC,
1118
- )
1119
  def forward(
1120
  self,
1121
  input_ids: Optional[torch.LongTensor] = None,
@@ -1128,6 +968,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1128
  output_attentions: Optional[bool] = None,
1129
  output_hidden_states: Optional[bool] = None,
1130
  return_dict: Optional[bool] = None,
 
1131
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1132
  r"""
1133
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1135,6 +976,15 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1135
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1136
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
  """
 
 
 
 
 
 
 
 
 
1138
 
1139
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1140
 
@@ -1158,9 +1008,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1158
  if labels is not None:
1159
  batch_size, seq_length = labels.shape
1160
  loss_fct = CrossEntropyLoss()
1161
- loss = loss_fct(
1162
- logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1163
- )
1164
 
1165
  if not return_dict:
1166
  output = (logits,) + transformer_outputs[2:]
@@ -1174,27 +1022,22 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1174
  )
1175
 
1176
 
1177
- @add_start_docstrings(
1178
- """
1179
- The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1180
- SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1181
- """,
1182
- FALCON_START_DOCSTRING,
1183
- )
1184
- class FalconForQuestionAnswering(FalconPreTrainedModel):
1185
  def __init__(self, config):
1186
  super().__init__(config)
1187
- self.transformer = FalconModel(config)
1188
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1189
 
1190
  # Initialize weights and apply final processing
1191
  self.post_init()
1192
 
1193
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1194
  def forward(
1195
  self,
1196
  input_ids: Optional[torch.LongTensor] = None,
1197
  attention_mask: Optional[torch.FloatTensor] = None,
 
1198
  head_mask: Optional[torch.FloatTensor] = None,
1199
  inputs_embeds: Optional[torch.FloatTensor] = None,
1200
  start_positions: Optional[torch.LongTensor] = None,
@@ -1218,6 +1061,7 @@ class FalconForQuestionAnswering(FalconPreTrainedModel):
1218
  outputs = self.transformer(
1219
  input_ids,
1220
  attention_mask=attention_mask,
 
1221
  head_mask=head_mask,
1222
  inputs_embeds=inputs_embeds,
1223
  output_attentions=output_attentions,
 
1
+ # port of models described in RW
2
+ # We use the bloom model as a starting point for these model.
3
+ # Please refer to the bloom models for usage instructions.
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import math
6
+ import warnings
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
 
20
  TokenClassifierOutput,
21
  )
22
  from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import logging
24
+ from .configuration_RW import RWConfig
 
25
 
26
  logger = logging.get_logger(__name__)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
29
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
30
+ class Linear(nn.Linear):
31
  def forward(self, input: torch.Tensor) -> torch.Tensor:
32
+ ret = input @ self.weight.T
33
  if self.bias is None:
34
+ return ret
35
+ else:
36
+ return ret + self.bias
37
+
38
 
39
+ from einops import rearrange
40
 
41
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
42
  def rotate_half(x):
43
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
44
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
45
 
46
 
47
+ class RotaryEmbedding(torch.nn.Module):
48
  """Implementation of RotaryEmbedding from GPT-NeoX.
49
+ This implementation is design to operate on queries and keys that are compatible with
50
+ [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
51
  """
52
 
53
+ def __init__(
54
+ self,
55
+ head_dim: int,
56
+ base=10000,
57
+ ):
58
  super().__init__()
59
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
60
  self.register_buffer("inv_freq", inv_freq, persistent=False)
61
  self.head_dim = head_dim
62
+ self.seq_len_cached = None
63
+ self.batch_size_cached = None
64
  self.cos_cached: torch.Tensor | None = None
65
  self.sin_cached: torch.Tensor | None = None
66
 
67
+ def cos_sin(
68
+ self,
69
+ seq_len: int,
70
+ device="cuda",
71
+ dtype=torch.bfloat16,
72
+ ) -> torch.Tensor:
73
+ if seq_len != self.seq_len_cached:
74
+ self.seq_len_cached = seq_len
75
+ t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
77
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
78
 
 
85
  self.cos_cached = self.cos_cached.type(dtype)
86
  self.sin_cached = self.sin_cached.type(dtype)
87
 
88
+ return self.cos_cached, self.sin_cached
 
 
 
89
 
90
+ def forward(self, q, k):
91
+ batch, seq_len, head_dim = q.shape
92
+ cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
 
96
  def _make_causal_mask(
97
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
98
  ) -> torch.BoolTensor:
 
 
 
 
 
99
  batch_size, target_length = input_ids_shape
100
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
+ seq_ids = torch.arange(target_length, device=device)
103
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
+
105
+ if past_key_values_length > 0:
106
+ mask[:, :past_key_values_length] = False
107
 
 
 
 
 
 
 
108
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
  return expanded_mask
110
 
111
 
112
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
113
+ batch_size, src_length = mask.shape
114
+ tgt_length = tgt_length if tgt_length is not None else src_length
 
 
 
115
 
116
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
117
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
118
 
119
 
120
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
 
145
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
146
 
147
 
 
148
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  out = F.dropout(x, p=prob, training=training)
150
  out = residual + out
151
  return out
152
 
153
 
154
+ class Attention(nn.Module):
155
+ def __init__(self, config: RWConfig):
156
  super().__init__()
157
 
158
  self.hidden_size = config.hidden_size
159
+ self.num_heads = config.n_head
160
  self.head_dim = self.hidden_size // self.num_heads
161
  self.split_size = self.hidden_size
162
  self.hidden_dropout = config.hidden_dropout
 
167
  f" {self.num_heads})."
168
  )
169
 
170
+ self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
171
 
172
  # Layer-wise attention scaling
173
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
174
  self.beta = self.inv_norm_factor
175
+
176
+ self.query_key_value = Linear(
177
+ self.hidden_size,
178
+ (config.n_head_kv * 2 + config.n_head) * self.head_dim,
179
+ bias=config.bias,
180
+ )
181
+ self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
 
 
 
182
  self.attention_dropout = nn.Dropout(config.attention_dropout)
183
+ self.num_kv = config.n_head_kv
184
 
185
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
186
  """
187
+ Split the last dimension into (num_heads, head_dim), results share same memory
188
+ storage as `fused_qkv`
189
 
190
  Args:
191
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
192
 
193
  Returns:
194
+ query: [batch_size, seq_length, num_heads, head_dim]
195
+ key: [batch_size, seq_length, num_heads, head_dim]
196
  value: [batch_size, seq_length, num_heads, head_dim]
197
  """
198
+ batch, seq_len, _ = fused_qkv.shape
199
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
200
+ q = qkv[:, :, :, :-2]
201
+ k = qkv[:, :, :, [-2]]
202
+ v = qkv[:, :, :, [-1]]
203
+ k = torch.broadcast_to(k, q.shape)
204
+ v = torch.broadcast_to(v, q.shape)
205
+
206
+ q, k, v = [
207
+ rearrange(
208
+ x,
209
+ "batch seq_len group num_heads head_dim ->\
210
+ batch seq_len (group num_heads) head_dim",
211
+ head_dim=self.head_dim,
212
+ )
213
+ for x in [q, k, v]
214
+ ]
215
+ return q, k, v
 
216
 
 
217
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
218
  """
219
  Merge heads together over the last dimenstion
220
 
221
  Args:
222
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
223
 
224
  Returns:
225
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
 
242
  def forward(
243
  self,
244
  hidden_states: torch.Tensor,
245
+ alibi: torch.Tensor,
246
  attention_mask: torch.Tensor,
247
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
248
  head_mask: Optional[torch.Tensor] = None,
 
250
  output_attentions: bool = False,
251
  ):
252
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
253
+
254
  # 3 x [batch_size, seq_length, num_heads, head_dim]
255
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
256
 
257
+ batch_size, q_length, _, _ = query_layer.shape
258
 
259
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
260
  key_layer = key_layer.transpose(1, 2).reshape(
261
+ batch_size * self.num_heads,
262
+ q_length,
263
  self.head_dim,
264
  )
265
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
266
 
267
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
268
 
269
  if layer_past is not None:
270
  past_key, past_value = layer_past
271
  # concatenate along seq_length dimension:
272
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
273
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
274
  key_layer = torch.cat((past_key, key_layer), dim=1)
275
  value_layer = torch.cat((past_value, value_layer), dim=1)
276
 
277
  _, kv_length, _ = key_layer.shape
278
+
279
+ if use_cache is True:
280
  present = (key_layer, value_layer)
281
  else:
282
  present = None
283
 
 
 
 
 
 
 
284
  if alibi is None:
285
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
286
+ key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
287
+ value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
 
288
 
289
+ attn_output = F.scaled_dot_product_attention(
290
+ query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
291
+ )
 
 
 
 
 
 
292
 
293
+ x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
294
+ x = x.permute(0, 2, 1, 3)
295
+ attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
296
 
297
  output_tensor = self.dense(attn_output)
298
 
299
+ outputs = (output_tensor, present)
300
+ assert not output_attentions # not supported.
301
+ return outputs
 
 
302
  else:
303
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
304
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
305
 
306
  # change view to [batch_size, num_heads, q_length, kv_length]
307
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
308
 
309
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
310
  input_dtype = attention_scores.dtype
311
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
312
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
313
  attention_scores = attention_scores.to(torch.float32)
314
+ # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
315
+ attention_probs = F.softmax(
316
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
317
+ + attention_mask_float,
318
+ dim=-1,
319
+ dtype=hidden_states.dtype,
320
+ )
321
  # [batch_size, num_heads, q_length, kv_length]
322
  attention_probs = self.attention_dropout(attention_probs)
323
 
324
  if head_mask is not None:
325
  attention_probs = attention_probs * head_mask
326
 
327
+ # change view [batch_size x num_heads, q_length, kv_length]
328
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
329
 
330
  # matmul: [batch_size * num_heads, q_length, head_dim]
331
+ context_layer = attention_probs_reshaped @ value_layer
332
 
333
  # change view [batch_size, num_heads, q_length, head_dim]
334
  context_layer = self._merge_heads(context_layer)
335
 
336
  output_tensor = self.dense(context_layer)
337
 
338
+ outputs = (output_tensor, present)
339
  if output_attentions:
340
+ outputs += (attention_probs,)
341
+
342
+ return outputs
343
 
344
 
345
+ class MLP(nn.Module):
346
+ def __init__(self, config: RWConfig):
347
  super().__init__()
348
  hidden_size = config.hidden_size
349
 
350
+ self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
351
  self.act = nn.GELU()
352
+ self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
353
  self.hidden_dropout = config.hidden_dropout
354
 
355
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
358
  return x
359
 
360
 
361
+ class DecoderLayer(nn.Module):
362
+ def __init__(self, config: RWConfig):
363
  super().__init__()
364
  hidden_size = config.hidden_size
365
+
366
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
367
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
368
+
369
+ self.num_heads = config.n_head
370
+ self.self_attention = Attention(config)
371
+
372
+ self.mlp = MLP(config)
373
+
374
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
375
  self.hidden_dropout = config.hidden_dropout
 
376
 
377
+ self.config = config
 
 
 
 
 
 
 
 
378
 
379
  def forward(
380
  self,
381
  hidden_states: torch.Tensor,
382
+ alibi: torch.Tensor,
383
  attention_mask: torch.Tensor,
384
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
385
  head_mask: Optional[torch.Tensor] = None,
386
  use_cache: bool = False,
387
  output_attentions: bool = False,
388
  ):
 
389
 
390
+ ln_attn = self.ln_attn(hidden_states)
391
+ ln_mlp = self.ln_mlp(hidden_states)
392
+
393
+ residual = hidden_states
 
394
 
395
  # Self attention.
396
  attn_outputs = self.self_attention(
397
+ ln_attn,
398
  layer_past=layer_past,
399
  attention_mask=attention_mask,
400
  alibi=alibi,
 
405
 
406
  attention_output = attn_outputs[0]
407
 
 
 
 
 
 
 
 
 
 
408
  outputs = attn_outputs[1:]
409
 
410
  # MLP.
411
+ mlp_output = self.mlp(ln_mlp)
412
 
413
+ output = dropout_add(
414
+ mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
415
+ )
 
416
 
417
  if use_cache:
418
  outputs = (output,) + outputs
 
422
  return outputs # hidden_states, present, attentions
423
 
424
 
425
+ class RWPreTrainedModel(PreTrainedModel):
426
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  """
428
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
429
  models.
430
  """
431
 
432
+ config_class = RWConfig
433
  base_model_prefix = "transformer"
434
  supports_gradient_checkpointing = True
435
+ _no_split_modules = ["DecoderLayer"]
436
 
437
  def __init__(self, *inputs, **kwargs):
438
  super().__init__(*inputs, **kwargs)
439
 
440
  def _init_weights(self, module: nn.Module):
441
  """Initialize the weights."""
442
+ if isinstance(module, nn.Linear) or isinstance(module, Linear):
443
  # Slightly different from the TF version which uses truncated_normal for initialization
444
  # cf https://github.com/pytorch/pytorch/pull/5617
445
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
453
  module.bias.data.zero_()
454
  module.weight.data.fill_(1.0)
455
 
 
456
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
457
+ if isinstance(module, RWModel):
458
  module.gradient_checkpointing = value
459
 
460
  @staticmethod
461
+ def _convert_to_standard_cache(
462
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
463
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
464
  """
465
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
466
  num_heads, ...]))
467
  """
468
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
 
 
 
469
  num_heads = batch_size_times_num_heads // batch_size
470
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
471
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
472
  return tuple(
473
  (
474
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
475
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
476
  )
477
  for layer_past in past_key_value
478
  )
 
481
  def _convert_to_rw_cache(
482
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
483
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
484
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
485
  batch_size_times_num_heads = batch_size * num_heads
486
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
487
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
488
  return tuple(
489
  (
490
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
491
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
492
  )
493
  for layer_past in past_key_value
494
  )
495
 
496
 
497
+ class RWModel(RWPreTrainedModel):
498
+ def __init__(self, config: RWConfig):
 
 
 
 
499
  super().__init__(config)
500
 
501
  self.embed_dim = config.hidden_size
502
+ self.num_heads = config.n_head
503
+ self.alibi = config.alibi
504
 
505
  # Embedding + LN Embedding
506
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
507
 
508
  # Transformer blocks
509
+ self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
510
 
511
  # Final Layer Norm
512
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
519
  def get_input_embeddings(self):
520
  return self.word_embeddings
521
 
 
522
  def _prepare_attn_mask(
523
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
524
  ) -> torch.BoolTensor:
525
+ # create causal mask
526
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
 
 
 
 
 
 
 
 
527
  combined_attention_mask = None
528
  device = attention_mask.device
529
+ _, src_length = input_shape
530
 
531
+ if src_length > 1:
532
  combined_attention_mask = _make_causal_mask(
533
  input_shape, device=device, past_key_values_length=past_key_values_length
534
  )
535
 
536
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
537
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
538
  combined_attention_mask = (
539
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
540
  )
 
544
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
545
  self.word_embeddings = new_embeddings
546
 
 
 
 
 
 
 
547
  def forward(
548
  self,
549
  input_ids: Optional[torch.LongTensor] = None,
 
555
  output_attentions: Optional[bool] = None,
556
  output_hidden_states: Optional[bool] = None,
557
  return_dict: Optional[bool] = None,
558
+ **deprecated_arguments,
559
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
560
+ if deprecated_arguments.pop("position_ids", False) is not False:
561
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
562
+ warnings.warn(
563
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
564
+ " passing `position_ids`.",
565
+ FutureWarning,
566
+ )
567
+ if len(deprecated_arguments) > 0:
568
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
569
+
570
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
571
  output_hidden_states = (
572
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
585
 
586
  if past_key_values is None:
587
  past_key_values = tuple([None] * len(self.h))
 
 
588
 
589
  # Prepare head mask if needed
590
  # 1.0 in head_mask indicate we keep the head
591
  # attention_probs has shape batch_size x num_heads x N x N
592
  # head_mask has shape n_layer x batch x num_heads x N x N
593
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
594
 
595
  if inputs_embeds is None:
596
  inputs_embeds = self.word_embeddings(input_ids)
 
602
  all_hidden_states = () if output_hidden_states else None
603
 
604
  # Compute alibi tensor: check build_alibi_tensor documentation
605
+ seq_length_with_past = seq_length
606
  past_key_values_length = 0
607
  if past_key_values[0] is not None:
608
+ past_key_values_length = past_key_values[0][0].shape[2]
609
+ seq_length_with_past = seq_length_with_past + past_key_values_length
610
  if attention_mask is None:
611
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
612
  else:
613
  attention_mask = attention_mask.to(hidden_states.device)
614
 
615
+ if self.alibi:
616
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
617
  else:
618
  alibi = None
 
624
  )
625
 
626
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
627
+
628
  if output_hidden_states:
629
  all_hidden_states = all_hidden_states + (hidden_states,)
630
 
631
  if self.gradient_checkpointing and self.training:
632
+
633
  if use_cache:
634
  logger.warning(
635
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 
674
  if output_hidden_states:
675
  all_hidden_states = all_hidden_states + (hidden_states,)
676
 
 
 
 
677
  if not return_dict:
678
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
679
 
 
685
  )
686
 
687
 
688
+ class RWForCausalLM(RWPreTrainedModel):
689
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
 
 
 
690
 
691
+ def __init__(self, config: RWConfig):
692
  super().__init__(config)
693
+ self.transformer = RWModel(config)
694
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
695
 
696
  # Initialize weights and apply final processing
 
705
  def prepare_inputs_for_generation(
706
  self,
707
  input_ids: torch.LongTensor,
708
+ past: Optional[torch.Tensor] = None,
709
  attention_mask: Optional[torch.Tensor] = None,
710
  **kwargs,
711
  ) -> dict:
712
+ # only last token for input_ids if past is not None
713
+ if past:
714
+ input_ids = input_ids[:, -1].unsqueeze(-1)
715
+
716
+ # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
717
+ if past[0][0].shape[0] == input_ids.shape[0]:
718
+ past = self._convert_to_rw_cache(past)
719
 
720
  return {
721
  "input_ids": input_ids,
722
+ "past_key_values": past,
723
  "use_cache": kwargs.get("use_cache"),
724
  "attention_mask": attention_mask,
725
  }
726
 
 
 
 
 
 
 
727
  def forward(
728
  self,
729
  input_ids: Optional[torch.LongTensor] = None,
 
736
  output_attentions: Optional[bool] = None,
737
  output_hidden_states: Optional[bool] = None,
738
  return_dict: Optional[bool] = None,
739
+ **deprecated_arguments,
740
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
741
  r"""
742
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
744
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
745
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
746
  """
747
+ if deprecated_arguments.pop("position_ids", False) is not False:
748
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
749
+ warnings.warn(
750
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
751
+ " passing `position_ids`.",
752
+ FutureWarning,
753
+ )
754
+ if len(deprecated_arguments) > 0:
755
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
756
 
757
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
758
 
 
805
 
806
  Output shares the same memory storage as `past`.
807
  """
808
+ standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
809
 
810
  # Get a copy of `beam_idx` on all the devices where we need those indices.
811
  device_to_beam_idx = {
 
816
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
817
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
818
  )
819
+ for layer_past in standardized_past
820
  )
821
+ return self._convert_to_rw_cache(reordered_past)
822
 
823
 
824
+ class RWForSequenceClassification(RWPreTrainedModel):
825
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
826
+
827
+ def __init__(self, config: RWConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
828
  super().__init__(config)
829
  self.num_labels = config.num_labels
830
+ self.transformer = RWModel(config)
831
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
832
 
833
  # Initialize weights and apply final processing
834
  self.post_init()
835
 
 
 
 
 
 
 
836
  def forward(
837
  self,
838
  input_ids: Optional[torch.LongTensor] = None,
 
845
  output_attentions: Optional[bool] = None,
846
  output_hidden_states: Optional[bool] = None,
847
  return_dict: Optional[bool] = None,
848
+ **deprecated_arguments,
849
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
850
  r"""
851
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
853
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
854
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
855
  """
856
+ if deprecated_arguments.pop("position_ids", False) is not False:
857
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
858
+ warnings.warn(
859
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
860
+ " passing `position_ids`.",
861
+ FutureWarning,
862
+ )
863
+ if len(deprecated_arguments) > 0:
864
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
865
 
866
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
867
 
 
936
  )
937
 
938
 
939
+ class RWForTokenClassification(RWPreTrainedModel):
940
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
941
+
942
+ def __init__(self, config: RWConfig):
 
 
 
 
 
943
  super().__init__(config)
944
  self.num_labels = config.num_labels
945
 
946
+ self.transformer = RWModel(config)
947
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
948
  classifier_dropout = config.classifier_dropout
949
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
950
  classifier_dropout = config.hidden_dropout
951
  else:
952
  classifier_dropout = 0.1
 
956
  # Initialize weights and apply final processing
957
  self.post_init()
958
 
 
 
 
 
 
 
959
  def forward(
960
  self,
961
  input_ids: Optional[torch.LongTensor] = None,
 
968
  output_attentions: Optional[bool] = None,
969
  output_hidden_states: Optional[bool] = None,
970
  return_dict: Optional[bool] = None,
971
+ **deprecated_arguments,
972
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
973
  r"""
974
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
976
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
977
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
978
  """
979
+ if deprecated_arguments.pop("position_ids", False) is not False:
980
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
981
+ warnings.warn(
982
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
983
+ " passing `position_ids`.",
984
+ FutureWarning,
985
+ )
986
+ if len(deprecated_arguments) > 0:
987
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
988
 
989
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
990
 
 
1008
  if labels is not None:
1009
  batch_size, seq_length = labels.shape
1010
  loss_fct = CrossEntropyLoss()
1011
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
 
 
1012
 
1013
  if not return_dict:
1014
  output = (logits,) + transformer_outputs[2:]
 
1022
  )
1023
 
1024
 
1025
+ class RWForQuestionAnswering(RWPreTrainedModel):
1026
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1027
+
 
 
 
 
 
1028
  def __init__(self, config):
1029
  super().__init__(config)
1030
+ self.transformer = RWModel(config)
1031
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1032
 
1033
  # Initialize weights and apply final processing
1034
  self.post_init()
1035
 
 
1036
  def forward(
1037
  self,
1038
  input_ids: Optional[torch.LongTensor] = None,
1039
  attention_mask: Optional[torch.FloatTensor] = None,
1040
+ position_ids: Optional[torch.LongTensor] = None,
1041
  head_mask: Optional[torch.FloatTensor] = None,
1042
  inputs_embeds: Optional[torch.FloatTensor] = None,
1043
  start_positions: Optional[torch.LongTensor] = None,
 
1061
  outputs = self.transformer(
1062
  input_ids,
1063
  attention_mask=attention_mask,
1064
+ position_ids=position_ids,
1065
  head_mask=head_mask,
1066
  inputs_embeds=inputs_embeds,
1067
  output_attentions=output_attentions,
tokenizer_config.json CHANGED
@@ -1,11 +1,7 @@
1
  {
2
  "add_prefix_space": false,
3
  "eos_token": "<|endoftext|>",
4
- "model_input_names": [
5
- "input_ids",
6
- "attention_mask"
7
- ],
8
  "model_max_length": 2048,
9
  "special_tokens_map_file": null,
10
  "tokenizer_class": "PreTrainedTokenizerFast"
11
- }
 
1
  {
2
  "add_prefix_space": false,
3
  "eos_token": "<|endoftext|>",
 
 
 
 
4
  "model_max_length": 2048,
5
  "special_tokens_map_file": null,
6
  "tokenizer_class": "PreTrainedTokenizerFast"
7
+ }