Text2Text Generation
Transformers
PyTorch
French
flash_t5
custom_code
bourdoiscatie committed (verified) · Commit 0743270 · 1 Parent(s): 66ded02

Add FAT5-small
README.md ADDED
@@ -0,0 +1,224 @@
1
+ ---
2
+ language: fr
3
+ datasets:
4
+ - uonlp/CulturaX
5
+ - wikimedia/wikipedia
6
+ - eckendoerffer/justice_fr
7
+ - bigcode/the-stack-dedup
8
+ metrics:
9
+ - f1
10
+ - exact_match
11
+ library_name: transformers
12
+ co2_eq_emissions: 13.5
13
+ license: apache-2.0
14
+ ---
15
+
16
+ # FAT5 (Flash Attention T5) ⚡
17
+
18
+ <br>
19
+
20
+ <div align="center" style="line-height: 1;">
21
+ <a href="https://huggingface.co/spaces/CATIE-AQ/FAT5-report" style="margin: 2px;">
22
+ <img alt="Blog post (EN)" src="https://img.shields.io/badge/📝_Blog_post-English_version-f5de53?&color=blue" style="display: inline-block; vertical-align: middle;"/>
23
+ </a>
24
+ <a href="https://huggingface.co/spaces/CATIE-AQ/FAT5-rapport" style="margin: 2px;">
25
+ <img alt="Blog post (FR)" src="https://img.shields.io/badge/📝_Blog_post-French_version-f5de53?&color=blue" style="display: inline-block; vertical-align: middle;"/>
26
+ </a>
27
+ <a href="https://huggingface.co/collections/CATIE-AQ/catie-french-fat5-ul2-677697a35feea336389d6403" target="_blank" style="margin: 2px;">
28
+ <img alt="Hugging Face collection" src="https://img.shields.io/badge/🤗_Hugging_Face-Collection-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
29
+ </a>
30
+ <a href="https://github.com/catie-aq/flashT5" target="_blank" style="margin: 2px;">
31
+ <img alt="FAT5 GitHub" src="https://img.shields.io/github/stars/catie-aq/flashT5?style=social" style="display: inline-block; vertical-align: middle;"/>
32
+ </a>
34
+ <a href="https://opensource.org/licenses/Apache-2.0" target="_blank" style="margin: 2px;">
35
+ <img alt="License: Apache 2.0" src="https://img.shields.io/badge/License-Apache--2.0-green.svg" style="display: inline-block; vertical-align: middle;"/>
36
+ </a>
37
+ </div>
38
+
39
+ ## Introduction
40
+
41
+ FAT5 (for Flash Attention T5) is an implementation of T5 in PyTorch with an [UL2](https://arxiv.org/abs/2205.05131) objective optimized for GPGPU, for both training and inference. It relies on an experimental feature that combines [Flash Attention](https://arxiv.org/abs/2307.08691) (v2) with relative position encoding biases, which allows training or finetuning the model on longer sequences than the original T5. It also supports other positional embeddings such as [RoPE](https://arxiv.org/abs/2104.09864), [ALiBi](https://arxiv.org/abs/2108.12409) or [FIRE](https://arxiv.org/abs/2310.04418).
42
+ This methodology enabled us, as a proof of concept, to efficiently pretrain a 147M-parameter T5 in French in a reasonable time (1,461 hours to see 419B tokens) and with limited resources (a single A100, i.e. a computational budget of around €2,200); its weights are available in this repo.
43
+ To achieve this, we designed CUDA/Triton kernels to make Flash Attention compatible with T5, and to provide linear inference, thus extending the context size that can be taken into account by the model.
44
+ Other optimizations have also been implemented, as detailed in a subsequent [blog post](https://huggingface.co/spaces/CATIE-AQ/FAT5-report).
45
+
46
+ <br>
47
+
48
+ ## Motivation
49
+
50
+ While a lot of effort has been focused on optimizing decoder-only models, older architectures remain useful in many practical applications.
51
+ We focus on [T5](http://jmlr.org/papers/v21/20-074.html), an encoder-decoder architecture that exhibits very decent performance for [instruction tuning](https://arxiv.org/pdf/2306.04757.pdf) and sometimes even outperforms much larger models when [finetuned](https://arxiv.org/pdf/2402.00841.pdf). Moreover, it is a natural architecture to consider for [distillation](https://arxiv.org/abs/2305.02301) of much larger models.
52
+
53
+ A critical limitation of this architecture is the sequence length these models can handle, due to the quadratic memory footprint of attention. While this quadratic term cannot be removed without switching to other forms of attention (as in [LongT5](https://arxiv.org/abs/2112.07916)), it can still be alleviated to accommodate longer sequence lengths.
54
+ Another limitation is the pre-training time, since techniques such as Flash Attention are not available for this architecture.
55
+
56
+ <br>
57
+
58
+ ## Our work
59
+
60
+ We used the [nanoT5](https://github.com/PiotrNawrot/nanoT5?tab=readme-ov-file#cite) implementation as the base for our work.
61
+
62
+ We worked on optimizing the core component of the model: the attention. We used Flash Attention (v2), which optimizes both memory usage and the use of Tensor Cores.
63
+
64
+ We support different implementations of attention biases (a plain-PyTorch sketch of what an additive attention bias means follows the list):
65
+ - Full attention biases with Flash Attention 2 using this [PR](https://github.com/Dao-AILab/flash-attention/pull/617)
66
+ - T5-like relative position encoding biases with Flash Attention 2 using this [PR](https://github.com/Dao-AILab/flash-attention/pull/956)
67
+ - Full attention biases with a [triton implementation](src/model/ops/flash_attention_v2_bias.py) of Flash Attention 2
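+
+ For reference: "full attention biases" here means an additive term applied to the pre-softmax attention scores, exactly as in the `attn_ref.py` reference of this repo. A minimal plain-PyTorch sketch (illustrative shapes, stock `scaled_dot_product_attention` rather than our Flash/Triton kernels):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Illustrative shapes only: (batch, heads, seq_len, head_dim)
+ q = torch.randn(2, 8, 1024, 64)
+ k = torch.randn(2, 8, 1024, 64)
+ v = torch.randn(2, 8, 1024, 64)
+
+ # Additive bias (e.g. T5 relative position biases), broadcast over the batch
+ bias = torch.randn(1, 8, 1024, 1024)
+
+ # attn_mask accepts a float tensor that is added to q @ k^T / sqrt(d) before the softmax;
+ # this is the "full attention bias" that the kernels listed above fuse into Flash Attention
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
+ ```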
68
+
69
+ <center>
70
+ <img src="https://github.com/catie-aq/flashT5/raw/main/assets/FAT5_dark.gif" alt="FAT5_dark" width="100%">
71
+ </center>
72
+
73
+ Other parts of the architecture were optimized using [ad-hoc Triton kernels](https://github.com/catie-aq/flashT5/tree/main/src/model/ops) for the cross-entropy (and z-loss) and the layernorm.
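+
+ For clarity, the z-loss mentioned here is an auxiliary term added to the cross-entropy that penalizes the squared log-partition function of the logits (we use a coefficient of 1e-4, see `z_loss` in `config.json`). An unfused, plain-PyTorch equivalent of what the Triton kernel computes (ignoring label smoothing and the in-place backward) would be:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def cross_entropy_with_z_loss(logits, labels, z_loss_coef=1e-4, ignore_index=-100):
+     # Standard CE plus a penalty on the squared logsumexp of the logits (the "z-loss"),
+     # zeroed out for ignored labels, as in cross_entropy_loss.py
+     ce = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction="none")
+     lse = torch.logsumexp(logits.float(), dim=-1)
+     z_loss = z_loss_coef * lse.square()
+     z_loss = z_loss.masked_fill(labels == ignore_index, 0.0)
+     return ce + z_loss, z_loss
+ ```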
74
+
75
+ For pretext tasks during pre-training, we use the [UL2](https://arxiv.org/abs/2205.05131v3) mixture of denoisers with the following 7 tasks:
76
+
77
+ ```python
78
+ denoiser_list=[
79
+ {"mu": 3.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
80
+ {"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
81
+ {"mu": 4.0, "r": 0.0, "max_spans": 1, "prefix": "[S]"},
82
+ {"mu": 3.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"},
83
+ {"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
84
+ {"mu": 64.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
85
+ {"mu": 64.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"}]
86
+
87
+ denoiser_proportions=[0.165, 0.165, 0.34, 0.0825, 0.0825, 0.0825, 0.0825]
88
+ ```
89
+ where `mu` is the mean span length, `r` the masking ratio (fraction of corrupted tokens) and `prefix` the type of pretext task (the meaning of the letters `[R]`, `[S]` and `[X]` is described [here](https://huggingface.co/google/ul2#mixture-of-denoisers)).
90
+
91
+ As there was no implementation available in PyTorch, we [added one](https://github.com/catie-aq/flashT5/blob/main/src/data/data_collator_ul2.py) and adapted a dynamic batching mechanism to reduce padding.
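+
+ As a purely illustrative, hypothetical sketch of what a single denoiser entry does (the actual logic lives in the data collator linked above): spans with mean length `mu` are masked until roughly a fraction `r` of the sequence is corrupted, and each contiguous masked span is replaced by a sentinel token (the sentinel id below is made up):
+
+ ```python
+ import numpy as np
+
+ def corrupt_spans(tokens, mu=3.0, r=0.15, sentinel_start=32000, seed=0):
+     # Hypothetical sketch of T5/UL2-style span corruption, not the repo's implementation
+     rng = np.random.default_rng(seed)
+     n = len(tokens)
+     masked = np.zeros(n, dtype=bool)
+     while masked.sum() < int(round(r * n)):
+         length = max(1, int(rng.poisson(mu)))            # span length with mean ~mu
+         start = int(rng.integers(0, max(1, n - length)))
+         masked[start:start + length] = True              # mask until ~r of the tokens are covered
+     inputs, sentinel, i = [], sentinel_start, 0
+     while i < n:
+         if masked[i]:
+             inputs.append(sentinel)                      # one sentinel per contiguous masked span
+             sentinel += 1
+             while i < n and masked[i]:
+                 i += 1
+         else:
+             inputs.append(tokens[i])
+             i += 1
+     return inputs
+ ```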
92
+
93
+ <br>
94
+
95
+ ## Benchmarks
96
+
97
+ #### TFLOPS
98
+
99
+ The number of TFLOPS (trillions of floating-point operations per second) is probably the most eloquent measure of the impact of the optimizations carried out.
100
+ We therefore compare four approaches:
101
+ - the SDPA (Scaled Dot Product Attention) implementation with full bias,
102
+ - the same implementation but in Triton,
103
+ - the Flash Attention RPE implementation (our kernel),
104
+ - the Flash Attention implementation, i.e. without bias. We've included it here for reference, as it's unusable in practice for a T5.
105
+
106
+ For the forward pass, we have:
107
+
108
+ <center>
109
+ <img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/FWD-causal-True_dark.png" alt="FWD-causal-True_dark" width="100%">
110
+ </center>
111
+
112
+ For the forward pass, we can see that the Triton approach achieves 1.34 times more FLOPS than SDPA, and that the Flash Attention RPE approach achieves 1.99 times more FLOPS than SDPA.
113
+ We can also see that our bf16 implementation is equivalent to fp16 (doing even better at size 512).
114
+
115
+ For the backward pass, we have:
116
+
117
+ <center>
118
+ <img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/BWD-causal-True_dark.png" alt="BWD-causal-True_dark" width="100%">
119
+ </center>
120
+
121
+ For the backward pass, the Triton implementation is less efficient than SDPA, with 0.71 times the FLOPS of SDPA. The Flash Attention RPE implementation is more or less equivalent to SDPA (1.018 times more FLOPS).
122
+ We can also observe that Triton in head_dim 64 is more efficient than Triton in head_dim 128.
123
+
124
+
125
+ #### Torch vs Triton
126
+
127
+ We mentioned above that we had optimized parts of the architecture using ad hoc Triton kernels, namely the cross-entropy and RMSNorm layer. The following benchmarks should illustrate why.
128
+ For the cross-entropy, we obtain a forward pass 7 to 11.4 times faster, a backward pass 3.26 to 3.75 times faster, and memory reduced by a factor of 4:
129
+
130
+ <center>
131
+ <img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/CE_dark.png" alt="CE_dark" width="100%">
132
+ </center>
133
+
134
+ For the RMSNorm layer, we obtain a forward pass 3 to 5 times faster, a backward pass 2.33 to 4.33 times faster, and memory reduced by a factor of 3.2:
135
+ <center>
136
+ <img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/LN_dark.png" alt="LN_dark" width="100%">
137
+ </center>
138
+
139
+ Note that all benchmark graphs can be generated automatically using the following [code](https://github.com/catie-aq/flashT5/tree/main/benchmarks).
140
+
141
+ <br>
142
+
143
+ ## Applications
144
+
145
+ ### To French
146
+ We've pretrained a small (147M parameters) FAT5-UL2 in French. This is the model you'll find in this Hugging Face repo.
147
+ The dataset we used is a mixture of [CulturaX](https://huggingface.co/datasets/uonlp/CulturaX), [Wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia), [justice_fr](https://huggingface.co/datasets/eckendoerffer/justice_fr) and [The Stack](https://huggingface.co/datasets/bigcode/the-stack-dedup).
148
+ Our tokenizer of size 32,768 (8**5) is trained on CulturaX and The Stack.
149
+ Our model was pre-trained on sequences of 1,024 tokens on a single A100 for 1,461 hours (= 419.4B tokens seen) at an estimated cost of €2,200.
150
+
151
+ Carbon emissions for the pretraining of this model were estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact:
152
+ - Hardware Type: A100 PCIe 40/80GB
153
+ - Hours used: 1,461h
154
+ - Cloud Provider: Private Infrastructure
155
+ - Carbon Efficiency (kg/kWh): 0.03696 kg/kWh, estimated from [electricitymaps](https://app.electricitymaps.com/zone/FR) (average carbon intensity in France between October 18, 2024 and December 19, 2024)
156
+ - **Carbon Emitted** *(power consumption × time × carbon intensity of the power grid)*: **13.5 kg eq. CO2** (see the sanity check below).
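+
+ As a quick sanity check of the figure above, here is the arithmetic; the ~250 W average draw is our assumption, and is simply what the reported numbers imply for a single A100 PCIe:
+
+ ```python
+ hours = 1461                 # training time reported above
+ carbon_intensity = 0.03696   # kg CO2 eq. per kWh (France, Oct. 18 - Dec. 19, 2024)
+ avg_power_kw = 0.25          # assumed average GPU power draw (250 W)
+ print(avg_power_kw * hours * carbon_intensity)  # ≈ 13.5 kg CO2 eq.
+ ```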
157
+
158
+
159
+ ### To other languages
160
+
161
+ Our contribution focuses on French, with the pre-training and finetuning of models for comparison against French benchmarks. For other languages, we can't afford to do the same kind of work.
162
+
163
+ Nevertheless, to ensure that it can be used in other languages, we have developed [code](https://github.com/catie-aq/flashT5/blob/main/convert_huggingface_t5.py) for adapting already pre-trained (m)T5/FLAN-T5 weights to our method. In this way, we hope users of a specific language will be able to efficiently continue pre-training one of these models, for example to adapt it to more recent data.
164
+ Note, however, that this adaptation is limited, since the additional pre-training will have to be carried out within the precision of the original model. For example, if the model's weights are in FP32 (which is the case with the FLAN-T5), training will not be as fast as with the FAT5, which is in BF16.
165
+
166
+ For English speakers, we have already adapted the weights of the various versions of [FLAN-T5](https://arxiv.org/abs/2210.11416) to our method. All weights can be found in this Hugging Face [collection](https://huggingface.co/collections/CATIE-AQ/catie-english-fat5-flan-662b679a8e855c7c0137d69e).
167
+ To use one of these models, simply run:
168
+
169
+ ```python
170
+ from transformers import AutoModel, AutoTokenizer
171
+ model = AutoModel.from_pretrained("CATIE-AQ/FAT5-small-flan-en", trust_remote_code=True)
172
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
173
+ ```
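+
+ As a more complete (hedged) sketch, the `AutoModelForSeq2SeqLM` class registered by the custom code (see `auto_map` in `config.json`) can be used for generation; the prompt and generation settings below are purely illustrative:
+
+ ```python
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ model = AutoModelForSeq2SeqLM.from_pretrained("CATIE-AQ/FAT5-small-flan-en", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+
+ inputs = tokenizer("Translate to German: The house is wonderful.", return_tensors="pt")
+ outputs = model.generate(**inputs, max_new_tokens=32)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```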
174
+
175
+ <br>
176
+
177
+ ## Pretraining
178
+
179
+ If you want to pre-train your own model (to be specialized in a specific domain for example, and thus benefit from a custom tokenizer), we included a [tutorial](examples/minipile) to pretrain a small model on [minipile](https://huggingface.co/datasets/JeanKaddour/minipile) to show how it should be done.
180
+ You can find the documentation of the model configuration file [here](docs/configuration_file.md).
181
+ Note that we tested and trained the tutorial model on an A100; it may or may not work on other GPUs.
182
+
183
+ <br>
184
+
185
+ <!--
186
+ ## Finetuning
187
+
188
+ Once you've pre-trained your model, you'll want to finetune it. Because it's a custom model (hence the need for `trust_remote_code=True` to load it), it's currently causing some difficulties due to flaws in Hugging Face's Transformers library (during the `push_to_hub`, for example). This should be resolved in January 2025, as we're working on porting the FAT5 directly into Transformers with the help of their lovely team 🤗
189
+ -->
190
+
191
+ ## Roadmap
192
+
193
+ We invite you to consult the “Next stage” section of the blog post.
194
+
195
+ <br>
196
+
197
+ ## Citation
198
+ ```
199
+ The DOI will be added once the model is public.
200
+ ```
201
+
202
+ <br>
203
+
204
+ ## License
205
+ [Apache-2.0 license](https://github.com/catie-aq/flashT5/tree/main?tab=Apache-2.0-1-ov-file#readme)
206
+
207
+ <br>
208
+
209
+ ## Acknowledgments
210
+
211
+ We used the following repositories and thank their authors:
212
+ - [nanoT5](https://github.com/PiotrNawrot/nanoT5) for the simple implementation and the optimizer.
213
+ - [Flash attention](https://github.com/Dao-AILab/flash-attention) for the groundbreaking algorithm for computing attention.
214
+ - [Hugging Face](https://github.com/huggingface/transformers) for their excellent library.
215
+ - [FlagAttention](https://github.com/FlagOpen/FlagAttention) for the implementation of FA2 in Triton.
216
+ - [Unsloth](https://github.com/unslothai/unsloth) for the simple Triton kernels of the cross-entropy and layernorm that we adapted to our usage.
217
+ - [TurboT5](https://github.com/Knowledgator/TurboT5) for the improvement of the February 2024 version of our work.
218
+
219
+ This work was supported by the [Vaniila platform](http://vaniila.ai/).<br>
220
+ <div align="center">
221
+ <a href="http://vaniila.ai/" target="_blank">
222
+ <img src="https://www.vaniila.ai/wp-content/uploads/2020/02/Vaniila_bleu_horizontal.png" alt="Vaniila Logo" width="200">
223
+ </a>
224
+ </div>
adamw_scaled.py ADDED
@@ -0,0 +1,281 @@
1
+ import torch
2
+ import math
3
+
4
+ from torch.optim import Optimizer
5
+ from torch.optim.optimizer import _default_to_fused_or_foreach
6
+ from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
7
+ from typing import Iterable, Tuple
8
+ from torch import nn, Tensor
9
+
10
+ class AdamWScale(Optimizer):
11
+ """
12
+ This AdamW implementation is copied from Huggingface.
13
+ We modified it with Adagrad-style scaling by the RMS of each weight tensor.
14
+
15
+ Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
16
+ Regularization](https://arxiv.org/abs/1711.05101).
17
+
18
+ Parameters:
19
+ params (`Iterable[nn.parameter.Parameter]`):
20
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
21
+ lr (`float`, *optional*, defaults to 1e-3):
22
+ The learning rate to use.
23
+ betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
24
+ Adam's betas parameters (b1, b2).
25
+ eps (`float`, *optional*, defaults to 1e-6):
26
+ Adam's epsilon for numerical stability.
27
+ weight_decay (`float`, *optional*, defaults to 0.0):
28
+ Decoupled weight decay to apply.
29
+ kahan_sum (`bool`, *optional*, defaults to False):
30
+ Whether to use Kahan summation for updating parameters.
31
+ foreach (`bool`, *optional*, defaults to False):
32
+ Whether to use the foreach implementation.
33
+ correct_bias (`bool`, *optional*, defaults to True):
34
+ Whether to correct bias in Adam.
35
+ use_state_dtype (`torch.dtype`, *optional*, defaults to None):
36
+ The dtype to use for optimizer state. If None, use the default dtype.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ params: Iterable[nn.parameter.Parameter],
42
+ lr: float = 1e-3,
43
+ betas: Tuple[float, float] = (0.9, 0.999),
44
+ eps: float = 1e-6,
45
+ weight_decay: float = 0.0,
46
+ kahan_sum: bool = False,
47
+ foreach: bool = False,
48
+ correct_bias: bool = True,
49
+ use_state_dtype: torch.dtype = None
50
+ ):
51
+ if lr < 0.0:
52
+ raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
53
+ if not 0.0 <= betas[0] < 1.0:
54
+ raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
55
+ if not 0.0 <= betas[1] < 1.0:
56
+ raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
57
+ if not 0.0 <= eps:
58
+ raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
59
+
60
+ assert not (foreach and use_state_dtype is not None), "foreach is not supported with use_state_dtype"
61
+
62
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, \
63
+ kahan_sum=kahan_sum, correct_bias=correct_bias, use_state_dtype=use_state_dtype)
64
+
65
+ super().__init__(params, defaults)
66
+
67
+ @staticmethod
68
+ def _rms(tensor):
69
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
70
+
71
+ @torch.no_grad()
72
+ def step(self, closure=None):
73
+ """
74
+ Performs a single optimization step.
75
+
76
+ Arguments:
77
+ closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
78
+ """
79
+ loss = None
80
+ if closure is not None:
81
+ loss = closure()
82
+
83
+ for group in self.param_groups:
84
+ params, grads, exp_avgs, exp_avg_sqs, steps, kahan_comps = [], [], [], [], [], []
85
+
86
+ # Initialization
87
+ for p in group['params']:
88
+ if p.grad is None:
89
+ continue
90
+
91
+ params.append(p)
92
+ if p.grad.is_sparse:
93
+ raise RuntimeError('AdamWScale does not support sparse gradients')
94
+ grads.append(p.grad)
95
+
96
+ state = self.state[p]
97
+
98
+ # State initialization
99
+ if "kahan_comp" not in state:
100
+ state['step'] = torch.tensor(0, dtype=torch.int32, device=p.device)
101
+
102
+ if group["use_state_dtype"] in [torch.float16, torch.bfloat16]:
103
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=group["use_state_dtype"])
104
+ state['exp_avg_sq'] = torch.zeros_like(p, device=p.device, dtype=group["use_state_dtype"])
105
+ else:
106
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
107
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
108
+
109
+ if group["kahan_sum"] and p.dtype in [torch.float16, torch.bfloat16]:
110
+ state["kahan_comp"] = torch.zeros_like(p, memory_format=torch.preserve_format)
111
+ else:
112
+ state["kahan_comp"] = None
113
+ group["kahan_sum"] = False
114
+
115
+ exp_avgs.append(state['exp_avg'])
116
+ exp_avg_sqs.append(state['exp_avg_sq'])
117
+ kahan_comps.append(state["kahan_comp"])
118
+ steps.append(state["step"])
119
+
120
+ torch._foreach_add_(steps, 1)
121
+
122
+ # AdamW step
123
+ if group["foreach"] and _default_to_fused_or_foreach(params, False, False):
124
+ self._foreach_adamwscaled(params,
125
+ grads,
126
+ exp_avgs,
127
+ exp_avg_sqs,
128
+ steps,
129
+ kahan_comps,
130
+ group["lr"],
131
+ group["betas"][0],
132
+ group["betas"][1],
133
+ group["weight_decay"],
134
+ group["eps"],
135
+ group["kahan_sum"],
136
+ group["correct_bias"])
137
+ else:
138
+ self._adamwscaled(params,
139
+ grads,
140
+ exp_avgs,
141
+ exp_avg_sqs,
142
+ steps,
143
+ kahan_comps,
144
+ group["lr"],
145
+ group["betas"][0],
146
+ group["betas"][1],
147
+ group["weight_decay"],
148
+ group["eps"],
149
+ group["kahan_sum"],
150
+ group["correct_bias"])
151
+
152
+ return loss
153
+
154
+ def _adamwscaled(self,
155
+ params: list[Tensor],
156
+ grads: list[Tensor],
157
+ exp_avgs: list[Tensor],
158
+ exp_avg_sqs: list[Tensor],
159
+ steps: list[Tensor],
160
+ kahan_comps: list[Tensor],
161
+ lr: float,
162
+ beta1: float,
163
+ beta2: float,
164
+ weight_decay: float,
165
+ eps: float,
166
+ do_kahan_sum: bool,
167
+ correct_bias: bool):
168
+
169
+ for i, p in enumerate(params):
170
+
171
+ exp_avg, exp_avg_sq, grad, step, kahan_comp = exp_avgs[i], exp_avg_sqs[i], grads[i], steps[i], kahan_comps[i]
172
+
173
+ # Decay the first and second moment running average coefficient
174
+ # In-place operations to update the averages at the same time
175
+ exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
176
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1.0 - beta2))
177
+ denom = exp_avg_sq.sqrt().add_(eps)
178
+
179
+ step_size = lr
180
+ if correct_bias: # No bias correction for Bert
181
+ bias_correction1 = 1.0 - beta1 ** step
182
+ bias_correction2 = 1.0 - beta2 ** step
183
+ step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
184
+
185
+ # Adapt Step from Adafactor
186
+ step_size = step_size * max(1e-3, self._rms(p.data))
187
+
188
+ if do_kahan_sum:
189
+ # Adam step
190
+ kahan_comp.addcdiv_(exp_avg, denom, value=-step_size)
191
+
192
+ # update weights with kahan compensation using dev_grads as temp buffer
193
+ grad.copy_(p)
194
+ p.add_(kahan_comp)
195
+
196
+ # save error back to kahan compensation for next iteration
197
+ grad.sub_(p, alpha=1)
198
+ kahan_comp.add_(grad, alpha=1)
199
+ else:
200
+ p.addcdiv_(exp_avg, denom, value=-step_size)
201
+
202
+ # Just adding the square of the weights to the loss function is *not*
203
+ # the correct way of using L2 regularization/weight decay with Adam,
204
+ # since that will interact with the m and v parameters in strange ways.
205
+ #
206
+ # Instead we want to decay the weights in a manner that doesn't interact
207
+ # with the m/v parameters. This is equivalent to adding the square
208
+ # of the weights to the loss with plain (non-momentum) SGD.
209
+ # Add weight decay at the end (fixed version)
210
+ if weight_decay > 0.0:
211
+ p.add_(p, alpha=(-lr * weight_decay))
212
+
213
+ def _foreach_adamwscaled(self,
214
+ params: list[Tensor],
215
+ grads: list[Tensor],
216
+ exp_avgs: list[Tensor],
217
+ exp_avg_sqs: list[Tensor],
218
+ steps: list[Tensor],
219
+ kahan_comps: list[Tensor],
220
+ lr: float,
221
+ beta1: float,
222
+ beta2: float,
223
+ weight_decay: float,
224
+ eps: float,
225
+ do_kahan_sum: bool,
226
+ correct_bias: bool):
227
+
228
+ grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, kahan_comps])
229
+
230
+ for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_kahan_comps), _) in grouped_tensors.items():
231
+ # Foreach implementation
232
+ torch._foreach_mul_(dev_exp_avgs, beta1)
233
+ torch._foreach_add_(dev_exp_avgs, dev_grads, alpha=1 - beta1)
234
+
235
+ torch._foreach_mul_(dev_exp_avg_sqs, beta2)
236
+ torch._foreach_addcmul_(dev_exp_avg_sqs, dev_grads, dev_grads, 1 - beta2)
237
+
238
+ # Compute denominator
239
+ torch._foreach_copy_(dev_grads, dev_exp_avg_sqs)
240
+ torch._foreach_sqrt_(dev_grads)
241
+ torch._foreach_add_(dev_grads, eps)
242
+
243
+ step_size = [torch.tensor(lr, dtype=torch.float32, device=p.device) for p in dev_params]
244
+
245
+ if correct_bias:
246
+ torch._foreach_mul_(step_size,
247
+ [torch.tensor((math.sqrt(1 - beta2 ** steps[i].item()) / (1 - beta1 ** steps[i].item()) ), dtype=torch.float32, device=p.device)
248
+ for i, p in enumerate(dev_params)])
249
+
250
+ # Adapt step size using RMS of parameters
251
+ rms_p = torch._foreach_norm(dev_params)
252
+ numel = [torch.tensor(math.sqrt(p.numel())) for p in dev_params]
253
+ torch._foreach_div_(rms_p, numel)
254
+ torch._foreach_maximum_(rms_p, 1e-3)
255
+
256
+ torch._foreach_mul_(step_size, rms_p)
257
+ torch._foreach_div_(dev_grads, step_size)
258
+
259
+ # explicitly delete tensors when not used
260
+ del rms_p
261
+ del numel
262
+ del step_size
263
+
264
+ # Update parameters
265
+ if do_kahan_sum:
266
+ # Adam step
267
+ torch._foreach_addcdiv_(dev_kahan_comps, dev_exp_avgs, dev_grads, value=-1)
268
+
269
+ # update weights with kahan compensation using dev_grads as temp buffer
270
+ torch._foreach_copy_(dev_grads, dev_params)
271
+ torch._foreach_add_(dev_params, dev_kahan_comps, alpha=1)
272
+
273
+ # save error back to kahan compensation for next iteration
274
+ torch._foreach_sub_(dev_grads, dev_params, alpha=1)
275
+ torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1)
276
+ else:
277
+ torch._foreach_addcdiv_(dev_params, dev_exp_avgs, dev_grads, value=-1)
278
+
279
+ # Weight decay
280
+ if weight_decay > 0.0:
281
+ torch._foreach_add_(dev_params, dev_params, alpha=-weight_decay * lr)
attn_ref.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+
3
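+ # Reference (unfused) attention used to validate the Flash/Triton kernels:
+ # computes softmax(q @ k^T * sm_scale + b) @ v, with optional causal masking and dropout.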
+ def attn_ref(q, k, v, b, sm_scale, dropout_p=0.0, causal=False, upcast=False):
4
+ if upcast:
5
+ q, k, v = q.float(), k.float(), v.float()
6
+ if b is not None:
7
+ b = b.float()
8
+
9
+ if b is not None:
10
+ if (b.shape[0] != q.shape[0]) or (b.shape[1] != q.shape[1]):
11
+ b = b.expand(q.shape[0], q.shape[1], q.shape[2], k.shape[2])
12
+
13
+ ms = torch.arange(q.shape[2], device=q.device).unsqueeze(-1)
14
+ ns = torch.arange(k.shape[2], device=q.device)
15
+
16
+ p = torch.matmul(q, k.transpose(2, 3))
17
+ p *= sm_scale
18
+ if b is not None:
19
+ p += b
20
+
21
+ if causal:
22
+ p = torch.where(ms + k.shape[2] - q.shape[2] >= ns, p, float("-inf"))
23
+
24
+ p = torch.softmax(p.float(), dim=-1).to(q.dtype)
25
+ if dropout_p > 0.0:
26
+ p = torch.dropout(p, dropout_p, train=True)
27
+
28
+ ref_out = torch.matmul(p, v)
29
+ return ref_out
config.json ADDED
@@ -0,0 +1,59 @@
1
+ {
2
+ "alibi_mode": "symetric",
3
+ "architectures": [
4
+ "FlashT5ForConditionalGeneration"
5
+ ],
6
+ "attention_dropout_rate": 0.0,
7
+ "attention_scale": 1.0,
8
+ "attention_type": "triton",
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_flash_t5.FlashT5Config",
11
+ "AutoModel": "modeling_flash_t5.FlashT5EncoderModel",
12
+ "AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
13
+ "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
14
+ "AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
15
+ "AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification"
16
+ },
17
+ "classifier_dropout": 0.0,
18
+ "crossentropy_inplace_backward": false,
19
+ "d_ff": 2048,
20
+ "d_kv": 64,
21
+ "d_model": 512,
22
+ "decoder_start_token_id": 0,
23
+ "dense_act_fn": "relu",
24
+ "dropout_rate": 0.0,
25
+ "eos_token_id": 1,
26
+ "feed_forward_proj": "relu",
27
+ "fire_mlp_width": 32,
28
+ "initializer_factor": 1.0,
29
+ "is_encoder_decoder": false,
30
+ "is_gated_act": false,
31
+ "label_smoothing": 0.0,
32
+ "layer_norm_epsilon": 1e-06,
33
+ "max_sequence_length": 1024,
34
+ "model_type": "flash_t5",
35
+ "num_decoder_layers": 12,
36
+ "num_heads": 8,
37
+ "num_layers": 12,
38
+ "pad_token_id": 3,
39
+ "position_encoding_type": "t5",
40
+ "relative_attention_max_distance": 128,
41
+ "relative_attention_num_buckets": 32,
42
+ "rotary_base": 10000,
43
+ "rotary_emb_fraction": 1.0,
44
+ "rotary_interleaved": false,
45
+ "rotary_scale_base": null,
46
+ "tie_word_embeddings": false,
47
+ "torch_dtype": "float32",
48
+ "transformers_version": "4.46.0.dev0",
49
+ "use_cache": true,
50
+ "use_full_bias_size": false,
51
+ "use_gelu_act": true,
52
+ "use_glu_mlp": true,
53
+ "use_masking": false,
54
+ "use_randomized_position_encoding": false,
55
+ "use_triton_crossentropy": true,
56
+ "use_triton_layernorm": true,
57
+ "vocab_size": 32768,
58
+ "z_loss": 0.0001
59
+ }
configuration_flash_t5.py ADDED
@@ -0,0 +1,84 @@
1
+ import sys
2
+ from collections import OrderedDict
3
+ from typing import Mapping
4
+ import logging
5
+
6
+ from transformers import T5Config
7
+
8
+ AUTO_MAP = {
9
+ "AutoModel": "modeling_flash_t5.FlashT5EncoderModel",
10
+ "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
11
+ "AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification",
12
+ "AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
13
+ "AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
14
+ }
15
+
16
+ class FlashT5Config(T5Config):
17
+
18
+ model_type = "flash_t5"
19
+
20
+ def __init__(
21
+ self,
22
+ decoder_start_token_id=0,
23
+ pad_token_id=-100,
24
+ use_glu_mlp=False,
25
+ position_encoding_type="t5",
26
+ use_randomized_position_encoding=False,
27
+ label_smoothing=0.0,
28
+ z_loss=None,
29
+ attention_type="ref",
30
+ max_sequence_length=1024,
31
+ attention_dropout_rate=0.0,
32
+ alibi_mode="symetric",
33
+ use_triton_layernorm=False,
34
+ use_triton_crossentropy=False,
35
+ crossentropy_inplace_backward=False,
36
+ use_gelu_act=True,
37
+ use_full_bias_size=False,
38
+ rotary_emb_fraction=1.0,
39
+ rotary_base=10000,
40
+ rotary_interleaved=False,
41
+ rotary_scale_base=None,
42
+ fire_mlp_width=32,
43
+ use_masking=False,
44
+ attention_scale=None,
45
+ **kwargs,
46
+ ):
47
+ super().__init__(**kwargs)
48
+
49
+ self.decoder_start_token_id = decoder_start_token_id
50
+ self.pad_token_id = pad_token_id
51
+ self.use_glu_mlp = use_glu_mlp
52
+ self.position_encoding_type = position_encoding_type
53
+ self.use_randomized_position_encoding = use_randomized_position_encoding
54
+ self.label_smoothing = label_smoothing
55
+ self.z_loss = z_loss
56
+ self.attention_type = attention_type
57
+ self.max_sequence_length = max_sequence_length
58
+ self.alibi_mode = alibi_mode
59
+ self.attention_dropout_rate = attention_dropout_rate
60
+ self.use_triton_layernorm = use_triton_layernorm
61
+ self.use_triton_crossentropy = use_triton_crossentropy
62
+ self.crossentropy_inplace_backward = crossentropy_inplace_backward
63
+ self.use_gelu_act = use_gelu_act
64
+ self.use_full_bias_size = use_full_bias_size
65
+ self.rotary_base = rotary_base
66
+ self.rotary_interleaved = rotary_interleaved
67
+ self.rotary_scale_base = rotary_scale_base
68
+ self.rotary_emb_fraction = rotary_emb_fraction
69
+ self.fire_mlp_width = fire_mlp_width
70
+ self.use_masking = use_masking
71
+ self.attention_scale = attention_scale
72
+
73
+ self.auto_map = AUTO_MAP
74
+
75
+ def str_to_class(classname):
76
+ return getattr(sys.modules[__name__], classname)
77
+
78
+ # Register model in Auto API
79
+ try:
80
+ FlashT5Config.register_for_auto_class()
81
+ for key, value in AUTO_MAP.items():
82
+ str_to_class(value.split(".")[-1]).register_for_auto_class(key)
83
+ except Exception:
84
+ logging.warning("AutoRegister isn't available.")
cross_entropy_loss.py ADDED
@@ -0,0 +1,426 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Copyright 2024 CATIE. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Modification to the original version from Tri Dao:
17
+ # - support for torch.compile
18
+
19
+ from typing import Tuple, Optional
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+ import triton
25
+ import triton.language as tl
26
+
27
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
28
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
29
+ # version of PyTorch. The following 2 lines are for backward compatibility with
30
+ # older PyTorch.
31
+ if "all_gather_into_tensor" not in dir(torch.distributed):
32
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
33
+
34
+
35
+ @triton.heuristics(
36
+ {
37
+ "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
38
+ }
39
+ )
40
+ @triton.jit
41
+ def cross_entropy_fwd_kernel(
42
+ loss_ptr, # data ptrs
43
+ lse_ptr,
44
+ z_loss_ptr,
45
+ logits_ptr,
46
+ labels_ptr,
47
+ smoothing,
48
+ logit_scale,
49
+ lse_square_scale,
50
+ ignore_index,
51
+ total_classes,
52
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
53
+ n_cols, # shapes
54
+ logits_row_stride, # strides
55
+ BLOCK_SIZE: tl.constexpr,
56
+ HAS_SMOOTHING: tl.constexpr,
57
+ # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
58
+ SPLIT: tl.constexpr,
59
+ PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0)
60
+ ):
61
+ row_idx = tl.program_id(0)
62
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
63
+ sum_logits = 0.0 # For smoothing
64
+ if not PRECOMPUTED_LSE:
65
+ # Statistics for online softmax
66
+ m_i = -float("inf")
67
+ l_i = 0.0
68
+ for col_offset in range(0, n_cols, BLOCK_SIZE):
69
+ cols = col_offset + tl.arange(0, BLOCK_SIZE)
70
+ logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
71
+ tl.float32
72
+ ) * logit_scale
73
+ if HAS_SMOOTHING:
74
+ sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
75
+ m_i_new = tl.maximum(m_i, tl.max(logits))
76
+ l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
77
+ m_i = m_i_new
78
+ lse = tl.log(l_i) + m_i
79
+ tl.store(lse_ptr + row_idx, lse)
80
+ else:
81
+ lse = tl.load(lse_ptr + row_idx)
82
+ label_idx = tl.load(labels_ptr + row_idx)
83
+ if label_idx == ignore_index:
84
+ loss = 0.0
85
+ z_loss = 0.0
86
+ else:
87
+ label_idx -= class_start_idx
88
+ if label_idx >= 0 and label_idx < n_cols:
89
+ logits_label = tl.load(logits_ptr + label_idx) * logit_scale
90
+ if HAS_SMOOTHING:
91
+ loss = (
92
+ (lse if not SPLIT else 0.0)
93
+ - smoothing * sum_logits / total_classes
94
+ - (1 - smoothing) * logits_label
95
+ )
96
+ else:
97
+ loss = (lse if not SPLIT else 0.0) - logits_label
98
+ else:
99
+ # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
100
+ if HAS_SMOOTHING:
101
+ loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
102
+ else:
103
+ loss = 0.0
104
+ if not SPLIT:
105
+ z_loss = lse_square_scale * lse * lse
106
+ loss += z_loss
107
+ else:
108
+ z_loss = 0.0
109
+ tl.store(loss_ptr + row_idx, loss)
110
+ if not SPLIT:
111
+ tl.store(z_loss_ptr + row_idx, z_loss)
112
+
113
+
114
+ @triton.heuristics(
115
+ {
116
+ "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
117
+ }
118
+ )
119
+ @triton.jit
120
+ def cross_entropy_bwd_kernel(
121
+ dlogits_ptr, # data ptrs
122
+ dloss_ptr,
123
+ logits_ptr,
124
+ lse_ptr,
125
+ labels_ptr,
126
+ smoothing,
127
+ logit_scale,
128
+ lse_square_scale,
129
+ ignore_index,
130
+ total_classes,
131
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
132
+ n_cols, # shapes
133
+ logits_row_stride, # strides
134
+ dlogits_row_stride,
135
+ dloss_row_stride,
136
+ BLOCK_SIZE: tl.constexpr,
137
+ HAS_SMOOTHING: tl.constexpr,
138
+ ):
139
+ row_idx = tl.program_id(0)
140
+ col_block_idx = tl.program_id(1)
141
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
142
+ dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
143
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
144
+ label_idx = tl.load(labels_ptr + row_idx)
145
+ if label_idx != ignore_index:
146
+ dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
147
+ else:
148
+ dloss = 0.0
149
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
150
+ tl.float32
151
+ ) * logit_scale
152
+ lse = tl.load(lse_ptr + row_idx)
153
+ probs = tl.exp(logits - lse)
154
+ probs += 2.0 * lse_square_scale * lse * probs
155
+ label_idx -= class_start_idx
156
+ if HAS_SMOOTHING:
157
+ smooth_positive = 1.0 - smoothing
158
+ smooth_negative = smoothing / total_classes
159
+ probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
160
+ else:
161
+ probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
162
+ tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
163
+
164
+ @torch.library.custom_op("flasht5::cross_entropy_triton_fwd", mutates_args=(), device_types="cuda")
165
+ def cross_entropy_triton_fwd(
166
+ logits: torch.Tensor,
167
+ labels: torch.Tensor,
168
+ precomputed_lse: torch.Tensor,
169
+ use_precomputed_lse: bool,
170
+ split: bool,
171
+ smoothing: float,
172
+ logit_scale: float,
173
+ lse_square_scale: float,
174
+ ignore_index: int,
175
+ total_classes: int,
176
+ class_start_idx: int,
177
+ n_cols: int,
178
+ n_rows: int,
179
+ BLOCK_SIZE: int,
180
+ num_warps: int
181
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
182
+
183
+ if logits.stride(-1) != 1:
184
+ logits = logits.contiguous()
185
+
186
+ losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
187
+ if use_precomputed_lse:
188
+ assert precomputed_lse.shape == (n_rows,)
189
+ lse = precomputed_lse.contiguous()
190
+ else:
191
+ lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
192
+
193
+ z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
194
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
195
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
196
+ with torch.cuda.device(logits.device.index):
197
+ cross_entropy_fwd_kernel[(n_rows,)](
198
+ losses, # data ptrs
199
+ lse,
200
+ z_losses,
201
+ logits,
202
+ labels,
203
+ smoothing,
204
+ logit_scale,
205
+ lse_square_scale,
206
+ ignore_index,
207
+ total_classes,
208
+ class_start_idx,
209
+ n_cols, # shapes
210
+ logits.stride(0), # strides
211
+ BLOCK_SIZE=BLOCK_SIZE, # constants
212
+ SPLIT=split,
213
+ PRECOMPUTED_LSE=use_precomputed_lse,
214
+ num_warps=num_warps,
215
+ )
216
+
217
+ return losses, z_losses, lse
218
+
219
+
220
+ @torch.library.register_fake("flasht5::cross_entropy_triton_fwd")
221
+ def cross_entropy_triton_fwd_abstract(logits, labels, precomputed_lse, use_precomputed_lse, split, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
222
+ losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
223
+ z_losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
224
+ logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
225
+
226
+ return losses, z_losses, logsumexp
227
+
228
+ @torch.library.custom_op("flasht5::cross_entropy_triton_bwd", mutates_args={"logits"}, device_types="cuda")
229
+ def cross_entropy_triton_bwd(
230
+ dlosses: torch.Tensor,
231
+ logits: torch.Tensor,
232
+ lse: torch.Tensor,
233
+ labels: torch.Tensor,
234
+ inplace_backward: bool,
235
+ smoothing: float,
236
+ logit_scale: float,
237
+ lse_square_scale: float,
238
+ ignore_index: int,
239
+ total_classes: int,
240
+ class_start_idx: int,
241
+ n_cols: int,
242
+ n_rows: int,
243
+ BLOCK_SIZE: int,
244
+ num_warps: int
245
+ ) -> torch.Tensor:
246
+
247
+ dlogits = logits if inplace_backward else torch.empty_like(logits)
248
+
249
+ grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
250
+
251
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
252
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
253
+ with torch.cuda.device(logits.device.index):
254
+ cross_entropy_bwd_kernel[grid](
255
+ dlogits, # data ptrs
256
+ dlosses,
257
+ logits,
258
+ lse,
259
+ labels,
260
+ smoothing,
261
+ logit_scale,
262
+ lse_square_scale,
263
+ ignore_index,
264
+ total_classes,
265
+ class_start_idx,
266
+ n_cols, # shapes
267
+ logits.stride(0), # strides
268
+ dlogits.stride(0),
269
+ dlosses.stride(0),
270
+ BLOCK_SIZE=BLOCK_SIZE, # constants
271
+ num_warps=num_warps,
272
+ )
273
+
274
+ return dlogits if not inplace_backward else None
275
+
276
+ @torch.library.register_fake("flasht5::cross_entropy_triton_bwd")
277
+ def cross_entropy_triton_bwd_abstract(dlosses, logits, lse, labels, inplace_backward, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
278
+ return torch.empty_like(logits)
279
+
280
+ class CrossEntropyLoss(torch.autograd.Function):
281
+
282
+ @staticmethod
283
+ def forward(
284
+ ctx,
285
+ logits,
286
+ labels,
287
+ precomputed_lse=None,
288
+ smoothing=0.0,
289
+ logit_scale=1.0,
290
+ lse_square_scale=0.0,
291
+ ignore_index=-100,
292
+ inplace_backward=False,
293
+ process_group=None,
294
+ ):
295
+ # For some reason Triton generates wrong code when labels has dtype long and its address
296
+ # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
297
+ if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
298
+ labels = F.pad(labels, (0, 1))[..., :-1]
299
+ assert labels.data_ptr() % 16 == 0
300
+
301
+ n_rows, n_cols = logits.shape
302
+ assert labels.shape == (n_rows,)
303
+ world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
304
+ total_classes = world_size * n_cols
305
+ rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
306
+ class_start_idx = rank * n_cols
307
+ use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
308
+
309
+ MAX_BLOCK_SIZE = 16 * 1024
310
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
311
+ num_warps = (
312
+ 4
313
+ if BLOCK_SIZE < 2048
314
+ else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
315
+ )
316
+
317
+ losses, z_losses, lse = torch.ops.flasht5.cross_entropy_triton_fwd(
318
+ logits, labels, precomputed_lse, use_precomputed_lse, \
319
+ world_size > 1, smoothing, logit_scale, lse_square_scale, \
320
+ ignore_index, total_classes, class_start_idx, \
321
+ n_cols, n_rows, BLOCK_SIZE, num_warps
322
+ )
323
+
324
+ if world_size > 1:
325
+ # If there's no smoothing, if labels are in the vocab of this partition, losses contains
326
+ # - predicted logit, and 0 otherwise.
327
+ # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
328
+ # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
329
+ # For labels not in the vocab of this partition, losses contains
330
+ # -0.1 * sum logit / total_classes.
331
+ if world_size > 1:
332
+ lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
333
+ torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
334
+ handle_losses = torch.distributed.all_reduce(
335
+ losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
336
+ )
337
+ lse = torch.logsumexp(lse_allgather, dim=0)
338
+ handle_losses.wait()
339
+ # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
340
+ # we just have to add the (global) lse.
341
+ # If there's smoothing=0.1, the total losses are
342
+ # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
343
+ # Again, we just have to add the (global) lse.
344
+ losses += lse
345
+ if lse_square_scale != 0.0:
346
+ z_losses = lse_square_scale * lse.square()
347
+ z_losses.masked_fill_(labels == ignore_index, 0.0)
348
+ losses += z_losses
349
+ else:
350
+ z_losses = torch.zeros_like(losses)
351
+ losses.masked_fill_(labels == ignore_index, 0.0)
352
+
353
+ ctx.save_for_backward(logits, lse, labels)
354
+ ctx.mark_non_differentiable(z_losses)
355
+ ctx.smoothing = smoothing
356
+ ctx.logit_scale = logit_scale
357
+ ctx.lse_square_scale = lse_square_scale
358
+ ctx.ignore_index = ignore_index
359
+ ctx.total_classes = total_classes
360
+ ctx.class_start_idx = class_start_idx
361
+ ctx.inplace_backward = inplace_backward
362
+
363
+ return losses, z_losses
364
+
365
+ @staticmethod
366
+ def backward(ctx, grad_losses, grad_z_losses):
367
+ del grad_z_losses # z_losses are only for logging.
368
+
369
+ logits, lse, labels = ctx.saved_tensors
370
+
371
+ n_rows, n_cols = logits.shape
372
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
373
+ num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
374
+
375
+ dlogits = torch.ops.flasht5.cross_entropy_triton_bwd(
376
+ grad_losses, logits, lse, labels, \
377
+ ctx.inplace_backward, ctx.smoothing, ctx.logit_scale, \
378
+ ctx.lse_square_scale, ctx.ignore_index, ctx.total_classes, \
379
+ ctx.class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps
380
+ )
381
+
382
+ if ctx.inplace_backward:
383
+ dlogits = logits
384
+
385
+ return dlogits, None, None, None, None, None, None, None, None, None
386
+
387
+
388
+ def cross_entropy_loss(
389
+ logits: torch.Tensor,
390
+ labels: torch.Tensor,
391
+ precomputed_lse: Optional[torch.Tensor] = None,
392
+ label_smoothing: float = 0.0,
393
+ logit_scale: float = 1.0,
394
+ lse_square_scale: float = 0.0,
395
+ ignore_index=-100,
396
+ inplace_backward: bool = False,
397
+ process_group=None,
398
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
399
+ """
400
+ Arguments:
401
+ logits: (batch, vocab_size)
402
+ labels: (batch,)
403
+ label_smoothing: float
404
+ logit_scale: float. Multiply logits by this scale before calculating the loss.
405
+ lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
406
+ This is also referred to as "z-loss".
407
+ ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
408
+ inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
409
+ This saves memory.
410
+ process_group: if not None, we're doing Tensor Parallel: each process is responsible for
411
+ one part of the vocab. The loss will be aggregated across processes.
412
+ Returns:
413
+ losses: (batch,), float
414
+ z_losses: (batch,), float
415
+ """
416
+ return CrossEntropyLoss.apply(
417
+ logits.view(-1, logits.shape[-1]),
418
+ labels.view(-1),
419
+ precomputed_lse,
420
+ label_smoothing,
421
+ logit_scale,
422
+ lse_square_scale,
423
+ ignore_index,
424
+ inplace_backward,
425
+ process_group,
426
+ )
custom_heads_flash_t5.py ADDED
@@ -0,0 +1,315 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
4
+ import copy
5
+ from typing import Optional, Union, Tuple, List
6
+ from transformers.modeling_outputs import (
7
+ Seq2SeqQuestionAnsweringModelOutput,
8
+ QuestionAnsweringModelOutput,
9
+ TokenClassifierOutput,
10
+ BaseModelOutput,
11
+ Seq2SeqSequenceClassifierOutput,
12
+ SequenceClassifierOutput
13
+ )
14
+
15
+ from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model
16
+ from .configuration_flash_t5 import FlashT5Config
17
+
18
+
19
+ ################## Encoder only head ##################
20
+ class FlashT5ForTokenClassification(FlashT5PreTrainedModel):
21
+
22
+ def __init__(self, config: FlashT5Config):
23
+ super().__init__(config)
24
+ self.num_labels = config.num_labels
25
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
26
+
27
+ self.encoder = FlashT5Stack(config, self.shared)
28
+ self.dropout = nn.Dropout(config.classifier_dropout)
29
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
30
+
31
+ # Initialize weights and apply final processing
32
+ self.post_init()
33
+
34
+ # Initialize classifier
35
+ self.classifier.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
36
+ self.classifier.bias.data.zero_()
37
+
38
+ self.model_parallel = False
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: Optional[torch.Tensor] = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ head_mask: Optional[torch.Tensor] = None,
45
+ inputs_embeds: Optional[torch.Tensor] = None,
46
+ labels: Optional[torch.Tensor] = None,
47
+ output_attentions: Optional[bool] = None,
48
+ output_hidden_states: Optional[bool] = None,
49
+ return_dict: Optional[bool] = None,
50
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
51
+ r"""
52
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
53
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
54
+ Returns:
55
+ """
56
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57
+
58
+ outputs = self.encoder(
59
+ input_ids=input_ids,
60
+ attention_mask=attention_mask,
61
+ inputs_embeds=inputs_embeds,
62
+ head_mask=head_mask,
63
+ output_attentions=output_attentions,
64
+ output_hidden_states=output_hidden_states,
65
+ return_dict=return_dict,
66
+ )
67
+
68
+ hidden_states = outputs[0]
69
+ hidden_states = self.dropout(hidden_states)
70
+ logits = self.classifier(hidden_states)
71
+
72
+ loss = None
73
+ if labels is not None:
74
+ loss_fct = nn.CrossEntropyLoss()
75
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
76
+
77
+ if not return_dict:
78
+ output = (logits, outputs[2:-1])
79
+ return ((loss,) + output) if loss is not None else output
80
+
81
+ return TokenClassifierOutput(
82
+ loss=loss,
83
+ logits=logits,
84
+ hidden_states=outputs.hidden_states,
85
+ attentions=outputs.attentions,
86
+ )
87
+
88
+
89
+ class FlashT5ClassificationHead(nn.Module):
90
+ """Head for sentence-level classification tasks."""
91
+
92
+ def __init__(self, config: FlashT5Config):
93
+ super().__init__()
94
+ self.dense = nn.Linear(config.d_model, config.d_model)
95
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
96
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
97
+
98
+ # initialize weights
99
+ factor = config.initializer_factor
100
+ self.dense.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
101
+ if hasattr(self.dense, "bias") and self.dense.bias is not None:
102
+ self.dense.bias.data.zero_()
103
+ self.out_proj.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
104
+ if hasattr(self.out_proj, "bias") and self.out_proj.bias is not None:
105
+ self.out_proj.bias.data.zero_()
106
+
107
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
108
+ hidden_states = self.dropout(hidden_states)
109
+ hidden_states = self.dense(hidden_states)
110
+ hidden_states = torch.tanh(hidden_states)
111
+ hidden_states = self.dropout(hidden_states)
112
+ hidden_states = self.out_proj(hidden_states)
113
+ return hidden_states
114
+
115
+
116
+ class FlashT5ForSequenceClassification(FlashT5PreTrainedModel):
117
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
118
+
119
+ def __init__(self, config: FlashT5Config):
120
+ super().__init__(config)
121
+ self.model_dim = config.d_model
122
+ self.config.problem_type = None
123
+ self.config.is_encoder_decoder = False
124
+
125
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
126
+
127
+ encoder_config = copy.deepcopy(config)
128
+ encoder_config.is_decoder = False
129
+ encoder_config.is_encoder_decoder = False
130
+ encoder_config.use_cache = False
131
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
132
+ self.classification_head = FlashT5ClassificationHead(config)
133
+
134
+ # Initialize weights and apply final processing
135
+ self.post_init()
136
+
137
+ self.model_parallel = False
138
+
139
+ def forward(
140
+ self,
141
+ input_ids: torch.LongTensor = None,
142
+ attention_mask: Optional[torch.Tensor] = None,
143
+ head_mask: Optional[torch.Tensor] = None,
144
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
145
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
146
+ inputs_embeds: Optional[torch.FloatTensor] = None,
147
+ labels: Optional[torch.LongTensor] = None,
148
+ use_cache: Optional[bool] = None,
149
+ output_attentions: Optional[bool] = None,
150
+ output_hidden_states: Optional[bool] = None,
151
+ return_dict: Optional[bool] = None,
152
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
153
+ r"""
154
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
155
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
156
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
157
+ Returns:
158
+ """
159
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
160
+ if labels is not None:
161
+ use_cache = False
162
+
163
+ if input_ids is None and inputs_embeds is not None:
164
+ raise NotImplementedError(
165
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
166
+ )
167
+
168
+
169
+ outputs = self.encoder(
170
+ input_ids=input_ids,
171
+ attention_mask=attention_mask,
172
+ inputs_embeds=inputs_embeds,
173
+ head_mask=head_mask,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+ sequence_output = outputs[0]
179
+
180
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
181
+
182
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
183
+ raise ValueError("All examples must have the same number of <eos> tokens.")
184
+ batch_size, _, hidden_size = sequence_output.shape
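+ # Pool the sequence into a single vector: the last position of the encoder output is used here,
+ # while the commented-out line below is the alternative that gathers the final <eos> position.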
185
+ sentence_representation = sequence_output[:, -1, :]
186
+ # sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
187
+ logits = self.classification_head(sentence_representation)
188
+
189
+ loss = None
190
+ if labels is not None:
191
+ labels = labels.to(logits.device)
192
+ if self.config.problem_type is None:
193
+ if self.config.num_labels == 1:
194
+ self.config.problem_type = "regression"
195
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
196
+ self.config.problem_type = "single_label_classification"
197
+ else:
198
+ self.config.problem_type = "multi_label_classification"
199
+
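+ # e.g. num_labels == 1 -> regression (MSE loss); integer class labels -> single-label
+ # classification (cross-entropy); float multi-hot labels -> multi-label (BCE with logits).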
200
+ if self.config.problem_type == "regression":
201
+ loss_fct = nn.MSELoss()
202
+ if self.config.num_labels == 1:
203
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
204
+ else:
205
+ loss = loss_fct(logits, labels)
206
+ elif self.config.problem_type == "single_label_classification":
207
+ loss_fct = nn.CrossEntropyLoss()
208
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
209
+ elif self.config.problem_type == "multi_label_classification":
210
+ loss_fct = nn.BCEWithLogitsLoss()
211
+ loss = loss_fct(logits, labels)
212
+ if not return_dict:
213
+ output = (logits,) + outputs[1:]
214
+ return ((loss,) + output) if loss is not None else output
215
+
216
+ return SequenceClassifierOutput(
217
+ loss=loss,
218
+ logits=logits,
219
+ hidden_states=outputs.hidden_states,
220
+ attentions=outputs.attentions
221
+ )
222
+
223
+
224
+ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
225
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
226
+
227
+ def __init__(self, config: FlashT5Config):
228
+ super().__init__(config)
229
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
230
+
231
+ encoder_config = copy.deepcopy(config)
232
+ encoder_config.is_decoder = False
233
+ encoder_config.is_encoder_decoder = False
234
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
235
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
236
+
237
+ # Initialize weights and apply final processing
238
+ self.post_init()
239
+
240
+ self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
241
+ self.qa_outputs.bias.data.zero_()
242
+
243
+ self.model_parallel = False
244
+
245
+ def forward(
246
+ self,
247
+ input_ids: Optional[torch.LongTensor] = None,
248
+ attention_mask: Optional[torch.FloatTensor] = None,
249
+ head_mask: Optional[torch.FloatTensor] = None,
250
+ inputs_embeds: Optional[torch.FloatTensor] = None,
251
+ start_positions: Optional[torch.LongTensor] = None,
252
+ end_positions: Optional[torch.LongTensor] = None,
253
+ output_attentions: Optional[bool] = None,
254
+ output_hidden_states: Optional[bool] = None,
255
+ return_dict: Optional[bool] = None,
256
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
257
+ r"""
258
+ Returns:
259
+
260
+ Example:
261
+
262
+ ```python
263
+ >>> from transformers import AutoTokenizer
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("CATIE-AQ/FAT5-small")
+ >>> model = FlashT5ForQuestionAnswering.from_pretrained("CATIE-AQ/FAT5-small")
267
+ >>> input_ids = tokenizer(
268
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
269
+ ... ).input_ids # Batch size 1
270
+ >>> outputs = model(input_ids=input_ids)
271
+ >>> start_logits = outputs.start_logits
272
+ >>> end_logits = outputs.end_logits
273
+ ```"""
274
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
275
+
276
+ outputs = self.encoder(
277
+ input_ids,
278
+ attention_mask=attention_mask,
279
+ inputs_embeds=inputs_embeds,
280
+ )
281
+ sequence_output = outputs[0]
282
+
283
+ logits = self.qa_outputs(sequence_output)
284
+ start_logits, end_logits = logits.split(1, dim=-1)
285
+ start_logits = start_logits.squeeze(-1).contiguous()
286
+ end_logits = end_logits.squeeze(-1).contiguous()
287
+
288
+ total_loss = None
289
+ if start_positions is not None and end_positions is not None:
290
+ # If we are on multi-GPU, the split may add an extra dimension: squeeze it
291
+ if len(start_positions.size()) > 1:
292
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
293
+ if len(end_positions.size()) > 1:
294
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
295
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
296
+ ignored_index = start_logits.size(1)
297
+ start_positions = start_positions.clamp(0, ignored_index)
298
+ end_positions = end_positions.clamp(0, ignored_index)
299
+
300
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
301
+ start_loss = loss_fct(start_logits, start_positions)
302
+ end_loss = loss_fct(end_logits, end_positions)
303
+ total_loss = (start_loss + end_loss) / 2
304
+
305
+ if not return_dict:
306
+ output = (start_logits, end_logits) + outputs[1:]
307
+ return ((total_loss,) + output) if total_loss is not None else output
308
+
309
+ return QuestionAnsweringModelOutput(
310
+ loss=total_loss,
311
+ start_logits=start_logits,
312
+ end_logits=end_logits,
313
+ hidden_states=outputs.hidden_states,
314
+ attentions=outputs.attentions,
315
+ )
flash_attention_v2_bias.py ADDED
@@ -0,0 +1,905 @@
1
+ # Copyright 2023 BAAI
2
+ # Copyright 2024 CATIE
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Modifications to the original file:
17
+ # - Support for biases following https://github.com/FlagOpen/FlagAttention/pull/5
18
+ # - Support for shape (1,1,q,k) biases
19
+
20
+ import math
21
+ import torch
22
+ import triton
23
+ import triton.language as tl
24
+
25
+ from typing import Tuple
26
+
27
+ @torch.library.custom_op("flasht5::flash_attn_v2_fwd", mutates_args=(), device_types="cuda")
28
+ def flash_attn_v2_fwd(
29
+ q: torch.Tensor,
30
+ k: torch.Tensor,
31
+ v: torch.Tensor,
32
+ bias: torch.Tensor,
33
+ causal: bool,
34
+ sm_scale: float,
35
+ BLOCK_M: int,
36
+ BLOCK_N: int,
37
+ num_warps: int, num_stages: int
38
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
39
+
40
+ B, H, M, D = q.shape
41
+ N = k.shape[2]
42
+ P_SEQ = N - M
43
+ larger_m = M > N
44
+
45
+ # Trick to support shape such as (1, 1, seqlen_q, seqlen_k)
46
+ bias_batch_stride = bias.stride(0) if bias is not None else 0
47
+ bias_heads_stride = bias.stride(1) if bias is not None else 0
48
+ if bias is not None:
49
+ if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):
50
+ bias_batch_stride = 0
51
+ if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):
52
+ bias_heads_stride = 0
53
+
54
+ divisible_m = M % BLOCK_M == 0
55
+ divisible_n = N % BLOCK_N == 0
56
+ # consider using 3d grid to avoid div & rem
57
+ grid = (triton.cdiv(M, BLOCK_M), H, B)
58
+ o = torch.empty_like(q)
59
+ L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
60
+
61
+ with torch.cuda.device(q.device.index):
62
+ _fwd_kernel[grid](
63
+ q, k, v, bias, sm_scale,
64
+ L, o,
65
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
66
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
67
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
68
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
69
+ bias_batch_stride, bias_heads_stride,
70
+ bias.stride(2) if bias is not None else 0,
71
+ bias.stride(3) if bias is not None else 0,
72
+ B, H, M, N, P_SEQ,
73
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
74
+ IS_CAUSAL=causal, LARGER_M=larger_m,
75
+ DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
76
+ HAS_BIAS=(bias is not None),
77
+ num_warps=num_warps, num_stages=num_stages,
78
+ )
79
+
80
+ return o, L
81
+
82
+
83
+ @torch.library.register_fake("flasht5::flash_attn_v2_fwd")
84
+ def flash_attn_v2_fwd_abstract(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
85
+ B, H, M, D = q.shape
86
+ o = torch.empty_like(q)
87
+ L = torch.empty((B, H, M), dtype=torch.float32, device=q.device)
88
+
89
+ return o, L
90
+
91
+ @torch.library.custom_op("flasht5::flash_attn_v2_bwd", mutates_args=(), device_types="cuda")
92
+ def flash_attn_v2_bwd(
93
+ o: torch.Tensor,
94
+ do: torch.Tensor,
95
+ q: torch.Tensor,
96
+ k: torch.Tensor,
97
+ v: torch.Tensor,
98
+ bias: torch.Tensor,
99
+ L: torch.Tensor,
100
+ causal: bool,
101
+ sm_scale: float,
102
+ BLOCK_M: int,
103
+ BLOCK_N: int,
104
+ num_warps: int,
105
+ num_stages: int
106
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
107
+
108
+ B, H, M, D = q.shape
109
+ N = k.shape[2]
110
+ P_SEQ = N - M
111
+ larger_m = M > N
112
+
113
+ divisible_m = M % BLOCK_M == 0
114
+ divisible_n = N % BLOCK_N == 0
115
+
116
+ # Trick to support shape such as (1, 1, seqlen_q, seqlen_k)
117
+ bias_batch_stride = bias.stride(0) if bias is not None else 0
118
+ bias_heads_stride = bias.stride(1) if bias is not None else 0
119
+ if bias is not None:
120
+ if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):
121
+ bias_batch_stride = 0
122
+ if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):
123
+ bias_heads_stride = 0
124
+
125
+ delta = torch.empty_like(L)
126
+ grid = (triton.cdiv(M, BLOCK_M), H, B)
127
+
128
+ with torch.cuda.device(q.device.index):
129
+ _bwd_preprocess[grid](
130
+ o, do,
131
+ delta,
132
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
133
+ do.stride(0), do.stride(1), do.stride(2), do.stride(3),
134
+ delta.stride(0), delta.stride(1), delta.stride(2),
135
+ M,
136
+ BLOCK_M=BLOCK_M, D_HEAD=D,
137
+ DIVISIBLE_M=divisible_m,
138
+ )
139
+
140
+ dk = torch.empty_like(k)
141
+ dv = torch.empty_like(v)
142
+
143
+ HAS_BIAS = bias is not None
144
+ RETURN_DS = HAS_BIAS
145
+ IS_BATCH_REDUCED = (bias_batch_stride == 0)
146
+ #GROUP_SIZE_BIAS = min(B, 16)
147
+ GROUP_SIZE_BIAS = B
148
+
149
+ ds = None
150
+ locks = None
151
+ if RETURN_DS:
152
+ if IS_BATCH_REDUCED:
153
+ if causal:
154
+ ds = torch.zeros((GROUP_SIZE_BIAS, *bias.shape[1:]), dtype=bias.dtype, device=bias.device)
155
+ else:
156
+ ds = torch.empty((GROUP_SIZE_BIAS, *bias.shape[1:]), dtype=bias.dtype, device=bias.device)
157
+ locks = torch.zeros(2 * GROUP_SIZE_BIAS, dtype=torch.int32, device=q.device)
158
+ else:
159
+ if causal:
160
+ ds = torch.zeros_like(bias)
161
+ else:
162
+ ds = torch.empty_like(bias)
163
+
164
+ grid = (triton.cdiv(N, BLOCK_N), H, B)
165
+ with torch.cuda.device(q.device.index):
166
+ _bwd_kv_kernel[grid](
167
+ q, k, v, bias, sm_scale, do,
168
+ dk, dv, ds,
169
+ L, delta,
170
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
171
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
172
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
173
+ bias.stride(0) if HAS_BIAS else 0,
174
+ bias_heads_stride,
175
+ bias.stride(2) if HAS_BIAS else 0,
176
+ bias.stride(3) if HAS_BIAS else 0,
177
+ do.stride(0), do.stride(1), do.stride(2), do.stride(3),
178
+ dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
179
+ dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
180
+ B, H, M, N, P_SEQ,
181
+ locks,
182
+ BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
183
+ DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
184
+ HAS_BIAS=HAS_BIAS,
185
+ RETURN_DS=RETURN_DS,
186
+ IS_BATCH_REDUCED=IS_BATCH_REDUCED,
187
+ GROUP_SIZE_BIAS=GROUP_SIZE_BIAS,
188
+ num_stages=num_stages, num_warps=num_warps,
189
+ )
190
+
191
+ dq = torch.empty_like(q)
192
+ grid = (triton.cdiv(M, BLOCK_M), H, B)
193
+ with torch.cuda.device(q.device.index):
194
+ _bwd_q_kernel[grid](
195
+ q, k, v, bias, sm_scale, do,
196
+ dq,
197
+ L, delta,
198
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
199
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
200
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
201
+ bias_batch_stride, bias_heads_stride,
202
+ bias.stride(2) if HAS_BIAS else 0,
203
+ bias.stride(3) if HAS_BIAS else 0,
204
+ do.stride(0), do.stride(1), do.stride(2), do.stride(3),
205
+ dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
206
+ B, H, M, N, P_SEQ,
207
+ BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
208
+ CAUSAL=causal, LARGER_M=larger_m,
209
+ DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
210
+ HAS_BIAS=HAS_BIAS,
211
+ num_stages=num_stages, num_warps = num_warps,
212
+ )
213
+
214
+ if RETURN_DS and IS_BATCH_REDUCED and GROUP_SIZE_BIAS > 1:
215
+ ds = ds.sum(0, keepdim=True)
216
+
217
+ return dq, dk, dv, ds
218
+
219
+ @torch.library.register_fake("flasht5::flash_attn_v2_bwd")
220
+ def flash_attn_v2_bwd_abstract(o, do, q, k, v, bias, L, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
221
+ dq = torch.empty_like(q)
222
+ dk = torch.empty_like(k)
223
+ dv = torch.empty_like(v)
224
+ ds = torch.empty_like(bias) if bias is not None else None
225
+
226
+ return dq, dk, dv, ds
227
+
228
+ class FlashAttentionAdditiveBias(torch.autograd.Function):
229
+ @staticmethod
230
+ def forward(ctx, q, k, v, bias, causal, sm_scale):
231
+ Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
232
+
233
+ assert Dq == Dk == Dv
234
+ assert Dk in {16, 32, 64, 128}
235
+
236
+ B, H, M, D = q.shape
237
+ N = k.shape[2]
238
+
239
+ if sm_scale is None:
240
+ sm_scale = 1. / math.sqrt(D)
241
+
242
+ config = get_fwd_config(B, H, M, N, D, causal)
243
+ BLOCK_M, BLOCK_N, num_stages, num_warps = config
244
+
245
+ o, L = torch.ops.flasht5.flash_attn_v2_fwd(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages)
246
+
247
+ # autograd context maintenance
248
+ ctx.save_for_backward(q, k, v, bias, o, L)
249
+ ctx.sm_scale = sm_scale
250
+ ctx.causal = causal
251
+
252
+ return o
253
+
254
+ @staticmethod
255
+ def backward(ctx, do, *ignored):
256
+ q, k, v, bias, o, L = ctx.saved_tensors
257
+ sm_scale = ctx.sm_scale
258
+ causal = ctx.causal
259
+
260
+ B, H, M, D = q.shape
261
+ N = k.shape[2]
262
+
263
+ if sm_scale is None:
264
+ sm_scale = 1. / math.sqrt(D)
265
+
266
+ config = get_bwd_config(B, H, M, N, D, causal)
267
+ BLOCK_M, BLOCK_N, num_stages, num_warps = config
268
+
269
+ dq, dk, dv, ds = torch.ops.flasht5.flash_attn_v2_bwd(o, do, q, k, v, bias, L, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages)
270
+
271
+ return dq, dk, dv, ds, None, None, None, None
272
+
273
+
274
+ def flash_attention_v2_bias(q, k, v, bias, causal=False, sm_scale=None):
275
+ """
276
+ An implementation of FlashAttention v2(https://arxiv.org/abs/2307.08691).
277
+
278
+ Arguments:
279
+ q(torch.Tensor): The first queries. The shape is (batch_size, nheads, seqlen_q, headdim).
280
+ k(torch.Tensor): The first keys. The shape is (batch_size, nheads, seqlen_k, headdim).
281
+ v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim).
282
+ causal(bool): Whether causal masking is applied to attention scores before applying softmax.
283
+ sm_scale(float): The scaling of attention scores before applying softmax.
284
+
285
+ Returns:
286
+ out(torch.Tensor): The output. The shape is (batch_size, nheads, seqlen_q, headdim).
287
+ """
288
+ return FlashAttentionAdditiveBias.apply(q, k, v, bias, causal, sm_scale)
289
+
290
+
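+ # Usage sketch; the shapes below are illustrative assumptions only. The bias is additive and
+ # may have batch and/or head dimensions of size 1, e.g. a T5 relative-position bias broadcast
+ # over the batch:
+ #
+ # q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
+ # k = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
+ # v = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
+ # bias = torch.randn(1, 8, 512, 512, device="cuda", dtype=torch.float16)
+ # out = flash_attention_v2_bias(q, k, v, bias, causal=False, sm_scale=None)  # (2, 8, 512, 64)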
291
+ # --------------------------- Forward ---------------------------
292
+ # NOTE: this function can be overwritten at runtime to use your custom config
293
+ def get_fwd_config(B, H, M, N, D, causal):
294
+ if torch.cuda.get_device_capability() == (8, 0):
295
+ if not causal:
296
+ if D <= 64:
297
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
298
+ else:
299
+ if M <= 1024:
300
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
301
+ else:
302
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
303
+ else:
304
+ if D <= 64:
305
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
306
+ else:
307
+ if M <= 1024:
308
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
309
+ else:
310
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
311
+ elif torch.cuda.get_device_capability() == (8, 6):
312
+ if not causal:
313
+ if D <= 64:
314
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
315
+ else:
316
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
317
+ else: # causal
318
+ if D <= 64:
319
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
320
+ else:
321
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
322
+ else:
323
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
324
+ return (BLOCK_M, BLOCK_N, num_stages, num_warps)
325
+
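+ # As the note above says, this config function can be overridden at runtime. A minimal sketch,
+ # assuming the module is imported under the name used in this repo:
+ #
+ # from . import flash_attention_v2_bias as fa
+ # fa.get_fwd_config = lambda B, H, M, N, D, causal: (64, 64, 2, 4)  # BLOCK_M, BLOCK_N, stages, warps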
326
+
327
+ @triton.jit
328
+ def _fwd_kernel(
329
+ Q, K, V, B, sm_scale,
330
+ L, O,
331
+ stride_qz, stride_qh, stride_qm, stride_qk,
332
+ stride_kz, stride_kh, stride_kn, stride_kk,
333
+ stride_vz, stride_vh, stride_vn, stride_vk,
334
+ stride_oz, stride_oh, stride_om, stride_ok,
335
+ stride_bz, stride_bh, stride_bm, stride_bn,
336
+ Z, H, M, N, P_SEQ,
337
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
338
+ IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
339
+ DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
340
+ HAS_BIAS: tl.constexpr,
341
+ ):
342
+ input_dtype = Q.dtype.element_ty
343
+ # -- grid id --
344
+ start_m = tl.program_id(0)
345
+ off_h = tl.program_id(1)
346
+ off_z = tl.program_id(2)
347
+
348
+ # scale sm_scale by log_2(e) and use
349
+ # 2^x instead of exp in the loop because CSE and LICM
350
+ # don't work as expected with `exp` in the loop
351
+ log2e: tl.constexpr = 1.4426950408889634
352
+
353
+ # offset pointers for (batch, head)
354
+ Q += off_z * stride_qz + off_h * stride_qh
355
+ K += off_z * stride_kz + off_h * stride_kh
356
+ V += off_z * stride_vz + off_h * stride_vh
357
+ O += off_z * stride_oz + off_h * stride_oh
358
+ if HAS_BIAS:
359
+ B += off_z * stride_bz + off_h * stride_bh
360
+ L += (off_z * H + off_h) * M # l's shape is (B, H, M)
361
+
362
+ offs_m_base = tl.arange(0, BLOCK_M)
363
+ offs_m = start_m * BLOCK_M + offs_m_base
364
+ offs_n_base = tl.arange(0, BLOCK_N)
365
+ offs_k = tl.arange(0, BLOCK_DMODEL)
366
+
367
+ # initialize pointers to value-like data
368
+ q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
369
+ o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL)
370
+ l_ptrs = L + offs_m
371
+
372
+ # initialize pointer to m and l, fp32 for accumulators
373
+ m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
374
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
375
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
376
+
377
+ # load q
378
+ mask_m = offs_m < M
379
+ if DIVISIBLE_M:
380
+ q = tl.load(q_ptrs, cache_modifier=".cg")
381
+ else:
382
+ q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
383
+
384
+ # Dot-I trick: multiply q by the identity matrix to keep q in registers, which saves shared memory
385
+ if BLOCK_DMODEL < 128:
386
+ I = tl.where(offs_k[:, None] == offs_k,
387
+ tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
388
+ tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
389
+ q = tl.dot(q, I).to(input_dtype)
390
+ # else:
391
+ # I = tl.where(offs_m_base[:, None] == offs_m_base,
392
+ # tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype),
393
+ # tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype))
394
+ # q = tl.dot(I, q).to(input_dtype)
395
+
396
+ # NOTE: Loop-Bound-For-N
397
+ # The indices in m-dimension that this block may access is in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
398
+ # According to the rule of causal masking, then max index in n-dimension that this block may access
399
+ # is `P_SEQ + (start_m + 1) * BLOCK_M`.
400
+ # However, the upper bound of index in n-dimension should never exceed the sequence length of k/v(`P_SEQ + N_CTX`).
401
+ # `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`.
402
+ # At this case, there would be illegal memory access when loading k & v tiles
403
+ # if mask_n is not applied for loading(only when `DIVISIBLE_N`` is true).
404
+ # See also https://github.com/FlagOpen/FlagAttention/pull/8
405
+ if IS_CAUSAL:
406
+ hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
407
+ if LARGER_M:
408
+ hi = tl.maximum(0, hi)
409
+ else:
410
+ hi = N
411
+
412
+ # loop over k, v and update accumulators
413
+ offs_n_init = offs_n_base
414
+ k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) # (BLOCK_DMODEL, BLOCK_N)
415
+ v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
416
+ if HAS_BIAS:
417
+ bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn) # (BLOCK_M, BLOCK_N)
418
+
419
+ for start_n in range(0, hi, BLOCK_N):
420
+ start_n = tl.multiple_of(start_n, BLOCK_N)
421
+ offs_n = start_n + offs_n_base
422
+
423
+ # -- load k, v --
424
+ mask_n = offs_n < N
425
+ if DIVISIBLE_N:
426
+ k = tl.load(k_ptrs, cache_modifier=".cg")
427
+ v = tl.load(v_ptrs, cache_modifier=".cg")
428
+ else:
429
+ k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg")
430
+ v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg")
431
+
432
+ # -- load bias --
433
+ if HAS_BIAS:
434
+ if DIVISIBLE_M and DIVISIBLE_N:
435
+ b = tl.load(bias_ptrs)
436
+ else:
437
+ b = tl.load(bias_ptrs, mask_m[:, None] & mask_n[None, :])
438
+
439
+ # -- compute qk ---
440
+ s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
441
+ s += tl.dot(q, k) * sm_scale
442
+ if HAS_BIAS:
443
+ s += b
444
+
445
+ if not DIVISIBLE_N:
446
+ s = tl.where(mask_n[None, :], s, float("-inf"))
447
+ if IS_CAUSAL:
448
+ causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]
449
+ s = tl.where(causal_mask, s, float("-inf"))
450
+
451
+ # -- compute scaling constant ---
452
+ m_i_new = tl.maximum(m_i, tl.max(s, 1))
453
+ alpha = tl.math.exp2((m_i - m_i_new)*log2e)
454
+ p = tl.math.exp2((s - m_i_new[:, None])*log2e)
455
+
456
+ # -- scale and update acc: acc *= alpha[:, None]--
457
+ acc *= alpha[:, None]
458
+ acc += tl.dot(p.to(input_dtype), v)
459
+
460
+ # -- update m_i and l_i --
461
+ l_i = l_i * alpha + tl.sum(p, 1)
462
+ m_i = m_i_new
463
+ # update pointers
464
+ k_ptrs += BLOCK_N * stride_kn
465
+ v_ptrs += BLOCK_N * stride_vn
466
+ if HAS_BIAS:
467
+ bias_ptrs += BLOCK_N * stride_bn
468
+
469
+ # write back l & o
470
+ if IS_CAUSAL and LARGER_M:
471
+ is_empty_line = (offs_m + P_SEQ) < 0
472
+ acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))
473
+ l = tl.where(is_empty_line, float("-inf"), m_i + tl.log(l_i))
474
+ else:
475
+ acc = acc * (1.0 / l_i[:, None])
476
+ l = m_i + tl.log(l_i) # log(normalizer)
477
+
478
+ if DIVISIBLE_M:
479
+ tl.store(l_ptrs, l, cache_modifier=".cg")
480
+ tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg")
481
+ else:
482
+ tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg")
483
+ tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg")
484
+
485
+
486
+ # --------------------------- Backward ---------------------------
487
+ # NOTE: this function can be overwritten at runtime to use your custom config
488
+ def get_bwd_config(B, H, M, N, D, causal):
489
+ if torch.cuda.get_device_capability() == (8, 0):
490
+ if not causal:
491
+ BLOCK_M = 128 if D <= 64 else 64
492
+ BLOCK_N = 64
493
+ num_stages = 2
494
+ num_warps = 4
495
+ else:
496
+ BLOCK_M = 64
497
+ BLOCK_N = 64
498
+ num_stages = 3 if D <= 64 else 2
499
+ num_warps = 4
500
+ elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6)
501
+ if not causal:
502
+ if D <= 64:
503
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
504
+ else:
505
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8
506
+ else:
507
+ if D <= 64:
508
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
509
+ else:
510
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
511
+ else:
512
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
513
+ return (BLOCK_M, BLOCK_N, num_stages, num_warps)
514
+
515
+
516
+ @triton.jit
517
+ def _bwd_preprocess(
518
+ Out, DO,
519
+ Delta,
520
+ stride_oz, stride_oh, stride_om, stride_ok,
521
+ stride_doz, stride_doh, stride_dom, stride_dok,
522
+ stride_dz, stride_dh, stride_dm,
523
+ M,
524
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
525
+ DIVISIBLE_M: tl.constexpr,
526
+ ):
527
+ off_h = tl.program_id(1)
528
+ off_z = tl.program_id(2)
529
+ Out += off_z * stride_oz + off_h * stride_oh
530
+ DO += off_z * stride_doz + off_h * stride_doh
531
+ Delta += off_z * stride_dz + off_h * stride_dh
532
+
533
+ # compute delta = rowsum(Out * dOut) per query position (reused by the dq/dk/dv backward kernels)
534
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
535
+ off_n = tl.arange(0, D_HEAD)
536
+
537
+ # load
538
+ o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok
539
+ do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok
540
+
541
+ if DIVISIBLE_M:
542
+ o = tl.load(o_ptrs).to(tl.float32)
543
+ do = tl.load(do_ptrs).to(tl.float32)
544
+ else:
545
+ mask_m = off_m < M
546
+ o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32)
547
+ do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32)
548
+
549
+ # compute
550
+ delta = tl.sum(o * do, axis=1)
551
+ # write-back
552
+ d_ptrs = Delta + off_m * stride_dm
553
+ if DIVISIBLE_M:
554
+ tl.store(d_ptrs, delta)
555
+ else:
556
+ tl.store(d_ptrs, delta, mask=mask_m)
557
+
558
+
559
+ @triton.jit
560
+ def _bwd_kv_kernel(
561
+ Q, K, V, B, sm_scale, DO,
562
+ DK, DV, DS,
563
+ L,
564
+ D,
565
+ stride_qz, stride_qh, stride_qm, stride_qk,
566
+ stride_kz, stride_kh, stride_kn, stride_kk,
567
+ stride_vz, stride_vh, stride_vn, stride_vk,
568
+ stride_bz, stride_bh, stride_bm, stride_bn,
569
+ stride_doz, stride_doh, stride_dom, stride_dok,
570
+ stride_dkz, stride_dkh, stride_dkn, stride_dkk,
571
+ stride_dvz, stride_dvh, stride_dvn, stride_dvk,
572
+ Z, H, M, N, P_SEQ,
573
+ lock,
574
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
575
+ CAUSAL: tl.constexpr,
576
+ DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
577
+ HAS_BIAS: tl.constexpr,
578
+ RETURN_DS: tl.constexpr,
579
+ IS_BATCH_REDUCED: tl.constexpr,
580
+ GROUP_SIZE_BIAS: tl.constexpr,
581
+ ):
582
+ input_dtype = Q.dtype.element_ty
583
+ # -- grid id --
584
+ start_n = tl.program_id(0)
585
+ off_h = tl.program_id(1)
586
+ off_z = tl.program_id(2)
587
+ log2e: tl.constexpr = 1.4426950408889634
588
+
589
+ # offset pointers for (batch, head)
590
+ Q += off_z * stride_qz + off_h * stride_qh
591
+ K += off_z * stride_kz + off_h * stride_kh
592
+ V += off_z * stride_vz + off_h * stride_vh
593
+ if HAS_BIAS:
594
+ if IS_BATCH_REDUCED:
595
+ B += off_h * stride_bh
596
+ else:
597
+ B += off_z * stride_bz + off_h * stride_bh
598
+ DO += off_z * stride_doz + off_h * stride_doh
599
+
600
+ # offset pointers for batch/head
601
+ DK += off_z * stride_dkz + off_h * stride_dkh
602
+ DV += off_z * stride_dvz + off_h * stride_dvh
603
+
604
+ # offset pointer for ds tensor and locks for the reduction
605
+ if RETURN_DS:
606
+ DS += off_z * stride_bz + off_h * stride_bh
607
+
608
+ # offset pointers for batch/head
609
+ D += (off_z * H + off_h) * M
610
+ L += (off_z * H + off_h) * M
611
+
612
+ if CAUSAL:
613
+ lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0)
614
+ lo = (lo // BLOCK_M) * BLOCK_M
615
+ else:
616
+ lo = 0
617
+
618
+ offs_m_init = lo + tl.arange(0, BLOCK_M)
619
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
620
+ offs_m_base = tl.arange(0, BLOCK_M)
621
+ offs_k = tl.arange(0, BLOCK_DMODEL)
622
+
623
+ # initialize pointers to value-like data
624
+ q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
625
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
626
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
627
+ do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
628
+
629
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL)
630
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) # (BLOCK_N, BLOCK_DMODEL)
631
+
632
+ if HAS_BIAS:
633
+ bias_ptrs = B + (offs_m_init[:, None] * stride_bm + offs_n[None, :] * stride_bn)
634
+
635
+ if RETURN_DS:
636
+ ds_ptrs = DS + (offs_m_init[:, None] * stride_bm + offs_n[None, :] * stride_bn)
637
+
638
+ # k and v stay in SRAM throughout
639
+ mask_n = offs_n < N
640
+ if DIVISIBLE_N:
641
+ v = tl.load(v_ptrs)
642
+ k = tl.load(k_ptrs)
643
+ else:
644
+ v = tl.load(v_ptrs, mask=mask_n[:, None])
645
+ k = tl.load(k_ptrs, mask=mask_n[:, None])
646
+
647
+ # initialize dk and dv
648
+ dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
649
+ dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
650
+
651
+ # loop over a col
652
+ for start_m in range(lo, M, BLOCK_M):
653
+ start_m = tl.multiple_of(start_m, BLOCK_M)
654
+ offs_m = start_m + offs_m_base
655
+ causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
656
+
657
+ # load q on-chip
658
+ mask_m = offs_m < M
659
+ if DIVISIBLE_M:
660
+ q = tl.load(q_ptrs)
661
+ else:
662
+ valid_mask = mask_m[:, None] # & mask_n
663
+ q = tl.load(q_ptrs, mask=mask_m[:, None])
664
+
665
+ # load bias
666
+ if HAS_BIAS:
667
+ if DIVISIBLE_M and DIVISIBLE_N:
668
+ b = tl.load(bias_ptrs)
669
+ else:
670
+ b = tl.load(bias_ptrs, mask=mask_m[:, None] & mask_n[None, :])
671
+
672
+ # recompute p = softmax(qk * sm_scale, dim=-1)
673
+ s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
674
+ s += tl.dot(q, tl.trans(k)) * sm_scale
675
+
676
+ if HAS_BIAS:
677
+ s += b
678
+
679
+ # NOTE: since softmax in backward is pointwise, the normalizer was already saved in the
+ # forward pass, so masking on s is not needed here.
681
+ # s = tl.where(valid_mask, s , float("-inf"))
682
+ # if CAUSAL:
683
+ # s = tl.where(causal_mask, s, float("-inf"))
684
+
685
+ # -- recompute p ---
686
+ if DIVISIBLE_M:
687
+ l = tl.load(L + offs_m)
688
+ else:
689
+ l = tl.load(L + offs_m, mask=mask_m)
690
+ p = tl.math.exp2((s - l[:, None])*log2e) # (BLOCK_M, BLOCK_N)
691
+
692
+ if not DIVISIBLE_M:
693
+ p = tl.where(valid_mask, p, 0.0)
694
+ if CAUSAL:
695
+ p = tl.where(causal_mask, p, 0.0)
696
+
697
+ # compute dv = dot(p, do)
698
+ if DIVISIBLE_M:
699
+ do = tl.load(do_ptrs)
700
+ else:
701
+ do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL)
702
+ dv += tl.dot(tl.trans(p.to(do.dtype)), do) # (BLOCK_N, BLOCK_DMODEL) # still correct
703
+
704
+ # compute dp = dot(v, do)
705
+ if DIVISIBLE_M:
706
+ delta = tl.load(D + offs_m)
707
+ else:
708
+ delta = tl.load(D + offs_m, mask=mask_m)
709
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
710
+ dp += tl.dot(do, tl.trans(v))
711
+
712
+ # compute ds = p * (dp - delta[:, None])
713
+ ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
714
+
715
+ if not DIVISIBLE_M:
716
+ ds = tl.where(valid_mask, ds, 0.0)
717
+ if CAUSAL:
718
+ ds = tl.where(causal_mask, ds, 0.0)
719
+
720
+ ds = ds.to(input_dtype)
721
+ # compute dk = dot(ds.T, q) masking
722
+ dk += tl.dot(tl.trans(ds), q)
723
+
724
+ # store ds
725
+ if RETURN_DS:
726
+ if DIVISIBLE_M and DIVISIBLE_N:
727
+ tl.store(ds_ptrs, ds)
728
+ else:
729
+ tl.store(ds_ptrs, ds, mask=mask_m[:, None] & mask_n[None, :])
730
+
731
+ # increment pointers
732
+ q_ptrs += BLOCK_M * stride_qm
733
+ do_ptrs += BLOCK_M * stride_dom
734
+ if HAS_BIAS:
735
+ bias_ptrs += BLOCK_M * stride_bm
736
+ if RETURN_DS:
737
+ ds_ptrs += BLOCK_M * stride_bm
738
+
739
+ dk *= sm_scale
740
+ if DIVISIBLE_N:
741
+ tl.store(dk_ptrs, dk.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL)
742
+ tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL,)
743
+ else:
744
+ tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL)
745
+ tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL,)
746
+
747
+
748
+ @triton.jit
749
+ def _bwd_q_kernel(
750
+ Q, K, V, B, sm_scale, DO,
751
+ DQ,
752
+ L,
753
+ D,
754
+ stride_qz, stride_qh, stride_qm, stride_qk,
755
+ stride_kz, stride_kh, stride_kn, stride_kk,
756
+ stride_vz, stride_vh, stride_vn, stride_vk,
757
+ stride_bz, stride_bh, stride_bm, stride_bn,
758
+ stride_doz, stride_doh, stride_dom, stride_dok,
759
+ stride_dqz, stride_dqh, stride_dqm, stride_dqk,
760
+ Z, H, M, N, P_SEQ,
761
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
762
+ CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
763
+ DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
764
+ HAS_BIAS: tl.constexpr
765
+ ):
766
+ input_dtype = Q.dtype.element_ty
767
+ # -- grid id --
768
+ start_m = tl.program_id(0)
769
+ off_h = tl.program_id(1)
770
+ off_z = tl.program_id(2)
771
+
772
+ # scale sm_scale by log_2(e) and use
773
+ # 2^x instead of exp in the loop because CSE and LICM
774
+ # don't work as expected with `exp` in the loop
775
+ log2e: tl.constexpr = 1.4426950408889634
776
+
777
+ # offset pointers for (batch, head)
778
+ Q += off_z * stride_qz + off_h * stride_qh
779
+ K += off_z * stride_kz + off_h * stride_kh
780
+ V += off_z * stride_vz + off_h * stride_vh
781
+ if HAS_BIAS:
782
+ B += off_z * stride_bz + off_h * stride_bh
783
+ DO += off_z * stride_doz + off_h * stride_doh
784
+ D += (off_z * H + off_h) * M
785
+ L += (off_z * H + off_h) * M
786
+
787
+ # offset pointers for batch/head
788
+ DQ += off_z * stride_dqz + off_h * stride_dqh
789
+
790
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
791
+ offs_n_base = tl.arange(0, BLOCK_N)
792
+ offs_n_init = offs_n_base
793
+ offs_k = tl.arange(0, BLOCK_DMODEL)
794
+
795
+ # initialize pointers to value-like data
796
+ q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
797
+ k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
798
+ v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
799
+
800
+ if HAS_BIAS:
801
+ bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn)
802
+
803
+ dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) # (BLOCK_M, BLOCK_DMODEL)
804
+ do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
805
+
806
+ # pointer to row-wise quantities in value-like data
807
+ d_ptrs = D + offs_m
808
+ l_ptrs = L + offs_m
809
+
810
+ # load q: it will stay in SRAM throughout
811
+ mask_m = offs_m < M
812
+ if DIVISIBLE_M:
813
+ q = tl.load(q_ptrs)
814
+ do = tl.load(do_ptrs)
815
+ delta = tl.load(d_ptrs)
816
+ l = tl.load(l_ptrs)
817
+ else:
818
+ q = tl.load(q_ptrs, mask=mask_m[:, None])
819
+ do = tl.load(do_ptrs, mask=mask_m[:, None])
820
+ delta = tl.load(d_ptrs, mask=mask_m)
821
+ l = tl.load(l_ptrs, mask=mask_m)
822
+
823
+ # initialize dq
824
+ dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
825
+
826
+ # loop over k, v and update accumulator
827
+ # see note "Loop-Bound-For-N"
828
+ if CAUSAL:
829
+ hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
830
+ if LARGER_M:
831
+ hi = tl.maximum(0, hi)
832
+ else:
833
+ hi = N
834
+
835
+ # loop over a row
836
+ for start_n in range(0, hi, BLOCK_N):
837
+ offs_n = start_n + offs_n_base
838
+
839
+ # load k and v on-chip
840
+ mask_n = offs_n < N
841
+ if DIVISIBLE_N:
842
+ v = tl.load(v_ptrs)
843
+ k = tl.load(k_ptrs)
844
+ else:
845
+ v = tl.load(v_ptrs, mask=mask_n[:, None])
846
+ k = tl.load(k_ptrs, mask=mask_n[:, None])
847
+
848
+ # load bias
849
+ if HAS_BIAS:
850
+ if DIVISIBLE_M and DIVISIBLE_N:
851
+ b = tl.load(bias_ptrs)
852
+ else:
853
+ b = tl.load(bias_ptrs, mask=mask_m[:, None] & mask_n[None, :])
854
+
855
+ # recompute p = softmax(qk * sm_scale, dim=-1)
856
+ if not DIVISIBLE_N:
857
+ valid_mask = mask_n # & mask_m[:, None]
858
+ if CAUSAL:
859
+ causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
860
+
861
+ s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
862
+ s += tl.dot(q, tl.trans(k)) * sm_scale
863
+ if HAS_BIAS:
864
+ s += b
865
+
866
+ # NOTE: since softmax in backward is pointwise, the normalizer was already saved in the
+ # forward pass, so masking on s is not needed here.
868
+ # if CAUSAL:
869
+ # s = tl.where(causal_mask & valid_mask, s, float("-inf"))
870
+ # else:
871
+ # s = tl.where(valid_mask, s, float("-inf"))
872
+ p = tl.math.exp2((s - l[:, None])*log2e) # (BLOCK_M, BLOCK_N)
873
+
874
+ # compute dp = dot(v, do)
875
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
876
+ dp += tl.dot(do.to(input_dtype), tl.trans(v))
877
+ # no need to mask dp
878
+ # if CAUSAL:
879
+ # dp = tl.where(causal_mask & valid_mask, dp, 0.0)
880
+ # else:
881
+ # dp = tl.where(valid_mask, dp, 0.0)
882
+
883
+ # compute ds = p * (dp - delta[:, None])
884
+ # move scale out to dq at last
885
+ ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
886
+
887
+ # mask ds to ensure no small values
888
+ if not DIVISIBLE_N:
889
+ ds = tl.where(valid_mask, ds, 0.0)
890
+ if CAUSAL:
891
+ ds = tl.where(causal_mask, ds, 0.0)
892
+
893
+ dq += tl.dot(ds.to(input_dtype), k)
894
+
895
+ # increment pointers
896
+ k_ptrs += BLOCK_N * stride_kn
897
+ v_ptrs += BLOCK_N * stride_vn
898
+ if HAS_BIAS:
899
+ bias_ptrs += BLOCK_N * stride_bn
900
+
901
+ dq *= sm_scale
902
+ if DIVISIBLE_M:
903
+ tl.store(dq_ptrs, dq.to(input_dtype))
904
+ else:
905
+ tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None])
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "decoder_start_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 3,
6
+ "transformers_version": "4.46.0.dev0"
7
+ }
modeling_flash_t5.py ADDED
@@ -0,0 +1,790 @@
1
+ # From: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
2
+
3
+ import copy
4
+ import math
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+
11
+ from transformers.modeling_utils import ModuleUtilsMixin
12
+ from transformers.modeling_outputs import ModelOutput, Seq2SeqModelOutput, BaseModelOutput, Seq2SeqLMOutput
13
+ from transformers import PreTrainedModel
14
+
15
+ try:
16
+ from .rms_norm import fast_rms_layernorm
17
+ except ImportError:
18
+ fast_rms_layernorm = None
19
+
20
+ try:
21
+ from .cross_entropy_loss import cross_entropy_loss as fast_cross_entropy_loss
22
+ except ImportError:
23
+ fast_cross_entropy_loss = None
24
+
25
+ try:
26
+ from .flash_attention_v2_bias import flash_attention_v2_bias
27
+ except ImportError:
28
+ flash_attention_v2_bias = None
29
+
30
+ try:
31
+ from flash_attn import flash_attn_kvpacked_func, flash_attn_func
32
+ except ImportError:
33
+ flash_attn_kvpacked_func, flash_attn_func = None, None
34
+
35
+ from .attn_ref import attn_ref
36
+
37
+ from .configuration_flash_t5 import FlashT5Config
38
+ from .positional_encoding import ALiBiPositionalEncoding, RelativePositionalEncoding, RotaryPositionalEncoding, FIRE
39
+
40
+ class FlashT5CrossEntropyLoss(nn.Module):
41
+ def __init__(self, z_loss_factor=0.0, label_smoothing=0.0, use_triton_crossentropy=False, inplace_backward=False):
42
+
43
+ super().__init__()
44
+
45
+ if use_triton_crossentropy and fast_cross_entropy_loss is None:
46
+ raise ImportError("fast_cross_entropy_loss is not available")
47
+
48
+ self.use_triton_crossentropy = use_triton_crossentropy
49
+ self.z_loss_factor = z_loss_factor
50
+ self.label_smoothing = label_smoothing
51
+ self.inplace_backward = inplace_backward
52
+
53
+ self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
54
+
55
+ def compute_zloss(self, logits: torch.Tensor, z_loss: float):
56
+ logits_sum = torch.logsumexp(logits, dim=-1, keepdim=True)
57
+ log_z = torch.squeeze(logits_sum, axis=-1)
58
+ total_z_loss = z_loss * torch.square(log_z)
59
+ return total_z_loss.mean()
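+ # i.e. the auxiliary term is z_loss * log(Z)^2 with Z = sum_j exp(logit_j), which keeps the
+ # softmax normalizer close to 1 (the z-loss term used in T5x/PaLM-style training).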
60
+
61
+ def forward(self, logits, labels):
62
+
63
+ if self.use_triton_crossentropy:
64
+ return fast_cross_entropy_loss(logits, labels, \
65
+ lse_square_scale=self.z_loss_factor, \
66
+ label_smoothing=self.label_smoothing, \
67
+ inplace_backward=self.inplace_backward \
68
+ )[0].mean()
69
+
70
+ # use standard method
71
+ batch, seq_len, d = logits.shape
72
+ logits_flatten = logits.float().view(batch*seq_len, d) # Must cast to float32 for numerical stability
73
+ labels_flatten = labels.view(-1)
74
+ loss = self.cross_entropy_loss(logits_flatten, labels_flatten)
75
+ z_loss = 0.0
76
+ if self.z_loss_factor != 0.0:
77
+ z_loss = self.compute_zloss(logits_flatten[labels_flatten != -100],
78
+ z_loss=self.z_loss_factor)
79
+ return loss + z_loss
80
+
81
+ class FlashT5LayerNorm(nn.Module):
82
+ def __init__(self, hidden_size, eps=1e-6, use_triton_layernorm=False):
83
+ """
84
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
85
+ """
86
+ super().__init__()
87
+
88
+ if use_triton_layernorm and fast_rms_layernorm is None:
89
+ raise ImportError("fast_rms_layernorm is not available")
90
+
91
+ self.use_triton_layernorm = use_triton_layernorm
92
+ self.weight = nn.Parameter(torch.ones(hidden_size))
93
+ self.variance_epsilon = eps
94
+
95
+ def forward(self, hidden_states):
96
+
97
+ if self.use_triton_layernorm:
98
+ return fast_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
99
+
100
+ # T5 uses a layer norm which only scales and doesn't shift, also known as Root Mean Square
+ # Layer Normalization (https://arxiv.org/abs/1910.07467): the variance is computed without
+ # subtracting the mean and there is no bias. Additionally we make sure that the accumulation
+ # for half-precision inputs is done in fp32.
104
+
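+ # i.e. y = weight * x / sqrt(mean(x^2) + eps)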
105
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
106
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
107
+
108
+ # convert into half-precision if necessary
109
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
110
+ hidden_states = hidden_states.to(self.weight.dtype)
111
+
112
+ return self.weight * hidden_states
113
+
114
+ class FlashT5DenseAct(nn.Module):
115
+ def __init__(self, config: FlashT5Config):
116
+ super().__init__()
117
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
118
+ self.dropout = nn.Dropout(config.dropout_rate)
119
+ self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
120
+
121
+ def forward(self, hidden_states):
122
+ hidden_states = self.wi(hidden_states)
123
+ hidden_states = self.act(hidden_states)
124
+ hidden_states = self.dropout(hidden_states)
125
+ if (
126
+ isinstance(self.wo.weight, torch.Tensor)
127
+ and hidden_states.dtype != self.wo.weight.dtype
128
+ and self.wo.weight.dtype != torch.int8
129
+ ):
130
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
131
+
132
+ return hidden_states
133
+
134
+ class FlashT5DenseGatedAct(nn.Module):
135
+ def __init__(self, config: FlashT5Config):
136
+ super().__init__()
137
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
138
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
139
+ self.dropout = nn.Dropout(config.dropout_rate)
140
+ self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
141
+
142
+ self.use_gelu_act = config.use_gelu_act
143
+
144
+ def forward(self, hidden_states):
145
+
146
+ hidden_act = self.act(self.wi_0(hidden_states))
147
+ hidden_linear = self.wi_1(hidden_states)
148
+ hidden_states = hidden_act * hidden_linear
149
+ hidden_states = self.dropout(hidden_states)
150
+
151
+ return hidden_states
152
+
153
+ class FlashT5LayerFF(nn.Module):
154
+ def __init__(self, config: FlashT5Config):
155
+ super().__init__()
156
+ if config.use_glu_mlp:
157
+ self.act = FlashT5DenseGatedAct(config)
158
+ else:
159
+ self.act = FlashT5DenseAct(config)
160
+
161
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
162
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
163
+ self.dropout = nn.Dropout(config.dropout_rate)
164
+
165
+ def forward(self, hidden_states):
166
+ forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
167
+ forwarded_states = self.act(forwarded_states)
168
+ forwarded_states = self.wo(forwarded_states)
169
+ hidden_states = hidden_states + self.dropout(forwarded_states)
170
+ return hidden_states
171
+
172
+
173
+ class FlashT5Attention(nn.Module, ModuleUtilsMixin):
174
+ def __init__(self, config: FlashT5Config, has_positional_encoding=False, is_causal=False):
175
+ super().__init__()
176
+ self.is_decoder = config.is_decoder
177
+ self.has_positional_encoding = has_positional_encoding
178
+ self.is_causal = is_causal
179
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
180
+ self.relative_attention_max_distance = config.relative_attention_max_distance
181
+ self.d_model = config.d_model
182
+ self.key_value_proj_dim = config.d_kv
183
+ self.n_heads = config.num_heads
184
+ self.p_dropout = config.attention_dropout_rate
185
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
186
+ self.attention_type = config.attention_type
187
+ self.position_encoding_type = config.position_encoding_type
188
+ self.max_sequence_length = config.max_sequence_length
189
+ self.softmax_scale = config.attention_scale if config.attention_scale is not None else 1.0/math.sqrt(self.n_heads)
190
+ self.use_full_bias_size = config.use_full_bias_size
191
+ self.use_masking = config.use_masking
192
+
193
+ if self.use_masking and not self.use_full_bias_size:
194
+ raise ValueError("Masking can only be used with full batch size.")
195
+
196
+ if self.attention_type == "triton" and flash_attention_v2_bias is None:
197
+ raise ImportError("flash_attention_triton is not available")
198
+ elif self.attention_type.startswith("fa2") and flash_attn_func is None:
199
+ raise ImportError("Flash Attention 2 is not available")
200
+
201
+ if self.attention_type == "fa2_rpe" and self.position_encoding_type != "t5":
202
+ raise ValueError("fa2_rpe is not compatible with non-T5 position encoding")
203
+
204
+ assert (self.p_dropout == 0.0) or (self.attention_type != "triton"), "Triton attention does not support dropout"
205
+
206
+ self.pe_encoding = None
207
+ if self.position_encoding_type == "ALiBi" and has_positional_encoding:
208
+ # build alibi matrix with an upper bound on seq length
209
+ self.pe_encoding = ALiBiPositionalEncoding(self.max_sequence_length,
210
+ self.n_heads,
211
+ config.alibi_mode,
212
+ randomized_position=config.use_randomized_position_encoding)
213
+ elif self.position_encoding_type == "t5" and has_positional_encoding:
214
+ self.pe_encoding = RelativePositionalEncoding(self.relative_attention_num_buckets,
215
+ self.relative_attention_max_distance,
216
+ self.n_heads,
217
+ self.max_sequence_length,
218
+ bidirectional=(not self.is_decoder),
219
+ randomized_position=config.use_randomized_position_encoding)
220
+ elif self.position_encoding_type == "RoPE":
221
+ self.pe_encoding = RotaryPositionalEncoding(int(self.key_value_proj_dim * config.rotary_emb_fraction),
222
+ self.max_sequence_length,
223
+ config.rotary_base,
224
+ config.rotary_interleaved,
225
+ config.rotary_scale_base,
226
+ randomized_position=config.use_randomized_position_encoding)
227
+ elif self.position_encoding_type == "FIRE" and has_positional_encoding:
228
+ self.pe_encoding = FIRE(num_heads=self.n_heads,
229
+ mlp_width=config.fire_mlp_width,
230
+ init_c=0.1,
231
+ init_L=self.relative_attention_max_distance)
232
+
233
+ self.Wq = nn.Linear(self.d_model, self.inner_dim, bias=False)
234
+ self.Wk = nn.Linear(self.d_model, self.inner_dim, bias=False)
235
+ self.Wv = nn.Linear(self.d_model, self.inner_dim, bias=False)
236
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states,
241
+ mask=None,
242
+ key_value_states=None,
243
+ position_bias=None,
244
+ ):
245
+ """
246
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
247
+ """
248
+ # Input is (batch_size, seq_length, dim)
249
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
250
+ batch_size, seq_length = hidden_states.shape[:2]
251
+ key_length = seq_length if key_value_states is None else key_value_states.shape[1]
252
+ q = self.Wq(hidden_states)
253
+ if key_value_states is None:
254
+ k = self.Wk(hidden_states)
255
+ v = self.Wv(hidden_states)
256
+ else:
257
+ k = self.Wk(key_value_states)
258
+ v = self.Wv(key_value_states)
259
+
260
+ q = q.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim)
261
+ k = k.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
262
+ v = v.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
263
+
264
+ if position_bias is None and self.pe_encoding is not None and self.attention_type != "fa2_rpe":
265
+ q, k, v, position_bias = self.pe_encoding(q, k, v)
266
+
267
+ if position_bias is not None and self.use_full_bias_size:
268
+ position_bias = position_bias.expand(q.shape[0], q.shape[2], q.shape[1], k.shape[1])
269
+ if self.attention_type == "fa2_bias" or self.attention_type == "triton":
270
+ position_bias = position_bias.contiguous()
271
+
272
+ if position_bias is not None and mask is not None and self.use_masking:
273
+ mask = mask.unsqueeze(1)
274
+ if len(mask.shape) == 3:
275
+ mask = mask.unsqueeze(3)
276
+ position_bias = torch.where(mask, position_bias, torch.finfo(hidden_states.dtype).min)
277
+
278
+ if self.attention_type == "fa2_bias":
279
+ output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, \
280
+ attn_bias=position_bias, causal=self.is_causal)
281
+ elif self.attention_type == "fa2_rpe":
282
+ output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, \
283
+ rpe_weights=self.pe_encoding.relative_attention_bias.weight.t(), \
284
+ rpe_max_distance=self.relative_attention_max_distance, \
285
+ causal=self.is_causal)
286
+ elif self.attention_type == "triton":
287
+ q = q.permute(0, 2, 1, 3)
288
+ k = k.permute(0, 2, 1, 3)
289
+ v = v.permute(0, 2, 1, 3)
290
+ output = flash_attention_v2_bias(q, k, v, position_bias, self.is_causal, self.softmax_scale)
291
+ output = output.permute(0, 2, 1, 3)
292
+ else: # use flash attention
293
+ q = q.permute(0, 2, 1, 3)
294
+ k = k.permute(0, 2, 1, 3)
295
+ v = v.permute(0, 2, 1, 3)
296
+ output = attn_ref(q, k, v, position_bias, dropout_p=self.p_dropout, sm_scale=self.softmax_scale, causal=self.is_causal)
297
+ output = output.permute(0, 2, 1, 3)
298
+
299
+ output = self.o(output.reshape(output.shape[0], output.shape[1], self.inner_dim))
300
+ return (output, position_bias)
301
+
302
+
303
+ class FlashT5LayerSelfAttention(nn.Module):
304
+ def __init__(self, config, has_positional_encoding=False):
305
+ super().__init__()
306
+ self.self_attention = FlashT5Attention(config, has_positional_encoding=has_positional_encoding, is_causal=config.is_decoder)
307
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
308
+ self.dropout = nn.Dropout(config.dropout_rate)
309
+
310
+ def forward(
311
+ self,
312
+ hidden_states,
313
+ attention_mask=None,
314
+ position_bias=None,
315
+ ):
316
+ normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
317
+ attention_output = self.self_attention(
318
+ normed_hidden_states,
319
+ mask=attention_mask,
320
+ position_bias=position_bias,
321
+ )
322
+ hidden_states = hidden_states + self.dropout(attention_output[0])
323
+ outputs = (hidden_states,) + attention_output[1:]
324
+ return outputs
325
+
326
+
327
+ class FlashT5LayerCrossAttention(nn.Module):
328
+ def __init__(self, config):
329
+ super().__init__()
330
+ self.cross_attention = FlashT5Attention(config, has_positional_encoding=False)
331
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
332
+ self.dropout = nn.Dropout(config.dropout_rate)
333
+
334
+ def forward(
335
+ self,
336
+ hidden_states,
337
+ key_value_states,
338
+ attention_mask=None,
339
+ position_bias=None,
340
+ ):
341
+ normed_hidden_states = self.layer_norm(hidden_states)
342
+ attention_output = self.cross_attention(
343
+ normed_hidden_states,
344
+ mask=attention_mask,
345
+ key_value_states=key_value_states,
346
+ position_bias=position_bias,
347
+ )
348
+ layer_output = hidden_states + self.dropout(attention_output[0])
349
+ outputs = (layer_output,) + attention_output[1:]
350
+ return outputs
351
+
352
+
353
+ class FlashT5Block(nn.Module):
354
+ def __init__(self, config, has_positional_encoding=False):
355
+ super().__init__()
356
+ self.is_decoder = config.is_decoder
357
+
358
+ self.self_attention_layer = FlashT5LayerSelfAttention(config, has_positional_encoding=has_positional_encoding)
359
+
360
+ if self.is_decoder:
361
+ self.cross_attention_layer = FlashT5LayerCrossAttention(config)
362
+
363
+ self.ff_layer = FlashT5LayerFF(config)
364
+
365
+ def forward(
366
+ self,
367
+ hidden_states,
368
+ attention_mask=None,
369
+ position_bias=None,
370
+ encoder_hidden_states=None,
371
+ encoder_attention_mask=None,
372
+ encoder_decoder_position_bias=None,
373
+ ):
374
+ self_attention_outputs = self.self_attention_layer(
375
+ hidden_states,
376
+ attention_mask=attention_mask,
377
+ position_bias=position_bias,
378
+ )
379
+ hidden_states = self_attention_outputs[0]
380
+ attention_outputs = self_attention_outputs[1:] # Relative position weights
381
+
382
+ if self.is_decoder and encoder_hidden_states is not None:
383
+ cross_attention_outputs = self.cross_attention_layer(
384
+ hidden_states,
385
+ key_value_states=encoder_hidden_states,
386
+ attention_mask=encoder_attention_mask,
387
+ position_bias=encoder_decoder_position_bias,
388
+ )
389
+ hidden_states = cross_attention_outputs[0]
390
+
391
+ # Keep relative position weights
392
+ attention_outputs = attention_outputs + cross_attention_outputs[1:]
393
+
394
+ # Apply Feed Forward layer
395
+ hidden_states = self.ff_layer(hidden_states)
396
+
397
+ outputs = (hidden_states,) + attention_outputs
398
+ return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)
399
+
400
+
401
+ class FlashT5Stack(nn.Module, ModuleUtilsMixin):
402
+ def __init__(self, config, embed_tokens):
403
+ super().__init__()
404
+ assert embed_tokens is not None
405
+
406
+ self.config = config
407
+ self.embed_tokens = embed_tokens
408
+ self.is_decoder = config.is_decoder
409
+
410
+ self.block = nn.ModuleList(
411
+ [FlashT5Block(config, has_positional_encoding=bool(i == 0)) for i in range(config.num_layers)]
412
+ )
413
+
414
+ self.final_layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
415
+ self.dropout = nn.Dropout(config.dropout_rate)
416
+
417
+ def forward(
418
+ self,
419
+ input_ids=None,
420
+ # input_ids: Optional[torch.LongTensor] = None,
421
+ attention_mask=None,
422
+ encoder_hidden_states=None,
423
+ encoder_attention_mask=None,
424
+ inputs_embeds=None,
425
+ head_mask=None,
426
+ cross_attn_head_mask=None,
427
+ past_key_values=None,
428
+ use_cache=None,
429
+ output_attentions=None,
430
+ output_hidden_states=None,
431
+ return_dict=None,
432
+ ) -> BaseModelOutput:
433
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # Derive the shape from the embeddings so that passing inputs_embeds without input_ids also works
+ batch_size, seq_length = inputs_embeds.shape[:2]
+
+ if torch.is_autocast_enabled() and inputs_embeds.device.type == 'cuda':
+ inputs_embeds = inputs_embeds.to(torch.get_autocast_gpu_dtype())
441
+
442
+ # Masking
443
+ if attention_mask is None:
444
+ attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=torch.bool)
445
+
446
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
447
+ encoder_seq_length = encoder_hidden_states.shape[1]
448
+ encoder_attention_mask = torch.ones(
449
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool
450
+ )
451
+
452
+ position_bias = None
453
+ encoder_decoder_position_bias = None
454
+
455
+ hidden_states = self.dropout(inputs_embeds)
456
+
457
+ for _, layer_module in enumerate(self.block):
458
+ layer_outputs = layer_module(
459
+ hidden_states,
460
+ attention_mask=attention_mask,
461
+ position_bias=position_bias,
462
+ encoder_hidden_states=encoder_hidden_states,
463
+ encoder_attention_mask=encoder_attention_mask,
464
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
465
+ )
466
+
467
+ # We share the position biases between the layers - the first layer stores them
468
+ position_bias = layer_outputs[1]
469
+ if self.is_decoder and encoder_hidden_states is not None:
470
+ encoder_decoder_position_bias = layer_outputs[2]
471
+
472
+ hidden_states = layer_outputs[0]
473
+
474
+ hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
475
+ hidden_states = self.dropout(hidden_states)
476
+
477
+ return BaseModelOutput(
478
+ last_hidden_state=hidden_states
479
+ )
480
+
481
+
482
+
483
+ class FlashT5PreTrainedModel(PreTrainedModel):
484
+ """
485
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
486
+ models.
487
+ """
488
+
489
+ config_class = FlashT5Config
490
+ base_model_prefix = "transformer"
491
+ is_parallelizable = False
492
+ supports_gradient_checkpointing = True
493
+ _no_split_modules = ["FlashT5Block"]
494
+ _keep_in_fp32_modules = []
495
+
496
+ def _init_weights(self, module):
497
+ factor = self.config.initializer_factor # Used for testing weights initialization
498
+ if isinstance(module, FlashT5LayerNorm):
499
+ module.weight.data.fill_(factor * 1.0)
500
+ elif isinstance(module, (FlashT5ForConditionalGeneration)):
501
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
502
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
503
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * self.config.d_model ** -0.5)
504
+ elif isinstance(module, FlashT5DenseGatedAct):
505
+ d_ff, d_model = module.wi_0.weight.data.size()
506
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
507
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
508
+ elif isinstance(module, FlashT5LayerFF):
509
+ d_ff, d_model = module.wo.weight.data.size()
510
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
511
+ elif isinstance(module, FlashT5Attention):
512
+ d_model = self.config.d_model
513
+ key_value_proj_dim = self.config.d_kv
514
+ n_heads = self.config.num_heads
515
+ module.Wq.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
516
+ module.Wk.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
517
+ module.Wv.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
518
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
519
+ if module.has_positional_encoding:
520
+ if hasattr(module.pe_encoding, "relative_attention_bias"):
521
+ module.pe_encoding.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
522
+
523
+ def _shift_right(self, input_ids):
524
+ decoder_start_token_id = self.config.decoder_start_token_id
525
+ pad_token_id = self.config.pad_token_id
526
+
527
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
528
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
529
+ shifted_input_ids[..., 0] = decoder_start_token_id
530
+
531
+ # replace possible -100 values in labels by `pad_token_id`
532
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
533
+
534
+ return shifted_input_ids
535
+
536
+
537
+ class FlashT5Model(FlashT5PreTrainedModel):
538
+
539
+ def __init__(self, config: FlashT5Config):
540
+ super().__init__(config)
541
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
542
+
543
+ encoder_config = copy.deepcopy(config)
544
+ encoder_config.is_decoder = False
545
+ encoder_config.use_cache = False
546
+ encoder_config.is_encoder_decoder = False
547
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
548
+
549
+ decoder_config = copy.deepcopy(config)
550
+ decoder_config.is_decoder = True
551
+ decoder_config.is_encoder_decoder = False
552
+ decoder_config.num_layers = config.num_decoder_layers
553
+ self.decoder = FlashT5Stack(decoder_config, self.shared)
554
+
555
+ # Initialize weights and apply final processing
556
+ self.post_init()
557
+
558
+ # Model parallel
559
+ self.model_parallel = False
560
+ self.device_map = None
561
+
562
+ def get_input_embeddings(self):
563
+ return self.shared
564
+
565
+ def set_input_embeddings(self, new_embeddings):
566
+ self.shared = new_embeddings
567
+ self.encoder.set_input_embeddings(new_embeddings)
568
+ self.decoder.set_input_embeddings(new_embeddings)
569
+
570
+ def get_encoder(self):
571
+ return self.encoder
572
+
573
+ def get_decoder(self):
574
+ return self.decoder
575
+
576
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ encoder_outputs=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ decoder_inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
592
+
593
+ # Encode if needed (training, first prediction pass)
594
+ if encoder_outputs is None:
595
+ encoder_outputs = self.encoder(
596
+ input_ids=input_ids,
597
+ attention_mask=attention_mask,
598
+ inputs_embeds=inputs_embeds
599
+ )
600
+
601
+ hidden_states = encoder_outputs[0]
602
+
603
+ # Decode
604
+ decoder_outputs = self.decoder(
605
+ input_ids=decoder_input_ids,
606
+ attention_mask=decoder_attention_mask,
607
+ inputs_embeds=decoder_inputs_embeds,
608
+ encoder_hidden_states=hidden_states,
609
+ encoder_attention_mask=attention_mask
610
+ )
611
+
612
+ return Seq2SeqModelOutput(
613
+ last_hidden_state=decoder_outputs.last_hidden_state,
614
+ decoder_hidden_states=decoder_outputs.hidden_states,
615
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
616
+ encoder_hidden_states=encoder_outputs.hidden_states,
617
+ )
618
+
619
+ class FlashT5ForConditionalGeneration(FlashT5PreTrainedModel):
620
+
621
+ def __init__(self, config: FlashT5Config):
622
+ super().__init__(config)
623
+ config.is_encoder_decoder = False
624
+ assert not config.tie_word_embeddings
625
+
626
+ self.config = config
627
+ self.model_dim = config.d_model
628
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
629
+
630
+ encoder_config = copy.deepcopy(config)
631
+ encoder_config.is_decoder = False
632
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
633
+
634
+ decoder_config = copy.deepcopy(config)
635
+ decoder_config.is_decoder = True
636
+ decoder_config.num_layers = config.num_decoder_layers
637
+ self.decoder = FlashT5Stack(decoder_config, self.shared)
638
+
639
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
640
+
641
+ self.loss_fct = FlashT5CrossEntropyLoss(z_loss_factor=config.z_loss,
642
+ label_smoothing=config.label_smoothing,
643
+ use_triton_crossentropy=config.use_triton_crossentropy,
644
+ inplace_backward=config.crossentropy_inplace_backward)
645
+
646
+ # Initialize weights and apply final processing
647
+ self.post_init()
648
+
649
+ def prepare_inputs_for_generation(
650
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
651
+ ):
652
+ # do nothing
653
+ model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
654
+
655
+ return model_inputs
656
+
657
+ def get_input_embeddings(self):
658
+ return self.shared
659
+
660
+ def set_input_embeddings(self, value):
661
+ self.shared = value
662
+
663
+ def generate(
664
+ self,
665
+ input_ids: Optional[torch.LongTensor] = None,
666
+ attention_mask: Optional[torch.FloatTensor] = None,
667
+ max_length = 32,
668
+ **kwargs,
669
+ ) -> torch.LongTensor:
670
+ """
671
+ input_ids: B x L_encoder, int64
672
+ attention_mask: B x L_encoder, int64
673
+ 1 for tokens to attend to, 0 for tokens to ignore
674
+
675
+ Generation:
676
+ Starts with 0, ends with 1, padding is 0
677
+
678
+ # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s
679
+ """
680
+ B, _ = input_ids.size()
681
+ labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device)
682
+ encoder_hidden_states = None
683
+
684
+ for _ in range(max_length):
685
+ out = self.forward(
686
+ input_ids=input_ids,
687
+ attention_mask=attention_mask,
688
+ decoder_input_ids=labels,
689
+ encoder_hidden_states=encoder_hidden_states,
690
+ )
691
+ encoder_hidden_states = out.encoder_hidden_states
692
+ top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
693
+ labels = torch.cat([labels, top_labels], dim=-1)
694
+
695
+ if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B:
696
+ break
697
+
698
+ labels[:, -1] = 1
699
+
700
+ # Mask out the padding, i.e., all positions after the first 1 with 0
701
+ B, L = labels.size()
702
+ mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
703
+ labels = labels.masked_fill(~mask, 0)
704
+
705
+ return labels
706
+
707
+ def forward(
708
+ self,
709
+ input_ids: Optional[torch.LongTensor] = None,
710
+ attention_mask: Optional[torch.FloatTensor] = None,
711
+ decoder_input_ids: Optional[torch.LongTensor] = None,
712
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
713
+ labels: Optional[torch.LongTensor] = None,
714
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
715
+ ) -> Seq2SeqLMOutput:
716
+ """
717
+ input_ids: B x L_encoder, int64
718
+ attention_mask: B x L_encoder, int64
719
+ 1 for tokens to attend to, 0 for tokens to ignore
720
+ labels: B x L_decoder, int64
721
+ """
722
+ if encoder_hidden_states is None:
723
+ encoder_hidden_states = self.encoder(
724
+ input_ids=input_ids,
725
+ attention_mask=attention_mask,
726
+ )[0]
727
+
728
+ hidden_states = encoder_hidden_states
729
+
730
+ if labels is not None and decoder_input_ids is None:
731
+ decoder_input_ids = self._shift_right(labels)
732
+
733
+ decoder_outputs = self.decoder(
734
+ input_ids=decoder_input_ids,
735
+ attention_mask=decoder_attention_mask,
736
+ encoder_hidden_states=hidden_states,
737
+ encoder_attention_mask=attention_mask,
738
+ )
739
+
740
+ sequence_output = decoder_outputs[0]
741
+ lm_logits = self.lm_head(sequence_output)
742
+
743
+ loss = None
744
+ if labels is not None:
745
+ loss = self.loss_fct(lm_logits, labels)
746
+
747
+ return Seq2SeqLMOutput(
748
+ loss=loss,
749
+ logits=lm_logits,
750
+ encoder_hidden_states=encoder_hidden_states
751
+ )
752
+
753
+
754
+ class FlashT5EncoderModel(FlashT5PreTrainedModel):
755
+ def __init__(self, config: FlashT5Config):
756
+ super().__init__(config)
757
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
758
+ encoder_config = copy.deepcopy(config)
759
+ encoder_config.use_cache = False
760
+ encoder_config.is_encoder_decoder = False
761
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
762
+ # Initialize weights and apply final processing
763
+ self.post_init()
764
+ # Model parallel
765
+ self.model_parallel = False
766
+ self.device_map = None
767
+ def get_input_embeddings(self):
768
+ return self.shared
769
+ def set_input_embeddings(self, new_embeddings):
770
+ self.shared = new_embeddings
771
+ self.encoder.set_input_embeddings(new_embeddings)
772
+ def get_encoder(self):
773
+ return self.encoder
774
+ def forward(
775
+ self,
776
+ input_ids: Optional[torch.LongTensor] = None,
777
+ attention_mask: Optional[torch.FloatTensor] = None,
778
+ head_mask: Optional[torch.FloatTensor] = None,
779
+ inputs_embeds: Optional[torch.FloatTensor] = None,
780
+ output_attentions: Optional[bool] = None,
781
+ output_hidden_states: Optional[bool] = None,
782
+ return_dict: Optional[bool] = None,
783
+ token_type_ids: Optional[bool] = None,
784
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
785
+ encoder_outputs = self.encoder(
786
+ input_ids=input_ids,
787
+ attention_mask=attention_mask,
788
+ inputs_embeds=inputs_embeds
789
+ )
790
+ return encoder_outputs
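
As a side note, here is a minimal, hypothetical sketch of how this custom code is typically loaded through `transformers`; the repository id and the `auto_map` wiring are assumptions, and since the checkpoint is only pretrained (UL2 objective), the snippet illustrates the API rather than useful generations.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hypothetical repository id; adjust to the actual Hub path of this model.
model_id = "CATIE-AQ/FAT5-small"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code=True is required because the architecture lives in the repo's .py files.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)

# UL2/T5-style span-corruption prompt (illustrative only).
input_ids = tokenizer("Le FAT5 est un T5 optimisé <extra_id_0> Flash Attention.",
                      return_tensors="pt").input_ids
generated = model.generate(input_ids, max_length=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```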
optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a61ade6d36e7273c43a001cc4e665c4bae25570aecff45ed23f189b8a2b8687
3
+ size 1174905530
positional_encoding.py ADDED
@@ -0,0 +1,417 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange, repeat
5
+ try:
6
+ from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_
7
+ except:
8
+ apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_ = None, None, None
9
+
10
+ class RelativePositionalEncoding(nn.Module):
11
+
12
+ def __init__(self, relative_attention_num_buckets, relative_attention_max_distance, n_heads, max_sequence_length, bidirectional=True, randomized_position=False):
13
+
14
+ super().__init__()
15
+
16
+ self.relative_attention_num_buckets = relative_attention_num_buckets
17
+ self.relative_attention_max_distance = relative_attention_max_distance
18
+ self.n_heads = n_heads
19
+ self.max_sequence_length = max_sequence_length
20
+ self.bidirectional = bidirectional
21
+ self.randomized_position = randomized_position
22
+
23
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
24
+
25
+ @staticmethod
26
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
27
+ """
28
+ Adapted from Mesh Tensorflow:
29
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
30
+
31
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
32
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
33
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
34
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
35
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
36
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
37
+
38
+ Args:
39
+ relative_position: an int32 Tensor
40
+ bidirectional: a boolean - whether the attention is bidirectional
41
+ num_buckets: an integer
42
+ max_distance: an integer
43
+
44
+ Returns:
45
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
46
+ """
47
+ relative_buckets = 0
48
+ if bidirectional:
49
+ num_buckets //= 2
50
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
51
+ relative_position = torch.abs(relative_position)
52
+ else:
53
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
54
+ # now relative_position is in the range [0, inf)
55
+
56
+ # half of the buckets are for exact increments in positions
57
+ max_exact = num_buckets // 2
58
+ is_small = relative_position < max_exact
59
+
60
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
61
+ relative_position_if_large = max_exact + (
62
+ torch.log(relative_position.float() / max_exact)
63
+ / torch.log(torch.tensor(max_distance / max_exact))
64
+ * (num_buckets - max_exact)
65
+ ).to(torch.long)
66
+ relative_position_if_large = torch.min(
67
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
68
+ )
69
+
70
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
71
+ return relative_buckets
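+ # Worked example with bidirectional=True, num_buckets=32, max_distance=128: a relative
+ # position of +3 falls in bucket 16 + 3 = 19, -3 falls in bucket 3, and any distance of
+ # 128 or more saturates to bucket 31 (positive side) or 15 (negative side).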
72
+
73
+ def compute_bias(self, query_length, key_length, device=None):
74
+ """Compute binned relative position bias"""
75
+ if device is None:
76
+ device = self.relative_attention_bias.weight.device
77
+
78
+ if self.randomized_position:
79
+ context_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
80
+ context_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:query_length])
81
+ context_indices_rand[0] = 0 # root the first element of the sequence
82
+ context_position = context_position[context_indices_rand][:, None]
83
+
84
+ memory_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
85
+ memory_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:key_length])
86
+ memory_indices_rand[0] = 0 # root the first element of the sequence
87
+ memory_position = memory_position[memory_indices_rand][None, :]
88
+ else:
89
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
90
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
91
+
92
+ relative_position = memory_position - context_position # shape (query_length, key_length)
93
+
94
+ relative_position_bucket = self._relative_position_bucket(
95
+ relative_position, # shape (query_length, key_length)
96
+ bidirectional=self.bidirectional,
97
+ num_buckets=self.relative_attention_num_buckets,
98
+ max_distance=self.relative_attention_max_distance,
99
+ )
100
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
101
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
102
+ return values
103
+
104
+ def forward(self, q, k=None, v=None):
105
+
106
+ query_length = q.shape[1]
107
+ key_length = k.shape[1] if k is not None else query_length
108
+ bias = self.compute_bias(query_length, key_length, device=q.device).contiguous().to(q.dtype)
109
+
110
+ return q, k, v, bias
111
+
112
+
113
+ class ALiBiPositionalEncoding(nn.Module):
114
+
115
+ def __init__(self, max_sequence_length, num_heads, mode='symetric', randomized_position=False):
116
+
117
+ super().__init__()
118
+
119
+ self.max_sequence_length = max_sequence_length
120
+ self.num_heads = num_heads
121
+ self.mode = mode
122
+ self.randomized_position = randomized_position
123
+
124
+ self.alibi_bias = self.build_alibi_bias_matrix(num_heads, max_sequence_length, mode)
125
+
126
+ @staticmethod
127
+ def fill_with_neg_inf(t):
128
+ """FP16-compatible function that fills a tensor with -inf."""
129
+ return t.float().fill_(float("-inf")).type_as(t)
130
+
131
+ def get_slopes(self, n):
132
+
133
+ def get_slopes_power_of_2(n):
134
+ start = (2**(-2**-(math.log2(n)-3)))
135
+ ratio = start
136
+ return [start*ratio**i for i in range(n)]
137
+
138
+ if math.log2(n).is_integer():
139
+ return get_slopes_power_of_2(n) #In the paper, we only train models that have 2^a heads for some a. This function has
140
+ else: #some good properties that only occur when the input is a power of 2. To maintain that even
141
+ closest_power_of_2 = 2**math.floor(math.log2(n)) #when the number of heads is not a power of 2, we use this workaround.
142
+ return get_slopes_power_of_2(closest_power_of_2) + self.get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
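+ # Example: for 8 heads this returns [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256],
+ # the geometric slope sequence from the ALiBi paper.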
143
+
144
+ def build_symetric_alibi_bias_matrix(self, num_heads, maxpos):
145
+
146
+ context_position = torch.arange(maxpos)[:, None]
147
+ memory_position = torch.arange(maxpos)[None, :]
148
+
149
+ relative_position = memory_position - context_position
150
+ relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads, -1,-1)
151
+
152
+ slopes = torch.Tensor(self.get_slopes(num_heads)) * -1
153
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
154
+ return alibi.view(1, num_heads, maxpos, maxpos)
155
+
156
+ def build_asymetric_alibi_bias_matrix(self, num_heads, maxpos):
157
+ _future_mask_right = torch.triu(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(num_heads // 2, 1, 1)
158
+ _future_mask_left = torch.tril(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(num_heads // 2, 1, 1)
159
+
160
+ nonsym_mask = torch.cat((_future_mask_right, _future_mask_left), dim = 0).unsqueeze(0)
161
+ slopes = torch.Tensor(self.get_slopes(num_heads // 2)) * -1
162
+
163
+ context_position = torch.arange(maxpos)[:, None]
164
+ memory_position = torch.arange(maxpos)[None, :]
165
+
166
+ relative_position = memory_position - context_position
167
+ relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads // 2, -1,-1)
168
+
169
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
170
+ alibi = alibi.view(1, num_heads // 2, maxpos, maxpos)
171
+ alibi = alibi.repeat(1, 2, 1, 1)
172
+
173
+ return alibi.view(1, num_heads, maxpos, maxpos) + nonsym_mask.view(1, num_heads, maxpos, maxpos)
174
+
175
+
176
+ def build_alibi_bias_matrix(self, num_heads, maxpos, mode='symetric'):
177
+ if mode == 'symetric':
178
+ return self.build_symetric_alibi_bias_matrix(num_heads, maxpos)
179
+ elif mode == 'asymetric':
180
+ return self.build_asymetric_alibi_bias_matrix(num_heads, maxpos)
181
+ else:
182
+ raise ValueError("ALiBi mode " + mode + " is not implemented.")
183
+
184
+ def forward(self, q, k=None, v=None):
185
+
186
+ query_length = q.shape[1]
187
+ key_length = k.shape[1] if k is not None else query_length
188
+ assert (self.alibi_bias.shape[-1] >= query_length) and (self.alibi_bias.shape[-1] >= key_length), "Sequence length larger than allowed alibi bound"
189
+
190
+ if self.randomized_position:
191
+ query_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:query_length])
192
+ key_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:key_length])
193
+
194
+ # ground sequences
195
+ query_indices_rand[0] = 0
196
+ key_indices_rand[0] = 0
197
+
198
+ bias = self.alibi_bias[:, :, query_indices_rand, key_indices_rand].to(q.device)
199
+
200
+ else:
201
+ bias = self.alibi_bias[:, :, :query_length, :key_length].to(q.device)
202
+
203
+ return q, k, v, bias.to(q.dtype).contiguous()
204
+
205
+ class RotaryPositionalEncoding(nn.Module):
206
+
207
+ def __init__(self, dim,
208
+ max_sequence_length,
209
+ base=10000.0,
210
+ interleaved=False,
211
+ scale_base=None,
212
+ randomized_position=False):
213
+
214
+ super().__init__()
215
+
216
+ self.max_sequence_length = max_sequence_length
217
+ self.randomized_position = randomized_position
218
+
219
+ self.dim = dim
220
+ self.base = base
221
+ self.interleaved = interleaved
222
+ self.scale_base = scale_base
223
+
224
+ inv_freq = self._compute_inv_freq()
225
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
226
+
227
+ scale = (
228
+ (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
229
+ if scale_base is not None
230
+ else None
231
+ )
232
+ self.register_buffer("scale", scale, persistent=False)
233
+
234
+ self._cos_cached = None
235
+ self._sin_cached = None
236
+ self._cos_k_cached = None
237
+ self._sin_k_cached = None
238
+
239
+ def _compute_inv_freq(self, device=None):
240
+ return 1.0 / (
241
+ self.base
242
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
243
+ )
244
+
245
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
246
+ # Reset the tables if the sequence length has changed,
247
+ # if we're on a new device (possibly due to tracing for instance),
248
+ # or if we're switching from inference mode to training
249
+ if (
250
+ self._cos_cached is None
251
+ or self._cos_cached.device != device
252
+ or self._cos_cached.dtype != dtype
253
+ or (self.training and self._cos_cached.is_inference())
254
+ ):
255
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
256
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
257
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
258
+ inv_freq = self._compute_inv_freq(device=device)
259
+
260
+ # Don't do einsum, it converts fp32 to fp16 under AMP
261
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
262
+ t = torch.arange(seqlen, device=device, dtype=dtype)
263
+ freqs = torch.outer(t, inv_freq)
264
+ if self.scale is None:
265
+ self._cos_cached = torch.cos(freqs).to(dtype)
266
+ self._sin_cached = torch.sin(freqs).to(dtype)
267
+ self._cos_k_cached = None
268
+ self._sin_k_cached = None
269
+ else:
270
+ power = (
271
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
272
+ - seqlen // 2
273
+ ) / self.scale_base
274
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
275
+ # We want the multiplication by scale to happen in fp32
276
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
277
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
278
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
279
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
280
+
281
+ def forward(self, q, k=None, v=None):
282
+
283
+ if self._cos_cached is None:
284
+ self._update_cos_sin_cache(self.max_sequence_length, device=q.device, dtype=q.dtype)
285
+
286
+ if k is None and v is None:
287
+ q = apply_rotary_emb_qkv_(
288
+ q,
289
+ self._cos_cached,
290
+ self._sin_cached,
291
+ self._cos_k_cached,
292
+ self._sin_k_cached,
293
+ interleaved=self.interleaved,
294
+ seqlen_offsets=0
295
+ )
296
+ elif v is None and k is not None:
297
+ q = apply_rotary_emb_func(
298
+ q,
299
+ self._cos_cached,
300
+ self._sin_cached,
301
+ interleaved=self.interleaved,
302
+ inplace=True,
303
+ seqlen_offsets=0
304
+ )
305
+
306
+ k = apply_rotary_emb_kv_(
307
+ k,
308
+ self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
309
+ self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
310
+ interleaved=self.interleaved,
311
+ seqlen_offsets=0,
312
+ )
313
+ else:
314
+ q = apply_rotary_emb_func(
315
+ q,
316
+ self._cos_cached,
317
+ self._sin_cached,
318
+ interleaved=self.interleaved,
319
+ inplace=True,
320
+ seqlen_offsets=0
321
+ )
322
+
323
+ k = apply_rotary_emb_func(
324
+ k,
325
+ self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
326
+ self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
327
+ interleaved=self.interleaved,
328
+ seqlen_offsets=0,
329
+ )
330
+
331
+ v = apply_rotary_emb_func(
332
+ v,
333
+ self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
334
+ self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
335
+ interleaved=self.interleaved,
336
+ seqlen_offsets=0,
337
+ )
338
+
339
+ return q, k, v, None
340
+
341
+ class FIRE(nn.Module):
342
+
343
+ def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512., eps=1e-6):
344
+ """
345
+ FIRE attention bias module.
346
+
347
+ Args:
348
+ num_heads: number of attention heads.
349
+ mlp_width: Width of MLP.
350
+ init_c: initial value of log transformation parameter
351
+ init_L: initial value of thresholding parameter
352
+ eps: small constant for numerical stability
353
+ """
354
+
355
+ super(FIRE, self).__init__()
356
+
357
+ # Define the MLP layers
358
+ self.mlp = nn.Sequential(
359
+ nn.Linear(1, mlp_width),
360
+ nn.ReLU(),
361
+ nn.Linear(mlp_width, num_heads)
362
+ )
363
+
364
+ # Initialize c (log transformation parameter)
365
+ self.c = nn.Parameter(torch.tensor(init_c))
366
+
367
+
368
+ # Initialize L (threshold)
369
+ self.init_L = nn.Parameter(torch.tensor(init_L),
370
+ requires_grad=False)
371
+ # Learn a multiplier to L
372
+ self.L_multiplier = nn.Parameter(torch.tensor(1.0))
373
+ self.eps = eps
374
+
375
+ def apply_fire(self, seq_length, device):
376
+ """
377
+ Compute FIRE attention bias.
378
+
379
+ Args:
380
+ x: input sequence,
381
+ shape [bsz, seq_len, num_heads, hidden_dim]
382
+
383
+ Returns:
384
+ attention bias,
385
+ shape [1, num_heads, seq_len, seq_len]
386
+ """
387
+ positions = torch.arange(seq_length,
388
+ dtype=torch.float32,
389
+ device=device)
390
+
391
+ rel_distance = positions[:, None] - positions[None, :]
392
+
393
+ # Thresholding the normalizer
394
+ threshold = torch.abs(self.L_multiplier * self.init_L)
395
+ pos_normalizer = torch.max(positions, threshold)
396
+ pos_normalizer = pos_normalizer[:, None]
397
+
398
+ # Amplifying differences among local positions
399
+ # with log transform
400
+ rel_distance = torch.sign(rel_distance) * torch.log(
401
+ torch.abs(self.c * rel_distance) + 1
402
+ )
403
+ pos_normalizer = torch.log(
404
+ torch.abs(self.c * pos_normalizer) + 1
405
+ ) + self.eps
406
+
407
+ # Progressive interpolation
408
+ normalized_distance = rel_distance / pos_normalizer
409
+ fire_bias = self.mlp(normalized_distance.unsqueeze(-1))
410
+ fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2)
411
+ return fire_bias
412
+
413
+ def forward(self, q, k=None, v=None):
414
+
415
+ bias = self.apply_fire(q.shape[1], device=q.device).contiguous().to(q.dtype)
416
+
417
+ return q, k, v, bias
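
A minimal, hypothetical shape check of the `RelativePositionalEncoding` module defined above (tensor sizes are arbitrary):

```python
import torch

# Arbitrary sizes: batch=2, seq_len=16, 8 heads, head_dim=64.
pe = RelativePositionalEncoding(relative_attention_num_buckets=32,
                                relative_attention_max_distance=128,
                                n_heads=8,
                                max_sequence_length=1024)
q = torch.randn(2, 16, 8, 64)
q, k, v, bias = pe(q)
print(bias.shape)  # torch.Size([1, 8, 16, 16]): a (1, n_heads, q_len, k_len) additive bias
```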
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39cac982d9842a1d653966fb42b2fa3df4077334030c76cd8fa165b5aea244ea
3
+ size 587431346
rms_norm.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Copyright 2024 CATIE. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Modifications to the original file
17
+ # - support for torch.compile
18
+
19
+ import triton
20
+ import triton.language as tl
21
+ import torch
22
+ import math
23
+ from typing import Tuple
24
+
25
+ @triton.jit
26
+ def _rmsnorm_fwd_kernel(
27
+ X, # pointer to the input
28
+ Y, # pointer to the output
29
+ W, # pointer to the weights
30
+ Rstd, # pointer to the 1/std
31
+ stride_x_row, # how much to increase the pointer when moving by 1 row
32
+ stride_y_row,
33
+ N, # number of columns in X
34
+ eps, # epsilon to avoid division by zero
35
+ BLOCK_N: tl.constexpr,
36
+ IS_EVEN_N: tl.constexpr
37
+ ):
38
+
39
+ row = tl.program_id(0)
40
+ X += row * stride_x_row
41
+ Y += row * stride_y_row
42
+
43
+ # Compute mean and variance
44
+ cols = tl.arange(0, BLOCK_N)
45
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
46
+
47
+ xbar = tl.where(cols < N, x, 0.0)
48
+ var = tl.sum(xbar * xbar, axis=0) / N
49
+ rstd = 1 / tl.sqrt(var + eps)
50
+ tl.store(Rstd + row, rstd)
51
+
52
+ # Normalize and apply linear transformation
53
+ mask = cols < N
54
+ if IS_EVEN_N:
55
+ w = tl.load(W + cols).to(tl.float32)
56
+ else:
57
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
58
+
59
+ x_hat = x * rstd
60
+ y = x_hat * w
61
+
62
+ # Write output
63
+ if IS_EVEN_N:
64
+ tl.store(Y + cols, y)
65
+ else:
66
+ tl.store(Y + cols, y, mask=mask)
67
+
68
+ @triton.jit
69
+ def _rmsnorm_bwd_kernel(
70
+ X, # pointer to the input
71
+ W, # pointer to the weights
72
+ DY, # pointer to the output gradient
73
+ DX, # pointer to the input gradient
74
+ DW, # pointer to the partial sum of weights gradient
75
+ Rstd, # pointer to the 1/std
76
+ stride_x_row, # how much to increase the pointer when moving by 1 row
77
+ stride_dy_row,
78
+ stride_dx_row,
79
+ M, # number of rows in X
80
+ N, # number of columns in X
81
+ eps, # epsilon to avoid division by zero
82
+ rows_per_program,
83
+ BLOCK_N: tl.constexpr,
84
+ IS_EVEN_N: tl.constexpr
85
+ ):
86
+ # Map the program id to the elements of X, DX, and DY it should compute.
87
+ row_block_id = tl.program_id(0)
88
+ row_start = row_block_id * rows_per_program
89
+ cols = tl.arange(0, BLOCK_N)
90
+ mask = cols < N
91
+ X += row_start * stride_x_row
92
+
93
+ DY += row_start * stride_dy_row
94
+ DX += row_start * stride_dx_row
95
+
96
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
97
+
98
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
99
+
100
+ row_end = min((row_block_id + 1) * rows_per_program, M)
101
+
102
+ for row in range(row_start, row_end):
103
+ # Load data to SRAM
104
+ if IS_EVEN_N:
105
+ x = tl.load(X + cols).to(tl.float32)
106
+ dy = tl.load(DY + cols).to(tl.float32)
107
+ else:
108
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
109
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
110
+
111
+ rstd = tl.load(Rstd + row)
112
+
113
+ # Compute dx
114
+ xhat = x * rstd
115
+ if not IS_EVEN_N:
116
+ xhat = tl.where(mask, xhat, 0.0)
117
+
118
+ wdy = w * dy
119
+ dw += dy * xhat
120
+
121
+ c1 = tl.sum(xhat * wdy, axis=0) / N
122
+ dx = (wdy - xhat * c1) * rstd
123
+
124
+ tl.store(DX + cols, dx, mask=mask)
125
+
126
+ X += stride_x_row
127
+
128
+ DY += stride_dy_row
129
+ DX += stride_dx_row
130
+
131
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
132
+
133
+
134
+ @torch.library.custom_op("flasht5::rmsnorm_triton_fwd", mutates_args=(), device_types="cuda")
135
+ def rmsnorm_triton_fwd(
136
+ X: torch.Tensor,
137
+ weight: torch.Tensor,
138
+ eps: float
139
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
140
+
141
+ M, N = X.shape
142
+
143
+ assert X.stride(-1) == 1
144
+
145
+ assert weight.shape == (N,)
146
+ assert weight.stride(-1) == 1
147
+
148
+ # allocate output
149
+ Y = torch.empty_like(X)
150
+ assert Y.stride(-1) == 1
151
+
152
+ rstd = torch.empty((M,), dtype=torch.float32, device=X.device)
153
+
154
+ # Less than 64KB per feature: enqueue fused kernel
155
+ MAX_FUSED_SIZE = 65536 // X.element_size()
156
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
157
+ assert N <= BLOCK_N
158
+
159
+ # heuristics for number of warps
160
+ with torch.cuda.device(X.device.index):
161
+ _rmsnorm_fwd_kernel[(M,)](
162
+ X,
163
+ Y,
164
+ weight,
165
+ rstd,
166
+ X.stride(0),
167
+ Y.stride(0),
168
+ N,
169
+ eps,
170
+ BLOCK_N,
171
+ (N % BLOCK_N == 0)
172
+ )
173
+
174
+ return Y, rstd
175
+
176
+
177
+ @torch.library.register_fake("flasht5::rmsnorm_triton_fwd")
178
+ def rmsnorm_triton_fwd_abstract(X, weight, eps):
179
+ M, N = X.shape
180
+
181
+ Y = torch.empty_like(X)
182
+ rstd = torch.empty((M,), dtype=torch.float32, device=X.device)
183
+
184
+ return Y, rstd
185
+
186
+ @torch.library.custom_op("flasht5::rmsnorm_triton_bwd", mutates_args=(), device_types="cuda")
187
+ def rmsnorm_triton_bwd(
188
+ dy: torch.Tensor,
189
+ x: torch.Tensor,
190
+ weight: torch.Tensor,
191
+ rstd: torch.Tensor,
192
+ eps: float
193
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
194
+ M, N = x.shape
195
+ assert x.stride(-1) == 1
196
+ assert dy.stride(-1) == 1
197
+ assert dy.shape == (M, N)
198
+
199
+ assert weight.shape == (N,)
200
+ assert weight.stride(-1) == 1
201
+
202
+ # allocate output
203
+ dx = torch.empty_like(x)
204
+
205
+ # Less than 64KB per feature: enqueue fused kernel
206
+ MAX_FUSED_SIZE = 65536 // x.element_size()
207
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
208
+
209
+ assert N <= BLOCK_N
210
+
211
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
212
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
213
+
214
+ rows_per_program = math.ceil(M / sm_count)
215
+ grid = (sm_count,)
216
+ with torch.cuda.device(x.device.index):
217
+ _rmsnorm_bwd_kernel[grid](
218
+ x,
219
+ weight,
220
+ dy,
221
+ dx,
222
+ _dw,
223
+ rstd,
224
+ x.stride(0),
225
+ dy.stride(0),
226
+ dx.stride(0),
227
+ M,
228
+ N,
229
+ eps,
230
+ rows_per_program,
231
+ BLOCK_N,
232
+ (N % BLOCK_N == 0)
233
+ )
234
+ dw = _dw.sum(0).to(weight.dtype)
235
+
236
+ return dx, dw
237
+
238
+
239
+ @torch.library.register_fake("flasht5::rmsnorm_triton_bwd")
240
+ def rmsnorm_triton_bwd_abstract(dy, x, weight, rstd, eps):
241
+
242
+ M, N = x.shape
243
+ dx = torch.empty_like(x)
244
+ dw = torch.empty((1, N), dtype=torch.float32, device=weight.device)
245
+
246
+
247
+ return dx, dw
248
+
249
+
250
+ class Fast_RMS_Layernorm(torch.autograd.Function):
251
+ @staticmethod
252
+ def forward(ctx, X, W, eps=1e-6):
253
+
254
+ X_orig_shape = X.shape
255
+ X = X.reshape(-1, X.shape[-1])
256
+
257
+ y, rstd, = torch.ops.flasht5.rmsnorm_triton_fwd(X, W, eps)
258
+
259
+ y = y.reshape(X_orig_shape)
260
+
261
+ # We don't store y: the backward pass only needs X, W and rstd (x_hat is recomputed), which saves memory
262
+ ctx.save_for_backward(X, W, rstd)
263
+ ctx.x_shape_og = X_orig_shape
264
+ ctx.eps = eps
265
+
266
+ return y
267
+
268
+ @staticmethod
269
+ def backward(ctx, dY):
270
+ X, weight, rstd = ctx.saved_tensors
271
+ dY = dY.reshape(-1, dY.shape[-1])
272
+
273
+ assert dY.shape == X.shape
274
+
275
+ dx, dw = torch.ops.flasht5.rmsnorm_triton_bwd(
276
+ dY,
277
+ X,
278
+ weight,
279
+ rstd,
280
+ ctx.eps
281
+ )
282
+
283
+ return dx.reshape(ctx.x_shape_og), dw, None
284
+
285
+ def fast_rms_layernorm(X, W, eps):
286
+ out = Fast_RMS_Layernorm.apply(X, W, eps)
287
+ return out
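
A small, hypothetical sanity check of the fused kernel against a plain PyTorch RMSNorm (assumes a CUDA device with Triton installed and that this module is importable):

```python
import torch

# Arbitrary shapes; the wrapper flattens every leading dimension before calling the kernel.
x = torch.randn(4, 128, 512, device="cuda", dtype=torch.float32, requires_grad=True)
w = torch.ones(512, device="cuda", requires_grad=True)

y = fast_rms_layernorm(x, w, eps=1e-6)

# Reference RMSNorm: x / sqrt(mean(x^2) + eps) * w
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
print(torch.allclose(y, ref, atol=1e-4))
```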
rng_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:672546ccb6198cb6029856dffdff7fa8bbc52726b212606155673ca847d54511
3
+ size 14244
scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:605bdfbc75175c1891038202d54ff9a471d27b66aed2deedcdaf8b1fa6ca2185
3
+ size 1256
special_tokens_map.json ADDED
@@ -0,0 +1,266 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>",
103
+ "<extra_id_100>",
104
+ "<extra_id_101>",
105
+ "<extra_id_102>",
106
+ "<extra_id_103>",
107
+ "<extra_id_104>",
108
+ "<extra_id_105>",
109
+ "<extra_id_106>",
110
+ "<extra_id_107>",
111
+ "<extra_id_108>",
112
+ "<extra_id_109>",
113
+ "<extra_id_110>",
114
+ "<extra_id_111>",
115
+ "<extra_id_112>",
116
+ "<extra_id_113>",
117
+ "<extra_id_114>",
118
+ "<extra_id_115>",
119
+ "<extra_id_116>",
120
+ "<extra_id_117>",
121
+ "<extra_id_118>",
122
+ "<extra_id_119>",
123
+ "<extra_id_120>",
124
+ "<extra_id_121>",
125
+ "<extra_id_122>",
126
+ "<extra_id_123>",
127
+ "<extra_id_124>",
128
+ "<extra_id_125>",
129
+ "<extra_id_126>",
130
+ "<extra_id_127>",
131
+ "<extra_id_128>",
132
+ "<extra_id_129>",
133
+ "<extra_id_130>",
134
+ "<extra_id_131>",
135
+ "<extra_id_132>",
136
+ "<extra_id_133>",
137
+ "<extra_id_134>",
138
+ "<extra_id_135>",
139
+ "<extra_id_136>",
140
+ "<extra_id_137>",
141
+ "<extra_id_138>",
142
+ "<extra_id_139>",
143
+ "<extra_id_140>",
144
+ "<extra_id_141>",
145
+ "<extra_id_142>",
146
+ "<extra_id_143>",
147
+ "<extra_id_144>",
148
+ "<extra_id_145>",
149
+ "<extra_id_146>",
150
+ "<extra_id_147>",
151
+ "<extra_id_148>",
152
+ "<extra_id_149>",
153
+ "<extra_id_150>",
154
+ "<extra_id_151>",
155
+ "<extra_id_152>",
156
+ "<extra_id_153>",
157
+ "<extra_id_154>",
158
+ "<extra_id_155>",
159
+ "<extra_id_156>",
160
+ "<extra_id_157>",
161
+ "<extra_id_158>",
162
+ "<extra_id_159>",
163
+ "<extra_id_160>",
164
+ "<extra_id_161>",
165
+ "<extra_id_162>",
166
+ "<extra_id_163>",
167
+ "<extra_id_164>",
168
+ "<extra_id_165>",
169
+ "<extra_id_166>",
170
+ "<extra_id_167>",
171
+ "<extra_id_168>",
172
+ "<extra_id_169>",
173
+ "<extra_id_170>",
174
+ "<extra_id_171>",
175
+ "<extra_id_172>",
176
+ "<extra_id_173>",
177
+ "<extra_id_174>",
178
+ "<extra_id_175>",
179
+ "<extra_id_176>",
180
+ "<extra_id_177>",
181
+ "<extra_id_178>",
182
+ "<extra_id_179>",
183
+ "<extra_id_180>",
184
+ "<extra_id_181>",
185
+ "<extra_id_182>",
186
+ "<extra_id_183>",
187
+ "<extra_id_184>",
188
+ "<extra_id_185>",
189
+ "<extra_id_186>",
190
+ "<extra_id_187>",
191
+ "<extra_id_188>",
192
+ "<extra_id_189>",
193
+ "<extra_id_190>",
194
+ "<extra_id_191>",
195
+ "<extra_id_192>",
196
+ "<extra_id_193>",
197
+ "<extra_id_194>",
198
+ "<extra_id_195>",
199
+ "<extra_id_196>",
200
+ "<extra_id_197>",
201
+ "<extra_id_198>",
202
+ "<extra_id_199>",
203
+ "<extra_id_200>",
204
+ "<extra_id_201>",
205
+ "<extra_id_202>",
206
+ "<extra_id_203>",
207
+ "<extra_id_204>",
208
+ "<extra_id_205>",
209
+ "<extra_id_206>",
210
+ "<extra_id_207>",
211
+ "<extra_id_208>",
212
+ "<extra_id_209>",
213
+ "<extra_id_210>",
214
+ "<extra_id_211>",
215
+ "<extra_id_212>",
216
+ "<extra_id_213>",
217
+ "<extra_id_214>",
218
+ "<extra_id_215>",
219
+ "<extra_id_216>",
220
+ "<extra_id_217>",
221
+ "<extra_id_218>",
222
+ "<extra_id_219>",
223
+ "<extra_id_220>",
224
+ "<extra_id_221>",
225
+ "<extra_id_222>",
226
+ "<extra_id_223>",
227
+ "<extra_id_224>",
228
+ "<extra_id_225>",
229
+ "<extra_id_226>",
230
+ "<extra_id_227>",
231
+ "<extra_id_228>",
232
+ "<extra_id_229>",
233
+ "<extra_id_230>",
234
+ "<extra_id_231>",
235
+ "<extra_id_232>",
236
+ "<extra_id_233>",
237
+ "<extra_id_234>",
238
+ "<extra_id_235>",
239
+ "<extra_id_236>",
240
+ "<extra_id_237>",
241
+ "<extra_id_238>",
242
+ "<extra_id_239>",
243
+ "<extra_id_240>",
244
+ "<extra_id_241>",
245
+ "<extra_id_242>",
246
+ "<extra_id_243>",
247
+ "<extra_id_244>",
248
+ "<extra_id_245>",
249
+ "<extra_id_246>",
250
+ "<extra_id_247>",
251
+ "<extra_id_248>",
252
+ "<extra_id_249>",
253
+ "<extra_id_250>",
254
+ "<extra_id_251>",
255
+ "<extra_id_252>",
256
+ "<extra_id_253>",
257
+ "<extra_id_254>",
258
+ "<extra_id_255>"
259
+ ],
260
+ "cls_token": "<cls>",
261
+ "eos_token": "</s>",
262
+ "mask_token": "<mask>",
263
+ "pad_token": "<pad>",
264
+ "sep_token": "<sep>",
265
+ "unk_token": "<unk>"
266
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,2367 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<cls>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "</s>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<mask>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<pad>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<sep>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "<unk>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "6": {
52
+ "content": "<extra_id_0>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "7": {
60
+ "content": "<extra_id_1>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "8": {
68
+ "content": "<extra_id_2>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "9": {
76
+ "content": "<extra_id_3>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "10": {
84
+ "content": "<extra_id_4>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "11": {
92
+ "content": "<extra_id_5>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "12": {
100
+ "content": "<extra_id_6>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "13": {
108
+ "content": "<extra_id_7>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "14": {
116
+ "content": "<extra_id_8>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "15": {
124
+ "content": "<extra_id_9>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "16": {
132
+ "content": "<extra_id_10>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "17": {
140
+ "content": "<extra_id_11>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "18": {
148
+ "content": "<extra_id_12>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "19": {
156
+ "content": "<extra_id_13>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": true
162
+ },
163
+ "20": {
164
+ "content": "<extra_id_14>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": true
170
+ },
171
+ "21": {
172
+ "content": "<extra_id_15>",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ },
179
+ "22": {
180
+ "content": "<extra_id_16>",
181
+ "lstrip": false,
182
+ "normalized": false,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": true
186
+ },
187
+ "23": {
188
+ "content": "<extra_id_17>",
189
+ "lstrip": false,
190
+ "normalized": false,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": true
194
+ },
195
+ "24": {
196
+ "content": "<extra_id_18>",
197
+ "lstrip": false,
198
+ "normalized": false,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": true
202
+ },
203
+ "25": {
204
+ "content": "<extra_id_19>",
205
+ "lstrip": false,
206
+ "normalized": false,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": true
210
+ },
211
+ "26": {
212
+ "content": "<extra_id_20>",
213
+ "lstrip": false,
214
+ "normalized": false,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": true
218
+ },
219
+ "27": {
220
+ "content": "<extra_id_21>",
221
+ "lstrip": false,
222
+ "normalized": false,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": true
226
+ },
227
+ "28": {
228
+ "content": "<extra_id_22>",
229
+ "lstrip": false,
230
+ "normalized": false,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": true
234
+ },
235
+ "29": {
236
+ "content": "<extra_id_23>",
237
+ "lstrip": false,
238
+ "normalized": false,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": true
242
+ },
243
+ "30": {
244
+ "content": "<extra_id_24>",
245
+ "lstrip": false,
246
+ "normalized": false,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": true
250
+ },
251
+ "31": {
252
+ "content": "<extra_id_25>",
253
+ "lstrip": false,
254
+ "normalized": false,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": true
258
+ },
259
+ "32": {
260
+ "content": "<extra_id_26>",
261
+ "lstrip": false,
262
+ "normalized": false,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": true
266
+ },
267
+ "33": {
268
+ "content": "<extra_id_27>",
269
+ "lstrip": false,
270
+ "normalized": false,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": true
274
+ },
275
+ "34": {
276
+ "content": "<extra_id_28>",
277
+ "lstrip": false,
278
+ "normalized": false,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": true
282
+ },
283
+ "35": {
284
+ "content": "<extra_id_29>",
285
+ "lstrip": false,
286
+ "normalized": false,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": true
290
+ },
291
+ "36": {
292
+ "content": "<extra_id_30>",
293
+ "lstrip": false,
294
+ "normalized": false,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": true
298
+ },
299
+ "37": {
300
+ "content": "<extra_id_31>",
301
+ "lstrip": false,
302
+ "normalized": false,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": true
306
+ },
307
+ "38": {
308
+ "content": "<extra_id_32>",
309
+ "lstrip": false,
310
+ "normalized": false,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": true
314
+ },
315
+ "39": {
316
+ "content": "<extra_id_33>",
317
+ "lstrip": false,
318
+ "normalized": false,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": true
322
+ },
323
+ "40": {
324
+ "content": "<extra_id_34>",
325
+ "lstrip": false,
326
+ "normalized": false,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": true
330
+ },
331
+ "41": {
332
+ "content": "<extra_id_35>",
333
+ "lstrip": false,
334
+ "normalized": false,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": true
338
+ },
339
+ "42": {
340
+ "content": "<extra_id_36>",
341
+ "lstrip": false,
342
+ "normalized": false,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": true
346
+ },
347
+ "43": {
348
+ "content": "<extra_id_37>",
349
+ "lstrip": false,
350
+ "normalized": false,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": true
354
+ },
355
+ "44": {
356
+ "content": "<extra_id_38>",
357
+ "lstrip": false,
358
+ "normalized": false,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": true
362
+ },
363
+ "45": {
364
+ "content": "<extra_id_39>",
365
+ "lstrip": false,
366
+ "normalized": false,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": true
370
+ },
371
+ "46": {
372
+ "content": "<extra_id_40>",
373
+ "lstrip": false,
374
+ "normalized": false,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": true
378
+ },
379
+ "47": {
380
+ "content": "<extra_id_41>",
381
+ "lstrip": false,
382
+ "normalized": false,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": true
386
+ },
387
+ "48": {
388
+ "content": "<extra_id_42>",
389
+ "lstrip": false,
390
+ "normalized": false,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": true
394
+ },
395
+ "49": {
396
+ "content": "<extra_id_43>",
397
+ "lstrip": false,
398
+ "normalized": false,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": true
402
+ },
403
+ "50": {
404
+ "content": "<extra_id_44>",
405
+ "lstrip": false,
406
+ "normalized": false,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": true
410
+ },
411
+ "51": {
412
+ "content": "<extra_id_45>",
413
+ "lstrip": false,
414
+ "normalized": false,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": true
418
+ },
419
+ "52": {
420
+ "content": "<extra_id_46>",
421
+ "lstrip": false,
422
+ "normalized": false,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": true
426
+ },
427
+ "53": {
428
+ "content": "<extra_id_47>",
429
+ "lstrip": false,
430
+ "normalized": false,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": true
434
+ },
435
+ "54": {
436
+ "content": "<extra_id_48>",
437
+ "lstrip": false,
438
+ "normalized": false,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": true
442
+ },
443
+ "55": {
444
+ "content": "<extra_id_49>",
445
+ "lstrip": false,
446
+ "normalized": false,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": true
450
+ },
451
+ "56": {
452
+ "content": "<extra_id_50>",
453
+ "lstrip": false,
454
+ "normalized": false,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": true
458
+ },
459
+ "57": {
460
+ "content": "<extra_id_51>",
461
+ "lstrip": false,
462
+ "normalized": false,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": true
466
+ },
467
+ "58": {
468
+ "content": "<extra_id_52>",
469
+ "lstrip": false,
470
+ "normalized": false,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": true
474
+ },
475
+ "59": {
476
+ "content": "<extra_id_53>",
477
+ "lstrip": false,
478
+ "normalized": false,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": true
482
+ },
483
+ "60": {
484
+ "content": "<extra_id_54>",
485
+ "lstrip": false,
486
+ "normalized": false,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": true
490
+ },
491
+ "61": {
492
+ "content": "<extra_id_55>",
493
+ "lstrip": false,
494
+ "normalized": false,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": true
498
+ },
499
+ "62": {
500
+ "content": "<extra_id_56>",
501
+ "lstrip": false,
502
+ "normalized": false,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": true
506
+ },
507
+ "63": {
508
+ "content": "<extra_id_57>",
509
+ "lstrip": false,
510
+ "normalized": false,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": true
514
+ },
515
+ "64": {
516
+ "content": "<extra_id_58>",
517
+ "lstrip": false,
518
+ "normalized": false,
519
+ "rstrip": false,
520
+ "single_word": false,
521
+ "special": true
522
+ },
523
+ "65": {
524
+ "content": "<extra_id_59>",
525
+ "lstrip": false,
526
+ "normalized": false,
527
+ "rstrip": false,
528
+ "single_word": false,
529
+ "special": true
530
+ },
531
+ "66": {
532
+ "content": "<extra_id_60>",
533
+ "lstrip": false,
534
+ "normalized": false,
535
+ "rstrip": false,
536
+ "single_word": false,
537
+ "special": true
538
+ },
539
+ "67": {
540
+ "content": "<extra_id_61>",
541
+ "lstrip": false,
542
+ "normalized": false,
543
+ "rstrip": false,
544
+ "single_word": false,
545
+ "special": true
546
+ },
547
+ "68": {
548
+ "content": "<extra_id_62>",
549
+ "lstrip": false,
550
+ "normalized": false,
551
+ "rstrip": false,
552
+ "single_word": false,
553
+ "special": true
554
+ },
555
+ "69": {
556
+ "content": "<extra_id_63>",
557
+ "lstrip": false,
558
+ "normalized": false,
559
+ "rstrip": false,
560
+ "single_word": false,
561
+ "special": true
562
+ },
563
+ "70": {
564
+ "content": "<extra_id_64>",
565
+ "lstrip": false,
566
+ "normalized": false,
567
+ "rstrip": false,
568
+ "single_word": false,
569
+ "special": true
570
+ },
571
+ "71": {
572
+ "content": "<extra_id_65>",
573
+ "lstrip": false,
574
+ "normalized": false,
575
+ "rstrip": false,
576
+ "single_word": false,
577
+ "special": true
578
+ },
579
+ "72": {
580
+ "content": "<extra_id_66>",
581
+ "lstrip": false,
582
+ "normalized": false,
583
+ "rstrip": false,
584
+ "single_word": false,
585
+ "special": true
586
+ },
587
+ "73": {
588
+ "content": "<extra_id_67>",
589
+ "lstrip": false,
590
+ "normalized": false,
591
+ "rstrip": false,
592
+ "single_word": false,
593
+ "special": true
594
+ },
595
+ "74": {
596
+ "content": "<extra_id_68>",
597
+ "lstrip": false,
598
+ "normalized": false,
599
+ "rstrip": false,
600
+ "single_word": false,
601
+ "special": true
602
+ },
603
+ "75": {
604
+ "content": "<extra_id_69>",
605
+ "lstrip": false,
606
+ "normalized": false,
607
+ "rstrip": false,
608
+ "single_word": false,
609
+ "special": true
610
+ },
611
+ "76": {
612
+ "content": "<extra_id_70>",
613
+ "lstrip": false,
614
+ "normalized": false,
615
+ "rstrip": false,
616
+ "single_word": false,
617
+ "special": true
618
+ },
619
+ "77": {
620
+ "content": "<extra_id_71>",
621
+ "lstrip": false,
622
+ "normalized": false,
623
+ "rstrip": false,
624
+ "single_word": false,
625
+ "special": true
626
+ },
627
+ "78": {
628
+ "content": "<extra_id_72>",
629
+ "lstrip": false,
630
+ "normalized": false,
631
+ "rstrip": false,
632
+ "single_word": false,
633
+ "special": true
634
+ },
635
+ "79": {
636
+ "content": "<extra_id_73>",
637
+ "lstrip": false,
638
+ "normalized": false,
639
+ "rstrip": false,
640
+ "single_word": false,
641
+ "special": true
642
+ },
643
+ "80": {
644
+ "content": "<extra_id_74>",
645
+ "lstrip": false,
646
+ "normalized": false,
647
+ "rstrip": false,
648
+ "single_word": false,
649
+ "special": true
650
+ },
651
+ "81": {
652
+ "content": "<extra_id_75>",
653
+ "lstrip": false,
654
+ "normalized": false,
655
+ "rstrip": false,
656
+ "single_word": false,
657
+ "special": true
658
+ },
659
+ "82": {
660
+ "content": "<extra_id_76>",
661
+ "lstrip": false,
662
+ "normalized": false,
663
+ "rstrip": false,
664
+ "single_word": false,
665
+ "special": true
666
+ },
667
+ "83": {
668
+ "content": "<extra_id_77>",
669
+ "lstrip": false,
670
+ "normalized": false,
671
+ "rstrip": false,
672
+ "single_word": false,
673
+ "special": true
674
+ },
675
+ "84": {
676
+ "content": "<extra_id_78>",
677
+ "lstrip": false,
678
+ "normalized": false,
679
+ "rstrip": false,
680
+ "single_word": false,
681
+ "special": true
682
+ },
683
+ "85": {
684
+ "content": "<extra_id_79>",
685
+ "lstrip": false,
686
+ "normalized": false,
687
+ "rstrip": false,
688
+ "single_word": false,
689
+ "special": true
690
+ },
691
+ "86": {
692
+ "content": "<extra_id_80>",
693
+ "lstrip": false,
694
+ "normalized": false,
695
+ "rstrip": false,
696
+ "single_word": false,
697
+ "special": true
698
+ },
699
+ "87": {
700
+ "content": "<extra_id_81>",
701
+ "lstrip": false,
702
+ "normalized": false,
703
+ "rstrip": false,
704
+ "single_word": false,
705
+ "special": true
706
+ },
707
+ "88": {
708
+ "content": "<extra_id_82>",
709
+ "lstrip": false,
710
+ "normalized": false,
711
+ "rstrip": false,
712
+ "single_word": false,
713
+ "special": true
714
+ },
715
+ "89": {
716
+ "content": "<extra_id_83>",
717
+ "lstrip": false,
718
+ "normalized": false,
719
+ "rstrip": false,
720
+ "single_word": false,
721
+ "special": true
722
+ },
723
+ "90": {
724
+ "content": "<extra_id_84>",
725
+ "lstrip": false,
726
+ "normalized": false,
727
+ "rstrip": false,
728
+ "single_word": false,
729
+ "special": true
730
+ },
731
+ "91": {
732
+ "content": "<extra_id_85>",
733
+ "lstrip": false,
734
+ "normalized": false,
735
+ "rstrip": false,
736
+ "single_word": false,
737
+ "special": true
738
+ },
739
+ "92": {
740
+ "content": "<extra_id_86>",
741
+ "lstrip": false,
742
+ "normalized": false,
743
+ "rstrip": false,
744
+ "single_word": false,
745
+ "special": true
746
+ },
747
+ "93": {
748
+ "content": "<extra_id_87>",
749
+ "lstrip": false,
750
+ "normalized": false,
751
+ "rstrip": false,
752
+ "single_word": false,
753
+ "special": true
754
+ },
755
+ "94": {
756
+ "content": "<extra_id_88>",
757
+ "lstrip": false,
758
+ "normalized": false,
759
+ "rstrip": false,
760
+ "single_word": false,
761
+ "special": true
762
+ },
763
+ "95": {
764
+ "content": "<extra_id_89>",
765
+ "lstrip": false,
766
+ "normalized": false,
767
+ "rstrip": false,
768
+ "single_word": false,
769
+ "special": true
770
+ },
771
+ "96": {
772
+ "content": "<extra_id_90>",
773
+ "lstrip": false,
774
+ "normalized": false,
775
+ "rstrip": false,
776
+ "single_word": false,
777
+ "special": true
778
+ },
779
+ "97": {
780
+ "content": "<extra_id_91>",
781
+ "lstrip": false,
782
+ "normalized": false,
783
+ "rstrip": false,
784
+ "single_word": false,
785
+ "special": true
786
+ },
787
+ "98": {
788
+ "content": "<extra_id_92>",
789
+ "lstrip": false,
790
+ "normalized": false,
791
+ "rstrip": false,
792
+ "single_word": false,
793
+ "special": true
794
+ },
795
+ "99": {
796
+ "content": "<extra_id_93>",
797
+ "lstrip": false,
798
+ "normalized": false,
799
+ "rstrip": false,
800
+ "single_word": false,
801
+ "special": true
802
+ },
803
+ "100": {
804
+ "content": "<extra_id_94>",
805
+ "lstrip": false,
806
+ "normalized": false,
807
+ "rstrip": false,
808
+ "single_word": false,
809
+ "special": true
810
+ },
811
+ "101": {
812
+ "content": "<extra_id_95>",
813
+ "lstrip": false,
814
+ "normalized": false,
815
+ "rstrip": false,
816
+ "single_word": false,
817
+ "special": true
818
+ },
819
+ "102": {
820
+ "content": "<extra_id_96>",
821
+ "lstrip": false,
822
+ "normalized": false,
823
+ "rstrip": false,
824
+ "single_word": false,
825
+ "special": true
826
+ },
827
+ "103": {
828
+ "content": "<extra_id_97>",
829
+ "lstrip": false,
830
+ "normalized": false,
831
+ "rstrip": false,
832
+ "single_word": false,
833
+ "special": true
834
+ },
835
+ "104": {
836
+ "content": "<extra_id_98>",
837
+ "lstrip": false,
838
+ "normalized": false,
839
+ "rstrip": false,
840
+ "single_word": false,
841
+ "special": true
842
+ },
843
+ "105": {
844
+ "content": "<extra_id_99>",
845
+ "lstrip": false,
846
+ "normalized": false,
847
+ "rstrip": false,
848
+ "single_word": false,
849
+ "special": true
850
+ },
851
+ "106": {
852
+ "content": "<extra_id_100>",
853
+ "lstrip": false,
854
+ "normalized": false,
855
+ "rstrip": false,
856
+ "single_word": false,
857
+ "special": true
858
+ },
859
+ "107": {
860
+ "content": "<extra_id_101>",
861
+ "lstrip": false,
862
+ "normalized": false,
863
+ "rstrip": false,
864
+ "single_word": false,
865
+ "special": true
866
+ },
867
+ "108": {
868
+ "content": "<extra_id_102>",
869
+ "lstrip": false,
870
+ "normalized": false,
871
+ "rstrip": false,
872
+ "single_word": false,
873
+ "special": true
874
+ },
875
+ "109": {
876
+ "content": "<extra_id_103>",
877
+ "lstrip": false,
878
+ "normalized": false,
879
+ "rstrip": false,
880
+ "single_word": false,
881
+ "special": true
882
+ },
883
+ "110": {
884
+ "content": "<extra_id_104>",
885
+ "lstrip": false,
886
+ "normalized": false,
887
+ "rstrip": false,
888
+ "single_word": false,
889
+ "special": true
890
+ },
891
+ "111": {
892
+ "content": "<extra_id_105>",
893
+ "lstrip": false,
894
+ "normalized": false,
895
+ "rstrip": false,
896
+ "single_word": false,
897
+ "special": true
898
+ },
899
+ "112": {
900
+ "content": "<extra_id_106>",
901
+ "lstrip": false,
902
+ "normalized": false,
903
+ "rstrip": false,
904
+ "single_word": false,
905
+ "special": true
906
+ },
907
+ "113": {
908
+ "content": "<extra_id_107>",
909
+ "lstrip": false,
910
+ "normalized": false,
911
+ "rstrip": false,
912
+ "single_word": false,
913
+ "special": true
914
+ },
915
+ "114": {
916
+ "content": "<extra_id_108>",
917
+ "lstrip": false,
918
+ "normalized": false,
919
+ "rstrip": false,
920
+ "single_word": false,
921
+ "special": true
922
+ },
923
+ "115": {
924
+ "content": "<extra_id_109>",
925
+ "lstrip": false,
926
+ "normalized": false,
927
+ "rstrip": false,
928
+ "single_word": false,
929
+ "special": true
930
+ },
931
+ "116": {
932
+ "content": "<extra_id_110>",
933
+ "lstrip": false,
934
+ "normalized": false,
935
+ "rstrip": false,
936
+ "single_word": false,
937
+ "special": true
938
+ },
939
+ "117": {
940
+ "content": "<extra_id_111>",
941
+ "lstrip": false,
942
+ "normalized": false,
943
+ "rstrip": false,
944
+ "single_word": false,
945
+ "special": true
946
+ },
947
+ "118": {
948
+ "content": "<extra_id_112>",
949
+ "lstrip": false,
950
+ "normalized": false,
951
+ "rstrip": false,
952
+ "single_word": false,
953
+ "special": true
954
+ },
955
+ "119": {
956
+ "content": "<extra_id_113>",
957
+ "lstrip": false,
958
+ "normalized": false,
959
+ "rstrip": false,
960
+ "single_word": false,
961
+ "special": true
962
+ },
963
+ "120": {
964
+ "content": "<extra_id_114>",
965
+ "lstrip": false,
966
+ "normalized": false,
967
+ "rstrip": false,
968
+ "single_word": false,
969
+ "special": true
970
+ },
971
+ "121": {
972
+ "content": "<extra_id_115>",
973
+ "lstrip": false,
974
+ "normalized": false,
975
+ "rstrip": false,
976
+ "single_word": false,
977
+ "special": true
978
+ },
979
+ "122": {
980
+ "content": "<extra_id_116>",
981
+ "lstrip": false,
982
+ "normalized": false,
983
+ "rstrip": false,
984
+ "single_word": false,
985
+ "special": true
986
+ },
987
+ "123": {
988
+ "content": "<extra_id_117>",
989
+ "lstrip": false,
990
+ "normalized": false,
991
+ "rstrip": false,
992
+ "single_word": false,
993
+ "special": true
994
+ },
995
+ "124": {
996
+ "content": "<extra_id_118>",
997
+ "lstrip": false,
998
+ "normalized": false,
999
+ "rstrip": false,
1000
+ "single_word": false,
1001
+ "special": true
1002
+ },
1003
+ "125": {
1004
+ "content": "<extra_id_119>",
1005
+ "lstrip": false,
1006
+ "normalized": false,
1007
+ "rstrip": false,
1008
+ "single_word": false,
1009
+ "special": true
1010
+ },
1011
+ "126": {
1012
+ "content": "<extra_id_120>",
1013
+ "lstrip": false,
1014
+ "normalized": false,
1015
+ "rstrip": false,
1016
+ "single_word": false,
1017
+ "special": true
1018
+ },
1019
+ "127": {
1020
+ "content": "<extra_id_121>",
1021
+ "lstrip": false,
1022
+ "normalized": false,
1023
+ "rstrip": false,
1024
+ "single_word": false,
1025
+ "special": true
1026
+ },
1027
+ "128": {
1028
+ "content": "<extra_id_122>",
1029
+ "lstrip": false,
1030
+ "normalized": false,
1031
+ "rstrip": false,
1032
+ "single_word": false,
1033
+ "special": true
1034
+ },
1035
+ "129": {
1036
+ "content": "<extra_id_123>",
1037
+ "lstrip": false,
1038
+ "normalized": false,
1039
+ "rstrip": false,
1040
+ "single_word": false,
1041
+ "special": true
1042
+ },
1043
+ "130": {
1044
+ "content": "<extra_id_124>",
1045
+ "lstrip": false,
1046
+ "normalized": false,
1047
+ "rstrip": false,
1048
+ "single_word": false,
1049
+ "special": true
1050
+ },
1051
+ "131": {
1052
+ "content": "<extra_id_125>",
1053
+ "lstrip": false,
1054
+ "normalized": false,
1055
+ "rstrip": false,
1056
+ "single_word": false,
1057
+ "special": true
1058
+ },
1059
+ "132": {
1060
+ "content": "<extra_id_126>",
1061
+ "lstrip": false,
1062
+ "normalized": false,
1063
+ "rstrip": false,
1064
+ "single_word": false,
1065
+ "special": true
1066
+ },
1067
+ "133": {
1068
+ "content": "<extra_id_127>",
1069
+ "lstrip": false,
1070
+ "normalized": false,
1071
+ "rstrip": false,
1072
+ "single_word": false,
1073
+ "special": true
1074
+ },
1075
+ "134": {
1076
+ "content": "<extra_id_128>",
1077
+ "lstrip": false,
1078
+ "normalized": false,
1079
+ "rstrip": false,
1080
+ "single_word": false,
1081
+ "special": true
1082
+ },
1083
+ "135": {
1084
+ "content": "<extra_id_129>",
1085
+ "lstrip": false,
1086
+ "normalized": false,
1087
+ "rstrip": false,
1088
+ "single_word": false,
1089
+ "special": true
1090
+ },
1091
+ "136": {
1092
+ "content": "<extra_id_130>",
1093
+ "lstrip": false,
1094
+ "normalized": false,
1095
+ "rstrip": false,
1096
+ "single_word": false,
1097
+ "special": true
1098
+ },
1099
+ "137": {
1100
+ "content": "<extra_id_131>",
1101
+ "lstrip": false,
1102
+ "normalized": false,
1103
+ "rstrip": false,
1104
+ "single_word": false,
1105
+ "special": true
1106
+ },
1107
+ "138": {
1108
+ "content": "<extra_id_132>",
1109
+ "lstrip": false,
1110
+ "normalized": false,
1111
+ "rstrip": false,
1112
+ "single_word": false,
1113
+ "special": true
1114
+ },
1115
+ "139": {
1116
+ "content": "<extra_id_133>",
1117
+ "lstrip": false,
1118
+ "normalized": false,
1119
+ "rstrip": false,
1120
+ "single_word": false,
1121
+ "special": true
1122
+ },
1123
+ "140": {
1124
+ "content": "<extra_id_134>",
1125
+ "lstrip": false,
1126
+ "normalized": false,
1127
+ "rstrip": false,
1128
+ "single_word": false,
1129
+ "special": true
1130
+ },
1131
+ "141": {
1132
+ "content": "<extra_id_135>",
1133
+ "lstrip": false,
1134
+ "normalized": false,
1135
+ "rstrip": false,
1136
+ "single_word": false,
1137
+ "special": true
1138
+ },
1139
+ "142": {
1140
+ "content": "<extra_id_136>",
1141
+ "lstrip": false,
1142
+ "normalized": false,
1143
+ "rstrip": false,
1144
+ "single_word": false,
1145
+ "special": true
1146
+ },
1147
+ "143": {
1148
+ "content": "<extra_id_137>",
1149
+ "lstrip": false,
1150
+ "normalized": false,
1151
+ "rstrip": false,
1152
+ "single_word": false,
1153
+ "special": true
1154
+ },
1155
+ "144": {
1156
+ "content": "<extra_id_138>",
1157
+ "lstrip": false,
1158
+ "normalized": false,
1159
+ "rstrip": false,
1160
+ "single_word": false,
1161
+ "special": true
1162
+ },
1163
+ "145": {
1164
+ "content": "<extra_id_139>",
1165
+ "lstrip": false,
1166
+ "normalized": false,
1167
+ "rstrip": false,
1168
+ "single_word": false,
1169
+ "special": true
1170
+ },
1171
+ "146": {
1172
+ "content": "<extra_id_140>",
1173
+ "lstrip": false,
1174
+ "normalized": false,
1175
+ "rstrip": false,
1176
+ "single_word": false,
1177
+ "special": true
1178
+ },
1179
+ "147": {
1180
+ "content": "<extra_id_141>",
1181
+ "lstrip": false,
1182
+ "normalized": false,
1183
+ "rstrip": false,
1184
+ "single_word": false,
1185
+ "special": true
1186
+ },
1187
+ "148": {
1188
+ "content": "<extra_id_142>",
1189
+ "lstrip": false,
1190
+ "normalized": false,
1191
+ "rstrip": false,
1192
+ "single_word": false,
1193
+ "special": true
1194
+ },
1195
+ "149": {
1196
+ "content": "<extra_id_143>",
1197
+ "lstrip": false,
1198
+ "normalized": false,
1199
+ "rstrip": false,
1200
+ "single_word": false,
1201
+ "special": true
1202
+ },
1203
+ "150": {
1204
+ "content": "<extra_id_144>",
1205
+ "lstrip": false,
1206
+ "normalized": false,
1207
+ "rstrip": false,
1208
+ "single_word": false,
1209
+ "special": true
1210
+ },
1211
+ "151": {
1212
+ "content": "<extra_id_145>",
1213
+ "lstrip": false,
1214
+ "normalized": false,
1215
+ "rstrip": false,
1216
+ "single_word": false,
1217
+ "special": true
1218
+ },
1219
+ "152": {
1220
+ "content": "<extra_id_146>",
1221
+ "lstrip": false,
1222
+ "normalized": false,
1223
+ "rstrip": false,
1224
+ "single_word": false,
1225
+ "special": true
1226
+ },
1227
+ "153": {
1228
+ "content": "<extra_id_147>",
1229
+ "lstrip": false,
1230
+ "normalized": false,
1231
+ "rstrip": false,
1232
+ "single_word": false,
1233
+ "special": true
1234
+ },
1235
+ "154": {
1236
+ "content": "<extra_id_148>",
1237
+ "lstrip": false,
1238
+ "normalized": false,
1239
+ "rstrip": false,
1240
+ "single_word": false,
1241
+ "special": true
1242
+ },
1243
+ "155": {
1244
+ "content": "<extra_id_149>",
1245
+ "lstrip": false,
1246
+ "normalized": false,
1247
+ "rstrip": false,
1248
+ "single_word": false,
1249
+ "special": true
1250
+ },
1251
+ "156": {
1252
+ "content": "<extra_id_150>",
1253
+ "lstrip": false,
1254
+ "normalized": false,
1255
+ "rstrip": false,
1256
+ "single_word": false,
1257
+ "special": true
1258
+ },
1259
+ "157": {
1260
+ "content": "<extra_id_151>",
1261
+ "lstrip": false,
1262
+ "normalized": false,
1263
+ "rstrip": false,
1264
+ "single_word": false,
1265
+ "special": true
1266
+ },
1267
+ "158": {
1268
+ "content": "<extra_id_152>",
1269
+ "lstrip": false,
1270
+ "normalized": false,
1271
+ "rstrip": false,
1272
+ "single_word": false,
1273
+ "special": true
1274
+ },
1275
+ "159": {
1276
+ "content": "<extra_id_153>",
1277
+ "lstrip": false,
1278
+ "normalized": false,
1279
+ "rstrip": false,
1280
+ "single_word": false,
1281
+ "special": true
1282
+ },
1283
+ "160": {
1284
+ "content": "<extra_id_154>",
1285
+ "lstrip": false,
1286
+ "normalized": false,
1287
+ "rstrip": false,
1288
+ "single_word": false,
1289
+ "special": true
1290
+ },
1291
+ "161": {
1292
+ "content": "<extra_id_155>",
1293
+ "lstrip": false,
1294
+ "normalized": false,
1295
+ "rstrip": false,
1296
+ "single_word": false,
1297
+ "special": true
1298
+ },
1299
+ "162": {
1300
+ "content": "<extra_id_156>",
1301
+ "lstrip": false,
1302
+ "normalized": false,
1303
+ "rstrip": false,
1304
+ "single_word": false,
1305
+ "special": true
1306
+ },
1307
+ "163": {
1308
+ "content": "<extra_id_157>",
1309
+ "lstrip": false,
1310
+ "normalized": false,
1311
+ "rstrip": false,
1312
+ "single_word": false,
1313
+ "special": true
1314
+ },
1315
+ "164": {
1316
+ "content": "<extra_id_158>",
1317
+ "lstrip": false,
1318
+ "normalized": false,
1319
+ "rstrip": false,
1320
+ "single_word": false,
1321
+ "special": true
1322
+ },
1323
+ "165": {
1324
+ "content": "<extra_id_159>",
1325
+ "lstrip": false,
1326
+ "normalized": false,
1327
+ "rstrip": false,
1328
+ "single_word": false,
1329
+ "special": true
1330
+ },
1331
+ "166": {
1332
+ "content": "<extra_id_160>",
1333
+ "lstrip": false,
1334
+ "normalized": false,
1335
+ "rstrip": false,
1336
+ "single_word": false,
1337
+ "special": true
1338
+ },
1339
+ "167": {
1340
+ "content": "<extra_id_161>",
1341
+ "lstrip": false,
1342
+ "normalized": false,
1343
+ "rstrip": false,
1344
+ "single_word": false,
1345
+ "special": true
1346
+ },
1347
+ "168": {
1348
+ "content": "<extra_id_162>",
1349
+ "lstrip": false,
1350
+ "normalized": false,
1351
+ "rstrip": false,
1352
+ "single_word": false,
1353
+ "special": true
1354
+ },
1355
+ "169": {
1356
+ "content": "<extra_id_163>",
1357
+ "lstrip": false,
1358
+ "normalized": false,
1359
+ "rstrip": false,
1360
+ "single_word": false,
1361
+ "special": true
1362
+ },
1363
+ "170": {
1364
+ "content": "<extra_id_164>",
1365
+ "lstrip": false,
1366
+ "normalized": false,
1367
+ "rstrip": false,
1368
+ "single_word": false,
1369
+ "special": true
1370
+ },
1371
+ "171": {
1372
+ "content": "<extra_id_165>",
1373
+ "lstrip": false,
1374
+ "normalized": false,
1375
+ "rstrip": false,
1376
+ "single_word": false,
1377
+ "special": true
1378
+ },
1379
+ "172": {
1380
+ "content": "<extra_id_166>",
1381
+ "lstrip": false,
1382
+ "normalized": false,
1383
+ "rstrip": false,
1384
+ "single_word": false,
1385
+ "special": true
1386
+ },
1387
+ "173": {
1388
+ "content": "<extra_id_167>",
1389
+ "lstrip": false,
1390
+ "normalized": false,
1391
+ "rstrip": false,
1392
+ "single_word": false,
1393
+ "special": true
1394
+ },
1395
+ "174": {
1396
+ "content": "<extra_id_168>",
1397
+ "lstrip": false,
1398
+ "normalized": false,
1399
+ "rstrip": false,
1400
+ "single_word": false,
1401
+ "special": true
1402
+ },
1403
+ "175": {
1404
+ "content": "<extra_id_169>",
1405
+ "lstrip": false,
1406
+ "normalized": false,
1407
+ "rstrip": false,
1408
+ "single_word": false,
1409
+ "special": true
1410
+ },
1411
+ "176": {
1412
+ "content": "<extra_id_170>",
1413
+ "lstrip": false,
1414
+ "normalized": false,
1415
+ "rstrip": false,
1416
+ "single_word": false,
1417
+ "special": true
1418
+ },
1419
+ "177": {
1420
+ "content": "<extra_id_171>",
1421
+ "lstrip": false,
1422
+ "normalized": false,
1423
+ "rstrip": false,
1424
+ "single_word": false,
1425
+ "special": true
1426
+ },
1427
+ "178": {
1428
+ "content": "<extra_id_172>",
1429
+ "lstrip": false,
1430
+ "normalized": false,
1431
+ "rstrip": false,
1432
+ "single_word": false,
1433
+ "special": true
1434
+ },
1435
+ "179": {
1436
+ "content": "<extra_id_173>",
1437
+ "lstrip": false,
1438
+ "normalized": false,
1439
+ "rstrip": false,
1440
+ "single_word": false,
1441
+ "special": true
1442
+ },
1443
+ "180": {
1444
+ "content": "<extra_id_174>",
1445
+ "lstrip": false,
1446
+ "normalized": false,
1447
+ "rstrip": false,
1448
+ "single_word": false,
1449
+ "special": true
1450
+ },
1451
+ "181": {
1452
+ "content": "<extra_id_175>",
1453
+ "lstrip": false,
1454
+ "normalized": false,
1455
+ "rstrip": false,
1456
+ "single_word": false,
1457
+ "special": true
1458
+ },
1459
+ "182": {
1460
+ "content": "<extra_id_176>",
1461
+ "lstrip": false,
1462
+ "normalized": false,
1463
+ "rstrip": false,
1464
+ "single_word": false,
1465
+ "special": true
1466
+ },
1467
+ "183": {
1468
+ "content": "<extra_id_177>",
1469
+ "lstrip": false,
1470
+ "normalized": false,
1471
+ "rstrip": false,
1472
+ "single_word": false,
1473
+ "special": true
1474
+ },
1475
+ "184": {
1476
+ "content": "<extra_id_178>",
1477
+ "lstrip": false,
1478
+ "normalized": false,
1479
+ "rstrip": false,
1480
+ "single_word": false,
1481
+ "special": true
1482
+ },
1483
+ "185": {
1484
+ "content": "<extra_id_179>",
1485
+ "lstrip": false,
1486
+ "normalized": false,
1487
+ "rstrip": false,
1488
+ "single_word": false,
1489
+ "special": true
1490
+ },
1491
+ "186": {
1492
+ "content": "<extra_id_180>",
1493
+ "lstrip": false,
1494
+ "normalized": false,
1495
+ "rstrip": false,
1496
+ "single_word": false,
1497
+ "special": true
1498
+ },
1499
+ "187": {
1500
+ "content": "<extra_id_181>",
1501
+ "lstrip": false,
1502
+ "normalized": false,
1503
+ "rstrip": false,
1504
+ "single_word": false,
1505
+ "special": true
1506
+ },
1507
+ "188": {
1508
+ "content": "<extra_id_182>",
1509
+ "lstrip": false,
1510
+ "normalized": false,
1511
+ "rstrip": false,
1512
+ "single_word": false,
1513
+ "special": true
1514
+ },
1515
+ "189": {
1516
+ "content": "<extra_id_183>",
1517
+ "lstrip": false,
1518
+ "normalized": false,
1519
+ "rstrip": false,
1520
+ "single_word": false,
1521
+ "special": true
1522
+ },
1523
+ "190": {
1524
+ "content": "<extra_id_184>",
1525
+ "lstrip": false,
1526
+ "normalized": false,
1527
+ "rstrip": false,
1528
+ "single_word": false,
1529
+ "special": true
1530
+ },
1531
+ "191": {
1532
+ "content": "<extra_id_185>",
1533
+ "lstrip": false,
1534
+ "normalized": false,
1535
+ "rstrip": false,
1536
+ "single_word": false,
1537
+ "special": true
1538
+ },
1539
+ "192": {
1540
+ "content": "<extra_id_186>",
1541
+ "lstrip": false,
1542
+ "normalized": false,
1543
+ "rstrip": false,
1544
+ "single_word": false,
1545
+ "special": true
1546
+ },
1547
+ "193": {
1548
+ "content": "<extra_id_187>",
1549
+ "lstrip": false,
1550
+ "normalized": false,
1551
+ "rstrip": false,
1552
+ "single_word": false,
1553
+ "special": true
1554
+ },
1555
+ "194": {
1556
+ "content": "<extra_id_188>",
1557
+ "lstrip": false,
1558
+ "normalized": false,
1559
+ "rstrip": false,
1560
+ "single_word": false,
1561
+ "special": true
1562
+ },
1563
+ "195": {
1564
+ "content": "<extra_id_189>",
1565
+ "lstrip": false,
1566
+ "normalized": false,
1567
+ "rstrip": false,
1568
+ "single_word": false,
1569
+ "special": true
1570
+ },
1571
+ "196": {
1572
+ "content": "<extra_id_190>",
1573
+ "lstrip": false,
1574
+ "normalized": false,
1575
+ "rstrip": false,
1576
+ "single_word": false,
1577
+ "special": true
1578
+ },
1579
+ "197": {
1580
+ "content": "<extra_id_191>",
1581
+ "lstrip": false,
1582
+ "normalized": false,
1583
+ "rstrip": false,
1584
+ "single_word": false,
1585
+ "special": true
1586
+ },
1587
+ "198": {
1588
+ "content": "<extra_id_192>",
1589
+ "lstrip": false,
1590
+ "normalized": false,
1591
+ "rstrip": false,
1592
+ "single_word": false,
1593
+ "special": true
1594
+ },
1595
+ "199": {
1596
+ "content": "<extra_id_193>",
1597
+ "lstrip": false,
1598
+ "normalized": false,
1599
+ "rstrip": false,
1600
+ "single_word": false,
1601
+ "special": true
1602
+ },
1603
+ "200": {
1604
+ "content": "<extra_id_194>",
1605
+ "lstrip": false,
1606
+ "normalized": false,
1607
+ "rstrip": false,
1608
+ "single_word": false,
1609
+ "special": true
1610
+ },
1611
+ "201": {
1612
+ "content": "<extra_id_195>",
1613
+ "lstrip": false,
1614
+ "normalized": false,
1615
+ "rstrip": false,
1616
+ "single_word": false,
1617
+ "special": true
1618
+ },
1619
+ "202": {
1620
+ "content": "<extra_id_196>",
1621
+ "lstrip": false,
1622
+ "normalized": false,
1623
+ "rstrip": false,
1624
+ "single_word": false,
1625
+ "special": true
1626
+ },
1627
+ "203": {
1628
+ "content": "<extra_id_197>",
1629
+ "lstrip": false,
1630
+ "normalized": false,
1631
+ "rstrip": false,
1632
+ "single_word": false,
1633
+ "special": true
1634
+ },
1635
+ "204": {
1636
+ "content": "<extra_id_198>",
1637
+ "lstrip": false,
1638
+ "normalized": false,
1639
+ "rstrip": false,
1640
+ "single_word": false,
1641
+ "special": true
1642
+ },
1643
+ "205": {
1644
+ "content": "<extra_id_199>",
1645
+ "lstrip": false,
1646
+ "normalized": false,
1647
+ "rstrip": false,
1648
+ "single_word": false,
1649
+ "special": true
1650
+ },
1651
+ "206": {
1652
+ "content": "<extra_id_200>",
1653
+ "lstrip": false,
1654
+ "normalized": false,
1655
+ "rstrip": false,
1656
+ "single_word": false,
1657
+ "special": true
1658
+ },
1659
+ "207": {
1660
+ "content": "<extra_id_201>",
1661
+ "lstrip": false,
1662
+ "normalized": false,
1663
+ "rstrip": false,
1664
+ "single_word": false,
1665
+ "special": true
1666
+ },
1667
+ "208": {
1668
+ "content": "<extra_id_202>",
1669
+ "lstrip": false,
1670
+ "normalized": false,
1671
+ "rstrip": false,
1672
+ "single_word": false,
1673
+ "special": true
1674
+ },
1675
+ "209": {
1676
+ "content": "<extra_id_203>",
1677
+ "lstrip": false,
1678
+ "normalized": false,
1679
+ "rstrip": false,
1680
+ "single_word": false,
1681
+ "special": true
1682
+ },
1683
+ "210": {
1684
+ "content": "<extra_id_204>",
1685
+ "lstrip": false,
1686
+ "normalized": false,
1687
+ "rstrip": false,
1688
+ "single_word": false,
1689
+ "special": true
1690
+ },
1691
+ "211": {
1692
+ "content": "<extra_id_205>",
1693
+ "lstrip": false,
1694
+ "normalized": false,
1695
+ "rstrip": false,
1696
+ "single_word": false,
1697
+ "special": true
1698
+ },
1699
+ "212": {
1700
+ "content": "<extra_id_206>",
1701
+ "lstrip": false,
1702
+ "normalized": false,
1703
+ "rstrip": false,
1704
+ "single_word": false,
1705
+ "special": true
1706
+ },
1707
+ "213": {
1708
+ "content": "<extra_id_207>",
1709
+ "lstrip": false,
1710
+ "normalized": false,
1711
+ "rstrip": false,
1712
+ "single_word": false,
1713
+ "special": true
1714
+ },
1715
+ "214": {
1716
+ "content": "<extra_id_208>",
1717
+ "lstrip": false,
1718
+ "normalized": false,
1719
+ "rstrip": false,
1720
+ "single_word": false,
1721
+ "special": true
1722
+ },
1723
+ "215": {
1724
+ "content": "<extra_id_209>",
1725
+ "lstrip": false,
1726
+ "normalized": false,
1727
+ "rstrip": false,
1728
+ "single_word": false,
1729
+ "special": true
1730
+ },
1731
+ "216": {
1732
+ "content": "<extra_id_210>",
1733
+ "lstrip": false,
1734
+ "normalized": false,
1735
+ "rstrip": false,
1736
+ "single_word": false,
1737
+ "special": true
1738
+ },
1739
+ "217": {
1740
+ "content": "<extra_id_211>",
1741
+ "lstrip": false,
1742
+ "normalized": false,
1743
+ "rstrip": false,
1744
+ "single_word": false,
1745
+ "special": true
1746
+ },
1747
+ "218": {
1748
+ "content": "<extra_id_212>",
1749
+ "lstrip": false,
1750
+ "normalized": false,
1751
+ "rstrip": false,
1752
+ "single_word": false,
1753
+ "special": true
1754
+ },
1755
+ "219": {
1756
+ "content": "<extra_id_213>",
1757
+ "lstrip": false,
1758
+ "normalized": false,
1759
+ "rstrip": false,
1760
+ "single_word": false,
1761
+ "special": true
1762
+ },
1763
+ "220": {
1764
+ "content": "<extra_id_214>",
1765
+ "lstrip": false,
1766
+ "normalized": false,
1767
+ "rstrip": false,
1768
+ "single_word": false,
1769
+ "special": true
1770
+ },
1771
+ "221": {
1772
+ "content": "<extra_id_215>",
1773
+ "lstrip": false,
1774
+ "normalized": false,
1775
+ "rstrip": false,
1776
+ "single_word": false,
1777
+ "special": true
1778
+ },
1779
+ "222": {
1780
+ "content": "<extra_id_216>",
1781
+ "lstrip": false,
1782
+ "normalized": false,
1783
+ "rstrip": false,
1784
+ "single_word": false,
1785
+ "special": true
1786
+ },
1787
+ "223": {
1788
+ "content": "<extra_id_217>",
1789
+ "lstrip": false,
1790
+ "normalized": false,
1791
+ "rstrip": false,
1792
+ "single_word": false,
1793
+ "special": true
1794
+ },
1795
+ "224": {
1796
+ "content": "<extra_id_218>",
1797
+ "lstrip": false,
1798
+ "normalized": false,
1799
+ "rstrip": false,
1800
+ "single_word": false,
1801
+ "special": true
1802
+ },
1803
+ "225": {
1804
+ "content": "<extra_id_219>",
1805
+ "lstrip": false,
1806
+ "normalized": false,
1807
+ "rstrip": false,
1808
+ "single_word": false,
1809
+ "special": true
1810
+ },
1811
+ "226": {
1812
+ "content": "<extra_id_220>",
1813
+ "lstrip": false,
1814
+ "normalized": false,
1815
+ "rstrip": false,
1816
+ "single_word": false,
1817
+ "special": true
1818
+ },
1819
+ "227": {
1820
+ "content": "<extra_id_221>",
1821
+ "lstrip": false,
1822
+ "normalized": false,
1823
+ "rstrip": false,
1824
+ "single_word": false,
1825
+ "special": true
1826
+ },
1827
+ "228": {
1828
+ "content": "<extra_id_222>",
1829
+ "lstrip": false,
1830
+ "normalized": false,
1831
+ "rstrip": false,
1832
+ "single_word": false,
1833
+ "special": true
1834
+ },
1835
+ "229": {
1836
+ "content": "<extra_id_223>",
1837
+ "lstrip": false,
1838
+ "normalized": false,
1839
+ "rstrip": false,
1840
+ "single_word": false,
1841
+ "special": true
1842
+ },
1843
+ "230": {
1844
+ "content": "<extra_id_224>",
1845
+ "lstrip": false,
1846
+ "normalized": false,
1847
+ "rstrip": false,
1848
+ "single_word": false,
1849
+ "special": true
1850
+ },
1851
+ "231": {
1852
+ "content": "<extra_id_225>",
1853
+ "lstrip": false,
1854
+ "normalized": false,
1855
+ "rstrip": false,
1856
+ "single_word": false,
1857
+ "special": true
1858
+ },
1859
+ "232": {
1860
+ "content": "<extra_id_226>",
1861
+ "lstrip": false,
1862
+ "normalized": false,
1863
+ "rstrip": false,
1864
+ "single_word": false,
1865
+ "special": true
1866
+ },
1867
+ "233": {
1868
+ "content": "<extra_id_227>",
1869
+ "lstrip": false,
1870
+ "normalized": false,
1871
+ "rstrip": false,
1872
+ "single_word": false,
1873
+ "special": true
1874
+ },
1875
+ "234": {
1876
+ "content": "<extra_id_228>",
1877
+ "lstrip": false,
1878
+ "normalized": false,
1879
+ "rstrip": false,
1880
+ "single_word": false,
1881
+ "special": true
1882
+ },
1883
+ "235": {
1884
+ "content": "<extra_id_229>",
1885
+ "lstrip": false,
1886
+ "normalized": false,
1887
+ "rstrip": false,
1888
+ "single_word": false,
1889
+ "special": true
1890
+ },
1891
+ "236": {
1892
+ "content": "<extra_id_230>",
1893
+ "lstrip": false,
1894
+ "normalized": false,
1895
+ "rstrip": false,
1896
+ "single_word": false,
1897
+ "special": true
1898
+ },
1899
+ "237": {
1900
+ "content": "<extra_id_231>",
1901
+ "lstrip": false,
1902
+ "normalized": false,
1903
+ "rstrip": false,
1904
+ "single_word": false,
1905
+ "special": true
1906
+ },
1907
+ "238": {
1908
+ "content": "<extra_id_232>",
1909
+ "lstrip": false,
1910
+ "normalized": false,
1911
+ "rstrip": false,
1912
+ "single_word": false,
1913
+ "special": true
1914
+ },
1915
+ "239": {
1916
+ "content": "<extra_id_233>",
1917
+ "lstrip": false,
1918
+ "normalized": false,
1919
+ "rstrip": false,
1920
+ "single_word": false,
1921
+ "special": true
1922
+ },
1923
+ "240": {
1924
+ "content": "<extra_id_234>",
1925
+ "lstrip": false,
1926
+ "normalized": false,
1927
+ "rstrip": false,
1928
+ "single_word": false,
1929
+ "special": true
1930
+ },
1931
+ "241": {
1932
+ "content": "<extra_id_235>",
1933
+ "lstrip": false,
1934
+ "normalized": false,
1935
+ "rstrip": false,
1936
+ "single_word": false,
1937
+ "special": true
1938
+ },
1939
+ "242": {
1940
+ "content": "<extra_id_236>",
1941
+ "lstrip": false,
1942
+ "normalized": false,
1943
+ "rstrip": false,
1944
+ "single_word": false,
1945
+ "special": true
1946
+ },
1947
+ "243": {
1948
+ "content": "<extra_id_237>",
1949
+ "lstrip": false,
1950
+ "normalized": false,
1951
+ "rstrip": false,
1952
+ "single_word": false,
1953
+ "special": true
1954
+ },
1955
+ "244": {
1956
+ "content": "<extra_id_238>",
1957
+ "lstrip": false,
1958
+ "normalized": false,
1959
+ "rstrip": false,
1960
+ "single_word": false,
1961
+ "special": true
1962
+ },
1963
+ "245": {
1964
+ "content": "<extra_id_239>",
1965
+ "lstrip": false,
1966
+ "normalized": false,
1967
+ "rstrip": false,
1968
+ "single_word": false,
1969
+ "special": true
1970
+ },
1971
+ "246": {
1972
+ "content": "<extra_id_240>",
1973
+ "lstrip": false,
1974
+ "normalized": false,
1975
+ "rstrip": false,
1976
+ "single_word": false,
1977
+ "special": true
1978
+ },
1979
+ "247": {
1980
+ "content": "<extra_id_241>",
1981
+ "lstrip": false,
1982
+ "normalized": false,
1983
+ "rstrip": false,
1984
+ "single_word": false,
1985
+ "special": true
1986
+ },
1987
+ "248": {
1988
+ "content": "<extra_id_242>",
1989
+ "lstrip": false,
1990
+ "normalized": false,
1991
+ "rstrip": false,
1992
+ "single_word": false,
1993
+ "special": true
1994
+ },
1995
+ "249": {
1996
+ "content": "<extra_id_243>",
1997
+ "lstrip": false,
1998
+ "normalized": false,
1999
+ "rstrip": false,
2000
+ "single_word": false,
2001
+ "special": true
2002
+ },
2003
+ "250": {
2004
+ "content": "<extra_id_244>",
2005
+ "lstrip": false,
2006
+ "normalized": false,
2007
+ "rstrip": false,
2008
+ "single_word": false,
2009
+ "special": true
2010
+ },
2011
+ "251": {
2012
+ "content": "<extra_id_245>",
2013
+ "lstrip": false,
2014
+ "normalized": false,
2015
+ "rstrip": false,
2016
+ "single_word": false,
2017
+ "special": true
2018
+ },
2019
+ "252": {
2020
+ "content": "<extra_id_246>",
2021
+ "lstrip": false,
2022
+ "normalized": false,
2023
+ "rstrip": false,
2024
+ "single_word": false,
2025
+ "special": true
2026
+ },
2027
+ "253": {
2028
+ "content": "<extra_id_247>",
2029
+ "lstrip": false,
2030
+ "normalized": false,
2031
+ "rstrip": false,
2032
+ "single_word": false,
2033
+ "special": true
2034
+ },
2035
+ "254": {
2036
+ "content": "<extra_id_248>",
2037
+ "lstrip": false,
2038
+ "normalized": false,
2039
+ "rstrip": false,
2040
+ "single_word": false,
2041
+ "special": true
2042
+ },
2043
+ "255": {
2044
+ "content": "<extra_id_249>",
2045
+ "lstrip": false,
2046
+ "normalized": false,
2047
+ "rstrip": false,
2048
+ "single_word": false,
2049
+ "special": true
2050
+ },
2051
+ "256": {
2052
+ "content": "<extra_id_250>",
2053
+ "lstrip": false,
2054
+ "normalized": false,
2055
+ "rstrip": false,
2056
+ "single_word": false,
2057
+ "special": true
2058
+ },
2059
+ "257": {
2060
+ "content": "<extra_id_251>",
2061
+ "lstrip": false,
2062
+ "normalized": false,
2063
+ "rstrip": false,
2064
+ "single_word": false,
2065
+ "special": true
2066
+ },
2067
+ "258": {
2068
+ "content": "<extra_id_252>",
2069
+ "lstrip": false,
2070
+ "normalized": false,
2071
+ "rstrip": false,
2072
+ "single_word": false,
2073
+ "special": true
2074
+ },
2075
+ "259": {
2076
+ "content": "<extra_id_253>",
2077
+ "lstrip": false,
2078
+ "normalized": false,
2079
+ "rstrip": false,
2080
+ "single_word": false,
2081
+ "special": true
2082
+ },
2083
+ "260": {
2084
+ "content": "<extra_id_254>",
2085
+ "lstrip": false,
2086
+ "normalized": false,
2087
+ "rstrip": false,
2088
+ "single_word": false,
2089
+ "special": true
2090
+ },
2091
+ "261": {
2092
+ "content": "<extra_id_255>",
2093
+ "lstrip": false,
2094
+ "normalized": false,
2095
+ "rstrip": false,
2096
+ "single_word": false,
2097
+ "special": true
2098
+ }
2099
+ },
2100
+ "additional_special_tokens": [
2101
+ "<extra_id_0>",
2102
+ "<extra_id_1>",
2103
+ "<extra_id_2>",
2104
+ "<extra_id_3>",
2105
+ "<extra_id_4>",
2106
+ "<extra_id_5>",
2107
+ "<extra_id_6>",
2108
+ "<extra_id_7>",
2109
+ "<extra_id_8>",
2110
+ "<extra_id_9>",
2111
+ "<extra_id_10>",
2112
+ "<extra_id_11>",
2113
+ "<extra_id_12>",
2114
+ "<extra_id_13>",
2115
+ "<extra_id_14>",
2116
+ "<extra_id_15>",
2117
+ "<extra_id_16>",
2118
+ "<extra_id_17>",
2119
+ "<extra_id_18>",
2120
+ "<extra_id_19>",
2121
+ "<extra_id_20>",
2122
+ "<extra_id_21>",
2123
+ "<extra_id_22>",
2124
+ "<extra_id_23>",
2125
+ "<extra_id_24>",
2126
+ "<extra_id_25>",
2127
+ "<extra_id_26>",
2128
+ "<extra_id_27>",
2129
+ "<extra_id_28>",
2130
+ "<extra_id_29>",
2131
+ "<extra_id_30>",
2132
+ "<extra_id_31>",
2133
+ "<extra_id_32>",
2134
+ "<extra_id_33>",
2135
+ "<extra_id_34>",
2136
+ "<extra_id_35>",
2137
+ "<extra_id_36>",
2138
+ "<extra_id_37>",
2139
+ "<extra_id_38>",
2140
+ "<extra_id_39>",
2141
+ "<extra_id_40>",
2142
+ "<extra_id_41>",
2143
+ "<extra_id_42>",
2144
+ "<extra_id_43>",
2145
+ "<extra_id_44>",
2146
+ "<extra_id_45>",
2147
+ "<extra_id_46>",
2148
+ "<extra_id_47>",
2149
+ "<extra_id_48>",
2150
+ "<extra_id_49>",
2151
+ "<extra_id_50>",
2152
+ "<extra_id_51>",
2153
+ "<extra_id_52>",
2154
+ "<extra_id_53>",
2155
+ "<extra_id_54>",
2156
+ "<extra_id_55>",
2157
+ "<extra_id_56>",
2158
+ "<extra_id_57>",
2159
+ "<extra_id_58>",
2160
+ "<extra_id_59>",
2161
+ "<extra_id_60>",
2162
+ "<extra_id_61>",
2163
+ "<extra_id_62>",
2164
+ "<extra_id_63>",
2165
+ "<extra_id_64>",
2166
+ "<extra_id_65>",
2167
+ "<extra_id_66>",
2168
+ "<extra_id_67>",
2169
+ "<extra_id_68>",
2170
+ "<extra_id_69>",
2171
+ "<extra_id_70>",
2172
+ "<extra_id_71>",
2173
+ "<extra_id_72>",
2174
+ "<extra_id_73>",
2175
+ "<extra_id_74>",
2176
+ "<extra_id_75>",
2177
+ "<extra_id_76>",
2178
+ "<extra_id_77>",
2179
+ "<extra_id_78>",
2180
+ "<extra_id_79>",
2181
+ "<extra_id_80>",
2182
+ "<extra_id_81>",
2183
+ "<extra_id_82>",
2184
+ "<extra_id_83>",
2185
+ "<extra_id_84>",
2186
+ "<extra_id_85>",
2187
+ "<extra_id_86>",
2188
+ "<extra_id_87>",
2189
+ "<extra_id_88>",
2190
+ "<extra_id_89>",
2191
+ "<extra_id_90>",
2192
+ "<extra_id_91>",
2193
+ "<extra_id_92>",
2194
+ "<extra_id_93>",
2195
+ "<extra_id_94>",
2196
+ "<extra_id_95>",
2197
+ "<extra_id_96>",
2198
+ "<extra_id_97>",
2199
+ "<extra_id_98>",
2200
+ "<extra_id_99>",
2201
+ "<extra_id_100>",
2202
+ "<extra_id_101>",
2203
+ "<extra_id_102>",
2204
+ "<extra_id_103>",
2205
+ "<extra_id_104>",
2206
+ "<extra_id_105>",
2207
+ "<extra_id_106>",
2208
+ "<extra_id_107>",
2209
+ "<extra_id_108>",
2210
+ "<extra_id_109>",
2211
+ "<extra_id_110>",
2212
+ "<extra_id_111>",
2213
+ "<extra_id_112>",
2214
+ "<extra_id_113>",
2215
+ "<extra_id_114>",
2216
+ "<extra_id_115>",
2217
+ "<extra_id_116>",
2218
+ "<extra_id_117>",
2219
+ "<extra_id_118>",
2220
+ "<extra_id_119>",
2221
+ "<extra_id_120>",
2222
+ "<extra_id_121>",
2223
+ "<extra_id_122>",
2224
+ "<extra_id_123>",
2225
+ "<extra_id_124>",
2226
+ "<extra_id_125>",
2227
+ "<extra_id_126>",
2228
+ "<extra_id_127>",
2229
+ "<extra_id_128>",
2230
+ "<extra_id_129>",
2231
+ "<extra_id_130>",
2232
+ "<extra_id_131>",
2233
+ "<extra_id_132>",
2234
+ "<extra_id_133>",
2235
+ "<extra_id_134>",
2236
+ "<extra_id_135>",
2237
+ "<extra_id_136>",
2238
+ "<extra_id_137>",
2239
+ "<extra_id_138>",
2240
+ "<extra_id_139>",
2241
+ "<extra_id_140>",
2242
+ "<extra_id_141>",
2243
+ "<extra_id_142>",
2244
+ "<extra_id_143>",
2245
+ "<extra_id_144>",
2246
+ "<extra_id_145>",
2247
+ "<extra_id_146>",
2248
+ "<extra_id_147>",
2249
+ "<extra_id_148>",
2250
+ "<extra_id_149>",
2251
+ "<extra_id_150>",
2252
+ "<extra_id_151>",
2253
+ "<extra_id_152>",
2254
+ "<extra_id_153>",
2255
+ "<extra_id_154>",
2256
+ "<extra_id_155>",
2257
+ "<extra_id_156>",
2258
+ "<extra_id_157>",
2259
+ "<extra_id_158>",
2260
+ "<extra_id_159>",
2261
+ "<extra_id_160>",
2262
+ "<extra_id_161>",
2263
+ "<extra_id_162>",
2264
+ "<extra_id_163>",
2265
+ "<extra_id_164>",
2266
+ "<extra_id_165>",
2267
+ "<extra_id_166>",
2268
+ "<extra_id_167>",
2269
+ "<extra_id_168>",
2270
+ "<extra_id_169>",
2271
+ "<extra_id_170>",
2272
+ "<extra_id_171>",
2273
+ "<extra_id_172>",
2274
+ "<extra_id_173>",
2275
+ "<extra_id_174>",
2276
+ "<extra_id_175>",
2277
+ "<extra_id_176>",
2278
+ "<extra_id_177>",
2279
+ "<extra_id_178>",
2280
+ "<extra_id_179>",
2281
+ "<extra_id_180>",
2282
+ "<extra_id_181>",
2283
+ "<extra_id_182>",
2284
+ "<extra_id_183>",
2285
+ "<extra_id_184>",
2286
+ "<extra_id_185>",
2287
+ "<extra_id_186>",
2288
+ "<extra_id_187>",
2289
+ "<extra_id_188>",
2290
+ "<extra_id_189>",
2291
+ "<extra_id_190>",
2292
+ "<extra_id_191>",
2293
+ "<extra_id_192>",
2294
+ "<extra_id_193>",
2295
+ "<extra_id_194>",
2296
+ "<extra_id_195>",
2297
+ "<extra_id_196>",
2298
+ "<extra_id_197>",
2299
+ "<extra_id_198>",
2300
+ "<extra_id_199>",
2301
+ "<extra_id_200>",
2302
+ "<extra_id_201>",
2303
+ "<extra_id_202>",
2304
+ "<extra_id_203>",
2305
+ "<extra_id_204>",
2306
+ "<extra_id_205>",
2307
+ "<extra_id_206>",
2308
+ "<extra_id_207>",
2309
+ "<extra_id_208>",
2310
+ "<extra_id_209>",
2311
+ "<extra_id_210>",
2312
+ "<extra_id_211>",
2313
+ "<extra_id_212>",
2314
+ "<extra_id_213>",
2315
+ "<extra_id_214>",
2316
+ "<extra_id_215>",
2317
+ "<extra_id_216>",
2318
+ "<extra_id_217>",
2319
+ "<extra_id_218>",
2320
+ "<extra_id_219>",
2321
+ "<extra_id_220>",
2322
+ "<extra_id_221>",
2323
+ "<extra_id_222>",
2324
+ "<extra_id_223>",
2325
+ "<extra_id_224>",
2326
+ "<extra_id_225>",
2327
+ "<extra_id_226>",
2328
+ "<extra_id_227>",
2329
+ "<extra_id_228>",
2330
+ "<extra_id_229>",
2331
+ "<extra_id_230>",
2332
+ "<extra_id_231>",
2333
+ "<extra_id_232>",
2334
+ "<extra_id_233>",
2335
+ "<extra_id_234>",
2336
+ "<extra_id_235>",
2337
+ "<extra_id_236>",
2338
+ "<extra_id_237>",
2339
+ "<extra_id_238>",
2340
+ "<extra_id_239>",
2341
+ "<extra_id_240>",
2342
+ "<extra_id_241>",
2343
+ "<extra_id_242>",
2344
+ "<extra_id_243>",
2345
+ "<extra_id_244>",
2346
+ "<extra_id_245>",
2347
+ "<extra_id_246>",
2348
+ "<extra_id_247>",
2349
+ "<extra_id_248>",
2350
+ "<extra_id_249>",
2351
+ "<extra_id_250>",
2352
+ "<extra_id_251>",
2353
+ "<extra_id_252>",
2354
+ "<extra_id_253>",
2355
+ "<extra_id_254>",
2356
+ "<extra_id_255>"
2357
+ ],
+ "clean_up_tokenization_spaces": true,
+ "cls_token": "<cls>",
+ "eos_token": "</s>",
+ "mask_token": "<mask>",
+ "model_max_length": 1024,
+ "pad_token": "<pad>",
+ "sep_token": "<sep>",
+ "tokenizer_class": "PreTrainedTokenizerFast",
+ "unk_token": "<unk>"
+ }
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:48108a519b2d546b6b73832c5c3752b2c0920e3ce76ab654621c6c98f2de2ef0
+ size 5240