michael-guenther committed on
Commit
2e3ebcb
·
1 Parent(s): 2aec9c9

upload model

Browse files
Files changed (11) hide show
  1. bert_padding.py +218 -0
  2. block.py +400 -0
  3. config.json +27 -0
  4. configuration_bert.py +42 -0
  5. embedding.py +62 -0
  6. mha.py +735 -0
  7. mlp.py +194 -0
  8. modeling_bert.py +784 -0
  9. pytorch_model.bin +3 -0
  10. tokenizer.json +0 -0
  11. tokenizer_config.json +1 -0
bert_padding.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
2
+ # Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
3
+
4
+ # Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
class IndexFirstAxis(torch.autograd.Function):
    """Autograd function that selects rows of ``input`` along its first axis.

    The forward result equals ``input[indices]``; the backward pass scatters
    the incoming gradient into a zero tensor that has the original first-axis
    length.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return the rows of ``input`` (ndim >= 2) selected by the 1-D ``indices``."""
        assert input.ndim >= 2
        ctx.save_for_backward(indices)
        ctx.first_axis_dim = input.shape[0]
        trailing_shape = input.shape[1:]
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing,
        # so flatten the trailing dims and gather with a broadcast index.
        flat_input = rearrange(input, "b ... -> b (...)")
        gather_index = repeat(indices, "z -> z d", d=trailing_shape.numel())
        gathered = torch.gather(flat_input, 0, gather_index)
        return gathered.reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` back to the selected rows; other rows stay zero."""
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = rearrange(grad_output, "b ... -> b (...)")
        width = flat_grad.shape[1]
        grad_input = flat_grad.new_zeros((ctx.first_axis_dim, width))
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=width), flat_grad)
        # ``indices`` is not differentiable, hence the trailing None.
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None


index_first_axis = IndexFirstAxis.apply
42
+
43
+
44
class IndexPutFirstAxis(torch.autograd.Function):
    """Autograd function that writes ``values`` into rows of a fresh zero tensor.

    The forward result has ``first_axis_dim`` rows; row ``indices[i]`` holds
    ``values[i]`` and every other row is zero. The backward pass reads the
    gradient back at the same rows.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        """Place ``values`` (ndim >= 2) at the 1-D ``indices`` of a zero tensor."""
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded = values.new_zeros((first_axis_dim, *values.shape[1:]))
        # Plain advanced indexing is used here (not scatter).
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient rows that correspond to ``values``."""
        (indices,) = ctx.saved_tensors
        # Plain advanced indexing is used here (not gather).
        selected = grad_output[indices]
        # No gradient flows to ``indices`` or ``first_axis_dim``.
        return selected, None, None


index_put_first_axis = IndexPutFirstAxis.apply
68
+
69
+
70
class IndexFirstAxisResidual(torch.autograd.Function):
    """Select rows along the first axis and also hand back the full input.

    ``forward`` returns ``(input[indices], input.detach())``. ``backward``
    receives a gradient for each of the two outputs and returns the residual
    gradient with the selected-row gradient added at ``indices``.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return the selected rows together with a detached alias of ``input``."""
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim = input.shape[0]
        # Unlike IndexFirstAxis this uses plain indexing instead of gather on a flattened
        # tensor: we don't want to reshape input (b ... -> b (...)) since it could change
        # the channel_last memory format to channel_first. In other words, input might
        # not be contiguous.
        output = input[indices]
        # If we don't detach, Pytorch complains about output being a view and is being
        # modified inplace.
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        """Accumulate ``grad_output`` into ``grad_residual`` at the selected rows."""
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        # NOTE: grad_residual is updated IN PLACE to avoid an extra allocation, so the
        # tensor autograd passed in is modified (equivalent to
        # ``grad_residual[indices] += grad_output``).
        grad_input = grad_residual
        # Broadcast the 1-D row indices to grad_output's full shape for scatter_add_.
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        # ``indices`` is not differentiable, hence the trailing None.
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply
99
+
100
+
101
def unpad_input(hidden_states, attention_mask):
    """Drop the masked-out tokens of a padded batch.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Prepend a zero so that cu_seqlens[i]:cu_seqlens[i + 1] spans the i-th sequence.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
127
+
128
+
129
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples in one sequence. The attention_mask_in_length is used to mask out the other short samples. This enables efficient training on variable-length samples (e.g., the supervised fine-tuning task in large language models).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
        ```
        [
          [2, 3, 0, 0, 0, 0],
          [3, 2, 0, 0, 0, 0],
          [6, 0, 0, 0, 0, 0]
        ]
        ```
    , which refers to the 3D-attention mask:
        ```
        [
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
          ]
        ]
        ```.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask_in_length.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (num_sequences + 1), the cumulative lengths of the concatenated sequences (one
            entry per nonzero element of attention_mask_in_length, plus a leading 0), used to index
            into hidden_states.
        max_seqlen_in_batch: int, length of the longest concatenated sequence.
    """
    # Total number of valid tokens in each batch row.
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    # Ordinary (batch, seqlen) validity mask: position < number of valid tokens in that row.
    positions = torch.arange(seqlen, device=length.device, dtype=length.dtype)
    attention_mask_2d = positions.expand(len(length), seqlen) < length.unsqueeze(1)
    # Lengths of the individual concatenated sequences, in row-major order.
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
202
+
203
+
204
def pad_input(hidden_states, indices, batch, seqlen):
    """Scatter unpadded tokens back into a zero-filled padded batch.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    # Equivalent to allocating zeros of shape (batch * seqlen, ...) and assigning
    # output[indices] = hidden_states, but with a custom backward.
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
block.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch import Tensor
13
+ from torchvision.ops import StochasticDepth
14
+
15
+ from .mha import MHA
16
+ from .mlp import Mlp
17
+
18
+ try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
+ except ImportError:
21
+ layer_norm_fn, RMSNorm = None, None
22
+
23
+
24
class Block(nn.Module):
    """Transformer block (mixer + MLP) supporting pre-norm and post-norm layouts."""

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        # The second dropout / drop-path / norm only exist when there is a real MLP.
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Forward the inference-cache allocation request to the mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def _drop_path_rowscale(self, drop_path, ref):
        """Return the row-scale tensor for the fused layer norm, or None if drop-path is inactive.

        The scale is obtained by applying ``drop_path`` to a tensor of ones shaped like
        ``ref`` without its last dimension.
        """
        if drop_path.p == 0 or not self.training:
            return None
        return drop_path(torch.ones(ref.shape[:-1], device=ref.device, dtype=ref.dtype))

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            mixer_kwargs: optional extra keyword arguments for the mixer. The mapping passed
                in is never modified.

        Returns:
            ``(hidden_states, residual)`` when prenorm=True, otherwise ``hidden_states``.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=self._drop_path_rowscale(self.drop_path1, hidden_states),
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                # Build a new dict instead of writing into the caller's mapping, which
                # may be reused for other layers or calls.
                mixer_kwargs = {**mixer_kwargs, "mixer_subset": mixer_subset}
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Keep the residual aligned with the subset of positions the mixer returned.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=self._drop_path_rowscale(self.drop_path2, hidden_states),
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=self._drop_path_rowscale(self.drop_path1, mixer_out),
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=self._drop_path_rowscale(self.drop_path2, mlp_out),
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
            return hidden_states
260
+
261
+
262
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32

        # Default sub-layers: multi-head attention with 64-dim heads and a 4x-wide MLP.
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)

        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            # With tied_norm the MLP branch reuses norm1's output, so norm2 is not created.
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        norms = [self.norm1]
        if hasattr(self, "norm2"):
            norms.append(self.norm2)
        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for norm in norms:
                for param in norm.parameters():
                    param._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for norm in norms:
                for param in norm.parameters():
                    param._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Forward the inference-cache allocation request to the mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the running residual stream, or None for the first block.
            mixer_kwargs: optional extra keyword arguments for the mixer.

        Returns:
            ``(mixer_output, mlp_output, residual)``.
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if self.fused_dropout_add_ln:
            if self.tied_norm:
                weight2, bias2 = None, None
            else:
                weight2, bias2 = self.norm2.weight, self.norm2.bias
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm)
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                (hidden_states2,) = rest
        else:
            dropped1 = self.dropout1(hidden_states1)
            residual = dropped1 if residual is None else residual + dropped1
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                residual = residual + self.dropout2(hidden_states2)
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2 = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        extra_kwargs = {} if mixer_kwargs is None else mixer_kwargs
        hidden_states1 = self.mixer(hidden_states1, **extra_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "configuration_bert.BertConfig",
4
+ "AutoModel": "modeling_bert.BertModel",
5
+ "AutoModelForPreTraining": "modeling_bert.BertForPreTraining",
6
+ "AutoModelForMaskedLM": "modeling_bert.BertForPreTraining"
7
+ },
8
+ "attention_probs_dropout_prob": 0.1,
9
+ "bos_token_id": 0,
10
+ "eos_token_id": 2,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 768,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 3072,
16
+ "layer_norm_eps": 1e-05,
17
+ "max_position_embeddings": 514,
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "output_past": true,
21
+ "pad_token_id": 1,
22
+ "position_embedding_type": "absolute",
23
+ "transformers_version": "4.17.0.dev0",
24
+ "type_vocab_size": 1,
25
+ "use_cache": false,
26
+ "vocab_size": 250002
27
+ }
configuration_bert.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class BertConfig(PretrainedConfig):
    """Configuration holder for the BERT model files shipped in this repository.

    The special-token ids and any extra keyword arguments are passed to
    ``PretrainedConfig``; every other argument is stored unchanged as an
    attribute of the same name.
    """

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Model sizes.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        # Dropout rates.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # Embedding-table sizes.
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        # Remaining settings.
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
embedding.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
2
+ # Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
3
+
4
+ # Copyright (c) 2022, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
12
+
13
+
14
class BertEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings."""

    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)

        Returns the summed embeddings for input_ids.
        """
        # Unpacking also enforces that input_ids is 2-D.
        _, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                padding_idx = self.word_embeddings.padding_idx
                if padding_idx is None:
                    # Without a padding index, padding-aware position ids cannot be
                    # derived from input_ids; use plain 0..seqlen-1 positions, which
                    # broadcast over the batch dimension.
                    position_ids = torch.arange(
                        seqlen, dtype=torch.long, device=input_ids.device
                    )
                else:
                    position_ids = create_position_ids_from_input_ids(
                        input_ids, padding_idx=padding_idx
                    ).to(input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                # A (seqlen,) tensor of zeros; its embedding broadcasts over the batch.
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings
mha.py ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
+ # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import math
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange, repeat
12
+
13
+ from flash_attn.utils.distributed import get_dim_for_local_rank
14
+
15
+ try:
16
+ from flash_attn import (
17
+ flash_attn_kvpacked_func,
18
+ flash_attn_qkvpacked_func,
19
+ flash_attn_varlen_kvpacked_func,
20
+ flash_attn_varlen_qkvpacked_func,
21
+ flash_attn_with_kvcache,
22
+ )
23
+ except ImportError:
24
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
25
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
26
+ flash_attn_with_kvcache = None
27
+
28
+ try:
29
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
30
+ except ImportError:
31
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
32
+
33
+ try:
34
+ from flash_attn.layers.rotary import RotaryEmbedding
35
+ except ImportError:
36
+ RotaryEmbedding = None
37
+
38
+
39
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
40
def get_alibi_slopes(nheads):
    """Return a list with one ALiBi slope per attention head.

    For a power-of-two head count the slopes form a geometric sequence whose
    first term and ratio are both 2 ** (-8 / nheads). For any other count the
    result is the sequence for the largest power of two below ``nheads``,
    followed by every second slope of the sequence for twice that power of
    two, truncated so the total length is ``nheads``.
    """

    def _geometric_slopes(count):
        # First term and common ratio are the same value.
        base = 2 ** (-(2 ** -(math.log2(count) - 3)))
        return [base * base**k for k in range(count)]

    exponent = math.log2(nheads)
    if exponent.is_integer():
        return _geometric_slopes(nheads)

    lower_pow2 = 2 ** math.floor(exponent)
    missing = nheads - lower_pow2
    head_part = _geometric_slopes(lower_pow2)
    # Fill the remaining heads with the even-indexed slopes of the next
    # power-of-two sequence.
    tail_part = get_alibi_slopes(2 * lower_pow2)[0::2][:missing]
    return head_part + tail_part
54
+
55
+
56
class FlashSelfAttention(nn.Module):
    """Softmax scaled-dot-product self-attention backed by FlashAttention kernels.

    Arguments
    ---------
        causal: default for the causal flag; can be overridden per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        window_size, alibi_slopes, deterministic: stored and forwarded
            unchanged to the FlashAttention functions.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        # Both the padded and the variable-length kernels must be importable.
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # nn.Dropout is used only as a holder for the probability ``p``.
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Run multi-head softmax attention on a packed qkv tensor.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        if causal is None:
            causal = self.causal
        if self.alibi_slopes is not None:
            # The slopes buffer is kept in fp32 regardless of the module dtype.
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)

        # Dropout is only active in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        shared_options = dict(
            softmax_scale=self.softmax_scale,
            causal=causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.window_size,
            deterministic=self.deterministic,
        )

        if cu_seqlens is None:
            # Padded batch layout.
            return flash_attn_qkvpacked_func(qkv, dropout_p, **shared_options)

        # Unpadded (variable-length) layout.
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p,
            **shared_options,
        )
134
+
135
+
136
class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product cross-attention with softmax using
    the FlashAttention kv-packed kernels.

    Arguments
    ---------
        causal: default for the causal flag; can be overridden per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        alibi_slopes, window_size, deterministic: stored and forwarded
            unchanged to the FlashAttention functions.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # nn.Dropout is used only as a holder for the probability ``p``.
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            # The slopes buffer is kept in fp32 regardless of the module dtype.
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side length (previously max_seqlen was checked twice).
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size = q.shape[0]
            # Batch size and head dimension of q and kv must agree.
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
231
+
232
+
233
class SelfAttention(nn.Module):
    """Reference (pure PyTorch) softmax scaled-dot-product self-attention.

    Arguments
    ---------
        causal: default for the causal flag; can be overridden per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Compute multi-head softmax attention from a packed qkv tensor.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns a tensor of shape (B, S, H, D).
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        if causal is None:
            causal = self.causal
        q, k, v = qkv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        # The scale is folded into k before the dot product.
        scores = torch.einsum("bthd,bshd->bhts", q, k * scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padding.
            additive_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            additive_mask.masked_fill_(key_padding_mask, 0.0)
            # Broadcast over heads and query positions: (B, S) -> (B, 1, 1, S).
            scores = scores + additive_mask[:, None, None, :]

        if causal:
            # The mask is built in the default float dtype and cast afterwards
            # ("triu_tril_cuda_template" not implemented for 'BFloat16').
            upper_triangle = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            scores = scores + upper_triangle.to(dtype=scores.dtype)

        probs = torch.softmax(scores, dim=-1, dtype=v.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(probs), v)
283
+
284
+
285
class CrossAttention(nn.Module):
    """Reference (pure PyTorch) softmax scaled-dot-product cross-attention.

    Arguments
    ---------
        causal: default for the causal flag; can be overridden per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Compute multi-head softmax attention of q against packed kv.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        Returns a tensor of shape (B, Sq, H, D).
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        if causal is None:
            causal = self.causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            # Repeat each kv head so the head count matches q.
            group_size = q.shape[2] // kv.shape[3]
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=group_size)
        k, v = kv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        # The scale is folded into k before the dot product.
        scores = torch.einsum("bthd,bshd->bhts", q, k * scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padding.
            additive_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            additive_mask.masked_fill_(key_padding_mask, 0.0)
            # Broadcast over heads and query positions: (B, Sk) -> (B, 1, 1, Sk).
            scores = scores + additive_mask[:, None, None, :]

        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            query_pos = torch.arange(seqlen_q, device=q.device, dtype=torch.long)[:, None]
            key_pos = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            if key_padding_mask is None:
                effective_k = seqlen_k
            else:
                # Per-sample number of kept keys, shaped (B, 1, 1, 1) for broadcasting.
                effective_k = key_padding_mask.sum(-1)[:, None, None, None]
            blocked = key_pos > query_pos + effective_k - seqlen_q
            scores = scores.masked_fill(blocked, -10000.0)

        probs = torch.softmax(scores, dim=-1, dtype=v.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(probs), v)
345
+
346
+
347
class LinearResidual(nn.Linear):
    """nn.Linear variant whose forward also hands back its input.

    Exists for compatibility with FusedDense. NOTE: although the signature is
    annotated as returning a Tensor, the actual return value is the pair
    ``(linear_output, input)``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        projected = super().forward(input)
        # The untouched input is returned alongside as the residual.
        return projected, input
352
+
353
+
354
def _update_kv_cache(kv, inference_params, layer_idx):
    """Write ``kv`` into the per-layer cache and return the cache filled so far.

    kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

    The cache for ``layer_idx`` lives in ``inference_params.key_value_memory_dict``
    and is allocated on first use with shape
    (max_batch_size, max_seqlen, 2, nheads, head_dim). ``kv`` is stored at the
    batch / sequence offsets given by ``inference_params``; the returned slice
    covers this batch window and every sequence position up to and including
    the newly written ones.
    """
    num_heads, head_dim = kv.shape[-2:]
    cache_by_layer = inference_params.key_value_memory_dict
    if layer_idx in cache_by_layer:
        kv_cache = cache_by_layer[layer_idx]
    else:
        # Pre-allocate memory for key-values for inference.
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        cache_by_layer[layer_idx] = kv_cache

    # Destination window inside the cache for the incoming keys/values.
    row_lo = inference_params.batch_size_offset
    row_hi = row_lo + kv.shape[0]
    pos_lo = inference_params.seqlen_offset
    pos_hi = pos_lo + kv.shape[1]
    assert row_hi <= kv_cache.shape[0]
    assert pos_hi <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[row_lo:row_hi, pos_lo:pos_hi, ...] = kv
    return kv_cache[row_lo:row_hi, :pos_hi, ...]
381
+
382
+
383
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention

    Projects the input to q/k/v (packed ``Wqkv`` for self-attention, separate
    ``Wq`` / ``Wkv`` for cross-attention), optionally applies a depthwise
    convolution and rotary embeddings, runs the inner attention (FlashAttention
    or the pure PyTorch reference modules), and applies ``out_proj``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        layer_idx: key into the inference kv cache; required for generation.
        use_alibi / window_size: both require use_flash_attn=True.
        fused_bias_fc: use FusedDense instead of nn.Linear (raises ImportError if
            fused_dense is unavailable).
        checkpointing: wrap the inner attention call in torch.utils.checkpoint
            when not running with inference_params.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # Packed projection widths: q uses num_heads heads, k and v use num_heads_kv each.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        # Only the input projection returns the residual; Wq and out_proj never do.
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # Depthwise (groups == channels) convs; padding=2 with kernel_size=3
            # yields seqlen + 2 outputs, of which forward() drops the last two.
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        # Both inner modules are always built; forward() picks one per code path.
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Return an uninitialized kv cache of shape
        (batch_size, max_seqlen, 2, num_heads_kv, head_dim) on the device of
        ``out_proj.weight``; dtype defaults to that weight's dtype."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

        Delegates to the module-level ``_update_kv_cache`` using this layer's
        ``layer_idx``.
        """
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            # Relies on private members of RotaryEmbedding for the cos/sin tables.
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        # The cache for this layer must already exist (KeyError otherwise).
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths when available, otherwise the shared scalar offset.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        # Only the Flash inner module defines alibi_slopes; default to None otherwise.
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # Generic path: copy kv into the cache, then attend over the filled part.
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            # Fused path: hand the cache and the new kv to flash_attn_with_kvcache
            # (presumably it appends kv to the cache in place — confirm against flash-attn docs).
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        Returns:
            out, or (out, x) when return_residual is set.
        """
        # Argument compatibility checks: unpadded input, padding masks and
        # generation are mutually exclusive modes.
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        # Extra arguments for the inner attention depend on the backend.
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        # Offset passed to the rotary embedding: 0 outside generation, otherwise
        # per-sample lengths if present, else the shared scalar offset.
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            # Plain multi-head self-attention with a single packed projection.
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                # Drop the last two conv outputs so the length matches the input again.
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            # The fused rotary + kv-cache path is used only during decoding
            # (seqlen_offset > 0) with flash-attn and a rotary dim that is a
            # positive multiple of 16; everything else takes this branch.
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        # NOTE(review): kwargs is never empty here, and recent PyTorch
                        # versions appear to reject keyword arguments unless
                        # use_reentrant=False is passed — confirm against the
                        # installed torch version.
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            # Cross-attention, or self-attention with fewer kv heads (MQA/GQA).
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                # Split the packed projection: first num_heads * head_dim columns are q.
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                # NOTE(review): q and kv were already split into heads just above, yet
                # these patterns expect 3-D "b s d" tensors — looks inconsistent;
                # confirm whether this branch is ever exercised.
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            # Same fast-path selection as in the packed self-attention branch.
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        # NOTE(review): same keyword-argument caveat as the
                        # checkpoint call in the self-attention branch.
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        # Merge heads back into the hidden dimension before the output projection.
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
735
+
mlp.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
2
+ # Commit id: c3b219665292c61a51153d0ded4473c494296382
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.distributed import ProcessGroup
10
+
11
+
12
+ try:
13
+ from flash_attn.ops.activations import swiglu
14
+ except ImportError:
15
+ swiglu = None
16
+
17
+ try:
18
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
19
+ except ImportError:
20
+ ColumnParallelLinear, RowParallelLinear = None, None
21
+
22
+ try:
23
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
24
+ except ImportError:
25
+ FusedMLP, ParallelFusedMLP = None, None
26
+
27
+
28
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    hidden_features defaults to 4 * in_features and out_features defaults to
    in_features. With return_residual=True, forward returns ``(output, input)``
    instead of just the output.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = in_features * 4
        tensor_options = {"device": device, "dtype": dtype}
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **tensor_options)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **tensor_options)

    def forward(self, x):
        """Apply fc1, the activation and fc2 to ``x``."""
        hidden = self.activation(self.fc1(x))
        y = self.fc2(hidden)
        if self.return_residual:
            return y, x
        return y
55
+
56
+
57
class ParallelMLP(nn.Module):
    """MLP whose two projections are the parallel linears from ``flash_attn.ops.fused_dense``.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``; both
    receive ``process_group`` and ``sequence_parallel``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; ``4 * in_features`` when None.
        out_features: Size of the output; ``in_features`` when None.
        activation: Callable applied between the two linear layers.
        process_group: Process group handed to both parallel linears.
        sequence_parallel: Forwarded to both parallel linears.
        bias1: Whether ``fc1`` carries a bias term.
        bias2: Whether ``fc2`` carries a bias term.
        device: Forwarded to both linear constructors.
        dtype: Forwarded to both linear constructors.

    Raises:
        ImportError: If the fused_dense parallel linears could not be imported.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Raise explicitly instead of asserting: asserts vanish under ``python -O``
        # and the sibling ParallelGatedMlp reports this condition as ImportError.
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``."""
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y
100
+
101
+
102
class GatedMlp(nn.Module):
    """Gated MLP: ``fc1`` produces a value half and a gate half, which are combined
    and then projected by ``fc2``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Hidden width before rounding; ``int(8 * in_features / 3)``
            when None.
        out_features: Size of the output; ``in_features`` when None.
        activation: Gate activation. ``F.sigmoid`` selects ``F.glu``; ``F.silu`` uses
            the fused ``swiglu`` kernel when it was importable.
        bias1: Whether ``fc1`` carries a bias term.
        bias2: Whether ``fc2`` carries a bias term.
        multiple_of: The hidden width is rounded up to a multiple of this value.
        return_residual: If True, ``forward`` returns ``(output, input)``.
        device: Forwarded to both ``nn.Linear`` constructors.
        dtype: Forwarded to both ``nn.Linear`` constructors.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        linear_kwargs = dict(device=device, dtype=dtype)
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the hidden width up to the next multiple of ``multiple_of``.
        n_chunks = (hidden_features + multiple_of - 1) // multiple_of
        hidden_features = n_chunks * multiple_of
        self.return_residual = return_residual
        # fc1 emits value and gate side by side, hence twice the hidden width.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **linear_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **linear_kwargs)

    def forward(self, x):
        """Apply the gated MLP to ``x``; also return ``x`` when ``return_residual`` is set."""
        projected = self.fc1(x)
        act = self.activation
        if act == F.sigmoid:  # Special case for GLU
            hidden = F.glu(projected, dim=-1)
        else:
            value, gate = projected.chunk(2, dim=-1)
            if act == F.silu and swiglu is not None:  # Special case for SwiGLU
                hidden = swiglu(gate, value)
            else:
                hidden = value * act(gate)
        out = self.fc2(hidden)
        return (out, x) if self.return_residual else out
140
+
141
+
142
class ParallelGatedMlp(nn.Module):
    """Gated MLP built on the parallel linears from ``flash_attn.ops.fused_dense``.

    Args:
        in_features: Size of the last dimension of the input.
        process_group: Process group handed to both parallel linears.
        hidden_features: Hidden width before rounding; ``int(8 * in_features / 3)``
            when None.
        out_features: Size of the output; ``in_features`` when None.
        activation: Gate activation; ``F.sigmoid`` selects ``F.glu``.
        bias1: Whether ``fc1`` carries a bias term.
        bias2: Whether ``fc2`` carries a bias term.
        multiple_of: The hidden width is rounded up to a multiple of this value.
        sequence_parallel: Forwarded to both parallel linears.
        device: Forwarded to both linear constructors.
        dtype: Forwarded to both linear constructors.

    Raises:
        ImportError: If the fused_dense parallel linears could not be imported.
    """

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the hidden width up to the next multiple of ``multiple_of``.
        n_chunks = (hidden_features + multiple_of - 1) // multiple_of
        hidden_features = n_chunks * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        shared_kwargs = dict(sequence_parallel=sequence_parallel, device=device, dtype=dtype)
        # fc1 emits value and gate side by side, hence twice the hidden width.
        self.fc1 = ColumnParallelLinear(
            in_features, 2 * hidden_features, process_group, bias=bias1, **shared_kwargs
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **shared_kwargs
        )

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the projected result."""
        projected = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            hidden = F.glu(projected, dim=-1)
        else:
            value, gate = projected.chunk(2, dim=-1)
            hidden = value * self.activation(gate)
        return self.fc2(hidden)
modeling_bert.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2022, Tri Dao.
5
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
6
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
7
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
8
+
9
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
10
+
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ from typing import Any, Mapping
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from einops import rearrange
22
+ from transformers import BertConfig, PretrainedConfig, XLMRobertaConfig # TODO check whether to use XLMRobertaConfig
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.models.bert.modeling_bert import (
25
+ BaseModelOutputWithPoolingAndCrossAttentions,
26
+ BertForPreTrainingOutput,
27
+ )
28
+
29
+ from .bert_padding import (
30
+ index_first_axis,
31
+ index_first_axis_residual,
32
+ pad_input,
33
+ unpad_input,
34
+ )
35
+ from .block import Block
36
+ from .embedding import BertEmbeddings
37
+ from .mha import MHA
38
+ from .mlp import FusedMLP, Mlp
39
+
40
+ # from flash_attn.utils.pretrained import state_dict_from_pretrained
41
+
42
+ try:
43
+ from flash_attn.ops.fused_dense import FusedDense
44
+ except ImportError:
45
+ FusedDense = None
46
+
47
+ try:
48
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
49
+ except ImportError:
50
+ layer_norm_fn = None
51
+
52
+
53
+ try:
54
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
55
+ except ImportError:
56
+ CrossEntropyLoss = None
57
+
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+
62
def create_mixer_cls(config, cross_attn=False, return_residual=False):
    """Build a ``functools.partial`` over ``MHA`` pre-configured from ``config``.

    Args:
        config: Model config; ``num_attention_heads``, ``attention_probs_dropout_prob``
            and ``position_embedding_type`` are read directly, the remaining options
            via ``getattr`` with defaults.
        cross_attn: Forwarded to ``MHA``.
        return_residual: Forwarded to ``MHA``.

    Returns:
        A partial that constructs the attention mixer when called.
    """
    flash_enabled = getattr(config, "use_flash_attn", False)
    fused_bias = getattr(config, "fused_bias_fc", False)
    mha_kwargs = {}
    if config.position_embedding_type == "rotary":
        # Rotary options only apply when the config asks for rotary embeddings.
        mha_kwargs.update(
            rotary_emb_dim=getattr(config, "rotary_emb_dim", config.hidden_size),
            rotary_emb_base=getattr(config, "rotary_emb_base", 10000.0),
            rotary_emb_scale_base=getattr(config, "rotary_emb_scale_base", None),
            rotary_emb_interleaved=getattr(config, "rotary_emb_interleaved", False),
        )
    return partial(
        MHA,
        num_heads=config.num_attention_heads,
        cross_attn=cross_attn,
        dropout=config.attention_probs_dropout_prob,
        causal=False,
        fused_bias_fc=fused_bias,
        use_flash_attn=flash_enabled,
        return_residual=return_residual,
        **mha_kwargs,
    )
83
+
84
+
85
def create_mlp_cls(config, layer_idx=None, return_residual=False):
    """Build a ``functools.partial`` over ``Mlp`` or ``FusedMLP`` from ``config``.

    Args:
        config: Model config; ``intermediate_size`` and ``hidden_act`` are read
            directly, ``fused_mlp`` and ``mlp_checkpoint_lvl`` via ``getattr``.
        layer_idx: Index of the layer; required when ``mlp_checkpoint_lvl`` is a
            sequence holding one level per layer.
        return_residual: Forwarded to the MLP class.

    Returns:
        A partial that constructs the MLP when called.

    Raises:
        ImportError: If ``fused_mlp`` is requested but ``FusedMLP`` is unavailable.
    """
    inner_dim = config.intermediate_size
    tanh_gelu_names = ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]

    if not getattr(config, "fused_mlp", False):
        # Unfused path: plain Mlp with the matching GELU flavour.
        approximate = "tanh" if config.hidden_act in tanh_gelu_names else "none"
        return partial(
            Mlp,
            hidden_features=inner_dim,
            activation=partial(F.gelu, approximate=approximate),
            return_residual=return_residual,
        )

    assert config.hidden_act in tanh_gelu_names, "fused_mlp only supports approximate gelu"
    if FusedMLP is None:
        raise ImportError("fused_dense is not installed")
    checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
    # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
    if isinstance(checkpoint_lvl, Sequence):
        assert layer_idx is not None
        checkpoint_lvl = checkpoint_lvl[layer_idx]
    return partial(
        FusedMLP,
        hidden_features=inner_dim,
        checkpoint_lvl=checkpoint_lvl,
        return_residual=return_residual,
    )
119
+
120
+
121
def create_block(config, layer_idx=None):
    """Assemble one post-norm ``Block`` (attention mixer + MLP) from ``config``.

    When ``config.last_layer_subset`` is set, the final layer is built as a
    cross-attention layer and does not return a residual.

    Args:
        config: Model config providing sizes, dropout rates and LayerNorm epsilon.
        layer_idx: Index of the layer being created.

    Returns:
        The constructed ``Block``.
    """
    subset_enabled = getattr(config, "last_layer_subset", False)
    cross_attn = subset_enabled and layer_idx == config.num_hidden_layers - 1
    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
    # one layer) so we just choose not to return residual in this case.
    return_residual = not cross_attn
    block_kwargs = {
        "norm_cls": partial(nn.LayerNorm, eps=config.layer_norm_eps),
        "prenorm": False,
        "resid_dropout1": config.hidden_dropout_prob,
        "resid_dropout2": config.hidden_dropout_prob,
        "fused_dropout_add_ln": getattr(config, "fused_dropout_add_ln", False),
        "return_residual": return_residual,
    }
    return Block(
        config.hidden_size,
        create_mixer_cls(config, cross_attn, return_residual=return_residual),
        create_mlp_cls(config, layer_idx, return_residual=return_residual),
        **block_kwargs,
    )
143
+
144
+
145
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
146
+ def _init_weights(module, initializer_range=0.02):
147
+ if isinstance(module, nn.Linear):
148
+ nn.init.normal_(module.weight, std=initializer_range)
149
+ if module.bias is not None:
150
+ nn.init.zeros_(module.bias)
151
+ elif isinstance(module, nn.Embedding):
152
+ nn.init.normal_(module.weight, std=initializer_range)
153
+ if module.padding_idx is not None:
154
+ nn.init.zeros_(module.weight[module.padding_idx])
155
+
156
+
157
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` blocks created by ``create_block``."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.use_flash_attn = getattr(config, "use_flash_attn", False)
        self.layers = nn.ModuleList(
            [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
        """Run all layers over ``hidden_states``.

        If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.

        Args:
            hidden_states: Input activations; the first two dimensions are read as
                (batch, seqlen).
            key_padding_mask: Optional padding mask. With flash attention enabled the
                input is unpadded via ``unpad_input`` using this mask.
            subset_mask: (batch, seqlen), dtype=torch.bool

        Returns:
            The final hidden states. They are re-padded to (batch, seqlen, ...) on the
            flash path without a subset mask, and restricted to the selected tokens
            whenever ``subset_mask`` is given.
        """
        if key_padding_mask is None or not self.use_flash_attn:
            # Padded path: every layer sees the full (batch, seqlen) tensor.
            mixer_kwargs = (
                {"key_padding_mask": key_padding_mask.bool()}
                if key_padding_mask is not None
                else None
            )
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            if subset_mask is not None:
                hidden_states = hidden_states[subset_mask]
            return hidden_states

        # Unpadded (flash attention) path. key_padding_mask is guaranteed to be
        # non-None here because of the condition above.
        batch, seqlen = hidden_states.shape[:2]
        hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
            hidden_states, key_padding_mask
        )
        mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
        if subset_mask is None:
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            return pad_input(hidden_states, indices, batch, seqlen)

        # All layers but the last run on every unpadded token.
        for layer in self.layers[:-1]:
            hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        # Positions of the subset tokens within the unpadded token list.
        subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
        subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
        subset_cu_seqlens = F.pad(
            torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
        )
        hidden_states_subset, hidden_states = index_first_axis_residual(
            hidden_states, subset_idx
        )
        # It's ok to set max_seqlen_q to be much larger
        mixer_kwargs = {
            "x_kv": hidden_states,
            "cu_seqlens": subset_cu_seqlens,
            "max_seqlen": max_seqlen_in_batch,
            "cu_seqlens_k": cu_seqlens,
            "max_seqlen_k": max_seqlen_in_batch,
        }
        return self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
218
+
219
+
220
class BertPooler(nn.Module):
    """Dense layer followed by ``tanh``, applied to the first token's hidden state.

    Raises:
        ImportError: If ``config.fused_bias_fc`` is set but ``FusedDense`` is unavailable.
    """

    def __init__(self, config):
        super().__init__()
        if getattr(config, "fused_bias_fc", False):
            if FusedDense is None:
                raise ImportError("fused_dense is not installed")
            linear_cls = FusedDense
        else:
            linear_cls = nn.Linear
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        """Pool and project ``hidden_states``.

        Args:
            hidden_states: Encoder output.
            pool: When True, select index 0 along dimension 1 (the first token);
                when False, ``hidden_states`` is used as given.
        """
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        if pool:
            hidden_states = hidden_states[:, 0]
        return self.activation(self.dense(hidden_states))
237
+
238
+
239
class BertPredictionHeadTransform(nn.Module):
    """Dense -> GELU -> LayerNorm transform applied ahead of the LM decoder.

    Raises:
        ImportError: If ``fused_bias_fc`` is requested without ``FusedDense``, or
            ``fused_dropout_add_ln`` is requested without ``layer_norm_fn``.
    """

    def __init__(self, config):
        super().__init__()
        use_fused_dense = getattr(config, "fused_bias_fc", False)
        if use_fused_dense and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        linear_cls = FusedDense if use_fused_dense else nn.Linear
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        # The tanh-approximated GELU names map to approximate="tanh"; anything else is exact.
        if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
            approximate = "tanh"
        else:
            approximate = "none"
        self.transform_act_fn = nn.GELU(approximate=approximate)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the normalised, activated projection of ``hidden_states``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        if self.fused_dropout_add_ln:
            # Fused kernel path; reuses the parameters of ``self.layer_norm``.
            return layer_norm_fn(
                activated, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
            )
        return self.layer_norm(activated)
268
+
269
+
270
class BertLMPredictionHead(nn.Module):
    """Masked-LM head: ``BertPredictionHeadTransform`` followed by a vocab-sized decoder.

    Raises:
        ImportError: If ``config.fused_bias_fc`` is set but ``FusedDense`` is unavailable.
    """

    def __init__(self, config):
        super().__init__()
        if getattr(config, "fused_bias_fc", False):
            if FusedDense is None:
                raise ImportError("fused_dense is not installed")
            linear_cls = FusedDense
        else:
            linear_cls = nn.Linear

        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        """Return vocabulary logits for ``hidden_states``."""
        return self.decoder(self.transform(hidden_states))
288
+
289
+
290
class BertPreTrainingHeads(nn.Module):
    """Bundle of the two pre-training heads: masked-LM logits and a 2-way
    sequence-relationship classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(prediction_scores, seq_relationship_score)``.

        Args:
            sequence_output: Hidden states fed to the LM prediction head.
            pooled_output: Pooled representation fed to the 2-way classifier.
        """
        lm_logits = self.predictions(sequence_output)
        relationship_logits = self.seq_relationship(pooled_output)
        return lm_logits, relationship_logits
300
+
301
+
302
+ # class BertPreTrainedModel(nn.Module):
303
+ # """An abstract class to handle weights initialization and
304
+ # a simple interface for downloading and loading pretrained models.
305
+ # """
306
+ #
307
+ # def __init__(self, config, *inputs, **kwargs):
308
+ # super().__init__()
309
+ # if not isinstance(config, BertConfig):
310
+ # raise ValueError(
311
+ # "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
312
+ # "To create a model from a Google pretrained model use "
313
+ # "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
314
+ # self.__class__.__name__, self.__class__.__name__
315
+ # )
316
+ # )
317
+ # self.config = config
318
+ #
319
+ # @classmethod
320
+ # def from_pretrained(cls, model_name, config, *inputs, **kwargs):
321
+ # """
322
+ # Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
323
+ # Download and cache the pre-trained model file if needed.
324
+ #
325
+ # Params:
326
+ # pretrained_model_name_or_path: either:
327
+ # - a path or url to a pretrained model archive containing:
328
+ # . `bert_config.json` a configuration file for the model
329
+ # . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
330
+ # - a path or url to a pretrained model archive containing:
331
+ # . `bert_config.json` a configuration file for the model
332
+ # . `model.chkpt` a TensorFlow checkpoint
333
+ # *inputs, **kwargs: additional input for the specific Bert class
334
+ # (ex: num_labels for BertForSequenceClassification)
335
+ # """
336
+ # # Instantiate model.
337
+ # model = cls(config, *inputs, **kwargs)
338
+ # load_return = model.load_state_dict(
339
+ # remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
340
+ # )
341
+ # logger.info(load_return)
342
+ # return model
343
+
344
class BertPreTrainedModel(PreTrainedModel):
    """Abstract base class that handles weight initialization and provides a
    simple interface for downloading and loading pretrained models.
    """

    # NOTE(review): the import line carries a TODO questioning whether
    # XLMRobertaConfig is the right config class here.
    config_class = XLMRobertaConfig
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        """Record the gradient-checkpointing flag on ``BertEncoder`` modules only."""
        if not isinstance(module, BertEncoder):
            return
        module.gradient_checkpointing = value
355
+
356
+
357
+
358
class BertModel(BertPreTrainedModel):
    """BERT backbone: embeddings -> LayerNorm -> dropout -> encoder -> optional pooler."""

    def __init__(self, config: BertConfig, add_pooling_layer=True):
        """Build the model from ``config``.

        Args:
            config: Model configuration. Note that ``config.vocab_size`` is rounded
                up IN PLACE to a multiple of ``pad_vocab_size_multiple``.
            add_pooling_layer: When False, no pooler is created and ``forward``
                returns ``pooler_output=None``.

        Raises:
            ImportError: If ``fused_dropout_add_ln`` is set but ``layer_norm_fn``
                could not be imported.
        """
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        # Round the vocabulary size up to the requested multiple (mutates config).
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += self.pad_vocab_size_multiple - (
                config.vocab_size % self.pad_vocab_size_multiple
            )
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        # Only GELU variants are supported by the MLP factories in this file.
        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]

        self.embeddings = BertEmbeddings(
            config.hidden_size,
            config.vocab_size,
            config.max_position_embeddings,
            config.type_vocab_size,
            padding_idx=config.pad_token_id,
        )
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        # Initialise every Linear/Embedding submodule.
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        masked_tokens_mask=None,
    ):
        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
        we only want the output for the masked tokens. This means that we only compute the last
        layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool

        Returns a ``BaseModelOutputWithPoolingAndCrossAttentions`` with
        ``last_hidden_state`` and ``pooler_output`` (None when there is no pooler).
        """
        hidden_states = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
        )
        # TD [2022-12:18]: Don't need to force residual in fp32
        # BERT puts embedding LayerNorm before embedding dropout.
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_ln(hidden_states)
        else:
            # Fused kernel path; reuses the parameters of ``self.emb_ln``.
            hidden_states = layer_norm_fn(
                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
            )
        hidden_states = self.emb_drop(hidden_states)

        if masked_tokens_mask is not None:
            batch_size, seqlen = input_ids.shape[:2]
            # We also need the first column for the CLS token
            first_col_mask = torch.zeros(
                batch_size, seqlen, dtype=torch.bool, device=input_ids.device
            )
            first_col_mask[:, 0] = True
            # Tokens whose last-layer output is needed: masked tokens plus column 0.
            subset_mask = masked_tokens_mask | first_col_mask
        else:
            subset_mask = None

        sequence_output = self.encoder(
            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
        )

        if masked_tokens_mask is None:
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            # sequence_output holds only the subset tokens, so the first-column and
            # masked-token masks are re-indexed into that subset before selecting rows.
            if attention_mask is not None:
                subset_idx = subset_mask[attention_mask]
                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
            else:
                pool_input = sequence_output[first_col_mask[subset_mask]]
                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            # pool=False: pool_input already contains just the first-token rows.
            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
443
+
444
+
445
class BertForPreTraining(BertPreTrainedModel):
    """BERT with the masked-LM and next-sentence-prediction pre-training heads."""

    def __init__(self, config: BertConfig):
        """Build the backbone, heads and losses from ``config``.

        Raises:
            ImportError: If ``use_xentropy`` is set but the fused ``CrossEntropyLoss``
                could not be imported.
        """
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, "dense_seq_output", False)
        # If last_layer_subset, we only need the compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, "last_layer_subset", False)
        if self.last_layer_subset:
            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
        use_xentropy = getattr(config, "use_xentropy", False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError("xentropy_cuda is not installed")
        loss_cls = (
            nn.CrossEntropyLoss
            if not use_xentropy
            else partial(CrossEntropyLoss, inplace_backward=True)
        )

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        # Label 0 marks positions excluded from the MLM loss; -1 is ignored for NSP.
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Share the decoder weight with the input word-embedding matrix."""
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        next_sentence_label=None,
    ):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).

        Returns:
            A ``BertForPreTrainingOutput`` with
            - ``loss``: the sum of the masked language modeling loss and the next sentence
              classification loss when both `labels` and `next_sentence_label` are given,
              otherwise None;
            - ``prediction_logits``: the masked language modeling logits (already restricted
              to the labelled positions when ``dense_seq_output`` is on and labels are given);
            - ``seq_relationship_logits``: the next sentence classification logits.
        """
        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask,
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        if self.dense_seq_output and labels is not None:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            # With last_layer_subset the encoder has already returned only these tokens.
            if not self.last_layer_subset:
                sequence_output = index_first_axis(
                    rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
                )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            if (
                self.dense_seq_output and labels is not None
            ):  # prediction_scores are already flattened
                masked_lm_loss = self.mlm_loss(
                    prediction_scores, labels.flatten()[masked_token_idx]
                )
            else:
                masked_lm_loss = self.mlm_loss(
                    rearrange(prediction_scores, "... v -> (...) v"),
                    rearrange(labels, "... -> (...)"),
                )
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, "... t -> (...) t"),
                rearrange(next_sentence_label, "... -> (...)"),
            )
            total_loss = masked_lm_loss.float() + next_sentence_loss.float()

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
        )
542
+
543
+
544
def remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.

    Args:
        state_dict: Mapping from Huggingface parameter names to tensors. It is not
            modified; a new ``OrderedDict`` is returned.
        config: Model config; ``num_hidden_layers`` is read directly and
            ``last_layer_subset`` / ``pad_vocab_size_multiple`` via ``getattr``.
            ``vocab_size`` is used only when vocabulary padding is enabled.

    Returns:
        An ``OrderedDict`` with renamed keys, fused Q/K/V projections and, when
        ``pad_vocab_size_multiple > 1``, embeddings/decoder padded to ``config.vocab_size``.

    Raises:
        KeyError: If an expected attention (or, with padding, embedding/decoder)
            entry is missing from ``state_dict``.
    """

    def rename_keys(sd, rules):
        # Apply each (pattern, replacement) rule in order to every key; values and
        # key order are preserved. Dots in the patterns are escaped so they only
        # match literal dots in parameter names.
        renamed = OrderedDict()
        for key, value in sd.items():
            for pattern, repl in rules:
                key = re.sub(pattern, repl, key)
            renamed[key] = value
        return renamed

    state_dict = rename_keys(
        state_dict,
        [
            # LayerNorm (legacy gamma/beta names)
            (r"LayerNorm\.gamma$", "LayerNorm.weight"),
            (r"LayerNorm\.beta$", "LayerNorm.bias"),
            # Layers
            (r"^bert\.encoder\.layer\.", "bert.encoder.layers."),
            # LayerNorm
            (r"^bert\.embeddings\.LayerNorm\.", "bert.emb_ln."),
            (
                r"^bert\.encoder\.layers\.(\d+)\.attention\.output\.LayerNorm\.(weight|bias)",
                r"bert.encoder.layers.\1.norm1.\2",
            ),
            (
                r"^bert\.encoder\.layers\.(\d+)\.output\.LayerNorm\.(weight|bias)",
                r"bert.encoder.layers.\1.norm2.\2",
            ),
            (
                r"^cls\.predictions\.transform\.LayerNorm\.(weight|bias)",
                r"cls.predictions.transform.layer_norm.\1",
            ),
            # MLP
            (
                r"^bert\.encoder\.layers\.(\d+)\.intermediate\.dense\.(weight|bias)",
                r"bert.encoder.layers.\1.mlp.fc1.\2",
            ),
            (
                r"^bert\.encoder\.layers\.(\d+)\.output\.dense\.(weight|bias)",
                r"bert.encoder.layers.\1.mlp.fc2.\2",
            ),
        ],
    )

    # Attention: fuse the separate query/key/value projections.
    last_layer_subset = getattr(config, "last_layer_subset", False)
    for d in range(config.num_hidden_layers):
        prefix = f"bert.encoder.layers.{d}"
        Wq = state_dict.pop(f"{prefix}.attention.self.query.weight")
        Wk = state_dict.pop(f"{prefix}.attention.self.key.weight")
        Wv = state_dict.pop(f"{prefix}.attention.self.value.weight")
        bq = state_dict.pop(f"{prefix}.attention.self.query.bias")
        bk = state_dict.pop(f"{prefix}.attention.self.key.bias")
        bv = state_dict.pop(f"{prefix}.attention.self.value.bias")
        if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f"{prefix}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
            state_dict[f"{prefix}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
        else:
            # The last layer is cross-attention here: separate Wq and fused Wkv.
            state_dict[f"{prefix}.mixer.Wq.weight"] = Wq
            state_dict[f"{prefix}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"{prefix}.mixer.Wq.bias"] = bq
            state_dict[f"{prefix}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)

    state_dict = rename_keys(
        state_dict,
        [
            (
                r"^bert\.encoder\.layers\.(\d+)\.attention\.output\.dense\.(weight|bias)",
                r"bert.encoder.layers.\1.mixer.out_proj.\2",
            ),
            (r"^cls\.predictions\.bias", "cls.predictions.decoder.bias"),
        ],
    )

    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        state_dict["cls.predictions.decoder.weight"] = F.pad(
            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
        )
        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
        # strongly negative (i.e. the decoder shouldn't predict those indices).
        # TD [2022-05-09]: I don't think it affects the MLPerf training.
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        state_dict["cls.predictions.decoder.bias"] = F.pad(
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )

    return state_dict
655
+
656
+
657
def inv_remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.

    This function is meant to be the inverse of remap_state_dict.

    Args:
        state_dict: mapping from flash_attn-style parameter names to tensors.
            The mapping itself is left untouched; the conversion works on a
            shallow copy (the tensors are not copied).
        config: model configuration. ``num_hidden_layers`` is always read;
            ``pad_vocab_size_multiple`` (default 1) and ``last_layer_subset``
            (default False) are optional. ``orig_vocab_size`` is read only when
            ``pad_vocab_size_multiple > 1``.

    Returns:
        A new OrderedDict keyed by the Huggingface-style names. The per-layer
        query/key/value entries are row slices of the fused projection tensors.

    Raises:
        KeyError: if a fused projection expected for a layer (or, when the
            vocabulary was padded, one of the vocabulary-sized tensors) is
            missing from ``state_dict``.
        AttributeError: if the vocabulary was padded but ``config`` has no
            ``orig_vocab_size``.
    """

    def split_rows(tensor, parts):
        # Cut dim 0 into `parts` consecutive slices at i * n // parts.
        n = tensor.shape[0]
        bounds = [i * n // parts for i in range(parts + 1)]
        return [tensor[lo:hi] for lo, hi in zip(bounds, bounds[1:])]

    # Work on a shallow copy so the caller's mapping is not left half-converted
    # (the fused keys are popped below).
    state_dict = OrderedDict(state_dict)

    # Word embedding: drop the rows that were added to pad the vocabulary.
    if getattr(config, "pad_vocab_size_multiple", 1) > 1:
        orig_vocab_size = config.orig_vocab_size
        for name in (
            "bert.embeddings.word_embeddings.weight",
            "cls.predictions.decoder.weight",
            "cls.predictions.decoder.bias",
        ):
            state_dict[name] = state_dict[name][:orig_vocab_size]

    # Attention: split the fused projections back into query / key / value.
    last_layer_subset = getattr(config, "last_layer_subset", False)
    last_layer = config.num_hidden_layers - 1
    for d in range(config.num_hidden_layers):
        mixer = f"bert.encoder.layers.{d}.mixer"
        attn = f"bert.encoder.layers.{d}.attention.self"
        if last_layer_subset and d == last_layer:
            # With last_layer_subset the final layer stores Wq and Wkv separately.
            Wq = state_dict.pop(f"{mixer}.Wq.weight")
            Wk, Wv = split_rows(state_dict.pop(f"{mixer}.Wkv.weight"), 2)
            bq = state_dict.pop(f"{mixer}.Wq.bias")
            bk, bv = split_rows(state_dict.pop(f"{mixer}.Wkv.bias"), 2)
        else:
            Wq, Wk, Wv = split_rows(state_dict.pop(f"{mixer}.Wqkv.weight"), 3)
            bq, bk, bv = split_rows(state_dict.pop(f"{mixer}.Wqkv.bias"), 3)
        state_dict[f"{attn}.query.weight"] = Wq
        state_dict[f"{attn}.key.weight"] = Wk
        state_dict[f"{attn}.value.weight"] = Wv
        state_dict[f"{attn}.query.bias"] = bq
        state_dict[f"{attn}.key.bias"] = bk
        state_dict[f"{attn}.value.bias"] = bv

    # Key renames, applied to every key in this order. Dots are escaped so that
    # only literal '.' separators match. The "layers." -> "layer." rule must run
    # before the MLP / attention rules, which match on the singular form.
    rename_rules = [
        # LayerNorm
        (r"bert\.emb_ln\.", "bert.embeddings.LayerNorm."),
        (
            r"bert\.encoder\.layers\.(\d+)\.norm1\.(weight|bias)",
            r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
        ),
        (
            r"bert\.encoder\.layers\.(\d+)\.norm2\.(weight|bias)",
            r"bert.encoder.layers.\1.output.LayerNorm.\2",
        ),
        (
            r"cls\.predictions\.transform\.layer_norm\.(weight|bias)",
            r"cls.predictions.transform.LayerNorm.\1",
        ),
        # LayerNorm weight/bias -> gamma/beta
        (r"LayerNorm\.weight$", "LayerNorm.gamma"),
        (r"LayerNorm\.bias$", "LayerNorm.beta"),
        # Layers
        (r"bert\.encoder\.layers\.", "bert.encoder.layer."),
        # MLP
        (
            r"bert\.encoder\.layer\.(\d+)\.mlp\.fc1\.(weight|bias)",
            r"bert.encoder.layer.\1.intermediate.dense.\2",
        ),
        (
            r"bert\.encoder\.layer\.(\d+)\.mlp\.fc2\.(weight|bias)",
            r"bert.encoder.layer.\1.output.dense.\2",
        ),
        # Attention output projection
        (
            r"bert\.encoder\.layer\.(\d+)\.mixer\.out_proj\.(weight|bias)",
            r"bert.encoder.layer.\1.attention.output.dense.\2",
        ),
        # Decoder bias
        (r"cls\.predictions\.decoder\.bias", "cls.predictions.bias"),
    ]
    compiled_rules = [(re.compile(pattern), repl) for pattern, repl in rename_rules]

    def rename(key):
        for pattern, repl in compiled_rules:
            key = pattern.sub(repl, key)
        return key

    # Single pass over the dict; insertion order of the entries is preserved.
    return OrderedDict((rename(key), value) for key, value in state_dict.items())
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61bdee1ea6ae50618c387234ae94a500df9ce095e59d836b8aefef33e9d8884e
3
+ size 1112222546
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"model_max_length": 512}