emozilla commited on
Commit
2e25a9a
·
1 Parent(s): 4ece65f

add fast loading/inference

Browse files
README.md CHANGED
@@ -9,6 +9,9 @@ datasets:
9
  inference: false
10
  ---
11
 
 
 
 
12
  # MPT-7B-StoryWriter-65k+
13
 
14
  MPT-7B-StoryWriter-65k+ is a model designed to read and write fictional stories with super long context lengths.
 
9
  inference: false
10
  ---
11
 
12
+ The code for this model has been updated to include the adaptions from [Birchlabs/mosaicml-mpt-7b-chat-qlora](https://huggingface.co/Birchlabs/mosaicml-mpt-7b-chat-qlora) which allow MPT models to be loaded with `device_map="auto"` and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) support (e.g. `load_in_8bit`, `load_in_4bit`).
13
+ It also has the [latest key-value cache MPT code](https://github.com/mosaicml/llm-foundry/pull/210) to allow for fast inference with `transformers` (thus, `use_cache` is set to `True` in `config.json`).
14
+
15
  # MPT-7B-StoryWriter-65k+
16
 
17
  MPT-7B-StoryWriter-65k+ is a model designed to read and write fictional stories with super long context lengths.
adapt_tokenizer.py CHANGED
@@ -1,41 +1,27 @@
 
1
  from typing import Union
2
  from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3
- Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4
  NUM_SENTINEL_TOKENS: int = 100
5
 
6
  def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7
- """Adds sentinel tokens and padding token (if missing).
8
-
9
- Expands the tokenizer vocabulary to include sentinel tokens
10
- used in mixture-of-denoiser tasks as well as a padding token.
11
-
12
- All added tokens are added as special tokens. No tokens are
13
- added if sentinel tokens and padding token already exist.
14
- """
15
  sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17
- if tokenizer.pad_token is None:
18
  tokenizer.add_tokens('<pad>', special_tokens=True)
19
  tokenizer.pad_token = '<pad>'
20
- assert tokenizer.pad_token_id is not None
21
  sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23
  tokenizer.sentinel_token_ids = _sentinel_token_ids
24
 
25
  class AutoTokenizerForMOD(AutoTokenizer):
26
- """AutoTokenizer + Adaptation for MOD.
27
-
28
- A simple wrapper around AutoTokenizer to make instantiating
29
- an MOD-adapted tokenizer a bit easier.
30
-
31
- MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
32
- a padding token, and a property to get the token ids of the
33
- sentinel tokens.
34
- """
35
 
36
  @classmethod
37
  def from_pretrained(cls, *args, **kwargs):
38
- """See `AutoTokenizer.from_pretrained` docstring."""
39
  tokenizer = super().from_pretrained(*args, **kwargs)
40
  adapt_tokenizer_for_denoising(tokenizer)
41
- return tokenizer
 
1
+
2
  from typing import Union
3
  from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
4
+ Tokenizer = Union[(PreTrainedTokenizer, PreTrainedTokenizerFast)]
5
  NUM_SENTINEL_TOKENS: int = 100
6
 
7
  def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
8
+ 'Adds sentinel tokens and padding token (if missing).\n\n Expands the tokenizer vocabulary to include sentinel tokens\n used in mixture-of-denoiser tasks as well as a padding token.\n\n All added tokens are added as special tokens. No tokens are\n added if sentinel tokens and padding token already exist.\n '
 
 
 
 
 
 
 
9
  sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
10
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
11
+ if (tokenizer.pad_token is None):
12
  tokenizer.add_tokens('<pad>', special_tokens=True)
13
  tokenizer.pad_token = '<pad>'
14
+ assert (tokenizer.pad_token_id is not None)
15
  sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
16
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
17
  tokenizer.sentinel_token_ids = _sentinel_token_ids
18
 
19
  class AutoTokenizerForMOD(AutoTokenizer):
20
+ 'AutoTokenizer + Adaptation for MOD.\n\n A simple wrapper around AutoTokenizer to make instantiating\n an MOD-adapted tokenizer a bit easier.\n\n MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),\n a padding token, and a property to get the token ids of the\n sentinel tokens.\n '
 
 
 
 
 
 
 
 
21
 
22
  @classmethod
23
  def from_pretrained(cls, *args, **kwargs):
24
+ 'See `AutoTokenizer.from_pretrained` docstring.'
25
  tokenizer = super().from_pretrained(*args, **kwargs)
26
  adapt_tokenizer_for_denoising(tokenizer)
27
+ return tokenizer
attention.py CHANGED
@@ -1,13 +1,73 @@
1
  """Attention layers."""
2
  import math
3
  import warnings
4
- from typing import Optional
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
  from packaging import version
9
  from torch import nn
 
10
  from .norm import LPLayerNorm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
13
  if original_is_causal and num_query_tokens != num_key_tokens:
@@ -17,25 +77,57 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
17
  return False
18
  return original_is_causal
19
 
20
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
23
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
24
- min_val = torch.finfo(q.dtype).min
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  (b, _, s_q, d) = q.shape
26
  s_k = k.size(-1)
27
  if softmax_scale is None:
28
  softmax_scale = 1 / math.sqrt(d)
29
  attn_weight = q.matmul(k) * softmax_scale
30
  if attn_bias is not None:
 
 
 
 
31
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
32
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
33
  attn_weight = attn_weight + attn_bias
 
34
  if key_padding_mask is not None:
35
  if attn_bias is not None:
36
- warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
37
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
38
- if is_causal:
39
  s = max(s_q, s_k)
40
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
41
  causal_mask = causal_mask.tril()
@@ -49,8 +141,8 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
49
  out = attn_weight.matmul(v)
50
  out = rearrange(out, 'b h s d -> b s (h d)')
51
  if needs_weights:
52
- return (out, attn_weight)
53
- return (out, None)
54
 
55
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
56
  for tensor in tensors:
@@ -59,12 +151,38 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
59
  if not tensor.is_cuda:
60
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
61
 
62
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  try:
64
  from flash_attn import bert_padding, flash_attn_interface
65
  except:
66
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
67
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
 
 
 
68
  if attn_bias is not None:
69
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
70
  (batch_size, seqlen) = query.shape[:2]
@@ -84,9 +202,23 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
84
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
85
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
86
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
87
- return (output, None)
88
 
89
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  try:
91
  from .flash_attn_triton import flash_attn_func
92
  except:
@@ -100,6 +232,18 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
100
  if not _installed:
101
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
102
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
 
 
 
103
  if dropout_p:
104
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
105
  if needs_weights:
@@ -119,14 +263,16 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
119
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
120
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
121
  output = attn_output.view(*attn_output.shape[:2], -1)
122
- return (output, None)
123
 
124
- class MultiheadAttention(nn.Module):
125
  """Multi-head self attention.
126
 
127
  Using torch or triton attention implemetation enables user to also use
128
  additive bias.
129
  """
 
 
130
 
131
  def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
132
  super().__init__()
@@ -160,7 +306,15 @@ class MultiheadAttention(nn.Module):
160
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
161
  self.out_proj._is_residual = True
162
 
163
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
 
 
 
 
 
 
 
 
164
  qkv = self.Wqkv(x)
165
  if self.clip_qkv:
166
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -170,17 +324,71 @@ class MultiheadAttention(nn.Module):
170
  dtype = query.dtype
171
  query = self.q_ln(query).to(dtype)
172
  key = self.k_ln(key).to(dtype)
173
- if past_key_value is not None:
174
- if len(past_key_value) != 0:
175
- key = torch.cat([past_key_value[0], key], dim=1)
176
- value = torch.cat([past_key_value[1], value], dim=1)
177
- past_key_value = (key, value)
178
- if attn_bias is not None:
179
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
180
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
181
- return (self.out_proj(context), attn_weights, past_key_value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- class MultiQueryAttention(nn.Module):
184
  """Multi-Query self attention.
185
 
186
  Using torch or triton attention implemetation enables user to also use
@@ -220,7 +428,15 @@ class MultiQueryAttention(nn.Module):
220
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
221
  self.out_proj._is_residual = True
222
 
223
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
 
 
 
 
 
 
 
 
224
  qkv = self.Wqkv(x)
225
  if self.clip_qkv:
226
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -234,11 +450,72 @@ class MultiQueryAttention(nn.Module):
234
  if len(past_key_value) != 0:
235
  key = torch.cat([past_key_value[0], key], dim=1)
236
  value = torch.cat([past_key_value[1], value], dim=1)
237
- past_key_value = (key, value)
238
  if attn_bias is not None:
239
  attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
240
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
241
- return (self.out_proj(context), attn_weights, past_key_value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
244
  if attn_impl == 'flash':
 
1
  """Attention layers."""
2
  import math
3
  import warnings
4
+ from typing import Optional, Dict, Any, NamedTuple, Protocol, Tuple, Union
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
  from packaging import version
9
  from torch import nn
10
+ from torch.utils.checkpoint import checkpoint
11
  from .norm import LPLayerNorm
12
+ from .is_torch_version import is_torch_version
13
+
14
+ class PastKeyValue(NamedTuple):
15
+ key: torch.Tensor
16
+ value: torch.Tensor
17
+
18
+ class AttnFnOutput(NamedTuple):
19
+ attns: torch.Tensor
20
+ attn_probs: Optional[torch.Tensor]
21
+ past_key_value: Union[PastKeyValue, Tuple, None]
22
+
23
+ class AttnFn(Protocol):
24
+ def __call__(
25
+ self,
26
+ query: torch.Tensor,
27
+ key: torch.Tensor,
28
+ value: torch.Tensor,
29
+ n_heads: int,
30
+ softmax_scale: Optional[float] = None,
31
+ attn_bias: Optional[torch.Tensor] = None,
32
+ key_padding_mask: Optional[torch.ByteTensor] = None,
33
+ is_causal = False,
34
+ dropout_p = 0.0,
35
+ training = False,
36
+ needs_weights = False,
37
+ multiquery = False,
38
+ ) -> AttnFnOutput: ...
39
+
40
+ class AttnFnCheckpointed(Protocol):
41
+ def __call__(
42
+ self,
43
+ query: torch.Tensor,
44
+ key: torch.Tensor,
45
+ value: torch.Tensor,
46
+ n_heads: int,
47
+ softmax_scale: Optional[float],
48
+ attn_bias: Optional[torch.Tensor],
49
+ key_padding_mask: Optional[torch.ByteTensor],
50
+ is_causal: bool,
51
+ dropout_p: float,
52
+ training: bool,
53
+ needs_weights: bool,
54
+ ) -> AttnFnOutput: ...
55
+
56
+ class AttnOutput(NamedTuple):
57
+ projected_context: torch.Tensor
58
+ attn_weights: Optional[torch.Tensor]
59
+ past_key_value: Union[PastKeyValue, Tuple, None]
60
+
61
+ class Attn(Protocol):
62
+ def __call__(
63
+ self,
64
+ x: torch.Tensor,
65
+ past_key_value: Union[PastKeyValue, Tuple, None] = None,
66
+ attn_bias: Optional[torch.Tensor] = None,
67
+ attention_mask: Optional[torch.ByteTensor] = None,
68
+ is_causal = True,
69
+ needs_weights = False,
70
+ ) -> AttnOutput: ...
71
 
72
  def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
73
  if original_is_causal and num_query_tokens != num_key_tokens:
 
77
  return False
78
  return original_is_causal
79
 
80
+ def scaled_multihead_dot_product_attention(
81
+ query: torch.Tensor,
82
+ key: torch.Tensor,
83
+ value: torch.Tensor,
84
+ n_heads: int,
85
+ past_key_value=None,
86
+ softmax_scale: Optional[float] = None,
87
+ attn_bias: Optional[torch.Tensor] = None,
88
+ key_padding_mask: Optional[torch.ByteTensor] = None,
89
+ is_causal = False,
90
+ dropout_p = 0.0,
91
+ training = False,
92
+ needs_weights = False,
93
+ multiquery = False,
94
+ ) -> AttnFnOutput:
95
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
96
+ kv_n_heads = 1 if multiquery else n_heads
97
+ k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
98
+ v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
99
+
100
+ if past_key_value is not None:
101
+ # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
102
+ # kv_cache is therefore stored using that shape.
103
+ # attn_impl: torch stores the kv_cache in the ordering which is most advantageous
104
+ # for its attn computation ie
105
+ # keys are stored as tensors with shape [b, h, d_head, s] and
106
+ # values are stored as tensors with shape [b, h, s, d_head]
107
+ if len(past_key_value) != 0:
108
+ k = torch.cat([past_key_value[0], k], dim=3)
109
+ v = torch.cat([past_key_value[1], v], dim=2)
110
+
111
+ past_key_value = (k, v)
112
  (b, _, s_q, d) = q.shape
113
  s_k = k.size(-1)
114
  if softmax_scale is None:
115
  softmax_scale = 1 / math.sqrt(d)
116
  attn_weight = q.matmul(k) * softmax_scale
117
  if attn_bias is not None:
118
+ # clamp to 0 necessary for torch 2.0 compile()
119
+ _s_q = max(0, attn_bias.size(2) - s_q)
120
+ _s_k = max(0, attn_bias.size(3) - s_k)
121
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
122
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
123
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
124
  attn_weight = attn_weight + attn_bias
125
+ min_val = torch.finfo(q.dtype).min
126
  if key_padding_mask is not None:
127
  if attn_bias is not None:
128
+ warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
129
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
130
+ if is_causal and (not q.size(2) == 1):
131
  s = max(s_q, s_k)
132
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
133
  causal_mask = causal_mask.tril()
 
141
  out = attn_weight.matmul(v)
142
  out = rearrange(out, 'b h s d -> b s (h d)')
143
  if needs_weights:
144
+ return AttnFnOutput(out, attn_weight, past_key_value)
145
+ return AttnFnOutput(out, None, past_key_value)
146
 
147
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
148
  for tensor in tensors:
 
151
  if not tensor.is_cuda:
152
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
153
 
154
+ def flash_attn_fn(
155
+ query: torch.Tensor,
156
+ key: torch.Tensor,
157
+ value: torch.Tensor,
158
+ n_heads: int,
159
+ past_key_value=None,
160
+ softmax_scale: Optional[float] = None,
161
+ attn_bias: Optional[torch.Tensor] = None,
162
+ key_padding_mask: Optional[torch.ByteTensor] = None,
163
+ is_causal = False,
164
+ dropout_p = 0.0,
165
+ training = False,
166
+ needs_weights = False,
167
+ multiquery = False,
168
+ ) -> AttnFnOutput:
169
  try:
170
  from flash_attn import bert_padding, flash_attn_interface
171
  except:
172
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
173
  check_valid_inputs(query, key, value)
174
+ if past_key_value is not None:
175
+ if len(past_key_value) != 0:
176
+ key = torch.cat([past_key_value[0], key], dim=1)
177
+ value = torch.cat([past_key_value[1], value], dim=1)
178
+
179
+ past_key_value = (key, value)
180
+
181
+ if attn_bias is not None:
182
+ # clamp to 0 necessary for torch 2.0 compile()
183
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
184
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
185
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
186
  if attn_bias is not None:
187
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
188
  (batch_size, seqlen) = query.shape[:2]
 
202
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
203
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
204
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
205
+ return AttnFnOutput(output, None, past_key_value)
206
 
207
+ def triton_flash_attn_fn(
208
+ query: torch.Tensor,
209
+ key: torch.Tensor,
210
+ value: torch.Tensor,
211
+ n_heads: int,
212
+ past_key_value=None,
213
+ softmax_scale: Optional[float] = None,
214
+ attn_bias: Optional[torch.Tensor] = None,
215
+ key_padding_mask: Optional[torch.ByteTensor] = None,
216
+ is_causal = False,
217
+ dropout_p = 0.0,
218
+ training = False,
219
+ needs_weights = False,
220
+ multiquery = False,
221
+ ) -> AttnFnOutput:
222
  try:
223
  from .flash_attn_triton import flash_attn_func
224
  except:
 
232
  if not _installed:
233
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
234
  check_valid_inputs(query, key, value)
235
+ if past_key_value is not None:
236
+ if len(past_key_value) != 0:
237
+ key = torch.cat([past_key_value[0], key], dim=1)
238
+ value = torch.cat([past_key_value[1], value], dim=1)
239
+
240
+ past_key_value = (key, value)
241
+
242
+ if attn_bias is not None:
243
+ # clamp to 0 necessary for torch 2.0 compile()
244
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
245
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
246
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
247
  if dropout_p:
248
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
249
  if needs_weights:
 
263
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
264
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
265
  output = attn_output.view(*attn_output.shape[:2], -1)
266
+ return AttnFnOutput(output, None, past_key_value)
267
 
268
+ class MultiheadAttention(nn.Module, Attn):
269
  """Multi-head self attention.
270
 
271
  Using torch or triton attention implemetation enables user to also use
272
  additive bias.
273
  """
274
+ gradient_checkpointing = False
275
+ attn_fn: AttnFn
276
 
277
  def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
278
  super().__init__()
 
306
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
307
  self.out_proj._is_residual = True
308
 
309
+ def forward(
310
+ self,
311
+ x: torch.Tensor,
312
+ past_key_value: Union[PastKeyValue, Tuple, None] = None,
313
+ attn_bias: Optional[torch.Tensor] = None,
314
+ attention_mask: Optional[torch.ByteTensor] = None,
315
+ is_causal = True,
316
+ needs_weights = False,
317
+ ) -> AttnOutput:
318
  qkv = self.Wqkv(x)
319
  if self.clip_qkv:
320
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
 
324
  dtype = query.dtype
325
  query = self.q_ln(query).to(dtype)
326
  key = self.k_ln(key).to(dtype)
327
+ if self.training and self.gradient_checkpointing:
328
+ ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
329
+ def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
330
+ def custom_forward(
331
+ query: torch.Tensor,
332
+ key: torch.Tensor,
333
+ value: torch.Tensor,
334
+ n_heads: int,
335
+ softmax_scale: Optional[float],
336
+ attn_bias: Optional[torch.Tensor],
337
+ key_padding_mask: Optional[torch.ByteTensor],
338
+ is_causal: bool,
339
+ dropout_p: float,
340
+ training: bool,
341
+ needs_weights: bool,
342
+ ):
343
+ return attn_fn(
344
+ query,
345
+ key,
346
+ value,
347
+ n_heads,
348
+ softmax_scale,
349
+ attn_bias,
350
+ key_padding_mask,
351
+ is_causal,
352
+ dropout_p,
353
+ training,
354
+ needs_weights,
355
+ False, # multiquery
356
+ )
357
+ return custom_forward
358
+ attn_fn_out: AttnFnOutput = checkpoint(
359
+ create_custom_forward(self.attn_fn),
360
+ query,
361
+ key,
362
+ value,
363
+ self.n_heads,
364
+ self.softmax_scale,
365
+ attn_bias,
366
+ key_padding_mask,
367
+ is_causal,
368
+ self.attn_dropout_p,
369
+ self.training,
370
+ needs_weights,
371
+ **ckpt_kwargs,
372
+ )
373
+ else:
374
+ attn_fn_out: AttnFnOutput = self.attn_fn(
375
+ query,
376
+ key,
377
+ value,
378
+ self.n_heads,
379
+ past_key_value=past_key_value,
380
+ softmax_scale=self.softmax_scale,
381
+ attn_bias=attn_bias,
382
+ key_padding_mask=key_padding_mask,
383
+ is_causal=is_causal,
384
+ dropout_p=self.attn_dropout_p,
385
+ training=self.training,
386
+ needs_weights=needs_weights,
387
+ )
388
+ context, attn_weights, past_key_value = attn_fn_out
389
+ return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
390
 
391
+ class MultiQueryAttention(nn.Module, Attn):
392
  """Multi-Query self attention.
393
 
394
  Using torch or triton attention implemetation enables user to also use
 
428
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
429
  self.out_proj._is_residual = True
430
 
431
+ def forward(
432
+ self,
433
+ x: torch.Tensor,
434
+ past_key_value: Union[PastKeyValue, Tuple, None] = None,
435
+ attn_bias: Optional[torch.Tensor] = None,
436
+ attention_mask: Optional[torch.ByteTensor] = None,
437
+ is_causal = True,
438
+ needs_weights = False,
439
+ ) -> AttnOutput:
440
  qkv = self.Wqkv(x)
441
  if self.clip_qkv:
442
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
 
450
  if len(past_key_value) != 0:
451
  key = torch.cat([past_key_value[0], key], dim=1)
452
  value = torch.cat([past_key_value[1], value], dim=1)
453
+ past_key_value = PastKeyValue(key, value)
454
  if attn_bias is not None:
455
  attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
456
+ if self.training and self.gradient_checkpointing:
457
+ ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
458
+ def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
459
+ def custom_forward(
460
+ query: torch.Tensor,
461
+ key: torch.Tensor,
462
+ value: torch.Tensor,
463
+ n_heads: int,
464
+ softmax_scale: Optional[float],
465
+ attn_bias: Optional[torch.Tensor],
466
+ key_padding_mask: Optional[torch.ByteTensor],
467
+ is_causal: bool,
468
+ dropout_p: float,
469
+ training: bool,
470
+ needs_weights: bool,
471
+ ):
472
+ return attn_fn(
473
+ query,
474
+ key,
475
+ value,
476
+ n_heads,
477
+ softmax_scale,
478
+ attn_bias,
479
+ key_padding_mask,
480
+ is_causal,
481
+ dropout_p,
482
+ training,
483
+ needs_weights,
484
+ True, # multiquery
485
+ )
486
+ return custom_forward
487
+ attn_fn_out: AttnFnOutput = checkpoint(
488
+ create_custom_forward(self.attn_fn),
489
+ query,
490
+ key,
491
+ value,
492
+ self.n_heads,
493
+ self.softmax_scale,
494
+ attn_bias,
495
+ key_padding_mask,
496
+ is_causal,
497
+ self.attn_dropout_p,
498
+ self.training,
499
+ needs_weights,
500
+ **ckpt_kwargs,
501
+ )
502
+ else:
503
+ attn_fn_out: AttnFnOutput = self.attn_fn(
504
+ query,
505
+ key,
506
+ value,
507
+ self.n_heads,
508
+ past_key_value=past_key_value,
509
+ softmax_scale=self.softmax_scale,
510
+ attn_bias=attn_bias,
511
+ key_padding_mask=key_padding_mask,
512
+ is_causal=is_causal,
513
+ dropout_p=self.attn_dropout_p,
514
+ training=self.training,
515
+ needs_weights=needs_weights,
516
+ )
517
+ context, attn_weights = attn_fn_out
518
+ return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
519
 
520
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
521
  if attn_impl == 'flash':
blocks.py CHANGED
@@ -1,10 +1,15 @@
1
  """GPT Blocks used for the GPT Model."""
2
- from typing import Dict, Optional, Tuple
3
  import torch
4
  import torch.nn as nn
5
- from .attention import ATTN_CLASS_REGISTRY
6
  from .norm import NORM_CLASS_REGISTRY
7
 
 
 
 
 
 
8
  class MPTMLP(nn.Module):
9
 
10
  def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
@@ -18,6 +23,7 @@ class MPTMLP(nn.Module):
18
  return self.down_proj(self.act(self.up_proj(x)))
19
 
20
  class MPTBlock(nn.Module):
 
21
 
22
  def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23
  del kwargs
@@ -31,11 +37,11 @@ class MPTBlock(nn.Module):
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33
 
34
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
- return (x, past_key_value)
 
1
  """GPT Blocks used for the GPT Model."""
2
+ from typing import Dict, Optional, Tuple, NamedTuple, Union
3
  import torch
4
  import torch.nn as nn
5
+ from .attention import ATTN_CLASS_REGISTRY, Attn, PastKeyValue
6
  from .norm import NORM_CLASS_REGISTRY
7
 
8
+ class MPTBlockOutput(NamedTuple):
9
+ hidden_states: torch.Tensor
10
+ attn_probs: Optional[torch.Tensor]
11
+ past_key_value: Union[PastKeyValue, Tuple, None]
12
+
13
  class MPTMLP(nn.Module):
14
 
15
  def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
 
23
  return self.down_proj(self.act(self.up_proj(x)))
24
 
25
  class MPTBlock(nn.Module):
26
+ attn: Attn
27
 
28
  def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
29
  del kwargs
 
37
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
38
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
39
 
40
+ def forward(self, x: torch.Tensor, past_key_value: Union[PastKeyValue, Tuple, None] = None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> MPTBlockOutput:
41
  a = self.norm_1(x)
42
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
43
  x = x + self.resid_attn_dropout(b)
44
  m = self.norm_2(x)
45
  n = self.ffn(m)
46
  x = x + self.resid_ffn_dropout(n)
47
+ return MPTBlockOutput(x, attn_weights, past_key_value)
config.json CHANGED
@@ -46,7 +46,7 @@
46
  "tokenizer_name": "EleutherAI/gpt-neox-20b",
47
  "torch_dtype": "bfloat16",
48
  "transformers_version": "4.28.1",
49
- "use_cache": false,
50
  "verbose": 0,
51
  "vocab_size": 50432
52
  }
 
46
  "tokenizer_name": "EleutherAI/gpt-neox-20b",
47
  "torch_dtype": "bfloat16",
48
  "transformers_version": "4.28.1",
49
+ "use_cache": true,
50
  "verbose": 0,
51
  "vocab_size": 50432
52
  }
configuration_mpt.py CHANGED
@@ -1,67 +1,15 @@
1
- """A HuggingFace-style model configuration."""
 
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
9
 
10
- def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
11
- """The MPT configuration class.
12
-
13
- Args:
14
- d_model (int): The size of the embedding dimension of the model.
15
- n_heads (int): The number of attention heads.
16
- n_layers (int): The number of layers in the model.
17
- expansion_ratio (int): The ratio of the up/down scale in the MLP.
18
- max_seq_len (int): The maximum sequence length of the model.
19
- vocab_size (int): The size of the vocabulary.
20
- resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
21
- emb_pdrop (float): The dropout probability for the embedding layer.
22
- learned_pos_emb (bool): Whether to use learned positional embeddings
23
- attn_config (Dict): A dictionary used to configure the model's attention module:
24
- attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
25
- attn_pdrop (float): The dropout probability for the attention layers.
26
- attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
27
- qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
28
- clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
29
- this value.
30
- softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
31
- use the default scale of ``1/sqrt(d_keys)``.
32
- prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
33
- extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
34
- can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
35
- attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
36
- When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
37
- which sub-sequence each token belongs to.
38
- Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
39
- alibi (bool): Whether to use the alibi bias instead of position embeddings.
40
- alibi_bias_max (int): The maximum value of the alibi bias.
41
- init_device (str): The device to use for parameter initialization.
42
- logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
43
- no_bias (bool): Whether to use bias in all layers.
44
- verbose (int): The verbosity level. 0 is silent.
45
- embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
46
- norm_type (str): choose type of norm to use
47
- multiquery_attention (bool): Whether to use multiquery attention implementation.
48
- use_cache (bool): Whether or not the model should return the last key/values attentions
49
- init_config (Dict): A dictionary used to configure the model initialization:
50
- init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
51
- 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
52
- 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
53
- init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
54
- emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
55
- emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
56
- used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
57
- init_std (float): The standard deviation of the normal distribution used to initialize the model,
58
- if using the baseline_ parameter initialization scheme.
59
- init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
60
- fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
61
- init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
62
- ---
63
- See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
64
- """
65
  self.d_model = d_model
66
  self.n_heads = n_heads
67
  self.n_layers = n_layers
@@ -80,39 +28,39 @@ class MPTConfig(PretrainedConfig):
80
  self.norm_type = norm_type
81
  self.use_cache = use_cache
82
  self.init_config = init_config
83
- if 'name' in kwargs:
84
  del kwargs['name']
85
- if 'loss_fn' in kwargs:
86
  del kwargs['loss_fn']
87
  super().__init__(**kwargs)
88
  self._validate_config()
89
 
90
  def _set_config_defaults(self, config, config_defaults):
91
  for (k, v) in config_defaults.items():
92
- if k not in config:
93
  config[k] = v
94
  return config
95
 
96
  def _validate_config(self):
97
  self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
98
  self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
99
- if self.d_model % self.n_heads != 0:
100
  raise ValueError('d_model must be divisible by n_heads')
101
- if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
102
  raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
103
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
104
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
105
- if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
106
  raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
107
- if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
108
  raise NotImplementedError('alibi only implemented with torch and triton attention.')
109
- if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
110
  raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
111
- if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
112
  raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
113
- if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
114
  raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
115
- if self.init_config.get('name', None) is None:
116
  raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
117
- if not self.learned_pos_emb and (not self.attn_config['alibi']):
118
- raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
 
1
+
2
+ 'A HuggingFace-style model configuration.'
3
  from typing import Dict, Optional, Union
4
  from transformers import PretrainedConfig
5
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
6
+ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
7
 
8
  class MPTConfig(PretrainedConfig):
9
  model_type = 'mpt'
10
 
11
+ def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[(float, str)]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
12
+ "The MPT configuration class.\n\n Args:\n d_model (int): The size of the embedding dimension of the model.\n n_heads (int): The number of attention heads.\n n_layers (int): The number of layers in the model.\n expansion_ratio (int): The ratio of the up/down scale in the MLP.\n max_seq_len (int): The maximum sequence length of the model.\n vocab_size (int): The size of the vocabulary.\n resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.\n emb_pdrop (float): The dropout probability for the embedding layer.\n learned_pos_emb (bool): Whether to use learned positional embeddings\n attn_config (Dict): A dictionary used to configure the model's attention module:\n attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention\n attn_pdrop (float): The dropout probability for the attention layers.\n attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.\n qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.\n clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to\n this value.\n softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,\n use the default scale of ``1/sqrt(d_keys)``.\n prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an\n extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix\n can attend to one another bi-directionally. Tokens outside the prefix use causal attention.\n attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.\n When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates\n which sub-sequence each token belongs to.\n Defaults to ``False`` meaning any provided `sequence_id` will be ignored.\n alibi (bool): Whether to use the alibi bias instead of position embeddings.\n alibi_bias_max (int): The maximum value of the alibi bias.\n init_device (str): The device to use for parameter initialization.\n logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.\n no_bias (bool): Whether to use bias in all layers.\n verbose (int): The verbosity level. 0 is silent.\n embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.\n norm_type (str): choose type of norm to use\n multiquery_attention (bool): Whether to use multiquery attention implementation.\n use_cache (bool): Whether or not the model should return the last key/values attentions\n init_config (Dict): A dictionary used to configure the model initialization:\n init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',\n 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or\n 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.\n init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.\n emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.\n emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution\n used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.\n init_std (float): The standard deviation of the normal distribution used to initialize the model,\n if using the baseline_ parameter initialization scheme.\n init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.\n fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.\n init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.\n ---\n See llmfoundry.models.utils.param_init_fns.py for info on other param init config options\n "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  self.d_model = d_model
14
  self.n_heads = n_heads
15
  self.n_layers = n_layers
 
28
  self.norm_type = norm_type
29
  self.use_cache = use_cache
30
  self.init_config = init_config
31
+ if ('name' in kwargs):
32
  del kwargs['name']
33
+ if ('loss_fn' in kwargs):
34
  del kwargs['loss_fn']
35
  super().__init__(**kwargs)
36
  self._validate_config()
37
 
38
  def _set_config_defaults(self, config, config_defaults):
39
  for (k, v) in config_defaults.items():
40
+ if (k not in config):
41
  config[k] = v
42
  return config
43
 
44
  def _validate_config(self):
45
  self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
46
  self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
47
+ if ((self.d_model % self.n_heads) != 0):
48
  raise ValueError('d_model must be divisible by n_heads')
49
+ if any((((prob < 0) or (prob > 1)) for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
50
  raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
51
+ if (self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']):
52
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
53
+ if (self.attn_config['prefix_lm'] and (self.attn_config['attn_impl'] not in ['torch', 'triton'])):
54
  raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
55
+ if (self.attn_config['alibi'] and (self.attn_config['attn_impl'] not in ['torch', 'triton'])):
56
  raise NotImplementedError('alibi only implemented with torch and triton attention.')
57
+ if (self.attn_config['attn_uses_sequence_id'] and (self.attn_config['attn_impl'] not in ['torch', 'triton'])):
58
  raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
59
+ if ((self.embedding_fraction > 1) or (self.embedding_fraction <= 0)):
60
  raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
61
+ if (isinstance(self.logit_scale, str) and (self.logit_scale != 'inv_sqrt_d_model')):
62
  raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
63
+ if (self.init_config.get('name', None) is None):
64
  raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
65
+ if ((not self.learned_pos_emb) and (not self.attn_config['alibi'])):
66
+ raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
flash_attn_triton.py CHANGED
@@ -1,324 +1,283 @@
1
- """
2
- Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
3
- update imports to use 'triton_pre_mlir'
4
 
5
- *Experimental* implementation of FlashAttention in Triton.
6
- Tested with triton==2.0.0.dev20221202.
7
- Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
8
- other than 64:
9
- https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
10
- We'll update this implementation with the new Triton backend once this is fixed.
11
-
12
- We use the FlashAttention implementation from Phil Tillet a starting point.
13
- https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
14
-
15
- Changes:
16
- - Implement both causal and non-causal attention.
17
- - Implement both self-attention and cross-attention.
18
- - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
19
- - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
20
- - Support attention bias.
21
- - Speed up the forward pass a bit, and only store the LSE instead of m and l.
22
- - Make the backward for d=128 much faster by reducing register spilling.
23
- - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
24
- small batch size * nheads.
25
-
26
- Caution:
27
- - This is an *experimental* implementation. The forward pass should be quite robust but
28
- I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
29
- - This implementation has only been tested on A100.
30
- - If you plan to use headdim other than 64 and 128, you should test for race conditions
31
- (due to the Triton compiler), as done in tests/test_flash_attn.py
32
- "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
33
- for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
34
- that there are none left for other head dimensions.
35
-
36
- Differences between this Triton version and the CUDA version:
37
- - Triton version doesn't support dropout.
38
- - Triton forward is generally faster than CUDA forward, while Triton backward is
39
- generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
40
- than CUDA forward + backward.
41
- - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
42
- - Triton version supports attention bias, while CUDA version doesn't.
43
- """
44
  import math
45
  import torch
46
  import triton_pre_mlir as triton
47
  import triton_pre_mlir.language as tl
48
 
49
- @triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})
50
  @triton.jit
51
  def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
52
  start_m = tl.program_id(0)
53
  off_hb = tl.program_id(1)
54
- off_b = off_hb // nheads
55
- off_h = off_hb % nheads
56
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
57
  offs_n = tl.arange(0, BLOCK_N)
58
  offs_d = tl.arange(0, BLOCK_HEADDIM)
59
- q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
60
- k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
61
- v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
62
- if BIAS_TYPE == 'vector':
63
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
64
- elif BIAS_TYPE == 'matrix':
65
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
66
- t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
67
- lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
68
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
69
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
70
- if EVEN_M & EVEN_N:
71
  if EVEN_HEADDIM:
72
  q = tl.load(q_ptrs)
73
  else:
74
- q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
75
  elif EVEN_HEADDIM:
76
- q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
77
  else:
78
- q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
79
- end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
80
  for start_n in range(0, end_n, BLOCK_N):
81
  start_n = tl.multiple_of(start_n, BLOCK_N)
82
- if EVEN_N & EVEN_M:
83
  if EVEN_HEADDIM:
84
- k = tl.load(k_ptrs + start_n * stride_kn)
85
  else:
86
- k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
87
  elif EVEN_HEADDIM:
88
- k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
89
  else:
90
- k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
91
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
92
  qk += tl.dot(q, k, trans_b=True)
93
- if not EVEN_N:
94
- qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float('-inf'))
95
  if IS_CAUSAL:
96
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float('-inf'))
97
- if BIAS_TYPE != 'none':
98
- if BIAS_TYPE == 'vector':
99
  if EVEN_N:
100
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
101
  else:
102
- bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
103
  bias = bias[None, :]
104
- elif BIAS_TYPE == 'matrix':
105
- if EVEN_M & EVEN_N:
106
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
107
  else:
108
- bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
109
- qk = qk * softmax_scale + bias
110
  m_ij = tl.maximum(tl.max(qk, 1), lse_i)
111
- p = tl.exp(qk - m_ij[:, None])
112
  else:
113
- m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
114
- p = tl.exp(qk * softmax_scale - m_ij[:, None])
115
  l_ij = tl.sum(p, 1)
116
- acc_o_scale = tl.exp(m_i - m_ij)
117
  tl.store(t_ptrs, acc_o_scale)
118
  acc_o_scale = tl.load(t_ptrs)
119
- acc_o = acc_o * acc_o_scale[:, None]
120
- if EVEN_N & EVEN_M:
121
  if EVEN_HEADDIM:
122
- v = tl.load(v_ptrs + start_n * stride_vn)
123
  else:
124
- v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
125
  elif EVEN_HEADDIM:
126
- v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
127
  else:
128
- v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
129
  p = p.to(v.dtype)
130
  acc_o += tl.dot(p, v)
131
  m_i = m_ij
132
- l_i_new = tl.exp(lse_i - m_ij) + l_ij
133
- lse_i = m_ij + tl.log(l_i_new)
134
- o_scale = tl.exp(m_i - lse_i)
135
  tl.store(t_ptrs, o_scale)
136
  o_scale = tl.load(t_ptrs)
137
- acc_o = acc_o * o_scale[:, None]
138
  start_m = tl.program_id(0)
139
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
140
- lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
141
  tl.store(lse_ptrs, lse_i)
142
  offs_d = tl.arange(0, BLOCK_HEADDIM)
143
- out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
144
  if EVEN_M:
145
  if EVEN_HEADDIM:
146
  tl.store(out_ptrs, acc_o)
147
  else:
148
- tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
149
  elif EVEN_HEADDIM:
150
- tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
151
  else:
152
- tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
153
 
154
  @triton.jit
155
  def _bwd_preprocess_do_o_dot(Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):
156
  start_m = tl.program_id(0)
157
  off_hb = tl.program_id(1)
158
- off_b = off_hb // nheads
159
- off_h = off_hb % nheads
160
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
161
  offs_d = tl.arange(0, BLOCK_HEADDIM)
162
- o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
163
- do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
164
- delta = tl.sum(o * do, axis=1)
165
- tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
166
 
167
  @triton.jit
168
  def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
169
- if EVEN_N & EVEN_M:
170
  if EVEN_HEADDIM:
171
  tl.store(dv_ptrs, dv)
172
  tl.store(dk_ptrs, dk)
173
  else:
174
- tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
175
- tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
176
  elif EVEN_HEADDIM:
177
- tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
178
- tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
179
  else:
180
- tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
181
- tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
182
 
183
  @triton.jit
184
  def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
185
- begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
186
- offs_qm = begin_m + tl.arange(0, BLOCK_M)
187
- offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
188
  offs_m = tl.arange(0, BLOCK_M)
189
  offs_d = tl.arange(0, BLOCK_HEADDIM)
190
- q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
191
- k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
192
- v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
193
- do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
194
- dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
195
- if BIAS_TYPE == 'vector':
196
- b_ptrs = Bias + offs_n
197
- elif BIAS_TYPE == 'matrix':
198
- b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
199
  dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
200
  dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
201
- if begin_m >= seqlen_q:
202
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
203
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
204
  _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
205
  return
206
- if EVEN_N & EVEN_M:
207
  if EVEN_HEADDIM:
208
  k = tl.load(k_ptrs)
209
  v = tl.load(v_ptrs)
210
  else:
211
- k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
212
- v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
213
  elif EVEN_HEADDIM:
214
- k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
215
- v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
216
  else:
217
- k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
218
- v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
219
  num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
220
- for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
221
  start_m = tl.multiple_of(start_m, BLOCK_M)
222
- offs_m_curr = start_m + offs_m
223
- if EVEN_M & EVEN_HEADDIM:
224
  q = tl.load(q_ptrs)
225
  elif EVEN_HEADDIM:
226
- q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
227
  else:
228
- q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
229
  qk = tl.dot(q, k, trans_b=True)
230
- if not EVEN_N:
231
- qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
232
  if IS_CAUSAL:
233
- qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float('-inf'))
234
- if BIAS_TYPE != 'none':
235
  tl.debug_barrier()
236
- if BIAS_TYPE == 'vector':
237
  if EVEN_N:
238
  bias = tl.load(b_ptrs).to(tl.float32)
239
  else:
240
- bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
241
  bias = bias[None, :]
242
- elif BIAS_TYPE == 'matrix':
243
- if EVEN_M & EVEN_N:
244
  bias = tl.load(b_ptrs).to(tl.float32)
245
  else:
246
- bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32)
247
- qk = qk * softmax_scale + bias
248
- if not EVEN_M & EVEN_HEADDIM:
249
  tl.debug_barrier()
250
- lse_i = tl.load(LSE + offs_m_curr)
251
- if BIAS_TYPE == 'none':
252
- p = tl.exp(qk * softmax_scale - lse_i[:, None])
253
  else:
254
- p = tl.exp(qk - lse_i[:, None])
255
- if EVEN_M & EVEN_HEADDIM:
256
  do = tl.load(do_ptrs)
257
  else:
258
- do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
259
  dv += tl.dot(p.to(do.dtype), do, trans_a=True)
260
- if not EVEN_M & EVEN_HEADDIM:
261
  tl.debug_barrier()
262
  dp = tl.dot(do, v, trans_b=True)
263
- if not EVEN_HEADDIM:
264
  tl.debug_barrier()
265
- Di = tl.load(D + offs_m_curr)
266
- ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
267
  dk += tl.dot(ds, q, trans_a=True)
268
- if not EVEN_M & EVEN_HEADDIM:
269
  tl.debug_barrier()
270
- if not ATOMIC_ADD:
271
- if EVEN_M & EVEN_HEADDIM:
272
  dq = tl.load(dq_ptrs, eviction_policy='evict_last')
273
  dq += tl.dot(ds, k)
274
  tl.store(dq_ptrs, dq, eviction_policy='evict_last')
275
  elif EVEN_HEADDIM:
276
- dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy='evict_last')
277
  dq += tl.dot(ds, k)
278
- tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy='evict_last')
279
  else:
280
- dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy='evict_last')
281
  dq += tl.dot(ds, k)
282
- tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy='evict_last')
283
  else:
284
  dq = tl.dot(ds, k)
285
- if EVEN_M & EVEN_HEADDIM:
286
  tl.atomic_add(dq_ptrs, dq)
287
  elif EVEN_HEADDIM:
288
- tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
289
  else:
290
- tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
291
- dq_ptrs += BLOCK_M * stride_dqm
292
- q_ptrs += BLOCK_M * stride_qm
293
- do_ptrs += BLOCK_M * stride_dom
294
- if BIAS_TYPE == 'matrix':
295
- b_ptrs += BLOCK_M * stride_bm
296
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
297
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
298
  _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
299
 
300
  def init_to_zero(name):
301
- return lambda nargs: nargs[name].zero_()
302
 
303
  @triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ'))], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'])
304
- @triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})
305
  @triton.jit
306
  def _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
307
  off_hb = tl.program_id(1)
308
- off_b = off_hb // nheads
309
- off_h = off_hb % nheads
310
- Q += off_b * stride_qb + off_h * stride_qh
311
- K += off_b * stride_kb + off_h * stride_kh
312
- V += off_b * stride_vb + off_h * stride_vh
313
- DO += off_b * stride_dob + off_h * stride_doh
314
- DQ += off_b * stride_dqb + off_h * stride_dqh
315
- DK += off_b * stride_dkb + off_h * stride_dkh
316
- DV += off_b * stride_dvb + off_h * stride_dvh
317
- if BIAS_TYPE != 'none':
318
- Bias += off_b * stride_bb + off_h * stride_bh
319
- D += off_hb * seqlen_q_rounded
320
- LSE += off_hb * seqlen_q_rounded
321
- if not SEQUENCE_PARALLEL:
322
  num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
323
  for start_n in range(0, num_block_n):
324
  _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
@@ -329,86 +288,81 @@ def _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb,
329
  def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
330
  (batch, seqlen_q, nheads, d) = q.shape
331
  (_, seqlen_k, _, _) = k.shape
332
- assert k.shape == (batch, seqlen_k, nheads, d)
333
- assert v.shape == (batch, seqlen_k, nheads, d)
334
- assert d <= 128, 'FlashAttention only support head dimensions up to 128'
335
- assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
336
- assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
337
- assert q.is_cuda and k.is_cuda and v.is_cuda
338
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
339
- has_bias = bias is not None
340
  bias_type = 'none'
341
  if has_bias:
342
- assert bias.dtype in [q.dtype, torch.float]
343
  assert bias.is_cuda
344
- assert bias.dim() == 4
345
- if bias.stride(-1) != 1:
346
  bias = bias.contiguous()
347
- if bias.shape[2:] == (1, seqlen_k):
348
  bias_type = 'vector'
349
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
350
  bias_type = 'matrix'
351
  else:
352
  raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')
353
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
354
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
355
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
356
  lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
357
  tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
358
  o = torch.empty_like(q)
359
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
360
  BLOCK = 128
361
- num_warps = 4 if d <= 64 else 8
362
- grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
363
- _fwd_kernel[grid](q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1)
364
  return (o, lse, softmax_scale)
365
 
366
  def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
367
- if do.stride(-1) != 1:
368
  do = do.contiguous()
369
  (batch, seqlen_q, nheads, d) = q.shape
370
  (_, seqlen_k, _, _) = k.shape
371
- assert d <= 128
372
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
373
- assert lse.shape == (batch, nheads, seqlen_q_rounded)
374
- assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
375
- assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
376
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
377
  dq_accum = torch.empty_like(q, dtype=torch.float32)
378
  delta = torch.empty_like(lse)
379
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
380
- grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
381
  _bwd_preprocess_do_o_dot[grid](o, do, delta, o.stride(0), o.stride(2), o.stride(1), do.stride(0), do.stride(2), do.stride(1), nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM)
382
- has_bias = bias is not None
383
  bias_type = 'none'
384
  if has_bias:
385
- assert bias.dtype in [q.dtype, torch.float]
386
  assert bias.is_cuda
387
- assert bias.dim() == 4
388
- assert bias.stride(-1) == 1
389
- if bias.shape[2:] == (1, seqlen_k):
390
  bias_type = 'vector'
391
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
392
  bias_type = 'matrix'
393
  else:
394
  raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')
395
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
396
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
397
- grid = lambda META: (triton.cdiv(seqlen_k, META['BLOCK_N']) if META['SEQUENCE_PARALLEL'] else 1, batch * nheads)
398
- _bwd_kernel[grid](q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), dk.stride(0), dk.stride(2), dk.stride(1), dv.stride(0), dv.stride(2), dv.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM)
399
  dq.copy_(dq_accum)
400
 
401
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
402
 
403
  @staticmethod
404
  def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
405
- """
406
- qkv: (batch, seqlen, 3, nheads, headdim)
407
- bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
408
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
409
- ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
410
- """
411
- if qkv.stride(-1) != 1:
412
  qkv = qkv.contiguous()
413
  (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)
414
  ctx.save_for_backward(qkv, o, lse, bias)
@@ -418,7 +372,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
418
  @staticmethod
419
  def backward(ctx, do):
420
  (qkv, o, lse, bias) = ctx.saved_tensors
421
- assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
422
  with torch.inference_mode():
423
  dqkv = torch.empty_like(qkv)
424
  _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
@@ -429,14 +383,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
429
 
430
  @staticmethod
431
  def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
432
- """
433
- q: (batch, seqlen_q, nheads, headdim)
434
- kv: (batch, seqlen_k, 2, nheads, headdim)
435
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
436
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
437
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
438
- """
439
- (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
440
  (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)
441
  ctx.save_for_backward(q, kv, o, lse, bias)
442
  ctx.causal = causal
@@ -445,8 +393,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
445
  @staticmethod
446
  def backward(ctx, do):
447
  (q, kv, o, lse, bias) = ctx.saved_tensors
448
- if len(ctx.needs_input_grad) >= 3:
449
- assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
450
  with torch.inference_mode():
451
  dq = torch.empty_like(q)
452
  dkv = torch.empty_like(kv)
@@ -458,14 +406,8 @@ class FlashAttnFunc(torch.autograd.Function):
458
 
459
  @staticmethod
460
  def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
461
- """
462
- q: (batch_size, seqlen_q, nheads, headdim)
463
- k, v: (batch_size, seqlen_k, nheads, headdim)
464
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
465
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
466
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
467
- """
468
- (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
469
  (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
470
  ctx.save_for_backward(q, k, v, o, lse, bias)
471
  ctx.causal = causal
@@ -474,11 +416,11 @@ class FlashAttnFunc(torch.autograd.Function):
474
  @staticmethod
475
  def backward(ctx, do):
476
  (q, k, v, o, lse, bias) = ctx.saved_tensors
477
- assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
478
  with torch.inference_mode():
479
  dq = torch.empty_like(q)
480
  dk = torch.empty_like(k)
481
  dv = torch.empty_like(v)
482
  _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
483
  return (dq, dk, dv, None, None, None)
484
- flash_attn_func = FlashAttnFunc.apply
 
 
 
 
1
 
2
+ '\nCopied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py\nupdate imports to use \'triton_pre_mlir\'\n\n*Experimental* implementation of FlashAttention in Triton.\nTested with triton==2.0.0.dev20221202.\nTriton 2.0 has a new backend (MLIR) but seems like it doesn\'t yet work for head dimensions\nother than 64:\nhttps://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207\nWe\'ll update this implementation with the new Triton backend once this is fixed.\n\nWe use the FlashAttention implementation from Phil Tillet a starting point.\nhttps://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py\n\nChanges:\n- Implement both causal and non-causal attention.\n- Implement both self-attention and cross-attention.\n- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.\n- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.\n- Support attention bias.\n- Speed up the forward pass a bit, and only store the LSE instead of m and l.\n- Make the backward for d=128 much faster by reducing register spilling.\n- Optionally parallelize the backward pass across seqlen_k, to deal with the case of\nsmall batch size * nheads.\n\nCaution:\n- This is an *experimental* implementation. The forward pass should be quite robust but\nI\'m not 100% sure that the backward pass doesn\'t have race conditions (due to the Triton compiler).\n- This implementation has only been tested on A100.\n- If you plan to use headdim other than 64 and 128, you should test for race conditions\n(due to the Triton compiler), as done in tests/test_flash_attn.py\n"test_flash_attn_triton_race_condition". I\'ve tested and fixed many race conditions\nfor different head dimensions (40, 48, 64, 128, 80, 88, 96), but I\'m still not 100% confident\nthat there are none left for other head dimensions.\n\nDifferences between this Triton version and the CUDA version:\n- Triton version doesn\'t support dropout.\n- Triton forward is generally faster than CUDA forward, while Triton backward is\ngenerally slower than CUDA backward. Overall Triton forward + backward is slightly slower\nthan CUDA forward + backward.\n- Triton version doesn\'t support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).\n- Triton version supports attention bias, while CUDA version doesn\'t.\n'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import math
4
  import torch
5
  import triton_pre_mlir as triton
6
  import triton_pre_mlir.language as tl
7
 
8
+ @triton.heuristics({'EVEN_M': (lambda args: ((args['seqlen_q'] % args['BLOCK_M']) == 0)), 'EVEN_N': (lambda args: ((args['seqlen_k'] % args['BLOCK_N']) == 0)), 'EVEN_HEADDIM': (lambda args: (args['headdim'] == args['BLOCK_HEADDIM']))})
9
  @triton.jit
10
  def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
11
  start_m = tl.program_id(0)
12
  off_hb = tl.program_id(1)
13
+ off_b = (off_hb // nheads)
14
+ off_h = (off_hb % nheads)
15
+ offs_m = ((start_m * BLOCK_M) + tl.arange(0, BLOCK_M))
16
  offs_n = tl.arange(0, BLOCK_N)
17
  offs_d = tl.arange(0, BLOCK_HEADDIM)
18
+ q_ptrs = (((Q + (off_b * stride_qb)) + (off_h * stride_qh)) + ((offs_m[:, None] * stride_qm) + offs_d[None, :]))
19
+ k_ptrs = (((K + (off_b * stride_kb)) + (off_h * stride_kh)) + ((offs_n[:, None] * stride_kn) + offs_d[None, :]))
20
+ v_ptrs = (((V + (off_b * stride_vb)) + (off_h * stride_vh)) + ((offs_n[:, None] * stride_vn) + offs_d[None, :]))
21
+ if (BIAS_TYPE == 'vector'):
22
+ b_ptrs = (((Bias + (off_b * stride_bb)) + (off_h * stride_bh)) + offs_n)
23
+ elif (BIAS_TYPE == 'matrix'):
24
+ b_ptrs = (((Bias + (off_b * stride_bb)) + (off_h * stride_bh)) + ((offs_m[:, None] * stride_bm) + offs_n[None, :]))
25
+ t_ptrs = ((TMP + (off_hb * seqlen_q_rounded)) + offs_m)
26
+ lse_i = (tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf'))
27
+ m_i = (tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf'))
28
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
29
+ if (EVEN_M & EVEN_N):
30
  if EVEN_HEADDIM:
31
  q = tl.load(q_ptrs)
32
  else:
33
+ q = tl.load(q_ptrs, mask=(offs_d[None, :] < headdim), other=0.0)
34
  elif EVEN_HEADDIM:
35
+ q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q), other=0.0)
36
  else:
37
+ q = tl.load(q_ptrs, mask=((offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)), other=0.0)
38
+ end_n = (seqlen_k if (not IS_CAUSAL) else tl.minimum(((start_m + 1) * BLOCK_M), seqlen_k))
39
  for start_n in range(0, end_n, BLOCK_N):
40
  start_n = tl.multiple_of(start_n, BLOCK_N)
41
+ if (EVEN_N & EVEN_M):
42
  if EVEN_HEADDIM:
43
+ k = tl.load((k_ptrs + (start_n * stride_kn)))
44
  else:
45
+ k = tl.load((k_ptrs + (start_n * stride_kn)), mask=(offs_d[None, :] < headdim), other=0.0)
46
  elif EVEN_HEADDIM:
47
+ k = tl.load((k_ptrs + (start_n * stride_kn)), mask=((start_n + offs_n)[:, None] < seqlen_k), other=0.0)
48
  else:
49
+ k = tl.load((k_ptrs + (start_n * stride_kn)), mask=(((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim)), other=0.0)
50
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
51
  qk += tl.dot(q, k, trans_b=True)
52
+ if (not EVEN_N):
53
+ qk += tl.where(((start_n + offs_n)[None, :] < seqlen_k), 0, float('-inf'))
54
  if IS_CAUSAL:
55
+ qk += tl.where((offs_m[:, None] >= (start_n + offs_n)[None, :]), 0, float('-inf'))
56
+ if (BIAS_TYPE != 'none'):
57
+ if (BIAS_TYPE == 'vector'):
58
  if EVEN_N:
59
+ bias = tl.load((b_ptrs + start_n)).to(tl.float32)
60
  else:
61
+ bias = tl.load((b_ptrs + start_n), mask=((start_n + offs_n) < seqlen_k), other=0.0).to(tl.float32)
62
  bias = bias[None, :]
63
+ elif (BIAS_TYPE == 'matrix'):
64
+ if (EVEN_M & EVEN_N):
65
+ bias = tl.load((b_ptrs + start_n)).to(tl.float32)
66
  else:
67
+ bias = tl.load((b_ptrs + start_n), mask=((offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k)), other=0.0).to(tl.float32)
68
+ qk = ((qk * softmax_scale) + bias)
69
  m_ij = tl.maximum(tl.max(qk, 1), lse_i)
70
+ p = tl.exp((qk - m_ij[:, None]))
71
  else:
72
+ m_ij = tl.maximum((tl.max(qk, 1) * softmax_scale), lse_i)
73
+ p = tl.exp(((qk * softmax_scale) - m_ij[:, None]))
74
  l_ij = tl.sum(p, 1)
75
+ acc_o_scale = tl.exp((m_i - m_ij))
76
  tl.store(t_ptrs, acc_o_scale)
77
  acc_o_scale = tl.load(t_ptrs)
78
+ acc_o = (acc_o * acc_o_scale[:, None])
79
+ if (EVEN_N & EVEN_M):
80
  if EVEN_HEADDIM:
81
+ v = tl.load((v_ptrs + (start_n * stride_vn)))
82
  else:
83
+ v = tl.load((v_ptrs + (start_n * stride_vn)), mask=(offs_d[None, :] < headdim), other=0.0)
84
  elif EVEN_HEADDIM:
85
+ v = tl.load((v_ptrs + (start_n * stride_vn)), mask=((start_n + offs_n)[:, None] < seqlen_k), other=0.0)
86
  else:
87
+ v = tl.load((v_ptrs + (start_n * stride_vn)), mask=(((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim)), other=0.0)
88
  p = p.to(v.dtype)
89
  acc_o += tl.dot(p, v)
90
  m_i = m_ij
91
+ l_i_new = (tl.exp((lse_i - m_ij)) + l_ij)
92
+ lse_i = (m_ij + tl.log(l_i_new))
93
+ o_scale = tl.exp((m_i - lse_i))
94
  tl.store(t_ptrs, o_scale)
95
  o_scale = tl.load(t_ptrs)
96
+ acc_o = (acc_o * o_scale[:, None])
97
  start_m = tl.program_id(0)
98
+ offs_m = ((start_m * BLOCK_M) + tl.arange(0, BLOCK_M))
99
+ lse_ptrs = ((Lse + (off_hb * seqlen_q_rounded)) + offs_m)
100
  tl.store(lse_ptrs, lse_i)
101
  offs_d = tl.arange(0, BLOCK_HEADDIM)
102
+ out_ptrs = (((Out + (off_b * stride_ob)) + (off_h * stride_oh)) + ((offs_m[:, None] * stride_om) + offs_d[None, :]))
103
  if EVEN_M:
104
  if EVEN_HEADDIM:
105
  tl.store(out_ptrs, acc_o)
106
  else:
107
+ tl.store(out_ptrs, acc_o, mask=(offs_d[None, :] < headdim))
108
  elif EVEN_HEADDIM:
109
+ tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q))
110
  else:
111
+ tl.store(out_ptrs, acc_o, mask=((offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)))
112
 
113
  @triton.jit
114
  def _bwd_preprocess_do_o_dot(Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):
115
  start_m = tl.program_id(0)
116
  off_hb = tl.program_id(1)
117
+ off_b = (off_hb // nheads)
118
+ off_h = (off_hb % nheads)
119
+ offs_m = ((start_m * BLOCK_M) + tl.arange(0, BLOCK_M))
120
  offs_d = tl.arange(0, BLOCK_HEADDIM)
121
+ o = tl.load(((((Out + (off_b * stride_ob)) + (off_h * stride_oh)) + (offs_m[:, None] * stride_om)) + offs_d[None, :]), mask=((offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)), other=0.0).to(tl.float32)
122
+ do = tl.load(((((DO + (off_b * stride_dob)) + (off_h * stride_doh)) + (offs_m[:, None] * stride_dom)) + offs_d[None, :]), mask=((offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)), other=0.0).to(tl.float32)
123
+ delta = tl.sum((o * do), axis=1)
124
+ tl.store(((Delta + (off_hb * seqlen_q_rounded)) + offs_m), delta)
125
 
126
  @triton.jit
127
  def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
128
+ if (EVEN_N & EVEN_M):
129
  if EVEN_HEADDIM:
130
  tl.store(dv_ptrs, dv)
131
  tl.store(dk_ptrs, dk)
132
  else:
133
+ tl.store(dv_ptrs, dv, mask=(offs_d[None, :] < headdim))
134
+ tl.store(dk_ptrs, dk, mask=(offs_d[None, :] < headdim))
135
  elif EVEN_HEADDIM:
136
+ tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k))
137
+ tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k))
138
  else:
139
+ tl.store(dv_ptrs, dv, mask=((offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)))
140
+ tl.store(dk_ptrs, dk, mask=((offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)))
141
 
142
  @triton.jit
143
  def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
144
+ begin_m = (0 if (not IS_CAUSAL) else (((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M))
145
+ offs_qm = (begin_m + tl.arange(0, BLOCK_M))
146
+ offs_n = ((start_n * BLOCK_N) + tl.arange(0, BLOCK_N))
147
  offs_m = tl.arange(0, BLOCK_M)
148
  offs_d = tl.arange(0, BLOCK_HEADDIM)
149
+ q_ptrs = (Q + ((offs_qm[:, None] * stride_qm) + offs_d[None, :]))
150
+ k_ptrs = (K + ((offs_n[:, None] * stride_kn) + offs_d[None, :]))
151
+ v_ptrs = (V + ((offs_n[:, None] * stride_vn) + offs_d[None, :]))
152
+ do_ptrs = (DO + ((offs_qm[:, None] * stride_dom) + offs_d[None, :]))
153
+ dq_ptrs = (DQ + ((offs_qm[:, None] * stride_dqm) + offs_d[None, :]))
154
+ if (BIAS_TYPE == 'vector'):
155
+ b_ptrs = (Bias + offs_n)
156
+ elif (BIAS_TYPE == 'matrix'):
157
+ b_ptrs = (Bias + ((offs_qm[:, None] * stride_bm) + offs_n[None, :]))
158
  dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
159
  dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
160
+ if (begin_m >= seqlen_q):
161
+ dv_ptrs = (DV + ((offs_n[:, None] * stride_dvn) + offs_d[None, :]))
162
+ dk_ptrs = (DK + ((offs_n[:, None] * stride_dkn) + offs_d[None, :]))
163
  _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
164
  return
165
+ if (EVEN_N & EVEN_M):
166
  if EVEN_HEADDIM:
167
  k = tl.load(k_ptrs)
168
  v = tl.load(v_ptrs)
169
  else:
170
+ k = tl.load(k_ptrs, mask=(offs_d[None, :] < headdim), other=0.0)
171
+ v = tl.load(v_ptrs, mask=(offs_d[None, :] < headdim), other=0.0)
172
  elif EVEN_HEADDIM:
173
+ k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k), other=0.0)
174
+ v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k), other=0.0)
175
  else:
176
+ k = tl.load(k_ptrs, mask=((offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)), other=0.0)
177
+ v = tl.load(v_ptrs, mask=((offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)), other=0.0)
178
  num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
179
+ for start_m in range(begin_m, (num_block_m * BLOCK_M), BLOCK_M):
180
  start_m = tl.multiple_of(start_m, BLOCK_M)
181
+ offs_m_curr = (start_m + offs_m)
182
+ if (EVEN_M & EVEN_HEADDIM):
183
  q = tl.load(q_ptrs)
184
  elif EVEN_HEADDIM:
185
+ q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q), other=0.0)
186
  else:
187
+ q = tl.load(q_ptrs, mask=((offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)), other=0.0)
188
  qk = tl.dot(q, k, trans_b=True)
189
+ if (not EVEN_N):
190
+ qk = tl.where((offs_n[None, :] < seqlen_k), qk, float('-inf'))
191
  if IS_CAUSAL:
192
+ qk = tl.where((offs_m_curr[:, None] >= offs_n[None, :]), qk, float('-inf'))
193
+ if (BIAS_TYPE != 'none'):
194
  tl.debug_barrier()
195
+ if (BIAS_TYPE == 'vector'):
196
  if EVEN_N:
197
  bias = tl.load(b_ptrs).to(tl.float32)
198
  else:
199
+ bias = tl.load(b_ptrs, mask=(offs_n < seqlen_k), other=0.0).to(tl.float32)
200
  bias = bias[None, :]
201
+ elif (BIAS_TYPE == 'matrix'):
202
+ if (EVEN_M & EVEN_N):
203
  bias = tl.load(b_ptrs).to(tl.float32)
204
  else:
205
+ bias = tl.load(b_ptrs, mask=((offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k)), other=0.0).to(tl.float32)
206
+ qk = ((qk * softmax_scale) + bias)
207
+ if (not (EVEN_M & EVEN_HEADDIM)):
208
  tl.debug_barrier()
209
+ lse_i = tl.load((LSE + offs_m_curr))
210
+ if (BIAS_TYPE == 'none'):
211
+ p = tl.exp(((qk * softmax_scale) - lse_i[:, None]))
212
  else:
213
+ p = tl.exp((qk - lse_i[:, None]))
214
+ if (EVEN_M & EVEN_HEADDIM):
215
  do = tl.load(do_ptrs)
216
  else:
217
+ do = tl.load(do_ptrs, mask=((offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)), other=0.0)
218
  dv += tl.dot(p.to(do.dtype), do, trans_a=True)
219
+ if (not (EVEN_M & EVEN_HEADDIM)):
220
  tl.debug_barrier()
221
  dp = tl.dot(do, v, trans_b=True)
222
+ if (not EVEN_HEADDIM):
223
  tl.debug_barrier()
224
+ Di = tl.load((D + offs_m_curr))
225
+ ds = ((p * (dp - Di[:, None])) * softmax_scale).to(q.dtype)
226
  dk += tl.dot(ds, q, trans_a=True)
227
+ if (not (EVEN_M & EVEN_HEADDIM)):
228
  tl.debug_barrier()
229
+ if (not ATOMIC_ADD):
230
+ if (EVEN_M & EVEN_HEADDIM):
231
  dq = tl.load(dq_ptrs, eviction_policy='evict_last')
232
  dq += tl.dot(ds, k)
233
  tl.store(dq_ptrs, dq, eviction_policy='evict_last')
234
  elif EVEN_HEADDIM:
235
+ dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q), other=0.0, eviction_policy='evict_last')
236
  dq += tl.dot(ds, k)
237
+ tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q), eviction_policy='evict_last')
238
  else:
239
+ dq = tl.load(dq_ptrs, mask=((offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)), other=0.0, eviction_policy='evict_last')
240
  dq += tl.dot(ds, k)
241
+ tl.store(dq_ptrs, dq, mask=((offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)), eviction_policy='evict_last')
242
  else:
243
  dq = tl.dot(ds, k)
244
+ if (EVEN_M & EVEN_HEADDIM):
245
  tl.atomic_add(dq_ptrs, dq)
246
  elif EVEN_HEADDIM:
247
+ tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q))
248
  else:
249
+ tl.atomic_add(dq_ptrs, dq, mask=((offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)))
250
+ dq_ptrs += (BLOCK_M * stride_dqm)
251
+ q_ptrs += (BLOCK_M * stride_qm)
252
+ do_ptrs += (BLOCK_M * stride_dom)
253
+ if (BIAS_TYPE == 'matrix'):
254
+ b_ptrs += (BLOCK_M * stride_bm)
255
+ dv_ptrs = (DV + ((offs_n[:, None] * stride_dvn) + offs_d[None, :]))
256
+ dk_ptrs = (DK + ((offs_n[:, None] * stride_dkn) + offs_d[None, :]))
257
  _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
258
 
259
  def init_to_zero(name):
260
+ return (lambda nargs: nargs[name].zero_())
261
 
262
  @triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ'))], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'])
263
+ @triton.heuristics({'EVEN_M': (lambda args: ((args['seqlen_q'] % args['BLOCK_M']) == 0)), 'EVEN_N': (lambda args: ((args['seqlen_k'] % args['BLOCK_N']) == 0)), 'EVEN_HEADDIM': (lambda args: (args['headdim'] == args['BLOCK_HEADDIM']))})
264
  @triton.jit
265
  def _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
266
  off_hb = tl.program_id(1)
267
+ off_b = (off_hb // nheads)
268
+ off_h = (off_hb % nheads)
269
+ Q += ((off_b * stride_qb) + (off_h * stride_qh))
270
+ K += ((off_b * stride_kb) + (off_h * stride_kh))
271
+ V += ((off_b * stride_vb) + (off_h * stride_vh))
272
+ DO += ((off_b * stride_dob) + (off_h * stride_doh))
273
+ DQ += ((off_b * stride_dqb) + (off_h * stride_dqh))
274
+ DK += ((off_b * stride_dkb) + (off_h * stride_dkh))
275
+ DV += ((off_b * stride_dvb) + (off_h * stride_dvh))
276
+ if (BIAS_TYPE != 'none'):
277
+ Bias += ((off_b * stride_bb) + (off_h * stride_bh))
278
+ D += (off_hb * seqlen_q_rounded)
279
+ LSE += (off_hb * seqlen_q_rounded)
280
+ if (not SEQUENCE_PARALLEL):
281
  num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
282
  for start_n in range(0, num_block_n):
283
  _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
 
288
  def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
289
  (batch, seqlen_q, nheads, d) = q.shape
290
  (_, seqlen_k, _, _) = k.shape
291
+ assert (k.shape == (batch, seqlen_k, nheads, d))
292
+ assert (v.shape == (batch, seqlen_k, nheads, d))
293
+ assert (d <= 128), 'FlashAttention only support head dimensions up to 128'
294
+ assert (q.dtype == k.dtype == v.dtype), 'All tensors must have the same type'
295
+ assert (q.dtype in [torch.float16, torch.bfloat16]), 'Only support fp16 and bf16'
296
+ assert (q.is_cuda and k.is_cuda and v.is_cuda)
297
+ softmax_scale = (softmax_scale or (1.0 / math.sqrt(d)))
298
+ has_bias = (bias is not None)
299
  bias_type = 'none'
300
  if has_bias:
301
+ assert (bias.dtype in [q.dtype, torch.float])
302
  assert bias.is_cuda
303
+ assert (bias.dim() == 4)
304
+ if (bias.stride((- 1)) != 1):
305
  bias = bias.contiguous()
306
+ if (bias.shape[2:] == (1, seqlen_k)):
307
  bias_type = 'vector'
308
+ elif (bias.shape[2:] == (seqlen_q, seqlen_k)):
309
  bias_type = 'matrix'
310
  else:
311
  raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')
312
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
313
+ bias_strides = ((bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0))
314
+ seqlen_q_rounded = (math.ceil((seqlen_q / 128)) * 128)
315
  lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
316
  tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
317
  o = torch.empty_like(q)
318
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
319
  BLOCK = 128
320
+ num_warps = (4 if (d <= 64) else 8)
321
+ grid = (lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), (batch * nheads)))
322
+ _fwd_kernel[grid](q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, (seqlen_q // 32), (seqlen_k // 32), bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1)
323
  return (o, lse, softmax_scale)
324
 
325
  def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
326
+ if (do.stride((- 1)) != 1):
327
  do = do.contiguous()
328
  (batch, seqlen_q, nheads, d) = q.shape
329
  (_, seqlen_k, _, _) = k.shape
330
+ assert (d <= 128)
331
+ seqlen_q_rounded = (math.ceil((seqlen_q / 128)) * 128)
332
+ assert (lse.shape == (batch, nheads, seqlen_q_rounded))
333
+ assert (q.stride((- 1)) == k.stride((- 1)) == v.stride((- 1)) == o.stride((- 1)) == 1)
334
+ assert (dq.stride((- 1)) == dk.stride((- 1)) == dv.stride((- 1)) == 1)
335
+ softmax_scale = (softmax_scale or (1.0 / math.sqrt(d)))
336
  dq_accum = torch.empty_like(q, dtype=torch.float32)
337
  delta = torch.empty_like(lse)
338
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
339
+ grid = (lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), (batch * nheads)))
340
  _bwd_preprocess_do_o_dot[grid](o, do, delta, o.stride(0), o.stride(2), o.stride(1), do.stride(0), do.stride(2), do.stride(1), nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM)
341
+ has_bias = (bias is not None)
342
  bias_type = 'none'
343
  if has_bias:
344
+ assert (bias.dtype in [q.dtype, torch.float])
345
  assert bias.is_cuda
346
+ assert (bias.dim() == 4)
347
+ assert (bias.stride((- 1)) == 1)
348
+ if (bias.shape[2:] == (1, seqlen_k)):
349
  bias_type = 'vector'
350
+ elif (bias.shape[2:] == (seqlen_q, seqlen_k)):
351
  bias_type = 'matrix'
352
  else:
353
  raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')
354
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
355
+ bias_strides = ((bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0))
356
+ grid = (lambda META: ((triton.cdiv(seqlen_k, META['BLOCK_N']) if META['SEQUENCE_PARALLEL'] else 1), (batch * nheads)))
357
+ _bwd_kernel[grid](q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), dk.stride(0), dk.stride(2), dk.stride(1), dv.stride(0), dv.stride(2), dv.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, (seqlen_q // 32), (seqlen_k // 32), bias_type, causal, BLOCK_HEADDIM)
358
  dq.copy_(dq_accum)
359
 
360
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
361
 
362
  @staticmethod
363
  def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
364
+ '\n qkv: (batch, seqlen, 3, nheads, headdim)\n bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).\n For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).\n ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)\n '
365
+ if (qkv.stride((- 1)) != 1):
 
 
 
 
 
366
  qkv = qkv.contiguous()
367
  (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)
368
  ctx.save_for_backward(qkv, o, lse, bias)
 
372
  @staticmethod
373
  def backward(ctx, do):
374
  (qkv, o, lse, bias) = ctx.saved_tensors
375
+ assert (not ctx.needs_input_grad[1]), 'FlashAttention does not support bias gradient yet'
376
  with torch.inference_mode():
377
  dqkv = torch.empty_like(qkv)
378
  _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
 
383
 
384
  @staticmethod
385
  def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
386
+ '\n q: (batch, seqlen_q, nheads, headdim)\n kv: (batch, seqlen_k, 2, nheads, headdim)\n bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).\n For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).\n ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)\n '
387
+ (q, kv) = [(x if (x.stride((- 1)) == 1) else x.contiguous()) for x in [q, kv]]
 
 
 
 
 
 
388
  (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)
389
  ctx.save_for_backward(q, kv, o, lse, bias)
390
  ctx.causal = causal
 
393
  @staticmethod
394
  def backward(ctx, do):
395
  (q, kv, o, lse, bias) = ctx.saved_tensors
396
+ if (len(ctx.needs_input_grad) >= 3):
397
+ assert (not ctx.needs_input_grad[2]), 'FlashAttention does not support bias gradient yet'
398
  with torch.inference_mode():
399
  dq = torch.empty_like(q)
400
  dkv = torch.empty_like(kv)
 
406
 
407
  @staticmethod
408
  def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
409
+ '\n q: (batch_size, seqlen_q, nheads, headdim)\n k, v: (batch_size, seqlen_k, nheads, headdim)\n bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).\n For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).\n ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)\n '
410
+ (q, k, v) = [(x if (x.stride((- 1)) == 1) else x.contiguous()) for x in [q, k, v]]
 
 
 
 
 
 
411
  (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
412
  ctx.save_for_backward(q, k, v, o, lse, bias)
413
  ctx.causal = causal
 
416
  @staticmethod
417
  def backward(ctx, do):
418
  (q, k, v, o, lse, bias) = ctx.saved_tensors
419
+ assert (not ctx.needs_input_grad[3]), 'FlashAttention does not support bias gradient yet'
420
  with torch.inference_mode():
421
  dq = torch.empty_like(q)
422
  dk = torch.empty_like(k)
423
  dv = torch.empty_like(v)
424
  _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
425
  return (dq, dk, dv, None, None, None)
426
+ flash_attn_func = FlashAttnFunc.apply
generation_config.json CHANGED
@@ -1,5 +1,7 @@
1
  {
2
  "_from_model_config": true,
3
  "transformers_version": "4.28.1",
4
- "use_cache": false
 
 
5
  }
 
1
  {
2
  "_from_model_config": true,
3
  "transformers_version": "4.28.1",
4
+ "eos_token_id": 0,
5
+ "pad_token_id": 0,
6
+ "use_cache": true
7
  }
hf_prefixlm_converter.py CHANGED
@@ -1,11 +1,5 @@
1
- """Converts Huggingface Causal LM to Prefix LM.
2
 
3
- Conversion does lightweight surgery on a HuggingFace
4
- Causal LM to convert it to a Prefix LM.
5
-
6
- Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
- and treat the input prompt as the prefix in `generate`.
8
- """
9
  import math
10
  import warnings
11
  from types import MethodType
@@ -24,31 +18,17 @@ from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_op
24
  from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
25
  logger = logging.get_logger(__name__)
26
  _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
- CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
28
 
29
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
30
- """Converts a GPT-style Causal LM to a Prefix LM.
31
-
32
- Supported HuggingFace model classes:
33
- - `GPT2LMHeadModel`
34
- - `GPTNeoForCausalLM`
35
- - `GPTNeoXForCausalLM`
36
- - `GPTJForCausalLM`
37
-
38
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
39
- """
40
  if hasattr(model, '_prefix_lm_converted'):
41
  return model
42
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
43
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
44
 
45
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
46
- """Helper that gets a list of the model's attention modules.
47
-
48
- Each module has a `bias` buffer used for causal masking. The Prefix LM
49
- conversion adds logic to dynamically manipulate these biases to support
50
- Prefix LM attention masking.
51
- """
52
  attn_modules = []
53
  if isinstance(model, GPTNeoXForCausalLM):
54
  blocks = model.gpt_neox.layers
@@ -56,7 +36,7 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
56
  blocks = model.transformer.h
57
  for block in blocks:
58
  if isinstance(model, GPTNeoForCausalLM):
59
- if block.attn.attention_type != 'global':
60
  continue
61
  attn_module = block.attn.attention
62
  elif isinstance(model, GPTNeoXForCausalLM):
@@ -69,41 +49,41 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
69
  setattr(model, '_original_generate', getattr(model, 'generate'))
70
 
71
  def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
72
- """Wraps original forward to enable PrefixLM attention."""
73
 
74
  def call_og_forward():
75
  if isinstance(self, GPTNeoXForCausalLM):
76
  return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
77
  else:
78
  return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
79
- if bidirectional_mask is None:
80
  return call_og_forward()
81
  assert isinstance(bidirectional_mask, torch.Tensor)
82
  attn_modules = _get_attn_modules(model)
83
  (b, s) = bidirectional_mask.shape
84
- max_length = attn_modules[0].bias.shape[-1]
85
- if s > max_length:
86
- raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
87
- assert s <= max_length
88
- if s < max_length:
89
- pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
93
  attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
94
  output = call_og_forward()
95
  for attn_module in attn_modules:
96
- attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
97
  return output
98
 
99
- def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
100
- """Wraps original generate to enable PrefixLM attention."""
101
  attn_modules = _get_attn_modules(model)
102
  for attn_module in attn_modules:
103
  attn_module.bias.data[:] = 1
104
  output = self._original_generate(*args, **kwargs)
105
  for attn_module in attn_modules:
106
- attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
107
  return output
108
  setattr(model, 'forward', MethodType(forward, model))
109
  setattr(model, 'generate', MethodType(generate, model))
@@ -111,85 +91,79 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
111
  return model
112
 
113
  def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
114
- """Converts a BLOOM Causal LM to a Prefix LM.
115
-
116
- Supported HuggingFace model classes:
117
- - `BloomForCausalLM`
118
-
119
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
120
- """
121
  if hasattr(model, '_prefix_lm_converted'):
122
  return model
123
  assert isinstance(model, BloomForCausalLM)
124
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
125
 
126
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
127
  combined_attention_mask = None
128
  device = attention_mask.device
129
  (_, src_length) = input_shape
130
- if src_length > 1:
131
  combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
132
- if bidirectional_mask is not None:
133
- assert attention_mask.shape == bidirectional_mask.shape
134
  expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
135
  combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
136
  expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
137
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
138
  return combined_attention_mask
139
 
140
  def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
141
  num_heads = self.config.n_head
142
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
143
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
144
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
145
  slopes = torch.pow(base, powers)
146
- if closest_power_of_2 != num_heads:
147
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
148
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
149
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
150
  slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
151
- qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
152
- ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
153
- diffs = qa - ka + key_length - query_length
154
- diffs = -diffs.abs()
155
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
156
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
157
  return alibi.to(dtype)
158
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
159
 
160
- def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
161
- if deprecated_arguments.pop('position_ids', False) is not False:
162
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
163
- if len(deprecated_arguments) > 0:
164
  raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
165
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
166
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
167
- use_cache = use_cache if use_cache is not None else self.config.use_cache
168
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
169
- if input_ids is not None and inputs_embeds is not None:
170
  raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
171
- elif input_ids is not None:
172
  (batch_size, seq_length) = input_ids.shape
173
- elif inputs_embeds is not None:
174
  (batch_size, seq_length, _) = inputs_embeds.shape
175
  else:
176
  raise ValueError('You have to specify either input_ids or inputs_embeds')
177
- if past_key_values is None:
178
- past_key_values = tuple([None] * len(self.h))
179
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
180
- if inputs_embeds is None:
181
  inputs_embeds = self.word_embeddings(input_ids)
182
  hidden_states = self.word_embeddings_layernorm(inputs_embeds)
183
- presents = () if use_cache else None
184
- all_self_attentions = () if output_attentions else None
185
- all_hidden_states = () if output_hidden_states else None
186
  seq_length_with_past = seq_length
187
  past_key_values_length = 0
188
- if past_key_values[0] is not None:
189
  tmp = past_key_values[0][0]
190
  past_key_values_length = tmp.shape[2]
191
- seq_length_with_past = seq_length_with_past + past_key_values_length
192
- if attention_mask is None:
193
  attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
194
  else:
195
  attention_mask = attention_mask.to(hidden_states.device)
@@ -198,8 +172,8 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
198
  for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
199
  if output_hidden_states:
200
  hst = (hidden_states,)
201
- all_hidden_states = all_hidden_states + hst
202
- if self.gradient_checkpointing and self.training:
203
  if use_cache:
204
  logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
205
  use_cache = False
@@ -213,50 +187,50 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
213
  else:
214
  outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
215
  hidden_states = outputs[0]
216
- if use_cache is True:
217
- presents = presents + (outputs[1],)
218
  if output_attentions:
219
- oa = (outputs[2 if use_cache else 1],)
220
- all_self_attentions = all_self_attentions + oa
221
  hidden_states = self.ln_f(hidden_states)
222
  if output_hidden_states:
223
  hst = (hidden_states,)
224
- all_hidden_states = all_hidden_states + hst
225
- if not return_dict:
226
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
227
  return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
228
  setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
229
  setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
230
  setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
231
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
232
 
233
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
234
- """Replacement forward method for BloomCausalLM."""
235
- if deprecated_arguments.pop('position_ids', False) is not False:
236
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
237
- if len(deprecated_arguments) > 0:
238
  raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
239
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
240
  transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
241
  hidden_states = transformer_outputs[0]
242
  lm_logits = self.lm_head(hidden_states)
243
  loss = None
244
- if labels is not None:
245
- shift_logits = lm_logits[..., :-1, :].contiguous()
246
  shift_labels = labels[..., 1:].contiguous()
247
  (batch_size, seq_length, vocab_size) = shift_logits.shape
248
  loss_fct = CrossEntropyLoss()
249
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
250
- if not return_dict:
251
- output = (lm_logits,) + transformer_outputs[1:]
252
- return (loss,) + output if loss is not None else output
253
  return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
254
 
255
  def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
256
  if past:
257
- input_ids = input_ids[:, -1].unsqueeze(-1)
258
  bidirectional_mask = None
259
- if past[0][0].shape[0] == input_ids.shape[0]:
260
  past = self._convert_to_bloom_cache(past)
261
  else:
262
  bidirectional_mask = torch.ones_like(input_ids)
@@ -267,36 +241,30 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
267
  return model
268
 
269
  def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
270
- """Converts an OPT Causal LM to a Prefix LM.
271
-
272
- Supported HuggingFace model classes:
273
- - `OPTForCausalLM`
274
-
275
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
276
- """
277
  if hasattr(model, '_prefix_lm_converted'):
278
  return model
279
  assert isinstance(model, OPTForCausalLM)
280
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
281
  setattr(model, '_original_forward', getattr(model, 'forward'))
282
  setattr(model, '_original_generate', getattr(model, 'generate'))
283
  model.model.decoder.bidirectional_mask = None
284
 
285
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
286
  combined_attention_mask = None
287
- if input_shape[-1] > 1:
288
- if self.bidirectional_mask == 'g':
289
  (bsz, src_length) = input_shape
290
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
291
  else:
292
  combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
293
- if self.bidirectional_mask is not None:
294
- assert attention_mask.shape == self.bidirectional_mask.shape
295
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
296
  combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
297
- if attention_mask is not None:
298
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
299
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
300
  return combined_attention_mask
301
  setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
302
 
@@ -304,7 +272,7 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
304
 
305
  def call_og_forward():
306
  return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
307
- if bidirectional_mask is None:
308
  return call_og_forward()
309
  self.model.decoder.bidirectional_mask = bidirectional_mask
310
  try:
@@ -315,8 +283,8 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
315
  self.model.decoder.bidirectional_mask = None
316
  return outputs
317
 
318
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
319
- """Wraps original generate to enable PrefixLM-style attention."""
320
  self.model.decoder.bidirectional_mask = 'g'
321
  try:
322
  output = self._original_generate(*args, **kwargs)
@@ -329,66 +297,11 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
329
  setattr(model, 'generate', MethodType(generate, model))
330
  setattr(model, '_prefix_lm_converted', True)
331
  return model
332
- _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
333
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
334
 
335
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
336
- """Converts a HuggingFace Causal LM to a Prefix LM.
337
-
338
- Supported HuggingFace model classes:
339
- - `GPT2LMHeadModel`
340
- - `GPTNeoForCausalLM`
341
- - `GPTNeoXForCausalLM`
342
- - `GPTJForCausalLM`
343
- - `BloomForCausalLM`
344
- - `OPTForCausalLM`
345
-
346
- Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
347
- `generate` method and/or select underlying methods depending on the model class.
348
-
349
- These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
350
-
351
- Notes on training:
352
- To actually train the converted model as a Prefix LM, training batches will need to indicate
353
- the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
354
-
355
- **This is not a standard input and requires custom layers either within or after your dataloader.**
356
-
357
- In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
358
- such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
359
- That is, the prefix portion of the sequence should not generate any loss. Loss should only be
360
- generated by the target portion of the sequence.
361
-
362
- Notes on `GPTNeoForCausalLM`:
363
- To simplify the implementation, "global" and "local" attention layers are handled differently.
364
- For "global" layers, we handle conversion as described above. For "local" layers, which use a
365
- causal attention mask within a restricted local window, we do not alter the masking.
366
-
367
- Notes on `forward` method conversion:
368
- After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
369
- which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
370
- belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
371
- 0 indicates token positions belonging to the target.
372
-
373
- The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
374
- causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
375
- the causal masks before returning the result.
376
-
377
- Notes on `generate` method conversion:
378
- After conversion, the `generate` method will have the same signature but will internally
379
- convert all causal masks to be purely bidirectional, call the original `generate` method, and
380
- (where appropriate) reset the causal masks before returning the result.
381
-
382
- This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
383
- "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
384
- each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
385
- another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
386
- previously-generated tokens (also as expected in a Prefix LM).
387
-
388
- To preserve the API, the original methods are renamed to `_original_forward` and
389
- `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
390
- them, respectively. Although implementation details vary by model class.
391
- """
392
  if isinstance(model, _SUPPORTED_GPT_MODELS):
393
  return _convert_gpt_causal_lm_to_prefix_lm(model)
394
  elif isinstance(model, BloomForCausalLM):
@@ -396,20 +309,17 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
396
  elif isinstance(model, OPTForCausalLM):
397
  return _convert_opt_causal_lm_to_prefix_lm(model)
398
  else:
399
- raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
400
-
401
- def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
402
- """Attempts to add bidirectional_mask to batch if missing.
403
 
404
- Raises:
405
- KeyError if bidirectional_mask is missing and can't be inferred
406
- """
407
- if 'bidirectional_mask' not in batch:
408
- if batch.get('mode', None) == 'icl_task':
409
  batch['bidirectional_mask'] = batch['attention_mask'].clone()
410
  for (i, continuation_indices) in enumerate(batch['continuation_indices']):
411
- batch['bidirectional_mask'][i, continuation_indices] = 0
412
- elif 'labels' in batch and 'attention_mask' in batch:
413
- batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])
414
  else:
415
- raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
 
 
1
 
2
+ 'Converts Huggingface Causal LM to Prefix LM.\n\nConversion does lightweight surgery on a HuggingFace\nCausal LM to convert it to a Prefix LM.\n\nPrefix LMs accepts a `bidirectional_mask` input in `forward`\nand treat the input prompt as the prefix in `generate`.\n'
 
 
 
 
 
3
  import math
4
  import warnings
5
  from types import MethodType
 
18
  from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
19
  logger = logging.get_logger(__name__)
20
  _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
21
+ CAUSAL_GPT_TYPES = Union[(GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)]
22
 
23
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
24
+ 'Converts a GPT-style Causal LM to a Prefix LM.\n\n Supported HuggingFace model classes:\n - `GPT2LMHeadModel`\n - `GPTNeoForCausalLM`\n - `GPTNeoXForCausalLM`\n - `GPTJForCausalLM`\n\n See `convert_hf_causal_lm_to_prefix_lm` for more details.\n '
 
 
 
 
 
 
 
 
 
25
  if hasattr(model, '_prefix_lm_converted'):
26
  return model
27
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
28
+ assert (model.config.add_cross_attention == False), 'Only supports GPT-style decoder-only models'
29
 
30
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
31
+ "Helper that gets a list of the model's attention modules.\n\n Each module has a `bias` buffer used for causal masking. The Prefix LM\n conversion adds logic to dynamically manipulate these biases to support\n Prefix LM attention masking.\n "
 
 
 
 
 
32
  attn_modules = []
33
  if isinstance(model, GPTNeoXForCausalLM):
34
  blocks = model.gpt_neox.layers
 
36
  blocks = model.transformer.h
37
  for block in blocks:
38
  if isinstance(model, GPTNeoForCausalLM):
39
+ if (block.attn.attention_type != 'global'):
40
  continue
41
  attn_module = block.attn.attention
42
  elif isinstance(model, GPTNeoXForCausalLM):
 
49
  setattr(model, '_original_generate', getattr(model, 'generate'))
50
 
51
  def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
52
+ 'Wraps original forward to enable PrefixLM attention.'
53
 
54
  def call_og_forward():
55
  if isinstance(self, GPTNeoXForCausalLM):
56
  return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
57
  else:
58
  return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
59
+ if (bidirectional_mask is None):
60
  return call_og_forward()
61
  assert isinstance(bidirectional_mask, torch.Tensor)
62
  attn_modules = _get_attn_modules(model)
63
  (b, s) = bidirectional_mask.shape
64
+ max_length = attn_modules[0].bias.shape[(- 1)]
65
+ if (s > max_length):
66
+ raise ValueError((f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).'))
67
+ assert (s <= max_length)
68
+ if (s < max_length):
69
+ pad = torch.zeros((int(b), int((max_length - s))), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
70
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
71
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
72
  for attn_module in attn_modules:
73
  attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
74
  output = call_og_forward()
75
  for attn_module in attn_modules:
76
+ attn_module.bias.data = torch.tril(attn_module.bias.data[(0, 0)])[(None, None)]
77
  return output
78
 
79
+ def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[(str, Any)]):
80
+ 'Wraps original generate to enable PrefixLM attention.'
81
  attn_modules = _get_attn_modules(model)
82
  for attn_module in attn_modules:
83
  attn_module.bias.data[:] = 1
84
  output = self._original_generate(*args, **kwargs)
85
  for attn_module in attn_modules:
86
+ attn_module.bias.data = torch.tril(attn_module.bias.data[(0, 0)])[(None, None)]
87
  return output
88
  setattr(model, 'forward', MethodType(forward, model))
89
  setattr(model, 'generate', MethodType(generate, model))
 
91
  return model
92
 
93
  def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
94
+ 'Converts a BLOOM Causal LM to a Prefix LM.\n\n Supported HuggingFace model classes:\n - `BloomForCausalLM`\n\n See `convert_hf_causal_lm_to_prefix_lm` for more details.\n '
 
 
 
 
 
 
95
  if hasattr(model, '_prefix_lm_converted'):
96
  return model
97
  assert isinstance(model, BloomForCausalLM)
98
+ assert (model.config.add_cross_attention == False), 'Only supports BLOOM decoder-only models'
99
 
100
+ def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[(int, int)], past_key_values_length: int) -> torch.BoolTensor:
101
  combined_attention_mask = None
102
  device = attention_mask.device
103
  (_, src_length) = input_shape
104
+ if (src_length > 1):
105
  combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
106
+ if (bidirectional_mask is not None):
107
+ assert (attention_mask.shape == bidirectional_mask.shape)
108
  expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
109
  combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
110
  expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
111
+ combined_attention_mask = (expanded_attn_mask if (combined_attention_mask is None) else (expanded_attn_mask | combined_attention_mask))
112
  return combined_attention_mask
113
 
114
  def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
115
  num_heads = self.config.n_head
116
+ closest_power_of_2 = (2 ** math.floor(math.log2(num_heads)))
117
+ base = torch.tensor((2 ** (- (2 ** (- (math.log2(closest_power_of_2) - 3))))), device=device, dtype=torch.float32)
118
+ powers = torch.arange(1, (1 + closest_power_of_2), device=device, dtype=torch.int32)
119
  slopes = torch.pow(base, powers)
120
+ if (closest_power_of_2 != num_heads):
121
+ extra_base = torch.tensor((2 ** (- (2 ** (- (math.log2((2 * closest_power_of_2)) - 3))))), device=device, dtype=torch.float32)
122
+ num_remaining_heads = min(closest_power_of_2, (num_heads - closest_power_of_2))
123
+ extra_powers = torch.arange(1, (1 + (2 * num_remaining_heads)), 2, device=device, dtype=torch.int32)
124
  slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
125
+ qa = torch.arange(query_length, device=device, dtype=torch.int32).view((- 1), 1)
126
+ ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, (- 1))
127
+ diffs = (((qa - ka) + key_length) - query_length)
128
+ diffs = (- diffs.abs())
129
+ alibi = (slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length))
130
+ alibi = alibi.expand(batch_size, (- 1), (- 1), (- 1)).reshape((- 1), query_length, key_length)
131
  return alibi.to(dtype)
132
+ KeyValueT = Tuple[(torch.Tensor, torch.Tensor)]
133
 
134
+ def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[(KeyValueT, ...)]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[(Tuple[(torch.Tensor, ...)], BaseModelOutputWithPastAndCrossAttentions)]:
135
+ if (deprecated_arguments.pop('position_ids', False) is not False):
136
+ warnings.warn(('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.'), FutureWarning)
137
+ if (len(deprecated_arguments) > 0):
138
  raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
139
+ output_attentions = (output_attentions if (output_attentions is not None) else self.config.output_attentions)
140
+ output_hidden_states = (output_hidden_states if (output_hidden_states is not None) else self.config.output_hidden_states)
141
+ use_cache = (use_cache if (use_cache is not None) else self.config.use_cache)
142
+ return_dict = (return_dict if (return_dict is not None) else self.config.use_return_dict)
143
+ if ((input_ids is not None) and (inputs_embeds is not None)):
144
  raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
145
+ elif (input_ids is not None):
146
  (batch_size, seq_length) = input_ids.shape
147
+ elif (inputs_embeds is not None):
148
  (batch_size, seq_length, _) = inputs_embeds.shape
149
  else:
150
  raise ValueError('You have to specify either input_ids or inputs_embeds')
151
+ if (past_key_values is None):
152
+ past_key_values = tuple(([None] * len(self.h)))
153
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
154
+ if (inputs_embeds is None):
155
  inputs_embeds = self.word_embeddings(input_ids)
156
  hidden_states = self.word_embeddings_layernorm(inputs_embeds)
157
+ presents = (() if use_cache else None)
158
+ all_self_attentions = (() if output_attentions else None)
159
+ all_hidden_states = (() if output_hidden_states else None)
160
  seq_length_with_past = seq_length
161
  past_key_values_length = 0
162
+ if (past_key_values[0] is not None):
163
  tmp = past_key_values[0][0]
164
  past_key_values_length = tmp.shape[2]
165
+ seq_length_with_past = (seq_length_with_past + past_key_values_length)
166
+ if (attention_mask is None):
167
  attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
168
  else:
169
  attention_mask = attention_mask.to(hidden_states.device)
 
172
  for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
173
  if output_hidden_states:
174
  hst = (hidden_states,)
175
+ all_hidden_states = (all_hidden_states + hst)
176
+ if (self.gradient_checkpointing and self.training):
177
  if use_cache:
178
  logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
179
  use_cache = False
 
187
  else:
188
  outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
189
  hidden_states = outputs[0]
190
+ if (use_cache is True):
191
+ presents = (presents + (outputs[1],))
192
  if output_attentions:
193
+ oa = (outputs[(2 if use_cache else 1)],)
194
+ all_self_attentions = (all_self_attentions + oa)
195
  hidden_states = self.ln_f(hidden_states)
196
  if output_hidden_states:
197
  hst = (hidden_states,)
198
+ all_hidden_states = (all_hidden_states + hst)
199
+ if (not return_dict):
200
+ return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if (v is not None)))
201
  return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
202
  setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
203
  setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
204
  setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
205
+ KeyValueT = Tuple[(torch.Tensor, torch.Tensor)]
206
 
207
+ def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[(KeyValueT, ...)]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[(Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions)]:
208
+ 'Replacement forward method for BloomCausalLM.'
209
+ if (deprecated_arguments.pop('position_ids', False) is not False):
210
+ warnings.warn(('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.'), FutureWarning)
211
+ if (len(deprecated_arguments) > 0):
212
  raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
213
+ return_dict = (return_dict if (return_dict is not None) else self.config.use_return_dict)
214
  transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
215
  hidden_states = transformer_outputs[0]
216
  lm_logits = self.lm_head(hidden_states)
217
  loss = None
218
+ if (labels is not None):
219
+ shift_logits = lm_logits[..., :(- 1), :].contiguous()
220
  shift_labels = labels[..., 1:].contiguous()
221
  (batch_size, seq_length, vocab_size) = shift_logits.shape
222
  loss_fct = CrossEntropyLoss()
223
+ loss = loss_fct(shift_logits.view((batch_size * seq_length), vocab_size), shift_labels.view((batch_size * seq_length)))
224
+ if (not return_dict):
225
+ output = ((lm_logits,) + transformer_outputs[1:])
226
+ return (((loss,) + output) if (loss is not None) else output)
227
  return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
228
 
229
  def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
230
  if past:
231
+ input_ids = input_ids[:, (- 1)].unsqueeze((- 1))
232
  bidirectional_mask = None
233
+ if (past[0][0].shape[0] == input_ids.shape[0]):
234
  past = self._convert_to_bloom_cache(past)
235
  else:
236
  bidirectional_mask = torch.ones_like(input_ids)
 
241
  return model
242
 
243
  def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
244
+ 'Converts an OPT Causal LM to a Prefix LM.\n\n Supported HuggingFace model classes:\n - `OPTForCausalLM`\n\n See `convert_hf_causal_lm_to_prefix_lm` for more details.\n '
 
 
 
 
 
 
245
  if hasattr(model, '_prefix_lm_converted'):
246
  return model
247
  assert isinstance(model, OPTForCausalLM)
248
+ assert (model.config.add_cross_attention == False), 'Only supports OPT decoder-only models'
249
  setattr(model, '_original_forward', getattr(model, 'forward'))
250
  setattr(model, '_original_generate', getattr(model, 'generate'))
251
  model.model.decoder.bidirectional_mask = None
252
 
253
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
254
  combined_attention_mask = None
255
+ if (input_shape[(- 1)] > 1):
256
+ if (self.bidirectional_mask == 'g'):
257
  (bsz, src_length) = input_shape
258
+ combined_attention_mask = torch.zeros((bsz, 1, src_length, (src_length + past_key_values_length)), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
259
  else:
260
  combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
261
+ if (self.bidirectional_mask is not None):
262
+ assert (attention_mask.shape == self.bidirectional_mask.shape)
263
+ expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[(- 1)]).to(inputs_embeds.device)
264
  combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
265
+ if (attention_mask is not None):
266
+ expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[(- 1)]).to(inputs_embeds.device)
267
+ combined_attention_mask = (expanded_attn_mask if (combined_attention_mask is None) else (expanded_attn_mask + combined_attention_mask))
268
  return combined_attention_mask
269
  setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
270
 
 
272
 
273
  def call_og_forward():
274
  return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
275
+ if (bidirectional_mask is None):
276
  return call_og_forward()
277
  self.model.decoder.bidirectional_mask = bidirectional_mask
278
  try:
 
283
  self.model.decoder.bidirectional_mask = None
284
  return outputs
285
 
286
+ def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[(str, Any)]):
287
+ 'Wraps original generate to enable PrefixLM-style attention.'
288
  self.model.decoder.bidirectional_mask = 'g'
289
  try:
290
  output = self._original_generate(*args, **kwargs)
 
297
  setattr(model, 'generate', MethodType(generate, model))
298
  setattr(model, '_prefix_lm_converted', True)
299
  return model
300
+ _SUPPORTED_HF_MODELS = (_SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM))
301
+ CAUSAL_LM_TYPES = Union[(GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM)]
302
 
303
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
304
+ 'Converts a HuggingFace Causal LM to a Prefix LM.\n\n Supported HuggingFace model classes:\n - `GPT2LMHeadModel`\n - `GPTNeoForCausalLM`\n - `GPTNeoXForCausalLM`\n - `GPTJForCausalLM`\n - `BloomForCausalLM`\n - `OPTForCausalLM`\n\n Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the\n `generate` method and/or select underlying methods depending on the model class.\n\n These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".\n\n Notes on training:\n To actually train the converted model as a Prefix LM, training batches will need to indicate\n the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.\n\n **This is not a standard input and requires custom layers either within or after your dataloader.**\n\n In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`\n such that `batch[\'labels\'][batch[\'bidirectional_mask\'] == 1] == -100`.\n That is, the prefix portion of the sequence should not generate any loss. Loss should only be\n generated by the target portion of the sequence.\n\n Notes on `GPTNeoForCausalLM`:\n To simplify the implementation, "global" and "local" attention layers are handled differently.\n For "global" layers, we handle conversion as described above. For "local" layers, which use a\n causal attention mask within a restricted local window, we do not alter the masking.\n\n Notes on `forward` method conversion:\n After conversion, the `forward` method will handle a new input, `bidirectional_mask`,\n which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions\n belonging to the prefix (prefix tokens can attend to one another bidirectionally), and\n 0 indicates token positions belonging to the target.\n\n The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing\n causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset\n the causal masks before returning the result.\n\n Notes on `generate` method conversion:\n After conversion, the `generate` method will have the same signature but will internally\n convert all causal masks to be purely bidirectional, call the original `generate` method, and\n (where appropriate) reset the causal masks before returning the result.\n\n This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token\n "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates\n each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one\n another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and\n previously-generated tokens (also as expected in a Prefix LM).\n\n To preserve the API, the original methods are renamed to `_original_forward` and\n `_original_generate`, and replaced with new `forward` and `generate` methods that wrap\n them, respectively. Although implementation details vary by model class.\n '
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  if isinstance(model, _SUPPORTED_GPT_MODELS):
306
  return _convert_gpt_causal_lm_to_prefix_lm(model)
307
  elif isinstance(model, BloomForCausalLM):
 
309
  elif isinstance(model, OPTForCausalLM):
310
  return _convert_opt_causal_lm_to_prefix_lm(model)
311
  else:
312
+ raise TypeError(((f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:') + f'''
313
+ {_SUPPORTED_HF_MODELS}'''))
 
 
314
 
315
+ def add_bidirectional_mask_if_missing(batch: Dict[(str, Any)]):
316
+ "Attempts to add bidirectional_mask to batch if missing.\n\n Raises:\n KeyError if bidirectional_mask is missing and can't be inferred\n "
317
+ if ('bidirectional_mask' not in batch):
318
+ if (batch.get('mode', None) == 'icl_task'):
 
319
  batch['bidirectional_mask'] = batch['attention_mask'].clone()
320
  for (i, continuation_indices) in enumerate(batch['continuation_indices']):
321
+ batch['bidirectional_mask'][(i, continuation_indices)] = 0
322
+ elif (('labels' in batch) and ('attention_mask' in batch)):
323
+ batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], (- 100))).type_as(batch['attention_mask'])
324
  else:
325
+ raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
is_torch_version.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import logging
3
+ import operator as op
4
+ from packaging import version
5
+ from packaging.version import Version, parse
6
+ from typing import Union
7
+ import importlib.util
8
+
9
+ # The package importlib_metadata is in a different place, depending on the python version.
10
+ if sys.version_info < (3, 8):
11
+ import importlib_metadata
12
+ else:
13
+ import importlib.metadata as importlib_metadata
14
+
15
+ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ _torch_available = importlib.util.find_spec("torch") is not None
20
+ if _torch_available:
21
+ try:
22
+ _torch_version = importlib_metadata.version("torch")
23
+ logger.info(f"PyTorch version {_torch_version} available.")
24
+ except importlib_metadata.PackageNotFoundError:
25
+ _torch_available = False
26
+
27
+ # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
28
+ def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
29
+ """
30
+ Args:
31
+ Compares a library version to some requirement using a given operation.
32
+ library_or_version (`str` or `packaging.version.Version`):
33
+ A library name or a version to check.
34
+ operation (`str`):
35
+ A string representation of an operator, such as `">"` or `"<="`.
36
+ requirement_version (`str`):
37
+ The version to compare the library version against
38
+ """
39
+ if operation not in STR_OPERATION_TO_FUNC.keys():
40
+ raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
41
+ operation = STR_OPERATION_TO_FUNC[operation]
42
+ if isinstance(library_or_version, str):
43
+ library_or_version = parse(importlib_metadata.version(library_or_version))
44
+ return operation(library_or_version, parse(requirement_version))
45
+
46
+ # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
47
+ def is_torch_version(operation: str, version: str):
48
+ """
49
+ Args:
50
+ Compares the current PyTorch version to a given reference with an operation.
51
+ operation (`str`):
52
+ A string representation of an operator, such as `">"` or `"<="`
53
+ version (`str`):
54
+ A string version of PyTorch
55
+ """
56
+ return compare_versions(parse(_torch_version), operation, version)
meta_init_context.py CHANGED
@@ -1,72 +1,31 @@
 
1
  from contextlib import contextmanager
2
  import torch
3
  import torch.nn as nn
4
 
5
  @contextmanager
6
  def init_empty_weights(include_buffers: bool=False):
7
- """Meta initialization context manager.
8
-
9
- A context manager under which models are initialized with all parameters
10
- on the meta device, therefore creating an empty model. Useful when just
11
- initializing the model would blow the available RAM.
12
-
13
- Args:
14
- include_buffers (`bool`, *optional*, defaults to `False`): Whether or
15
- not to also put all buffers on the meta device while initializing.
16
-
17
- Example:
18
- ```python
19
- import torch.nn as nn
20
-
21
- # Initialize a model with 100 billions parameters in no time and without using any RAM.
22
- with init_empty_weights():
23
- tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
24
- ```
25
-
26
- <Tip warning={true}>
27
-
28
- Any model created under this context manager has no weights. As such you can't do something like
29
- `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
30
-
31
- </Tip>
32
- """
33
  with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34
- yield f
35
 
36
  @contextmanager
37
  def init_on_device(device: torch.device, include_buffers: bool=False):
38
- """Device initialization context manager.
39
-
40
- A context manager under which models are initialized with all parameters
41
- on the specified device.
42
-
43
- Args:
44
- device (`torch.device`): Device to initialize all parameters on.
45
- include_buffers (`bool`, *optional*, defaults to `False`): Whether or
46
- not to also put all buffers on the meta device while initializing.
47
-
48
- Example:
49
- ```python
50
- import torch.nn as nn
51
-
52
- with init_on_device(device=torch.device("cuda")):
53
- tst = nn.Liner(100, 100) # on `cuda` device
54
- ```
55
- """
56
  old_register_parameter = nn.Module.register_parameter
57
  if include_buffers:
58
  old_register_buffer = nn.Module.register_buffer
59
 
60
  def register_empty_parameter(module, name, param):
61
  old_register_parameter(module, name, param)
62
- if param is not None:
63
  param_cls = type(module._parameters[name])
64
  kwargs = module._parameters[name].__dict__
65
  module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
66
 
67
  def register_empty_buffer(module, name, buffer):
68
  old_register_buffer(module, name, buffer)
69
- if buffer is not None:
70
  module._buffers[name] = module._buffers[name].to(device)
71
  if include_buffers:
72
  tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
@@ -85,10 +44,10 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
85
  nn.Module.register_buffer = register_empty_buffer
86
  for torch_function_name in tensor_constructors_to_patch.keys():
87
  setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
88
- yield
89
  finally:
90
  nn.Module.register_parameter = old_register_parameter
91
  if include_buffers:
92
  nn.Module.register_buffer = old_register_buffer
93
  for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94
- setattr(torch, torch_function_name, old_torch_function)
 
1
+
2
  from contextlib import contextmanager
3
  import torch
4
  import torch.nn as nn
5
 
6
  @contextmanager
7
  def init_empty_weights(include_buffers: bool=False):
8
+ "Meta initialization context manager.\n\n A context manager under which models are initialized with all parameters\n on the meta device, therefore creating an empty model. Useful when just\n initializing the model would blow the available RAM.\n\n Args:\n include_buffers (`bool`, *optional*, defaults to `False`): Whether or\n not to also put all buffers on the meta device while initializing.\n\n Example:\n ```python\n import torch.nn as nn\n\n # Initialize a model with 100 billions parameters in no time and without using any RAM.\n with init_empty_weights():\n tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])\n ```\n\n <Tip warning={true}>\n\n Any model created under this context manager has no weights. As such you can't do something like\n `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].\n\n </Tip>\n "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
10
+ (yield f)
11
 
12
  @contextmanager
13
  def init_on_device(device: torch.device, include_buffers: bool=False):
14
+ 'Device initialization context manager.\n\n A context manager under which models are initialized with all parameters\n on the specified device.\n\n Args:\n device (`torch.device`): Device to initialize all parameters on.\n include_buffers (`bool`, *optional*, defaults to `False`): Whether or\n not to also put all buffers on the meta device while initializing.\n\n Example:\n ```python\n import torch.nn as nn\n\n with init_on_device(device=torch.device("cuda")):\n tst = nn.Liner(100, 100) # on `cuda` device\n ```\n '
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  old_register_parameter = nn.Module.register_parameter
16
  if include_buffers:
17
  old_register_buffer = nn.Module.register_buffer
18
 
19
  def register_empty_parameter(module, name, param):
20
  old_register_parameter(module, name, param)
21
+ if (param is not None):
22
  param_cls = type(module._parameters[name])
23
  kwargs = module._parameters[name].__dict__
24
  module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
25
 
26
  def register_empty_buffer(module, name, buffer):
27
  old_register_buffer(module, name, buffer)
28
+ if (buffer is not None):
29
  module._buffers[name] = module._buffers[name].to(device)
30
  if include_buffers:
31
  tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
 
44
  nn.Module.register_buffer = register_empty_buffer
45
  for torch_function_name in tensor_constructors_to_patch.keys():
46
  setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
47
+ (yield)
48
  finally:
49
  nn.Module.register_parameter = old_register_parameter
50
  if include_buffers:
51
  nn.Module.register_buffer = old_register_buffer
52
  for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
53
+ setattr(torch, torch_function_name, old_torch_function)
modeling_mpt.py CHANGED
@@ -4,25 +4,45 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
  """
5
  import math
6
  import warnings
7
- from typing import List, Optional, Tuple, Union
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
 
11
  from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
- from .attention import attn_bias_shape, build_attn_bias
14
- from .blocks import MPTBlock
 
15
  from .norm import NORM_CLASS_REGISTRY
16
  from .configuration_mpt import MPTConfig
17
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
 
 
 
 
 
26
 
27
  class MPTModel(MPTPreTrainedModel):
28
 
@@ -64,6 +84,7 @@ class MPTModel(MPTPreTrainedModel):
64
  if self.config.init_config['verbose'] > 1:
65
  init_fn_name = self.config.init_config['name']
66
  warnings.warn(f'Using {init_fn_name} initialization.')
 
67
 
68
  def get_input_embeddings(self):
69
  return self.wte
@@ -95,7 +116,9 @@ class MPTModel(MPTPreTrainedModel):
95
  if attn_bias is None:
96
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
97
  else:
98
- attn_bias = attn_bias[:, :, :, -s_k:]
 
 
99
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
100
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
101
  min_val = torch.finfo(attn_bias.dtype).min
@@ -130,6 +153,12 @@ class MPTModel(MPTPreTrainedModel):
130
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
131
  return_dict = return_dict if return_dict is not None else self.config.return_dict
132
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
 
 
 
133
  if attention_mask is not None:
134
  attention_mask = attention_mask.bool()
135
  if prefix_mask is not None:
@@ -137,7 +166,10 @@ class MPTModel(MPTPreTrainedModel):
137
  if not return_dict:
138
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
139
  if output_attentions:
140
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
 
 
 
141
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
142
  raise NotImplementedError('MPT does not support training with left padding.')
143
  if self.prefix_lm and prefix_mask is None:
@@ -157,7 +189,12 @@ class MPTModel(MPTPreTrainedModel):
157
  if past_key_values is not None:
158
  if len(past_key_values) != self.config.n_layers:
159
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
 
 
 
160
  past_position = past_key_values[0][0].size(1)
 
 
161
  if S + past_position > self.config.max_seq_len:
162
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
163
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
@@ -175,16 +212,60 @@ class MPTModel(MPTPreTrainedModel):
175
  if use_cache and past_key_values is None:
176
  past_key_values = [() for _ in range(self.config.n_layers)]
177
  all_hidden_states = () if output_hidden_states else None
 
178
  for (b_idx, block) in enumerate(self.blocks):
179
  if output_hidden_states:
180
  assert all_hidden_states is not None
181
  all_hidden_states = all_hidden_states + (x,)
182
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
183
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  if past_key_values is not None:
185
  past_key_values[b_idx] = past_key_value
 
 
 
186
  x = self.norm_f(x)
187
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
 
 
 
188
 
189
  def param_init_fn(self, module):
190
  init_fn_name = self.config.init_config['name']
@@ -231,7 +312,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
231
  def get_decoder(self):
232
  return self.transformer
233
 
234
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
235
  return_dict = return_dict if return_dict is not None else self.config.return_dict
236
  use_cache = use_cache if use_cache is not None else self.config.use_cache
237
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
@@ -245,7 +326,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
245
  labels = torch.roll(labels, shifts=-1)
246
  labels[:, -1] = -100
247
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
248
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
249
 
250
  def param_init_fn(self, module):
251
  init_fn_name = self.config.init_config['name']
 
4
  """
5
  import math
6
  import warnings
7
+ from typing import Any, List, Optional, Tuple, Union, Protocol, Dict
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
+ from torch.utils.checkpoint import checkpoint
12
  from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
13
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.utils import logging
15
+ from .attention import attn_bias_shape, build_attn_bias, PastKeyValue, MultiheadAttention, MultiQueryAttention
16
+ from .blocks import MPTBlock, MPTBlockOutput
17
  from .norm import NORM_CLASS_REGISTRY
18
  from .configuration_mpt import MPTConfig
19
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
20
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
21
  from .meta_init_context import init_empty_weights
22
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
23
+ from .is_torch_version import is_torch_version
24
+
25
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
26
 
27
+ logger = logging.get_logger(__name__)
28
+
29
+ class MPTBlockCheckpointedForward(Protocol):
30
+ def __call__(
31
+ x: torch.Tensor,
32
+ past_key_value: Union[PastKeyValue, Tuple, None],
33
+ attn_bias: Optional[torch.Tensor],
34
+ attention_mask: Optional[torch.ByteTensor],
35
+ is_causal: bool,
36
+ ) -> MPTBlockOutput: ...
37
+
38
  class MPTPreTrainedModel(PreTrainedModel):
39
  config_class = MPTConfig
40
  base_model_prefix = 'model'
41
+ _no_split_modules = ['MPTBlock']
42
+ supports_gradient_checkpointing = True
43
+ def _set_gradient_checkpointing(self, module: nn.Module, value=False) -> None:
44
+ if isinstance(module, MPTModel) or isinstance(module, MultiheadAttention) or isinstance(module, MultiQueryAttention):
45
+ module.gradient_checkpointing = value
46
 
47
  class MPTModel(MPTPreTrainedModel):
48
 
 
84
  if self.config.init_config['verbose'] > 1:
85
  init_fn_name = self.config.init_config['name']
86
  warnings.warn(f'Using {init_fn_name} initialization.')
87
+ self.gradient_checkpointing = False
88
 
89
  def get_input_embeddings(self):
90
  return self.wte
 
116
  if attn_bias is None:
117
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
118
  else:
119
+ # clamp to 0 necessary for torch 2.0 compile()
120
+ _s_k = max(0, attn_bias.size(-1) - s_k)
121
+ attn_bias = attn_bias[:, :, :, _s_k:]
122
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
123
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
124
  min_val = torch.finfo(attn_bias.dtype).min
 
153
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
154
  return_dict = return_dict if return_dict is not None else self.config.return_dict
155
  use_cache = use_cache if use_cache is not None else self.config.use_cache
156
+ if self.gradient_checkpointing and self.training:
157
+ if use_cache:
158
+ logger.warning_once(
159
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
160
+ )
161
+ use_cache = False
162
  if attention_mask is not None:
163
  attention_mask = attention_mask.bool()
164
  if prefix_mask is not None:
 
166
  if not return_dict:
167
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
168
  if output_attentions:
169
+ if self.attn_impl != 'torch':
170
+ raise NotImplementedError(
171
+ 'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
172
+ )
173
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
174
  raise NotImplementedError('MPT does not support training with left padding.')
175
  if self.prefix_lm and prefix_mask is None:
 
189
  if past_key_values is not None:
190
  if len(past_key_values) != self.config.n_layers:
191
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
192
+ # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
193
+ # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
194
+ # Here we shift position embedding using the `seq` dim of the past key
195
  past_position = past_key_values[0][0].size(1)
196
+ if self.attn_impl == 'torch':
197
+ past_position = past_key_values[0][0].size(3)
198
  if S + past_position > self.config.max_seq_len:
199
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
200
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
 
212
  if use_cache and past_key_values is None:
213
  past_key_values = [() for _ in range(self.config.n_layers)]
214
  all_hidden_states = () if output_hidden_states else None
215
+ all_self_attns = () if output_attentions else None
216
  for (b_idx, block) in enumerate(self.blocks):
217
  if output_hidden_states:
218
  assert all_hidden_states is not None
219
  all_hidden_states = all_hidden_states + (x,)
220
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
221
+ if self.gradient_checkpointing and self.training:
222
+ ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
223
+ def create_custom_forward(module: MPTBlock) -> MPTBlockCheckpointedForward:
224
+ def custom_forward(
225
+ x: torch.Tensor,
226
+ past_key_value: Union[PastKeyValue, Tuple, None],
227
+ attn_bias: Optional[torch.Tensor],
228
+ attention_mask: Optional[torch.ByteTensor],
229
+ is_causal: bool
230
+ ):
231
+ return module.forward(
232
+ x,
233
+ past_key_value,
234
+ attn_bias,
235
+ attention_mask,
236
+ is_causal,
237
+ )
238
+ return custom_forward
239
+ block_out: MPTBlockOutput = checkpoint(
240
+ create_custom_forward(block),
241
+ x,
242
+ past_key_value,
243
+ attn_bias,
244
+ attention_mask,
245
+ self.is_causal,
246
+ **ckpt_kwargs,
247
+ )
248
+ else:
249
+ block_out: MPTBlockOutput = block(
250
+ x,
251
+ past_key_value=past_key_value,
252
+ attn_bias=attn_bias,
253
+ attention_mask=attention_mask,
254
+ is_causal=self.is_causal,
255
+ )
256
+ x, attn_weights, past_key_value = block_out
257
+ del block_out
258
  if past_key_values is not None:
259
  past_key_values[b_idx] = past_key_value
260
+ if output_attentions:
261
+ assert all_self_attns is not None # pyright
262
+ all_self_attns = all_self_attns + (attn_weights,)
263
  x = self.norm_f(x)
264
+ # add hidden states from the last decoder layer
265
+ if output_hidden_states:
266
+ assert all_hidden_states is not None # pyright
267
+ all_hidden_states = all_hidden_states + (x,)
268
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
269
 
270
  def param_init_fn(self, module):
271
  init_fn_name = self.config.init_config['name']
 
312
  def get_decoder(self):
313
  return self.transformer
314
 
315
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, *args, **kwargs):
316
  return_dict = return_dict if return_dict is not None else self.config.return_dict
317
  use_cache = use_cache if use_cache is not None else self.config.use_cache
318
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
 
326
  labels = torch.roll(labels, shifts=-1)
327
  labels[:, -1] = -100
328
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
329
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
330
 
331
  def param_init_fn(self, module):
332
  init_fn_name = self.config.init_config['name']
norm.py CHANGED
@@ -1,10 +1,11 @@
 
1
  import torch
2
 
3
  def _cast_if_autocast_enabled(tensor):
4
  if torch.is_autocast_enabled():
5
- if tensor.device.type == 'cuda':
6
  dtype = torch.get_autocast_gpu_dtype()
7
- elif tensor.device.type == 'cpu':
8
  dtype = torch.get_autocast_cpu_dtype()
9
  else:
10
  raise NotImplementedError()
@@ -19,15 +20,15 @@ class LPLayerNorm(torch.nn.LayerNorm):
19
  def forward(self, x):
20
  module_device = x.device
21
  downcast_x = _cast_if_autocast_enabled(x)
22
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23
- downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24
  with torch.autocast(enabled=False, device_type=module_device.type):
25
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26
 
27
  def rms_norm(x, weight=None, eps=1e-05):
28
- output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29
- if weight is not None:
30
- return output * weight
31
  return output
32
 
33
  class RMSNorm(torch.nn.Module):
@@ -50,7 +51,7 @@ class LPRMSNorm(RMSNorm):
50
 
51
  def forward(self, x):
52
  downcast_x = _cast_if_autocast_enabled(x)
53
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54
  with torch.autocast(enabled=False, device_type=x.device.type):
55
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56
- NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
 
1
+
2
  import torch
3
 
4
  def _cast_if_autocast_enabled(tensor):
5
  if torch.is_autocast_enabled():
6
+ if (tensor.device.type == 'cuda'):
7
  dtype = torch.get_autocast_gpu_dtype()
8
+ elif (tensor.device.type == 'cpu'):
9
  dtype = torch.get_autocast_cpu_dtype()
10
  else:
11
  raise NotImplementedError()
 
20
  def forward(self, x):
21
  module_device = x.device
22
  downcast_x = _cast_if_autocast_enabled(x)
23
+ downcast_weight = (_cast_if_autocast_enabled(self.weight) if (self.weight is not None) else self.weight)
24
+ downcast_bias = (_cast_if_autocast_enabled(self.bias) if (self.bias is not None) else self.bias)
25
  with torch.autocast(enabled=False, device_type=module_device.type):
26
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
27
 
28
  def rms_norm(x, weight=None, eps=1e-05):
29
+ output = (x / torch.rsqrt((x.pow(2).mean((- 1), keepdim=True) + eps)))
30
+ if (weight is not None):
31
+ return (output * weight)
32
  return output
33
 
34
  class RMSNorm(torch.nn.Module):
 
51
 
52
  def forward(self, x):
53
  downcast_x = _cast_if_autocast_enabled(x)
54
+ downcast_weight = (_cast_if_autocast_enabled(self.weight) if (self.weight is not None) else self.weight)
55
  with torch.autocast(enabled=False, device_type=x.device.type):
56
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
57
+ NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
param_init_fns.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import math
2
  import warnings
3
  from collections.abc import Sequence
@@ -9,110 +10,110 @@ from .norm import NORM_CLASS_REGISTRY
9
 
10
  def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
11
  del kwargs
12
- if verbose > 1:
13
  warnings.warn(f"Initializing network using module's reset_parameters attribute")
14
  if hasattr(module, 'reset_parameters'):
15
  module.reset_parameters()
16
 
17
  def fused_init_helper_(module: nn.Module, init_fn_):
18
  _fused = getattr(module, '_fused', None)
19
- if _fused is None:
20
  raise RuntimeError(f'Internal logic error')
21
  (dim, splits) = _fused
22
  splits = (0, *splits, module.weight.size(dim))
23
- for (s, e) in zip(splits[:-1], splits[1:]):
24
- slice_indices = [slice(None)] * module.weight.ndim
25
  slice_indices[dim] = slice(s, e)
26
  init_fn_(module.weight[slice_indices])
27
 
28
- def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
29
  del kwargs
30
- if verbose > 1:
31
  warnings.warn(f'If model has bias parameters they are initialized to 0.')
32
  init_div_is_residual = init_div_is_residual
33
- if init_div_is_residual is False:
34
  div_is_residual = 1.0
35
- elif init_div_is_residual is True:
36
- div_is_residual = math.sqrt(2 * n_layers)
37
- elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
38
  div_is_residual = init_div_is_residual
39
- elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
40
  div_is_residual = float(init_div_is_residual)
41
  else:
42
  div_is_residual = 1.0
43
  raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
44
- if init_div_is_residual is not False:
45
- if verbose > 1:
46
- warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
47
  if isinstance(module, nn.Linear):
48
  if hasattr(module, '_fused'):
49
  fused_init_helper_(module, init_fn_)
50
  else:
51
  init_fn_(module.weight)
52
- if module.bias is not None:
53
  torch.nn.init.zeros_(module.bias)
54
- if init_div_is_residual is not False and getattr(module, '_is_residual', False):
55
  with torch.no_grad():
56
  module.weight.div_(div_is_residual)
57
  elif isinstance(module, nn.Embedding):
58
- if emb_init_std is not None:
59
  std = emb_init_std
60
- if std == 0:
61
  warnings.warn(f'Embedding layer initialized to 0.')
62
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
63
- if verbose > 1:
64
  warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
65
- elif emb_init_uniform_lim is not None:
66
  lim = emb_init_uniform_lim
67
  if isinstance(lim, Sequence):
68
- if len(lim) > 2:
69
  raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
70
- if lim[0] == lim[1]:
71
  warnings.warn(f'Embedding layer initialized to {lim[0]}.')
72
  else:
73
- if lim == 0:
74
  warnings.warn(f'Embedding layer initialized to 0.')
75
- lim = [-lim, lim]
76
  (a, b) = lim
77
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
78
- if verbose > 1:
79
  warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
80
  else:
81
  emb_init_fn_ = init_fn_
82
  emb_init_fn_(module.weight)
83
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
84
- if verbose > 1:
85
  warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
86
- if hasattr(module, 'weight') and module.weight is not None:
87
  torch.nn.init.ones_(module.weight)
88
- if hasattr(module, 'bias') and module.bias is not None:
89
  torch.nn.init.zeros_(module.bias)
90
  elif isinstance(module, nn.MultiheadAttention):
91
  if module._qkv_same_embed_dim:
92
- assert module.in_proj_weight is not None
93
- assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
94
- assert d_model is not None
95
  _d = d_model
96
- splits = (0, _d, 2 * _d, 3 * _d)
97
- for (s, e) in zip(splits[:-1], splits[1:]):
98
  init_fn_(module.in_proj_weight[s:e])
99
  else:
100
- assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
101
- assert module.in_proj_weight is None
102
  init_fn_(module.q_proj_weight)
103
  init_fn_(module.k_proj_weight)
104
  init_fn_(module.v_proj_weight)
105
- if module.in_proj_bias is not None:
106
  torch.nn.init.zeros_(module.in_proj_bias)
107
- if module.bias_k is not None:
108
  torch.nn.init.zeros_(module.bias_k)
109
- if module.bias_v is not None:
110
  torch.nn.init.zeros_(module.bias_v)
111
  init_fn_(module.out_proj.weight)
112
- if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
113
  with torch.no_grad():
114
  module.out_proj.weight.div_(div_is_residual)
115
- if module.out_proj.bias is not None:
116
  torch.nn.init.zeros_(module.out_proj.bias)
117
  else:
118
  for _ in module.parameters(recurse=False):
@@ -121,61 +122,56 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
121
  def _normal_init_(std, mean=0.0):
122
  return partial(torch.nn.init.normal_, mean=mean, std=std)
123
 
124
- def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
125
  del kwargs
126
  init_fn_ = _normal_init_(std=std)
127
- if verbose > 1:
128
  warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
129
  generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
130
 
131
- def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
132
  del kwargs
133
- if init_std is None:
134
  raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
135
  _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
136
 
137
- def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
138
  del kwargs
139
- std = math.sqrt(2 / (5 * d_model))
140
  _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
141
 
142
- def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
143
- """From section 2.3.1 of GPT-NeoX-20B:
144
-
145
- An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
146
- see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
147
- and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
148
- """
149
  del kwargs
150
- residual_div = n_layers / math.sqrt(10)
151
- if verbose > 1:
152
  warnings.warn(f'setting init_div_is_residual to {residual_div}')
153
  small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
154
 
155
- def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
156
  del kwargs
157
- if verbose > 1:
158
- warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
159
  kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
160
  generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
161
 
162
- def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
163
  del kwargs
164
- if verbose > 1:
165
- warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
166
  kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
167
  generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
168
 
169
- def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
170
  del kwargs
171
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
172
- if verbose > 1:
173
- warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
174
  generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
175
 
176
- def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
- if verbose > 1:
179
- warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
180
  generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
181
- MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
 
1
+
2
  import math
3
  import warnings
4
  from collections.abc import Sequence
 
10
 
11
  def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
12
  del kwargs
13
+ if (verbose > 1):
14
  warnings.warn(f"Initializing network using module's reset_parameters attribute")
15
  if hasattr(module, 'reset_parameters'):
16
  module.reset_parameters()
17
 
18
  def fused_init_helper_(module: nn.Module, init_fn_):
19
  _fused = getattr(module, '_fused', None)
20
+ if (_fused is None):
21
  raise RuntimeError(f'Internal logic error')
22
  (dim, splits) = _fused
23
  splits = (0, *splits, module.weight.size(dim))
24
+ for (s, e) in zip(splits[:(- 1)], splits[1:]):
25
+ slice_indices = ([slice(None)] * module.weight.ndim)
26
  slice_indices[dim] = slice(s, e)
27
  init_fn_(module.weight[slice_indices])
28
 
29
+ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, verbose: int=0, **kwargs):
30
  del kwargs
31
+ if (verbose > 1):
32
  warnings.warn(f'If model has bias parameters they are initialized to 0.')
33
  init_div_is_residual = init_div_is_residual
34
+ if (init_div_is_residual is False):
35
  div_is_residual = 1.0
36
+ elif (init_div_is_residual is True):
37
+ div_is_residual = math.sqrt((2 * n_layers))
38
+ elif (isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int)):
39
  div_is_residual = init_div_is_residual
40
+ elif (isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric()):
41
  div_is_residual = float(init_div_is_residual)
42
  else:
43
  div_is_residual = 1.0
44
  raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
45
+ if (init_div_is_residual is not False):
46
+ if (verbose > 1):
47
+ warnings.warn((f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.'))
48
  if isinstance(module, nn.Linear):
49
  if hasattr(module, '_fused'):
50
  fused_init_helper_(module, init_fn_)
51
  else:
52
  init_fn_(module.weight)
53
+ if (module.bias is not None):
54
  torch.nn.init.zeros_(module.bias)
55
+ if ((init_div_is_residual is not False) and getattr(module, '_is_residual', False)):
56
  with torch.no_grad():
57
  module.weight.div_(div_is_residual)
58
  elif isinstance(module, nn.Embedding):
59
+ if (emb_init_std is not None):
60
  std = emb_init_std
61
+ if (std == 0):
62
  warnings.warn(f'Embedding layer initialized to 0.')
63
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
64
+ if (verbose > 1):
65
  warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
66
+ elif (emb_init_uniform_lim is not None):
67
  lim = emb_init_uniform_lim
68
  if isinstance(lim, Sequence):
69
+ if (len(lim) > 2):
70
  raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
71
+ if (lim[0] == lim[1]):
72
  warnings.warn(f'Embedding layer initialized to {lim[0]}.')
73
  else:
74
+ if (lim == 0):
75
  warnings.warn(f'Embedding layer initialized to 0.')
76
+ lim = [(- lim), lim]
77
  (a, b) = lim
78
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
79
+ if (verbose > 1):
80
  warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
81
  else:
82
  emb_init_fn_ = init_fn_
83
  emb_init_fn_(module.weight)
84
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
85
+ if (verbose > 1):
86
  warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
87
+ if (hasattr(module, 'weight') and (module.weight is not None)):
88
  torch.nn.init.ones_(module.weight)
89
+ if (hasattr(module, 'bias') and (module.bias is not None)):
90
  torch.nn.init.zeros_(module.bias)
91
  elif isinstance(module, nn.MultiheadAttention):
92
  if module._qkv_same_embed_dim:
93
+ assert (module.in_proj_weight is not None)
94
+ assert ((module.q_proj_weight is None) and (module.k_proj_weight is None) and (module.v_proj_weight is None))
95
+ assert (d_model is not None)
96
  _d = d_model
97
+ splits = (0, _d, (2 * _d), (3 * _d))
98
+ for (s, e) in zip(splits[:(- 1)], splits[1:]):
99
  init_fn_(module.in_proj_weight[s:e])
100
  else:
101
+ assert ((module.q_proj_weight is not None) and (module.k_proj_weight is not None) and (module.v_proj_weight is not None))
102
+ assert (module.in_proj_weight is None)
103
  init_fn_(module.q_proj_weight)
104
  init_fn_(module.k_proj_weight)
105
  init_fn_(module.v_proj_weight)
106
+ if (module.in_proj_bias is not None):
107
  torch.nn.init.zeros_(module.in_proj_bias)
108
+ if (module.bias_k is not None):
109
  torch.nn.init.zeros_(module.bias_k)
110
+ if (module.bias_v is not None):
111
  torch.nn.init.zeros_(module.bias_v)
112
  init_fn_(module.out_proj.weight)
113
+ if ((init_div_is_residual is not False) and getattr(module.out_proj, '_is_residual', False)):
114
  with torch.no_grad():
115
  module.out_proj.weight.div_(div_is_residual)
116
+ if (module.out_proj.bias is not None):
117
  torch.nn.init.zeros_(module.out_proj.bias)
118
  else:
119
  for _ in module.parameters(recurse=False):
 
122
  def _normal_init_(std, mean=0.0):
123
  return partial(torch.nn.init.normal_, mean=mean, std=std)
124
 
125
+ def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, verbose: int=0, **kwargs):
126
  del kwargs
127
  init_fn_ = _normal_init_(std=std)
128
+ if (verbose > 1):
129
  warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
130
  generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
131
 
132
+ def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, verbose: int=0, **kwargs):
133
  del kwargs
134
+ if (init_std is None):
135
  raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
136
  _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
137
 
138
+ def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, verbose: int=0, **kwargs):
139
  del kwargs
140
+ std = math.sqrt((2 / (5 * d_model)))
141
  _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
142
 
143
+ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, verbose: int=0, **kwargs):
144
+ 'From section 2.3.1 of GPT-NeoX-20B:\n\n An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)\n see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151\n and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py\n '
 
 
 
 
 
145
  del kwargs
146
+ residual_div = (n_layers / math.sqrt(10))
147
+ if (verbose > 1):
148
  warnings.warn(f'setting init_div_is_residual to {residual_div}')
149
  small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
150
 
151
+ def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
152
  del kwargs
153
+ if (verbose > 1):
154
+ warnings.warn((f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'))
155
  kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
156
  generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
157
 
158
+ def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
159
  del kwargs
160
+ if (verbose > 1):
161
+ warnings.warn((f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'))
162
  kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
163
  generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
164
 
165
+ def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, init_gain: float=0, verbose: int=0, **kwargs):
166
  del kwargs
167
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
168
+ if (verbose > 1):
169
+ warnings.warn((f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}'))
170
  generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
171
 
172
+ def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[(int, float, str, bool)]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[(Tuple[(float, float)], float)]]=None, init_gain: float=0, verbose: int=0, **kwargs):
173
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
174
+ if (verbose > 1):
175
+ warnings.warn((f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}'))
176
  generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
177
+ MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}