bourdoiscatie
commited on
Add FAT5-small
Browse files- README.md +224 -0
- adamw_scaled.py +281 -0
- attn_ref.py +29 -0
- config.json +59 -0
- configuration_flash_t5.py +84 -0
- cross_entropy_loss.py +426 -0
- custom_heads_flash_t5.py +315 -0
- flash_attention_v2_bias.py +905 -0
- generation_config.json +7 -0
- modeling_flash_t5.py +790 -0
- optimizer.pt +3 -0
- positional_encoding.py +417 -0
- pytorch_model.bin +3 -0
- rms_norm.py +287 -0
- rng_state.pth +3 -0
- scheduler.pt +3 -0
- special_tokens_map.json +266 -0
- tokenizer.json +0 -0
- tokenizer_config.json +2367 -0
- trainer_state.json +0 -0
- training_args.bin +3 -0
README.md
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: fr
|
3 |
+
datasets:
|
4 |
+
- uonlp/CulturaX
|
5 |
+
- wikimedia/wikipedia
|
6 |
+
- eckendoerffer/justice_fr
|
7 |
+
- bigcode/the-stack-dedup
|
8 |
+
metrics:
|
9 |
+
- f1
|
10 |
+
- exact_match
|
11 |
+
library_name: transformers
|
12 |
+
co2_eq_emissions: 13.5
|
13 |
+
license: apache-2.0
|
14 |
+
---
|
15 |
+
|
16 |
+
# FAT5 (Flash Attention T5) ⚡
|
17 |
+
|
18 |
+
<br>
|
19 |
+
|
20 |
+
<div align="center" style="line-height: 1;">
|
21 |
+
<a href="https://huggingface.co/spaces/CATIE-AQ/FAT5-report" style="margin: 2px;">
|
22 |
+
<img alt="Blog post (EN)" src="https://img.shields.io/badge/📝_Blog_post-English_version-f5de53?&color=blue" style="display: inline-block; vertical-align: middle;"/>
|
23 |
+
</a>
|
24 |
+
<a href="https://huggingface.co/spaces/CATIE-AQ/FAT5-rapport" style="margin: 2px;">
|
25 |
+
<img alt="Blog post (FR)" src="https://img.shields.io/badge/📝_Blog_post-French_version-f5de53?&color=blue" style="display: inline-block; vertical-align: middle;"/>
|
26 |
+
<a href="https://huggingface.co/collections/CATIE-AQ/catie-french-fat5-ul2-677697a35feea336389d6403" target="_blank" style="margin: 2px;">
|
27 |
+
<img alt="Hugging Face collection" src="https://img.shields.io/badge/🤗_Hugging_Face-Collection-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
28 |
+
</a>
|
29 |
+
</a>
|
30 |
+
<a href="https://github.com/catie-aq/flashT5" target="_blank" style="margin: 2px;">
|
31 |
+
<img alt="FAT5 GitHub" src="https://img.shields.io/github/stars/catie-aq/flashT5?style=social" style="display: inline-block; vertical-align: middle;"/>
|
32 |
+
</a>
|
33 |
+
</a>
|
34 |
+
<a href="https://opensource.org/licenses/Apache-2.0" target="_blank" style="margin: 2px;">
|
35 |
+
<img alt="License: Apache 2.0" src="https://img.shields.io/badge/License-Apache--2.0-green.svg" style="display: inline-block; vertical-align: middle;"/>
|
36 |
+
</a>
|
37 |
+
</div>
|
38 |
+
|
39 |
+
## Introduction
|
40 |
+
|
41 |
+
FAT5 (for Flash Attention T5) is an implementation of T5 in PyTorch with an [UL2](https://arxiv.org/abs/2205.05131) objective optimized for GPGPU for both training and inference. It uses an experimental feature for using [Flash Attention](https://arxiv.org/abs/2307.08691) (v2) with relative position encoding biases that allow to train or finetune the model on longer sequence lengths than the original T5. It also has support for other positional embeddings such as [RoPE](https://arxiv.org/abs/2104.09864), [ALiBi](https://arxiv.org/abs/2108.12409) or [FIRE](https://arxiv.org/abs/2310.04418).
|
42 |
+
This methodology enabled us to efficiently pretrain as a proof of concept a T5 with 147M parameters in French in a reasonable time (1,461H to see 419B tokens) and with limited resources (1 single A100; i.e. a computational budget of around €2,200) which you'll find the weights in this repo.
|
43 |
+
To achieve this, we designed CUDA/Triton kernels to make Flash Attention compatible with T5, and to provide linear inference, thus extending the context size that can be taken into account by the model.
|
44 |
+
Other optimizations have also been implemented, as detailed in a subsequent [blog post](https://huggingface.co/spaces/CATIE-AQ/FAT5-report).
|
45 |
+
|
46 |
+
<br>
|
47 |
+
|
48 |
+
## Motivation
|
49 |
+
|
50 |
+
While a lot of effort has been focused on optimizing decoder-only models, in many practical applications older architectures remains useful.
|
51 |
+
We focus on [T5](http://jmlr.org/papers/v21/20-074.html), an encoder-decoder architecture exhibiting very decent performances for [instruction tuning](https://arxiv.org/pdf/2306.04757.pdf) or even sometimes outperforming much larger models when [finetuned](https://arxiv.org/pdf/2402.00841.pdf). Moreover it’s a natural architecture while considering [distillation](https://arxiv.org/abs/2305.02301) of much larger models.
|
52 |
+
|
53 |
+
A critical limitation of this architecture is the length of the sequence that these models can deal with due to the quadratic size in memory. While this quadratic term cannot be removed without considering other form of attention (like for [LongT5](https://arxiv.org/abs/2112.07916)), it can still be alleviated to accomodate longer sequence lengths.
|
54 |
+
Another limitation is the pre-training time, since techniques such as Flash Attention are not available for this architecture.
|
55 |
+
|
56 |
+
<br>
|
57 |
+
|
58 |
+
## Our work
|
59 |
+
|
60 |
+
We used the [nanoT5](https://github.com/PiotrNawrot/nanoT5?tab=readme-ov-file#cite) implementation as the base for our work.
|
61 |
+
|
62 |
+
We worked on optimizing the core component of the model, which is the attention part. We used the Flash Attention (v2) that optimize both the memory usage and the efficient use of Tensor Cores.
|
63 |
+
|
64 |
+
We support different implementation of attention biases:
|
65 |
+
- Full attention biases with Flash Attention 2 using this [PR](https://github.com/Dao-AILab/flash-attention/pull/617)
|
66 |
+
- T5-like relative position encoding biases with Flash Attention 2 using this [PR](https://github.com/Dao-AILab/flash-attention/pull/956)
|
67 |
+
- Full attention biases with a [triton implementation](src/model/ops/flash_attention_v2_bias.py) of Flash Attention 2
|
68 |
+
|
69 |
+
<center>
|
70 |
+
<img src="https://github.com/catie-aq/flashT5/raw/main/assets/FAT5_dark.gif" alt="FAT5_dark" width="100%">
|
71 |
+
</center>
|
72 |
+
|
73 |
+
Other parts of the architecture where optimized using [ad-hoc Triton kernels](https://github.com/catie-aq/flashT5/tree/main/src/model/ops) for the cross-entropy (and z-loss) and layernorm.
|
74 |
+
|
75 |
+
For pretext tasks during pre-training, we use the [UL2](https://arxiv.org/abs/2205.05131v3) mixture of denoisers with the following 7 tasks:
|
76 |
+
|
77 |
+
```python
|
78 |
+
denoiser_list=[
|
79 |
+
{"mu": 3.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
|
80 |
+
{"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
|
81 |
+
{"mu": 4.0, "r": 0.0, "max_spans": 1, "prefix": "[S]"},
|
82 |
+
{"mu": 3.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"},
|
83 |
+
{"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
|
84 |
+
{"mu": 64.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
|
85 |
+
{"mu": 64.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"}]
|
86 |
+
|
87 |
+
denoiser_proportions=[0.165, 0.165, 0.34, 0.0825, 0.0825, 0.0825, 0.0825]
|
88 |
+
```
|
89 |
+
where `mu`: the span size, `r`: the % of masking in the span and `prefix`: the type of the pretext task (the meaning of the letters `[R]`, `[S]` and `[X]` is described [here](https://huggingface.co/google/ul2#mixture-of-denoisers)).
|
90 |
+
|
91 |
+
As there was no implementation available in PyTorch, we [added one](https://github.com/catie-aq/flashT5/blob/main/src/data/data_collator_ul2.py) and adapted a dynamic batching mechanism to reduce padding in the model.
|
92 |
+
|
93 |
+
<br>
|
94 |
+
|
95 |
+
## Benchmarks
|
96 |
+
|
97 |
+
#### TFLOPS
|
98 |
+
|
99 |
+
The number of TFLOPS (trillions of floating-point calculations a processor can perform in one second) is probably the most eloquent measure of the impact of the optimizations carried out.
|
100 |
+
We therefore compare four approaches:
|
101 |
+
• the SPDA (Scaled Dot Product Attention) implementation with full bias,
|
102 |
+
• the same implementation but in Triton,
|
103 |
+
• the Flash Attention RPE implementation (our kernel),
|
104 |
+
• the Flash Attention implementation, i.e. without bias. We've included it here for reference, as it's unusable in practice for a T5.
|
105 |
+
|
106 |
+
For the forward pass, we have:
|
107 |
+
|
108 |
+
<center>
|
109 |
+
<img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/FWD-causal-True_dark.png" alt="FWD-causal-True_dark" width="100%">
|
110 |
+
</center>
|
111 |
+
|
112 |
+
For the forward pass, we can see that the Triton approach achieves 1.34 times more FLOPS than SPDA, and that the Flash Attention RPE approach achieves 1.99 times more FLOPS than SPDA.
|
113 |
+
We can also see that our bf16 implementation is equivalent to fp16 (doing even better at size 512).
|
114 |
+
|
115 |
+
For the backward pass, we have:
|
116 |
+
|
117 |
+
<center>
|
118 |
+
<img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/BWD-causal-True_dark.png" alt="BWD-causal-True_dark" width="100%">
|
119 |
+
</center>
|
120 |
+
|
121 |
+
For the backward pass, the Triton implementation is less efficient than SPDA, with 0.71 times the FLOPS of SPDA. The Flash Attention RPE implementation is more or less equivalent to SPDA (1.018 times more FLOPS).
|
122 |
+
We can also observe that Triton in head_dim 64 is more efficient than Triton in head_dim 128.
|
123 |
+
|
124 |
+
|
125 |
+
#### Torch vs Triton
|
126 |
+
|
127 |
+
We mentioned above that we had optimized parts of the architecture using ad hoc Triton kernels, namely the cross-entropy and RMSNorm layer. The following benchmarks should illustrate why.
|
128 |
+
For cross-entropy, we obtain a forward pass 7 to 11.4 times faster, a backward pass 3.26 to 3.75 times faster and a memory reduced by a factor of 4:
|
129 |
+
|
130 |
+
<center>
|
131 |
+
<img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/CE_dark.png" alt="CE_dark" width="100%">
|
132 |
+
</center>
|
133 |
+
|
134 |
+
For the RMSNorm layer, we obtain a 3 to 5 times faster forward pass, a 2.33 to 4.33 times faster reverse pass and a memory reduced by a factor of 3.2:
|
135 |
+
<center>
|
136 |
+
<img src="https://github.com/catie-aq/flashT5/raw/main/assets/benchmarks/LN_dark.png" alt="LN_dark" width="100%">
|
137 |
+
</center>
|
138 |
+
|
139 |
+
Note that all benchmark graphs can be generated automatically using the following [code](https://github.com/catie-aq/flashT5/tree/main/benchmarks).
|
140 |
+
|
141 |
+
<br>
|
142 |
+
|
143 |
+
## Applications
|
144 |
+
|
145 |
+
### To French
|
146 |
+
We've pretrained a small (147M parameters) FAT5-UL2 in French. This is the model you'll find in this Hugging Face repo.
|
147 |
+
The dataset we used is a mixture of [CulturaX](https://huggingface.co/datasets/uonlp/CulturaX), [Wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia), [justice_fr](https://huggingface.co/datasets/eckendoerffer/justice_fr) and [The Stack](https://huggingface.co/datasets/bigcode/the-stack-dedup).
|
148 |
+
Our tokenizer of size 32,768 (8**5) is trained on CulturaX and The Stack.
|
149 |
+
Our model is pre-trained on a sequence of 1,024 tokens on a single A100 for 1,461H (= 419.4B tokens seen) at an estimated cost of €2,200.
|
150 |
+
|
151 |
+
Carbon emissions for the pretraining of this model were estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact:
|
152 |
+
• Hardware Type: A100 PCIe 40/80GB
|
153 |
+
• Hours used: 1,461h
|
154 |
+
• Cloud Provider: Private Infrastructure
|
155 |
+
• Carbon Efficiency (kg/kWh): 0.03696 kg (estimated from [electricitymaps](https://app.electricitymaps.com/zone/FR) (average carbon intensity in France average between October 18, 2024 and December 19, 2024)
|
156 |
+
• **Carbon Emitted** *(Power consumption x Time x Carbon produced based on location of power grid)*: **13.5 kg eq. CO2**.
|
157 |
+
|
158 |
+
|
159 |
+
### To other language
|
160 |
+
|
161 |
+
Our contribution focuses on French, with the pre-training and finetuning of models for comparison against French benchmarks. For other languages, we can't afford to do the same kind of work.
|
162 |
+
|
163 |
+
Nevertheless, to ensure that it can be used in other languages, we have developed a [code](https://github.com/catie-aq/flashT5/blob/main/convert_huggingface_t5.py) for adapting already pre-trained (m)T5/FLAN-T5 weights to our method. In this way, we hope users of a specific language will be able to efficiently continue pre-training one of these models to adapt it to more recent data, for example.
|
164 |
+
Note, however, that this adaptation is limited, since the additional pre-training will have to be carried out within the precision of the original model. For example, if the model's weights are in FP32 (which is the case with the FLAN-T5), training will not be as fast as with the FAT5, which is in BF16.
|
165 |
+
|
166 |
+
For English speakers, we have already adapted the weights of the various versions of the [FLANT-T5](https://arxiv.org/abs/2210.11416) to our method. All weights can be found in this Hugging Face [collection](https://huggingface.co/collections/CATIE-AQ/catie-english-fat5-flan-662b679a8e855c7c0137d69e).
|
167 |
+
To use one of the models, simply do the command:
|
168 |
+
|
169 |
+
```
|
170 |
+
from transformers import AutoModel, AutoTokenizer
|
171 |
+
model = AutoModel.from_pretrained("CATIE-AQ/FAT5-small-flan-en", trust_remote_code=True)
|
172 |
+
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
|
173 |
+
```
|
174 |
+
|
175 |
+
<br>
|
176 |
+
|
177 |
+
## Pretraining
|
178 |
+
|
179 |
+
If you want to pre-train your own model (to be specialized in a specific domain for example, and thus benefit from a custom tokenizer), we included a [tutorial](examples/minipile) to pretrain a small model on [minipile](https://huggingface.co/datasets/JeanKaddour/minipile) to show how it should be done.
|
180 |
+
You can find the documentation of the model configuration file [here](docs/configuration_file.md).
|
181 |
+
Note that we tested and trained the model of the tutorial on A100. It may or may not work with other GPUs.
|
182 |
+
|
183 |
+
<br>
|
184 |
+
|
185 |
+
<!--
|
186 |
+
## Finetuning
|
187 |
+
|
188 |
+
Once you've pre-trained your model, you'll want to finetune it. Because it's a custom model (hence the need for `trust_remote_code=True` to load it), it's currently causing some difficulties due to flaws in Hugging Face's Transformers library (during the `push_to_hub`, for example). This should be resolved in January 2025, as we're working on porting the FAT5 directly into Transformers with the help of their lovely team 🤗
|
189 |
+
-->
|
190 |
+
|
191 |
+
## Roadmap
|
192 |
+
|
193 |
+
We invite you to consult the “Next stage” section of the blog post.
|
194 |
+
|
195 |
+
<br>
|
196 |
+
|
197 |
+
## Citation
|
198 |
+
```
|
199 |
+
The DOI must be generated once model in public
|
200 |
+
```
|
201 |
+
|
202 |
+
<br>
|
203 |
+
|
204 |
+
## License
|
205 |
+
[Apache-2.0 license](https://github.com/catie-aq/flashT5/tree/main?tab=Apache-2.0-1-ov-file#readme)
|
206 |
+
|
207 |
+
<br>
|
208 |
+
|
209 |
+
## Ackowledgment
|
210 |
+
|
211 |
+
We use the following repos and thanks the authors for this:
|
212 |
+
- [nanoT5](https://github.com/PiotrNawrot/nanoT5) for the simple implementation and the optimizer.
|
213 |
+
- [Flash attention](https://github.com/Dao-AILab/flash-attention) for the groundbreaking algorithm for computing attention.
|
214 |
+
- [Hugging Face](https://github.com/huggingface/transformers) for their excellent library.
|
215 |
+
- [FlagAttention](https://github.com/FlagOpen/FlagAttention) for the implementation of FA2 in Triton.
|
216 |
+
- [Unsloth](https://github.com/unslothai/unsloth) for the simple Triton kernels of the cross-entropy and layernorm that we adapted to our usage.
|
217 |
+
- [TurboT5](https://github.com/Knowledgator/TurboT5) for the improvement of the February 2024 version of our work.
|
218 |
+
|
219 |
+
This work was support by the [Vaniila platform](http://vaniila.ai/).<br>
|
220 |
+
<div align="center">
|
221 |
+
<a href="[https://example.com](http://vaniila.ai/)" target="_blank">
|
222 |
+
<img src="https://www.vaniila.ai/wp-content/uploads/2020/02/Vaniila_bleu_horizontal.png" alt="Vaniila Logo" width="200">
|
223 |
+
</a>
|
224 |
+
</div>
|
adamw_scaled.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
|
4 |
+
from torch.optim import Optimizer
|
5 |
+
from torch.optim.optimizer import _default_to_fused_or_foreach
|
6 |
+
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
|
7 |
+
from typing import Iterable, Tuple
|
8 |
+
from torch import nn, Tensor
|
9 |
+
|
10 |
+
class AdamWScale(Optimizer):
|
11 |
+
"""
|
12 |
+
This AdamW implementation is copied from Huggingface.
|
13 |
+
We modified it with Adagrad scaling by rms of a weight tensor
|
14 |
+
|
15 |
+
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
|
16 |
+
Regularization](https://arxiv.org/abs/1711.05101).
|
17 |
+
|
18 |
+
Parameters:
|
19 |
+
params (`Iterable[nn.parameter.Parameter]`):
|
20 |
+
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
21 |
+
lr (`float`, *optional*, defaults to 1e-3):
|
22 |
+
The learning rate to use.
|
23 |
+
betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
|
24 |
+
Adam's betas parameters (b1, b2).
|
25 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
26 |
+
Adam's epsilon for numerical stability.
|
27 |
+
weight_decay (`float`, *optional*, defaults to 0.0):
|
28 |
+
Decoupled weight decay to apply.
|
29 |
+
kahan_sum (`bool`, *optional*, defaults to False):
|
30 |
+
Whether to use Kahan summation for updating parameters.
|
31 |
+
foreach (`bool`, *optional*, defaults to False):
|
32 |
+
Whether to use the foreach implementation.
|
33 |
+
correct_bias (`bool`, *optional*, defaults to True):
|
34 |
+
Whether to correct bias in Adam.
|
35 |
+
use_state_dtype (`torch.dtype`, *optional*, defaults to None):
|
36 |
+
The dtype to use for optimizer state. If None, use the default dtype.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
params: Iterable[nn.parameter.Parameter],
|
42 |
+
lr: float = 1e-3,
|
43 |
+
betas: Tuple[float, float] = (0.9, 0.999),
|
44 |
+
eps: float = 1e-6,
|
45 |
+
weight_decay: float = 0.0,
|
46 |
+
kahan_sum: bool = False,
|
47 |
+
foreach: bool = False,
|
48 |
+
correct_bias: bool = True,
|
49 |
+
use_state_dtype: torch.dtype = None
|
50 |
+
):
|
51 |
+
if lr < 0.0:
|
52 |
+
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
|
53 |
+
if not 0.0 <= betas[0] < 1.0:
|
54 |
+
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
|
55 |
+
if not 0.0 <= betas[1] < 1.0:
|
56 |
+
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
57 |
+
if not 0.0 <= eps:
|
58 |
+
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
59 |
+
|
60 |
+
assert not (foreach and use_state_dtype is not None), "foreach is not supported with use_state_dtype"
|
61 |
+
|
62 |
+
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, \
|
63 |
+
kahan_sum=kahan_sum, correct_bias=correct_bias, use_state_dtype=use_state_dtype)
|
64 |
+
|
65 |
+
super().__init__(params, defaults)
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def _rms(tensor):
|
69 |
+
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
70 |
+
|
71 |
+
@torch.no_grad()
|
72 |
+
def step(self, closure=None):
|
73 |
+
"""
|
74 |
+
Performs a single optimization step.
|
75 |
+
|
76 |
+
Arguments:
|
77 |
+
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
78 |
+
"""
|
79 |
+
loss = None
|
80 |
+
if closure is not None:
|
81 |
+
loss = closure()
|
82 |
+
|
83 |
+
for group in self.param_groups:
|
84 |
+
params, grads, exp_avgs, exp_avg_sqs, steps, kahan_comps = [], [], [], [], [], []
|
85 |
+
|
86 |
+
# Initialization
|
87 |
+
for p in group['params']:
|
88 |
+
if p.grad is None:
|
89 |
+
continue
|
90 |
+
|
91 |
+
params.append(p)
|
92 |
+
if p.grad.is_sparse:
|
93 |
+
raise RuntimeError('AdamWScale does not support sparse gradients')
|
94 |
+
grads.append(p.grad)
|
95 |
+
|
96 |
+
state = self.state[p]
|
97 |
+
|
98 |
+
# State initialization
|
99 |
+
if "kahan_comp" not in state:
|
100 |
+
state['step'] = torch.tensor(0, dtype=torch.int32, device=p.device)
|
101 |
+
|
102 |
+
if group["use_state_dtype"] in [torch.float16, torch.bfloat16]:
|
103 |
+
state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=group["use_state_dtype"])
|
104 |
+
state['exp_avg_sq'] = torch.zeros_like(p, device=p.device, dtype=group["use_state_dtype"])
|
105 |
+
else:
|
106 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
107 |
+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
108 |
+
|
109 |
+
if group["kahan_sum"] and p.dtype in [torch.float16, torch.bfloat16]:
|
110 |
+
state["kahan_comp"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
111 |
+
else:
|
112 |
+
state["kahan_comp"] = None
|
113 |
+
group["kahan_sum"] = False
|
114 |
+
|
115 |
+
exp_avgs.append(state['exp_avg'])
|
116 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
117 |
+
kahan_comps.append(state["kahan_comp"])
|
118 |
+
steps.append(state["step"])
|
119 |
+
|
120 |
+
torch._foreach_add_(steps, 1)
|
121 |
+
|
122 |
+
# AdamW step
|
123 |
+
if group["foreach"] and _default_to_fused_or_foreach(params, False, False):
|
124 |
+
self._foreach_adamwscaled(params,
|
125 |
+
grads,
|
126 |
+
exp_avgs,
|
127 |
+
exp_avg_sqs,
|
128 |
+
steps,
|
129 |
+
kahan_comps,
|
130 |
+
group["lr"],
|
131 |
+
group["betas"][0],
|
132 |
+
group["betas"][1],
|
133 |
+
group["weight_decay"],
|
134 |
+
group["eps"],
|
135 |
+
group["kahan_sum"],
|
136 |
+
group["correct_bias"])
|
137 |
+
else:
|
138 |
+
self._adamwscaled(params,
|
139 |
+
grads,
|
140 |
+
exp_avgs,
|
141 |
+
exp_avg_sqs,
|
142 |
+
steps,
|
143 |
+
kahan_comps,
|
144 |
+
group["lr"],
|
145 |
+
group["betas"][0],
|
146 |
+
group["betas"][1],
|
147 |
+
group["weight_decay"],
|
148 |
+
group["eps"],
|
149 |
+
group["kahan_sum"],
|
150 |
+
group["correct_bias"])
|
151 |
+
|
152 |
+
return loss
|
153 |
+
|
154 |
+
def _adamwscaled(self,
|
155 |
+
params: list[Tensor],
|
156 |
+
grads: list[Tensor],
|
157 |
+
exp_avgs: list[Tensor],
|
158 |
+
exp_avg_sqs: list[Tensor],
|
159 |
+
steps: list[Tensor],
|
160 |
+
kahan_comps: list[Tensor],
|
161 |
+
lr: float,
|
162 |
+
beta1: float,
|
163 |
+
beta2: float,
|
164 |
+
weight_decay: float,
|
165 |
+
eps: float,
|
166 |
+
do_kahan_sum: bool,
|
167 |
+
correct_bias: bool):
|
168 |
+
|
169 |
+
for i, p in enumerate(params):
|
170 |
+
|
171 |
+
exp_avg, exp_avg_sq, grad, step, kahan_comp = exp_avgs[i], exp_avg_sqs[i], grads[i], steps[i], kahan_comps[i]
|
172 |
+
|
173 |
+
# Decay the first and second moment running average coefficient
|
174 |
+
# In-place operations to update the averages at the same time
|
175 |
+
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
|
176 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1.0 - beta2))
|
177 |
+
denom = exp_avg_sq.sqrt().add_(eps)
|
178 |
+
|
179 |
+
step_size = lr
|
180 |
+
if correct_bias: # No bias correction for Bert
|
181 |
+
bias_correction1 = 1.0 - beta1 ** step
|
182 |
+
bias_correction2 = 1.0 - beta2 ** step
|
183 |
+
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
184 |
+
|
185 |
+
# Adapt Step from Adafactor
|
186 |
+
step_size = step_size * max(1e-3, self._rms(p.data))
|
187 |
+
|
188 |
+
if do_kahan_sum:
|
189 |
+
# Adam step
|
190 |
+
kahan_comp.addcdiv_(exp_avg, denom, value=-step_size)
|
191 |
+
|
192 |
+
# update weights with kahan compensation using dev_grads as temp buffer
|
193 |
+
grad.copy_(p)
|
194 |
+
p.add_(kahan_comp)
|
195 |
+
|
196 |
+
# save error back to kahan compensation for next iteration
|
197 |
+
grad.sub_(p, alpha=1)
|
198 |
+
kahan_comp.add_(grad, alpha=1)
|
199 |
+
else:
|
200 |
+
p.addcdiv_(exp_avg, denom, value=-step_size)
|
201 |
+
|
202 |
+
# Just adding the square of the weights to the loss function is *not*
|
203 |
+
# the correct way of using L2 regularization/weight decay with Adam,
|
204 |
+
# since that will interact with the m and v parameters in strange ways.
|
205 |
+
#
|
206 |
+
# Instead we want to decay the weights in a manner that doesn't interact
|
207 |
+
# with the m/v parameters. This is equivalent to adding the square
|
208 |
+
# of the weights to the loss with plain (non-momentum) SGD.
|
209 |
+
# Add weight decay at the end (fixed version)
|
210 |
+
if weight_decay > 0.0:
|
211 |
+
p.add_(p, alpha=(-lr * weight_decay))
|
212 |
+
|
213 |
+
def _foreach_adamwscaled(self,
|
214 |
+
params: list[Tensor],
|
215 |
+
grads: list[Tensor],
|
216 |
+
exp_avgs: list[Tensor],
|
217 |
+
exp_avg_sqs: list[Tensor],
|
218 |
+
steps: list[Tensor],
|
219 |
+
kahan_comps: list[Tensor],
|
220 |
+
lr: float,
|
221 |
+
beta1: float,
|
222 |
+
beta2: float,
|
223 |
+
weight_decay: float,
|
224 |
+
eps: float,
|
225 |
+
do_kahan_sum: bool,
|
226 |
+
correct_bias: bool):
|
227 |
+
|
228 |
+
grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, kahan_comps])
|
229 |
+
|
230 |
+
for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_kahan_comps), _) in grouped_tensors.items():
|
231 |
+
# Foreach implementation
|
232 |
+
torch._foreach_mul_(dev_exp_avgs, beta1)
|
233 |
+
torch._foreach_add_(dev_exp_avgs, dev_grads, alpha=1 - beta1)
|
234 |
+
|
235 |
+
torch._foreach_mul_(dev_exp_avg_sqs, beta2)
|
236 |
+
torch._foreach_addcmul_(dev_exp_avg_sqs, dev_grads, dev_grads, 1 - beta2)
|
237 |
+
|
238 |
+
# Compute denominator
|
239 |
+
torch._foreach_copy_(dev_grads, dev_exp_avg_sqs)
|
240 |
+
torch._foreach_sqrt_(dev_grads)
|
241 |
+
torch._foreach_add_(dev_grads, eps)
|
242 |
+
|
243 |
+
step_size = [torch.tensor(lr, dtype=torch.float32, device=p.device) for p in dev_params]
|
244 |
+
|
245 |
+
if correct_bias:
|
246 |
+
torch._foreach_mul_(step_size,
|
247 |
+
[torch.tensor((math.sqrt(1 - beta2 ** steps[i].item()) / (1 - beta1 ** steps[i].item()) ), dtype=torch.float32, device=p.device)
|
248 |
+
for i, p in enumerate(dev_params)])
|
249 |
+
|
250 |
+
# Adapt step size using RMS of parameters
|
251 |
+
rms_p = torch._foreach_norm(dev_params)
|
252 |
+
numel = [torch.tensor(math.sqrt(p.numel())) for p in dev_params]
|
253 |
+
torch._foreach_div_(rms_p, numel)
|
254 |
+
torch._foreach_maximum_(rms_p, 1e-3)
|
255 |
+
|
256 |
+
torch._foreach_mul_(step_size, rms_p)
|
257 |
+
torch._foreach_div_(dev_grads, step_size)
|
258 |
+
|
259 |
+
# explicitly delete tensors when not used
|
260 |
+
del rms_p
|
261 |
+
del numel
|
262 |
+
del step_size
|
263 |
+
|
264 |
+
# Update parameters
|
265 |
+
if do_kahan_sum:
|
266 |
+
# Adam step
|
267 |
+
torch._foreach_addcdiv_(dev_kahan_comps, dev_exp_avgs, dev_grads, value=-1)
|
268 |
+
|
269 |
+
# update weights with kahan compensation using dev_grads as temp buffer
|
270 |
+
torch._foreach_copy_(dev_grads, dev_params)
|
271 |
+
torch._foreach_add_(dev_params, dev_kahan_comps, alpha=1)
|
272 |
+
|
273 |
+
# save error back to kahan compensation for next iteration
|
274 |
+
torch._foreach_sub_(dev_grads, dev_params, alpha=1)
|
275 |
+
torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1)
|
276 |
+
else:
|
277 |
+
torch._foreach_addcdiv_(dev_params, dev_exp_avgs, dev_grads, value=-1)
|
278 |
+
|
279 |
+
# Weight decay
|
280 |
+
if weight_decay > 0.0:
|
281 |
+
torch._foreach_add_(dev_params, dev_params, alpha=-weight_decay * lr)
|
attn_ref.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def attn_ref(q, k, v, b, sm_scale, dropout_p=0.0, causal=False, upcast=False):
|
4 |
+
if upcast:
|
5 |
+
q, k, v = q.float(), k.float(), v.float()
|
6 |
+
if b is not None:
|
7 |
+
b = b.float()
|
8 |
+
|
9 |
+
if b is not None:
|
10 |
+
if (b.shape[0] != q.shape[0]) or (b.shape[1] != q.shape[1]):
|
11 |
+
b = b.expand(q.shape[0], q.shape[1], q.shape[2], k.shape[2])
|
12 |
+
|
13 |
+
ms = torch.arange(q.shape[2], device=q.device).unsqueeze(-1)
|
14 |
+
ns = torch.arange(k.shape[2], device=q.device)
|
15 |
+
|
16 |
+
p = torch.matmul(q, k.transpose(2, 3))
|
17 |
+
p *= sm_scale
|
18 |
+
if b is not None:
|
19 |
+
p += b
|
20 |
+
|
21 |
+
if causal:
|
22 |
+
p = torch.where(ms + k.shape[2] - q.shape[2] >= ns, p, float("-inf"))
|
23 |
+
|
24 |
+
p = torch.softmax(p.float(), dim=-1).to(q.dtype)
|
25 |
+
if dropout_p > 0.0:
|
26 |
+
p = torch.dropout(p, dropout_p, train=True)
|
27 |
+
|
28 |
+
ref_out = torch.matmul(p, v)
|
29 |
+
return ref_out
|
config.json
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alibi_mode": "symetric",
|
3 |
+
"architectures": [
|
4 |
+
"FlashT5ForConditionalGeneration"
|
5 |
+
],
|
6 |
+
"attention_dropout_rate": 0.0,
|
7 |
+
"attention_scale": 1.0,
|
8 |
+
"attention_type": "triton",
|
9 |
+
"auto_map": {
|
10 |
+
"AutoConfig": "configuration_flash_t5.FlashT5Config",
|
11 |
+
"AutoModel": "modeling_flash_t5.FlashT5EncoderModel",
|
12 |
+
"AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
|
13 |
+
"AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
|
14 |
+
"AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
|
15 |
+
"AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification"
|
16 |
+
},
|
17 |
+
"classifier_dropout": 0.0,
|
18 |
+
"crossentropy_inplace_backward": false,
|
19 |
+
"d_ff": 2048,
|
20 |
+
"d_kv": 64,
|
21 |
+
"d_model": 512,
|
22 |
+
"decoder_start_token_id": 0,
|
23 |
+
"dense_act_fn": "relu",
|
24 |
+
"dropout_rate": 0.0,
|
25 |
+
"eos_token_id": 1,
|
26 |
+
"feed_forward_proj": "relu",
|
27 |
+
"fire_mlp_width": 32,
|
28 |
+
"initializer_factor": 1.0,
|
29 |
+
"is_encoder_decoder": false,
|
30 |
+
"is_gated_act": false,
|
31 |
+
"label_smoothing": 0.0,
|
32 |
+
"layer_norm_epsilon": 1e-06,
|
33 |
+
"max_sequence_length": 1024,
|
34 |
+
"model_type": "flash_t5",
|
35 |
+
"num_decoder_layers": 12,
|
36 |
+
"num_heads": 8,
|
37 |
+
"num_layers": 12,
|
38 |
+
"pad_token_id": 3,
|
39 |
+
"position_encoding_type": "t5",
|
40 |
+
"relative_attention_max_distance": 128,
|
41 |
+
"relative_attention_num_buckets": 32,
|
42 |
+
"rotary_base": 10000,
|
43 |
+
"rotary_emb_fraction": 1.0,
|
44 |
+
"rotary_interleaved": false,
|
45 |
+
"rotary_scale_base": null,
|
46 |
+
"tie_word_embeddings": false,
|
47 |
+
"torch_dtype": "float32",
|
48 |
+
"transformers_version": "4.46.0.dev0",
|
49 |
+
"use_cache": true,
|
50 |
+
"use_full_bias_size": false,
|
51 |
+
"use_gelu_act": true,
|
52 |
+
"use_glu_mlp": true,
|
53 |
+
"use_masking": false,
|
54 |
+
"use_randomized_position_encoding": false,
|
55 |
+
"use_triton_crossentropy": true,
|
56 |
+
"use_triton_layernorm": true,
|
57 |
+
"vocab_size": 32768,
|
58 |
+
"z_loss": 0.0001
|
59 |
+
}
|
configuration_flash_t5.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from collections import OrderedDict
|
3 |
+
from typing import Mapping
|
4 |
+
import logging
|
5 |
+
|
6 |
+
from transformers import T5Config
|
7 |
+
|
8 |
+
AUTO_MAP = {
|
9 |
+
"AutoModel": "modeling_flash_t5.FlashT5EncoderModel",
|
10 |
+
"AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
|
11 |
+
"AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification",
|
12 |
+
"AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
|
13 |
+
"AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
|
14 |
+
}
|
15 |
+
|
16 |
+
class FlashT5Config(T5Config):
|
17 |
+
|
18 |
+
model_type = "flash_t5"
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
decoder_start_token_id=0,
|
23 |
+
pad_token_id=-100,
|
24 |
+
use_glu_mlp=False,
|
25 |
+
position_encoding_type="t5",
|
26 |
+
use_randomized_position_encoding=False,
|
27 |
+
label_smoothing=0.0,
|
28 |
+
z_loss=None,
|
29 |
+
attention_type="ref",
|
30 |
+
max_sequence_length=1024,
|
31 |
+
attention_dropout_rate=0.0,
|
32 |
+
alibi_mode="symetric",
|
33 |
+
use_triton_layernorm=False,
|
34 |
+
use_triton_crossentropy=False,
|
35 |
+
crossentropy_inplace_backward=False,
|
36 |
+
use_gelu_act=True,
|
37 |
+
use_full_bias_size=False,
|
38 |
+
rotary_emb_fraction=1.0,
|
39 |
+
rotary_base=10000,
|
40 |
+
rotary_interleaved=False,
|
41 |
+
rotary_scale_base=None,
|
42 |
+
fire_mlp_width=32,
|
43 |
+
use_masking=False,
|
44 |
+
attention_scale=None,
|
45 |
+
**kwargs,
|
46 |
+
):
|
47 |
+
super().__init__(**kwargs)
|
48 |
+
|
49 |
+
self.decoder_start_token_id = decoder_start_token_id
|
50 |
+
self.pad_token_id = pad_token_id
|
51 |
+
self.use_glu_mlp = use_glu_mlp
|
52 |
+
self.position_encoding_type = position_encoding_type
|
53 |
+
self.use_randomized_position_encoding = use_randomized_position_encoding
|
54 |
+
self.label_smoothing = label_smoothing
|
55 |
+
self.z_loss = z_loss
|
56 |
+
self.attention_type = attention_type
|
57 |
+
self.max_sequence_length = max_sequence_length
|
58 |
+
self.alibi_mode = alibi_mode
|
59 |
+
self.attention_dropout_rate = attention_dropout_rate
|
60 |
+
self.use_triton_layernorm = use_triton_layernorm
|
61 |
+
self.use_triton_crossentropy = use_triton_crossentropy
|
62 |
+
self.crossentropy_inplace_backward = crossentropy_inplace_backward
|
63 |
+
self.use_gelu_act = use_gelu_act
|
64 |
+
self.use_full_bias_size = use_full_bias_size
|
65 |
+
self.rotary_base = rotary_base
|
66 |
+
self.rotary_interleaved = rotary_interleaved
|
67 |
+
self.rotary_scale_base = rotary_scale_base
|
68 |
+
self.rotary_emb_fraction = rotary_emb_fraction
|
69 |
+
self.fire_mlp_width = fire_mlp_width
|
70 |
+
self.use_masking = use_masking
|
71 |
+
self.attention_scale = attention_scale
|
72 |
+
|
73 |
+
self.auto_map = AUTO_MAP
|
74 |
+
|
75 |
+
def str_to_class(classname):
|
76 |
+
return getattr(sys.modules[__name__], classname)
|
77 |
+
|
78 |
+
# Register model in Auto API
|
79 |
+
try:
|
80 |
+
FlashT5Config.register_for_auto_class()
|
81 |
+
for key, value in AUTO_MAP.items():
|
82 |
+
str_to_class(value.split(".")[-1]).register_for_auto_class(key)
|
83 |
+
except:
|
84 |
+
logging.warn("AutoRegister isn't available.")
|
cross_entropy_loss.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao.
|
2 |
+
# Copyright 2024 CATIE. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# Modification to the original version from Tri Dao:
|
17 |
+
# - support for torch.compile
|
18 |
+
|
19 |
+
from typing import Tuple, Optional
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
import triton
|
25 |
+
import triton.language as tl
|
26 |
+
|
27 |
+
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
28 |
+
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
29 |
+
# version of PyTorch. The following 2 lines are for backward compatibility with
|
30 |
+
# older PyTorch.
|
31 |
+
if "all_gather_into_tensor" not in dir(torch.distributed):
|
32 |
+
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
33 |
+
|
34 |
+
|
35 |
+
@triton.heuristics(
|
36 |
+
{
|
37 |
+
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
|
38 |
+
}
|
39 |
+
)
|
40 |
+
@triton.jit
|
41 |
+
def cross_entropy_fwd_kernel(
|
42 |
+
loss_ptr, # data ptrs
|
43 |
+
lse_ptr,
|
44 |
+
z_loss_ptr,
|
45 |
+
logits_ptr,
|
46 |
+
labels_ptr,
|
47 |
+
smoothing,
|
48 |
+
logit_scale,
|
49 |
+
lse_square_scale,
|
50 |
+
ignore_index,
|
51 |
+
total_classes,
|
52 |
+
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
|
53 |
+
n_cols, # shapes
|
54 |
+
logits_row_stride, # strides
|
55 |
+
BLOCK_SIZE: tl.constexpr,
|
56 |
+
HAS_SMOOTHING: tl.constexpr,
|
57 |
+
# if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
|
58 |
+
SPLIT: tl.constexpr,
|
59 |
+
PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0)
|
60 |
+
):
|
61 |
+
row_idx = tl.program_id(0)
|
62 |
+
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
|
63 |
+
sum_logits = 0.0 # For smoothing
|
64 |
+
if not PRECOMPUTED_LSE:
|
65 |
+
# Statistics for online softmax
|
66 |
+
m_i = -float("inf")
|
67 |
+
l_i = 0.0
|
68 |
+
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
69 |
+
cols = col_offset + tl.arange(0, BLOCK_SIZE)
|
70 |
+
logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
|
71 |
+
tl.float32
|
72 |
+
) * logit_scale
|
73 |
+
if HAS_SMOOTHING:
|
74 |
+
sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
|
75 |
+
m_i_new = tl.maximum(m_i, tl.max(logits))
|
76 |
+
l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
|
77 |
+
m_i = m_i_new
|
78 |
+
lse = tl.log(l_i) + m_i
|
79 |
+
tl.store(lse_ptr + row_idx, lse)
|
80 |
+
else:
|
81 |
+
lse = tl.load(lse_ptr + row_idx)
|
82 |
+
label_idx = tl.load(labels_ptr + row_idx)
|
83 |
+
if label_idx == ignore_index:
|
84 |
+
loss = 0.0
|
85 |
+
z_loss = 0.0
|
86 |
+
else:
|
87 |
+
label_idx -= class_start_idx
|
88 |
+
if label_idx >= 0 and label_idx < n_cols:
|
89 |
+
logits_label = tl.load(logits_ptr + label_idx) * logit_scale
|
90 |
+
if HAS_SMOOTHING:
|
91 |
+
loss = (
|
92 |
+
(lse if not SPLIT else 0.0)
|
93 |
+
- smoothing * sum_logits / total_classes
|
94 |
+
- (1 - smoothing) * logits_label
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
loss = (lse if not SPLIT else 0.0) - logits_label
|
98 |
+
else:
|
99 |
+
# If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
|
100 |
+
if HAS_SMOOTHING:
|
101 |
+
loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
|
102 |
+
else:
|
103 |
+
loss = 0.0
|
104 |
+
if not SPLIT:
|
105 |
+
z_loss = lse_square_scale * lse * lse
|
106 |
+
loss += z_loss
|
107 |
+
else:
|
108 |
+
z_loss = 0.0
|
109 |
+
tl.store(loss_ptr + row_idx, loss)
|
110 |
+
if not SPLIT:
|
111 |
+
tl.store(z_loss_ptr + row_idx, z_loss)
|
112 |
+
|
113 |
+
|
114 |
+
@triton.heuristics(
|
115 |
+
{
|
116 |
+
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
|
117 |
+
}
|
118 |
+
)
|
119 |
+
@triton.jit
|
120 |
+
def cross_entropy_bwd_kernel(
|
121 |
+
dlogits_ptr, # data ptrs
|
122 |
+
dloss_ptr,
|
123 |
+
logits_ptr,
|
124 |
+
lse_ptr,
|
125 |
+
labels_ptr,
|
126 |
+
smoothing,
|
127 |
+
logit_scale,
|
128 |
+
lse_square_scale,
|
129 |
+
ignore_index,
|
130 |
+
total_classes,
|
131 |
+
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
|
132 |
+
n_cols, # shapes
|
133 |
+
logits_row_stride, # strides
|
134 |
+
dlogits_row_stride,
|
135 |
+
dloss_row_stride,
|
136 |
+
BLOCK_SIZE: tl.constexpr,
|
137 |
+
HAS_SMOOTHING: tl.constexpr,
|
138 |
+
):
|
139 |
+
row_idx = tl.program_id(0)
|
140 |
+
col_block_idx = tl.program_id(1)
|
141 |
+
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
|
142 |
+
dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
|
143 |
+
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
144 |
+
label_idx = tl.load(labels_ptr + row_idx)
|
145 |
+
if label_idx != ignore_index:
|
146 |
+
dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
|
147 |
+
else:
|
148 |
+
dloss = 0.0
|
149 |
+
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
|
150 |
+
tl.float32
|
151 |
+
) * logit_scale
|
152 |
+
lse = tl.load(lse_ptr + row_idx)
|
153 |
+
probs = tl.exp(logits - lse)
|
154 |
+
probs += 2.0 * lse_square_scale * lse * probs
|
155 |
+
label_idx -= class_start_idx
|
156 |
+
if HAS_SMOOTHING:
|
157 |
+
smooth_positive = 1.0 - smoothing
|
158 |
+
smooth_negative = smoothing / total_classes
|
159 |
+
probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
|
160 |
+
else:
|
161 |
+
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
|
162 |
+
tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
|
163 |
+
|
164 |
+
@torch.library.custom_op("flasht5::cross_entropy_triton_fwd", mutates_args=(), device_types="cuda")
|
165 |
+
def cross_entropy_triton_fwd(
|
166 |
+
logits: torch.Tensor,
|
167 |
+
labels: torch.Tensor,
|
168 |
+
precomputed_lse: torch.Tensor,
|
169 |
+
use_precomputed_lse: bool,
|
170 |
+
split: bool,
|
171 |
+
smoothing: float,
|
172 |
+
logit_scale: float,
|
173 |
+
lse_square_scale: float,
|
174 |
+
ignore_index: int,
|
175 |
+
total_classes: int,
|
176 |
+
class_start_idx: int,
|
177 |
+
n_cols: int,
|
178 |
+
n_rows: int,
|
179 |
+
BLOCK_SIZE: int,
|
180 |
+
num_warps: int
|
181 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
182 |
+
|
183 |
+
if logits.stride(-1) != 1:
|
184 |
+
logits = logits.contiguous()
|
185 |
+
|
186 |
+
losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
187 |
+
if use_precomputed_lse:
|
188 |
+
assert precomputed_lse.shape == (n_rows,)
|
189 |
+
lse = precomputed_lse.contiguous()
|
190 |
+
else:
|
191 |
+
lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
192 |
+
|
193 |
+
z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
194 |
+
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
195 |
+
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
196 |
+
with torch.cuda.device(logits.device.index):
|
197 |
+
cross_entropy_fwd_kernel[(n_rows,)](
|
198 |
+
losses, # data ptrs
|
199 |
+
lse,
|
200 |
+
z_losses,
|
201 |
+
logits,
|
202 |
+
labels,
|
203 |
+
smoothing,
|
204 |
+
logit_scale,
|
205 |
+
lse_square_scale,
|
206 |
+
ignore_index,
|
207 |
+
total_classes,
|
208 |
+
class_start_idx,
|
209 |
+
n_cols, # shapes
|
210 |
+
logits.stride(0), # strides
|
211 |
+
BLOCK_SIZE=BLOCK_SIZE, # constants
|
212 |
+
SPLIT=split,
|
213 |
+
PRECOMPUTED_LSE=use_precomputed_lse,
|
214 |
+
num_warps=num_warps,
|
215 |
+
)
|
216 |
+
|
217 |
+
return losses, z_losses, lse
|
218 |
+
|
219 |
+
|
220 |
+
@torch.library.register_fake("flasht5::cross_entropy_triton_fwd")
|
221 |
+
def cross_entropy_triton_fwd_abstract(logits, labels, precomputed_lse, use_precomputed_lse, split, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
|
222 |
+
losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
|
223 |
+
z_losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
|
224 |
+
logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
|
225 |
+
|
226 |
+
return losses, z_losses, logsumexp
|
227 |
+
|
228 |
+
@torch.library.custom_op("flasht5::cross_entropy_triton_bwd", mutates_args={"logits"}, device_types="cuda")
|
229 |
+
def cross_entropy_triton_bwd(
|
230 |
+
dlosses: torch.Tensor,
|
231 |
+
logits: torch.Tensor,
|
232 |
+
lse: torch.Tensor,
|
233 |
+
labels: torch.Tensor,
|
234 |
+
inplace_backward: bool,
|
235 |
+
smoothing: float,
|
236 |
+
logit_scale: float,
|
237 |
+
lse_square_scale: float,
|
238 |
+
ignore_index: int,
|
239 |
+
total_classes: int,
|
240 |
+
class_start_idx: int,
|
241 |
+
n_cols: int,
|
242 |
+
n_rows: int,
|
243 |
+
BLOCK_SIZE: int,
|
244 |
+
num_warps: int
|
245 |
+
) -> torch.Tensor:
|
246 |
+
|
247 |
+
dlogits = logits if inplace_backward else torch.empty_like(logits)
|
248 |
+
|
249 |
+
grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
|
250 |
+
|
251 |
+
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
252 |
+
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
253 |
+
with torch.cuda.device(logits.device.index):
|
254 |
+
cross_entropy_bwd_kernel[grid](
|
255 |
+
dlogits, # data ptrs
|
256 |
+
dlosses,
|
257 |
+
logits,
|
258 |
+
lse,
|
259 |
+
labels,
|
260 |
+
smoothing,
|
261 |
+
logit_scale,
|
262 |
+
lse_square_scale,
|
263 |
+
ignore_index,
|
264 |
+
total_classes,
|
265 |
+
class_start_idx,
|
266 |
+
n_cols, # shapes
|
267 |
+
logits.stride(0), # strides
|
268 |
+
dlogits.stride(0),
|
269 |
+
dlosses.stride(0),
|
270 |
+
BLOCK_SIZE=BLOCK_SIZE, # constants
|
271 |
+
num_warps=num_warps,
|
272 |
+
)
|
273 |
+
|
274 |
+
return dlogits if not inplace_backward else None
|
275 |
+
|
276 |
+
@torch.library.register_fake("flasht5::cross_entropy_triton_bwd")
|
277 |
+
def cross_entropy_triton_bwd_abstract(dlosses, logits, lse, labels, inplace_backward, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
|
278 |
+
return torch.empty_like(logits)
|
279 |
+
|
280 |
+
class CrossEntropyLoss(torch.autograd.Function):
|
281 |
+
|
282 |
+
@staticmethod
|
283 |
+
def forward(
|
284 |
+
ctx,
|
285 |
+
logits,
|
286 |
+
labels,
|
287 |
+
precomputed_lse=None,
|
288 |
+
smoothing=0.0,
|
289 |
+
logit_scale=1.0,
|
290 |
+
lse_square_scale=0.0,
|
291 |
+
ignore_index=-100,
|
292 |
+
inplace_backward=False,
|
293 |
+
process_group=None,
|
294 |
+
):
|
295 |
+
# For some reason Triton generates wrong code when labels has dtype long and its address
|
296 |
+
# is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
|
297 |
+
if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
|
298 |
+
labels = F.pad(labels, (0, 1))[..., :-1]
|
299 |
+
assert labels.data_ptr() % 16 == 0
|
300 |
+
|
301 |
+
n_rows, n_cols = logits.shape
|
302 |
+
assert labels.shape == (n_rows,)
|
303 |
+
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
|
304 |
+
total_classes = world_size * n_cols
|
305 |
+
rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
|
306 |
+
class_start_idx = rank * n_cols
|
307 |
+
use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
|
308 |
+
|
309 |
+
MAX_BLOCK_SIZE = 16 * 1024
|
310 |
+
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
|
311 |
+
num_warps = (
|
312 |
+
4
|
313 |
+
if BLOCK_SIZE < 2048
|
314 |
+
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
|
315 |
+
)
|
316 |
+
|
317 |
+
losses, z_losses, lse = torch.ops.flasht5.cross_entropy_triton_fwd(
|
318 |
+
logits, labels, precomputed_lse, use_precomputed_lse, \
|
319 |
+
world_size > 1, smoothing, logit_scale, lse_square_scale, \
|
320 |
+
ignore_index, total_classes, class_start_idx, \
|
321 |
+
n_cols, n_rows, BLOCK_SIZE, num_warps
|
322 |
+
)
|
323 |
+
|
324 |
+
if world_size > 1:
|
325 |
+
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
|
326 |
+
# - predicted logit, and 0 otherwise.
|
327 |
+
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
|
328 |
+
# -0.9 * predicted logit - 0.1 * sum logit / total_classes.
|
329 |
+
# For labels not in the vocab of this partition, losses contains
|
330 |
+
# -0.1 * sum logit / total_classes.
|
331 |
+
if world_size > 1:
|
332 |
+
lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
|
333 |
+
torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
|
334 |
+
handle_losses = torch.distributed.all_reduce(
|
335 |
+
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
|
336 |
+
)
|
337 |
+
lse = torch.logsumexp(lse_allgather, dim=0)
|
338 |
+
handle_losses.wait()
|
339 |
+
# After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
|
340 |
+
# we just have to add the (global) lse.
|
341 |
+
# If there's smoothing=0.1, the total losses are
|
342 |
+
# -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
|
343 |
+
# Again, we just have to add the (global) lse.
|
344 |
+
losses += lse
|
345 |
+
if lse_square_scale != 0.0:
|
346 |
+
z_losses = lse_square_scale * lse.square()
|
347 |
+
z_losses.masked_fill_(labels == ignore_index, 0.0)
|
348 |
+
losses += z_losses
|
349 |
+
else:
|
350 |
+
z_losses = torch.zeros_like(losses)
|
351 |
+
losses.masked_fill_(labels == ignore_index, 0.0)
|
352 |
+
|
353 |
+
ctx.save_for_backward(logits, lse, labels)
|
354 |
+
ctx.mark_non_differentiable(z_losses)
|
355 |
+
ctx.smoothing = smoothing
|
356 |
+
ctx.logit_scale = logit_scale
|
357 |
+
ctx.lse_square_scale = lse_square_scale
|
358 |
+
ctx.ignore_index = ignore_index
|
359 |
+
ctx.total_classes = total_classes
|
360 |
+
ctx.class_start_idx = class_start_idx
|
361 |
+
ctx.inplace_backward = inplace_backward
|
362 |
+
|
363 |
+
return losses, z_losses
|
364 |
+
|
365 |
+
@staticmethod
|
366 |
+
def backward(ctx, grad_losses, grad_z_losses):
|
367 |
+
del grad_z_losses # z_losses are only for logging.
|
368 |
+
|
369 |
+
logits, lse, labels = ctx.saved_tensors
|
370 |
+
|
371 |
+
n_rows, n_cols = logits.shape
|
372 |
+
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
|
373 |
+
num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
|
374 |
+
|
375 |
+
dlogits = torch.ops.flasht5.cross_entropy_triton_bwd(
|
376 |
+
grad_losses, logits, lse, labels, \
|
377 |
+
ctx.inplace_backward, ctx.smoothing, ctx.logit_scale, \
|
378 |
+
ctx.lse_square_scale, ctx.ignore_index, ctx.total_classes, \
|
379 |
+
ctx.class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps
|
380 |
+
)
|
381 |
+
|
382 |
+
if ctx.inplace_backward:
|
383 |
+
dlogits = logits
|
384 |
+
|
385 |
+
return dlogits, None, None, None, None, None, None, None, None, None
|
386 |
+
|
387 |
+
|
388 |
+
def cross_entropy_loss(
|
389 |
+
logits: torch.Tensor,
|
390 |
+
labels: torch.Tensor,
|
391 |
+
precomputed_lse: Optional[torch.Tensor] = None,
|
392 |
+
label_smoothing: float = 0.0,
|
393 |
+
logit_scale: float = 1.0,
|
394 |
+
lse_square_scale: float = 0.0,
|
395 |
+
ignore_index=-100,
|
396 |
+
inplace_backward: bool = False,
|
397 |
+
process_group=None,
|
398 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
399 |
+
"""
|
400 |
+
Arguments:
|
401 |
+
logits: (batch, vocab_size)
|
402 |
+
labels: (batch,)
|
403 |
+
label_smoothing: float
|
404 |
+
logit_scale: float. Multiply logits by this scale before calculating the loss.
|
405 |
+
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
|
406 |
+
This is also referred to as "z-loss".
|
407 |
+
ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
|
408 |
+
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
|
409 |
+
This saves memory.
|
410 |
+
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
|
411 |
+
one part of the vocab. The loss will be aggregated across processes.
|
412 |
+
Returns:
|
413 |
+
losses: (batch,), float
|
414 |
+
z_losses: (batch,), float
|
415 |
+
"""
|
416 |
+
return CrossEntropyLoss.apply(
|
417 |
+
logits.view(-1, logits.shape[-1]),
|
418 |
+
labels.view(-1),
|
419 |
+
precomputed_lse,
|
420 |
+
label_smoothing,
|
421 |
+
logit_scale,
|
422 |
+
lse_square_scale,
|
423 |
+
ignore_index,
|
424 |
+
inplace_backward,
|
425 |
+
process_group,
|
426 |
+
)
|
custom_heads_flash_t5.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
4 |
+
import copy
|
5 |
+
from typing import Optional, Union, Tuple, List
|
6 |
+
from transformers.modeling_outputs import (
|
7 |
+
Seq2SeqQuestionAnsweringModelOutput,
|
8 |
+
QuestionAnsweringModelOutput,
|
9 |
+
TokenClassifierOutput,
|
10 |
+
BaseModelOutput,
|
11 |
+
Seq2SeqSequenceClassifierOutput,
|
12 |
+
SequenceClassifierOutput
|
13 |
+
)
|
14 |
+
|
15 |
+
from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model
|
16 |
+
from .configuration_flash_t5 import FlashT5Config
|
17 |
+
|
18 |
+
|
19 |
+
################## Encoder only head ##################
|
20 |
+
class FlashT5ForTokenClassification(FlashT5PreTrainedModel):
|
21 |
+
|
22 |
+
def __init__(self, config: FlashT5Config):
|
23 |
+
super().__init__(config)
|
24 |
+
self.num_labels = config.num_labels
|
25 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
26 |
+
|
27 |
+
self.encoder = FlashT5Stack(config, self.shared)
|
28 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
29 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
30 |
+
|
31 |
+
# Initialize weights and apply final processing
|
32 |
+
self.post_init()
|
33 |
+
|
34 |
+
# Initialize classifier
|
35 |
+
self.classifier.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
|
36 |
+
self.classifier.bias.data.zero_()
|
37 |
+
|
38 |
+
self.model_parallel = False
|
39 |
+
|
40 |
+
def forward(
|
41 |
+
self,
|
42 |
+
input_ids: Optional[torch.Tensor] = None,
|
43 |
+
attention_mask: Optional[torch.Tensor] = None,
|
44 |
+
head_mask: Optional[torch.Tensor] = None,
|
45 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
46 |
+
labels: Optional[torch.Tensor] = None,
|
47 |
+
output_attentions: Optional[bool] = None,
|
48 |
+
output_hidden_states: Optional[bool] = None,
|
49 |
+
return_dict: Optional[bool] = None,
|
50 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
51 |
+
r"""
|
52 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
53 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
54 |
+
Returns:
|
55 |
+
"""
|
56 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
57 |
+
|
58 |
+
outputs = self.encoder(
|
59 |
+
input_ids=input_ids,
|
60 |
+
attention_mask=attention_mask,
|
61 |
+
inputs_embeds=inputs_embeds,
|
62 |
+
head_mask=head_mask,
|
63 |
+
output_attentions=output_attentions,
|
64 |
+
output_hidden_states=output_hidden_states,
|
65 |
+
return_dict=return_dict,
|
66 |
+
)
|
67 |
+
|
68 |
+
hidden_states = outputs[0]
|
69 |
+
hidden_states = self.dropout(hidden_states)
|
70 |
+
logits = self.classifier(hidden_states)
|
71 |
+
|
72 |
+
loss = None
|
73 |
+
if labels is not None:
|
74 |
+
loss_fct = nn.CrossEntropyLoss()
|
75 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
76 |
+
|
77 |
+
if not return_dict:
|
78 |
+
output = (logits, outputs[2:-1])
|
79 |
+
return ((loss,) + output) if loss is not None else output
|
80 |
+
|
81 |
+
return TokenClassifierOutput(
|
82 |
+
loss=loss,
|
83 |
+
logits=logits,
|
84 |
+
hidden_states=outputs.hidden_states,
|
85 |
+
attentions=outputs.attentions,
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
class FlashT5ClassificationHead(nn.Module):
|
90 |
+
"""Head for sentence-level classification tasks."""
|
91 |
+
|
92 |
+
def __init__(self, config: FlashT5Config):
|
93 |
+
super().__init__()
|
94 |
+
self.dense = nn.Linear(config.d_model, config.d_model)
|
95 |
+
self.dropout = nn.Dropout(p=config.classifier_dropout)
|
96 |
+
self.out_proj = nn.Linear(config.d_model, config.num_labels)
|
97 |
+
|
98 |
+
# initialize weights
|
99 |
+
factor = config.initializer_factor
|
100 |
+
self.dense.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
|
101 |
+
if hasattr(self.dense, "bias") and self.dense.bias is not None:
|
102 |
+
self.dense.bias.data.zero_()
|
103 |
+
self.out_proj.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
|
104 |
+
if hasattr(self.out_proj, "bias") and self.out_proj.bias is not None:
|
105 |
+
self.out_proj.bias.data.zero_()
|
106 |
+
|
107 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
108 |
+
hidden_states = self.dropout(hidden_states)
|
109 |
+
hidden_states = self.dense(hidden_states)
|
110 |
+
hidden_states = torch.tanh(hidden_states)
|
111 |
+
hidden_states = self.dropout(hidden_states)
|
112 |
+
hidden_states = self.out_proj(hidden_states)
|
113 |
+
return hidden_states
|
114 |
+
|
115 |
+
|
116 |
+
class FlashT5ForSequenceClassification(FlashT5PreTrainedModel):
|
117 |
+
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
|
118 |
+
|
119 |
+
def __init__(self, config: FlashT5Config):
|
120 |
+
super().__init__(config)
|
121 |
+
self.model_dim = config.d_model
|
122 |
+
self.config.problem_type = None
|
123 |
+
self.config.is_encoder_decoder = False
|
124 |
+
|
125 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
126 |
+
|
127 |
+
encoder_config = copy.deepcopy(config)
|
128 |
+
encoder_config.is_decoder = False
|
129 |
+
encoder_config.is_encoder_decoder = False
|
130 |
+
encoder_config.use_cache = False
|
131 |
+
self.encoder = FlashT5Stack(encoder_config, self.shared)
|
132 |
+
self.classification_head = FlashT5ClassificationHead(config)
|
133 |
+
|
134 |
+
# Initialize weights and apply final processing
|
135 |
+
self.post_init()
|
136 |
+
|
137 |
+
self.model_parallel = False
|
138 |
+
|
139 |
+
def forward(
|
140 |
+
self,
|
141 |
+
input_ids: torch.LongTensor = None,
|
142 |
+
attention_mask: Optional[torch.Tensor] = None,
|
143 |
+
head_mask: Optional[torch.Tensor] = None,
|
144 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
145 |
+
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
146 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
147 |
+
labels: Optional[torch.LongTensor] = None,
|
148 |
+
use_cache: Optional[bool] = None,
|
149 |
+
output_attentions: Optional[bool] = None,
|
150 |
+
output_hidden_states: Optional[bool] = None,
|
151 |
+
return_dict: Optional[bool] = None,
|
152 |
+
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
153 |
+
r"""
|
154 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
155 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
156 |
+
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
157 |
+
Returns:
|
158 |
+
"""
|
159 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
160 |
+
if labels is not None:
|
161 |
+
use_cache = False
|
162 |
+
|
163 |
+
if input_ids is None and inputs_embeds is not None:
|
164 |
+
raise NotImplementedError(
|
165 |
+
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
|
166 |
+
)
|
167 |
+
|
168 |
+
|
169 |
+
outputs = self.encoder(
|
170 |
+
input_ids=input_ids,
|
171 |
+
attention_mask=attention_mask,
|
172 |
+
inputs_embeds=inputs_embeds,
|
173 |
+
head_mask=head_mask,
|
174 |
+
output_attentions=output_attentions,
|
175 |
+
output_hidden_states=output_hidden_states,
|
176 |
+
return_dict=return_dict,
|
177 |
+
)
|
178 |
+
sequence_output = outputs[0]
|
179 |
+
|
180 |
+
eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
|
181 |
+
|
182 |
+
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
183 |
+
raise ValueError("All examples must have the same number of <eos> tokens.")
|
184 |
+
batch_size, _, hidden_size = sequence_output.shape
|
185 |
+
sentence_representation = sequence_output[:, -1, :]
|
186 |
+
# sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
|
187 |
+
logits = self.classification_head(sentence_representation)
|
188 |
+
|
189 |
+
loss = None
|
190 |
+
if labels is not None:
|
191 |
+
labels = labels.to(logits.device)
|
192 |
+
if self.config.problem_type is None:
|
193 |
+
if self.config.num_labels == 1:
|
194 |
+
self.config.problem_type = "regression"
|
195 |
+
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
196 |
+
self.config.problem_type = "single_label_classification"
|
197 |
+
else:
|
198 |
+
self.config.problem_type = "multi_label_classification"
|
199 |
+
|
200 |
+
if self.config.problem_type == "regression":
|
201 |
+
loss_fct = nn.MSELoss()
|
202 |
+
if self.config.num_labels == 1:
|
203 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
204 |
+
else:
|
205 |
+
loss = loss_fct(logits, labels)
|
206 |
+
elif self.config.problem_type == "single_label_classification":
|
207 |
+
loss_fct = nn.CrossEntropyLoss()
|
208 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
209 |
+
elif self.config.problem_type == "multi_label_classification":
|
210 |
+
loss_fct = nn.BCEWithLogitsLoss()
|
211 |
+
loss = loss_fct(logits, labels)
|
212 |
+
if not return_dict:
|
213 |
+
output = (logits,) + outputs[1:]
|
214 |
+
return ((loss,) + output) if loss is not None else output
|
215 |
+
|
216 |
+
return SequenceClassifierOutput(
|
217 |
+
loss=loss,
|
218 |
+
logits=logits,
|
219 |
+
hidden_states=outputs.hidden_states,
|
220 |
+
attentions=outputs.attentions
|
221 |
+
)
|
222 |
+
|
223 |
+
|
224 |
+
class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
|
225 |
+
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
|
226 |
+
|
227 |
+
def __init__(self, config: FlashT5Config):
|
228 |
+
super().__init__(config)
|
229 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
230 |
+
|
231 |
+
encoder_config = copy.deepcopy(config)
|
232 |
+
encoder_config.is_decoder = False
|
233 |
+
encoder_config.is_encoder_decoder = False
|
234 |
+
self.encoder = FlashT5Stack(encoder_config, self.shared)
|
235 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
236 |
+
|
237 |
+
# Initialize weights and apply final processing
|
238 |
+
self.post_init()
|
239 |
+
|
240 |
+
self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
|
241 |
+
self.qa_outputs.bias.data.zero_()
|
242 |
+
|
243 |
+
self.model_parallel = False
|
244 |
+
|
245 |
+
def forward(
|
246 |
+
self,
|
247 |
+
input_ids: Optional[torch.LongTensor] = None,
|
248 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
249 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
250 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
251 |
+
start_positions: Optional[torch.LongTensor] = None,
|
252 |
+
end_positions: Optional[torch.LongTensor] = None,
|
253 |
+
output_attentions: Optional[bool] = None,
|
254 |
+
output_hidden_states: Optional[bool] = None,
|
255 |
+
return_dict: Optional[bool] = None,
|
256 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
257 |
+
r"""
|
258 |
+
Returns:
|
259 |
+
|
260 |
+
Example:
|
261 |
+
|
262 |
+
```python
|
263 |
+
>>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
|
264 |
+
|
265 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
|
266 |
+
>>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
|
267 |
+
>>> input_ids = tokenizer(
|
268 |
+
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
|
269 |
+
... ).input_ids # Batch size 1
|
270 |
+
>>> outputs = model(input_ids=input_ids)
|
271 |
+
>>> start_logits = outputs.start_logits
|
272 |
+
>>> end_logits = outputs.end_logits
|
273 |
+
```"""
|
274 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
275 |
+
|
276 |
+
outputs = self.encoder(
|
277 |
+
input_ids,
|
278 |
+
attention_mask=attention_mask,
|
279 |
+
inputs_embeds=inputs_embeds,
|
280 |
+
)
|
281 |
+
sequence_output = outputs[0]
|
282 |
+
|
283 |
+
logits = self.qa_outputs(sequence_output)
|
284 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
285 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
286 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
287 |
+
|
288 |
+
total_loss = None
|
289 |
+
if start_positions is not None and end_positions is not None:
|
290 |
+
# If we are on multi-GPU, split add a dimension
|
291 |
+
if len(start_positions.size()) > 1:
|
292 |
+
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
293 |
+
if len(end_positions.size()) > 1:
|
294 |
+
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
295 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
296 |
+
ignored_index = start_logits.size(1)
|
297 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
298 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
299 |
+
|
300 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
301 |
+
start_loss = loss_fct(start_logits, start_positions)
|
302 |
+
end_loss = loss_fct(end_logits, end_positions)
|
303 |
+
total_loss = (start_loss + end_loss) / 2
|
304 |
+
|
305 |
+
if not return_dict:
|
306 |
+
output = (start_logits, end_logits) + outputs[1:]
|
307 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
308 |
+
|
309 |
+
return QuestionAnsweringModelOutput(
|
310 |
+
loss=total_loss,
|
311 |
+
start_logits=start_logits,
|
312 |
+
end_logits=end_logits,
|
313 |
+
hidden_states=outputs.hidden_states,
|
314 |
+
attentions=outputs.attentions,
|
315 |
+
)
|
flash_attention_v2_bias.py
ADDED
@@ -0,0 +1,905 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 BAAI
|
2 |
+
# Copyright 2024 CATIE
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# Modifications to the orignal file
|
17 |
+
# - Support for biases following https://github.com/FlagOpen/FlagAttention/pull/5
|
18 |
+
# - Support for shape (1,1,q,k) biases
|
19 |
+
|
20 |
+
import math
|
21 |
+
import torch
|
22 |
+
import triton
|
23 |
+
import triton.language as tl
|
24 |
+
|
25 |
+
from typing import Tuple
|
26 |
+
|
27 |
+
@torch.library.custom_op("flasht5::flash_attn_v2_fwd", mutates_args=(), device_types="cuda")
|
28 |
+
def flash_attn_v2_fwd(
|
29 |
+
q: torch.Tensor,
|
30 |
+
k: torch.Tensor,
|
31 |
+
v: torch.Tensor,
|
32 |
+
bias: torch.Tensor,
|
33 |
+
causal: bool,
|
34 |
+
sm_scale: float,
|
35 |
+
BLOCK_M: int,
|
36 |
+
BLOCK_N: int,
|
37 |
+
num_warps: int, num_stages: int
|
38 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
39 |
+
|
40 |
+
B, H, M, D = q.shape
|
41 |
+
N = k.shape[2]
|
42 |
+
P_SEQ = N - M
|
43 |
+
larger_m = M > N
|
44 |
+
|
45 |
+
# Trick to support shape such as (1, 1, seqlen_q, seqlen_k)
|
46 |
+
bias_batch_stride = bias.stride(0) if bias is not None else 0
|
47 |
+
bias_heads_stride = bias.stride(1) if bias is not None else 0
|
48 |
+
if bias is not None:
|
49 |
+
if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):
|
50 |
+
bias_batch_stride = 0
|
51 |
+
if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):
|
52 |
+
bias_heads_stride = 0
|
53 |
+
|
54 |
+
divisible_m = M % BLOCK_M == 0
|
55 |
+
divisible_n = N % BLOCK_N == 0
|
56 |
+
# consider using 3d grid to avoid div & rem
|
57 |
+
grid = (triton.cdiv(M, BLOCK_M), H, B)
|
58 |
+
o = torch.empty_like(q)
|
59 |
+
L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
|
60 |
+
|
61 |
+
with torch.cuda.device(q.device.index):
|
62 |
+
_fwd_kernel[grid](
|
63 |
+
q, k, v, bias, sm_scale,
|
64 |
+
L, o,
|
65 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
66 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
67 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
68 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
69 |
+
bias_batch_stride, bias_heads_stride,
|
70 |
+
bias.stride(2) if bias is not None else 0,
|
71 |
+
bias.stride(3) if bias is not None else 0,
|
72 |
+
B, H, M, N, P_SEQ,
|
73 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
|
74 |
+
IS_CAUSAL=causal, LARGER_M=larger_m,
|
75 |
+
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
|
76 |
+
HAS_BIAS=(bias is not None),
|
77 |
+
num_warps=num_warps, num_stages=num_stages,
|
78 |
+
)
|
79 |
+
|
80 |
+
return o, L
|
81 |
+
|
82 |
+
|
83 |
+
@torch.library.register_fake("flasht5::flash_attn_v2_fwd")
|
84 |
+
def flash_attn_v2_fwd_abstract(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
|
85 |
+
B, H, M, D = q.shape
|
86 |
+
o = torch.empty_like(q)
|
87 |
+
L = torch.empty((B, H, M), dtype=torch.float32, device=q.device)
|
88 |
+
|
89 |
+
return o, L
|
90 |
+
|
91 |
+
@torch.library.custom_op("flasht5::flash_attn_v2_bwd", mutates_args=(), device_types="cuda")
|
92 |
+
def flash_attn_v2_bwd(
|
93 |
+
o: torch.Tensor,
|
94 |
+
do: torch.Tensor,
|
95 |
+
q: torch.Tensor,
|
96 |
+
k: torch.Tensor,
|
97 |
+
v: torch.Tensor,
|
98 |
+
bias: torch.Tensor,
|
99 |
+
L: torch.Tensor,
|
100 |
+
causal: bool,
|
101 |
+
sm_scale: float,
|
102 |
+
BLOCK_M: int,
|
103 |
+
BLOCK_N: int,
|
104 |
+
num_warps: int,
|
105 |
+
num_stages: int
|
106 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
107 |
+
|
108 |
+
B, H, M, D = q.shape
|
109 |
+
N = k.shape[2]
|
110 |
+
P_SEQ = N - M
|
111 |
+
larger_m = M > N
|
112 |
+
|
113 |
+
divisible_m = M % BLOCK_M == 0
|
114 |
+
divisible_n = N % BLOCK_N == 0
|
115 |
+
|
116 |
+
# Trick to support shape such as (1, 1, seqlen_q, seqlen_k)
|
117 |
+
bias_batch_stride = bias.stride(0) if bias is not None else 0
|
118 |
+
bias_heads_stride = bias.stride(1) if bias is not None else 0
|
119 |
+
if bias is not None:
|
120 |
+
if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):
|
121 |
+
bias_batch_stride = 0
|
122 |
+
if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):
|
123 |
+
bias_heads_stride = 0
|
124 |
+
|
125 |
+
delta = torch.empty_like(L)
|
126 |
+
grid = (triton.cdiv(M, BLOCK_M), H, B)
|
127 |
+
|
128 |
+
with torch.cuda.device(q.device.index):
|
129 |
+
_bwd_preprocess[grid](
|
130 |
+
o, do,
|
131 |
+
delta,
|
132 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
133 |
+
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
|
134 |
+
delta.stride(0), delta.stride(1), delta.stride(2),
|
135 |
+
M,
|
136 |
+
BLOCK_M=BLOCK_M, D_HEAD=D,
|
137 |
+
DIVISIBLE_M=divisible_m,
|
138 |
+
)
|
139 |
+
|
140 |
+
dk = torch.empty_like(k)
|
141 |
+
dv = torch.empty_like(v)
|
142 |
+
|
143 |
+
HAS_BIAS = bias is not None
|
144 |
+
RETURN_DS = HAS_BIAS
|
145 |
+
IS_BATCH_REDUCED = (bias_batch_stride == 0)
|
146 |
+
#GROUP_SIZE_BIAS = min(B, 16)
|
147 |
+
GROUP_SIZE_BIAS = B
|
148 |
+
|
149 |
+
ds = None
|
150 |
+
locks = None
|
151 |
+
if RETURN_DS:
|
152 |
+
if IS_BATCH_REDUCED:
|
153 |
+
if causal:
|
154 |
+
ds = torch.zeros((GROUP_SIZE_BIAS, *bias.shape[1:]), dtype=bias.dtype, device=bias.device)
|
155 |
+
else:
|
156 |
+
ds = torch.empty((GROUP_SIZE_BIAS, *bias.shape[1:]), dtype=bias.dtype, device=bias.device)
|
157 |
+
locks = torch.zeros(2 * GROUP_SIZE_BIAS, dtype=torch.int32, device=q.device)
|
158 |
+
else:
|
159 |
+
if causal:
|
160 |
+
ds = torch.zeros_like(bias)
|
161 |
+
else:
|
162 |
+
ds = torch.empty_like(bias)
|
163 |
+
|
164 |
+
grid = (triton.cdiv(N, BLOCK_N), H, B)
|
165 |
+
with torch.cuda.device(q.device.index):
|
166 |
+
_bwd_kv_kernel[grid](
|
167 |
+
q, k, v, bias, sm_scale, do,
|
168 |
+
dk, dv, ds,
|
169 |
+
L, delta,
|
170 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
171 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
172 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
173 |
+
bias.stride(0) if HAS_BIAS else 0,
|
174 |
+
bias_heads_stride,
|
175 |
+
bias.stride(2) if HAS_BIAS else 0,
|
176 |
+
bias.stride(3) if HAS_BIAS else 0,
|
177 |
+
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
|
178 |
+
dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
|
179 |
+
dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
|
180 |
+
B, H, M, N, P_SEQ,
|
181 |
+
locks,
|
182 |
+
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
|
183 |
+
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
|
184 |
+
HAS_BIAS=HAS_BIAS,
|
185 |
+
RETURN_DS=RETURN_DS,
|
186 |
+
IS_BATCH_REDUCED=IS_BATCH_REDUCED,
|
187 |
+
GROUP_SIZE_BIAS=GROUP_SIZE_BIAS,
|
188 |
+
num_stages=num_stages, num_warps=num_warps,
|
189 |
+
)
|
190 |
+
|
191 |
+
dq = torch.empty_like(q)
|
192 |
+
grid = (triton.cdiv(M, BLOCK_M), H, B)
|
193 |
+
with torch.cuda.device(q.device.index):
|
194 |
+
_bwd_q_kernel[grid](
|
195 |
+
q, k, v, bias, sm_scale, do,
|
196 |
+
dq,
|
197 |
+
L, delta,
|
198 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
199 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
200 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
201 |
+
bias_batch_stride, bias_heads_stride,
|
202 |
+
bias.stride(2) if HAS_BIAS else 0,
|
203 |
+
bias.stride(3) if HAS_BIAS else 0,
|
204 |
+
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
|
205 |
+
dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
|
206 |
+
B, H, M, N, P_SEQ,
|
207 |
+
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
|
208 |
+
CAUSAL=causal, LARGER_M=larger_m,
|
209 |
+
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
|
210 |
+
HAS_BIAS=HAS_BIAS,
|
211 |
+
num_stages=num_stages, num_warps = num_warps,
|
212 |
+
)
|
213 |
+
|
214 |
+
if RETURN_DS and IS_BATCH_REDUCED and GROUP_SIZE_BIAS > 1:
|
215 |
+
ds = ds.sum(0, keepdim=True)
|
216 |
+
|
217 |
+
return dq, dk, dv, ds
|
218 |
+
|
219 |
+
@torch.library.register_fake("flasht5::flash_attn_v2_bwd")
|
220 |
+
def flash_attn_v2_bwd_abstract(o, do, q, k, v, bias, L, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
|
221 |
+
dq = torch.empty_like(q)
|
222 |
+
dk = torch.empty_like(k)
|
223 |
+
dv = torch.empty_like(v)
|
224 |
+
ds = torch.empty_like(bias) if bias is not None else None
|
225 |
+
|
226 |
+
return dq, dk, dv, ds
|
227 |
+
|
228 |
+
class FlashAttentionAdditiveBias(torch.autograd.Function):
|
229 |
+
@staticmethod
|
230 |
+
def forward(ctx, q, k, v, bias, causal, sm_scale):
|
231 |
+
Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
|
232 |
+
|
233 |
+
assert Dq == Dk == Dv
|
234 |
+
assert Dk in {16, 32, 64, 128}
|
235 |
+
|
236 |
+
B, H, M, D = q.shape
|
237 |
+
N = k.shape[2]
|
238 |
+
|
239 |
+
if sm_scale is None:
|
240 |
+
sm_scale = 1. / math.sqrt(D)
|
241 |
+
|
242 |
+
config = get_fwd_config(B, H, M, N, D, causal)
|
243 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = config
|
244 |
+
|
245 |
+
o, L = torch.ops.flasht5.flash_attn_v2_fwd(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages)
|
246 |
+
|
247 |
+
# autograd context maintenance
|
248 |
+
ctx.save_for_backward(q, k, v, bias, o, L)
|
249 |
+
ctx.sm_scale = sm_scale
|
250 |
+
ctx.causal = causal
|
251 |
+
|
252 |
+
return o
|
253 |
+
|
254 |
+
@staticmethod
|
255 |
+
def backward(ctx, do, *ignored):
|
256 |
+
q, k, v, bias, o, L = ctx.saved_tensors
|
257 |
+
sm_scale = ctx.sm_scale
|
258 |
+
causal = ctx.causal
|
259 |
+
|
260 |
+
B, H, M, D = q.shape
|
261 |
+
N = k.shape[2]
|
262 |
+
|
263 |
+
if sm_scale is None:
|
264 |
+
sm_scale = 1. / math.sqrt(D)
|
265 |
+
|
266 |
+
config = get_bwd_config(B, H, M, N, D, causal)
|
267 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = config
|
268 |
+
|
269 |
+
dq, dk, dv, ds = torch.ops.flasht5.flash_attn_v2_bwd(o, do, q, k, v, bias, L, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages)
|
270 |
+
|
271 |
+
return dq, dk, dv, ds, None, None, None, None
|
272 |
+
|
273 |
+
|
274 |
+
def flash_attention_v2_bias(q, k, v, bias, causal=False, sm_scale=None):
|
275 |
+
"""
|
276 |
+
An implementation of FlashAttention v2(https://arxiv.org/abs/2307.08691).
|
277 |
+
|
278 |
+
Arguments:
|
279 |
+
q(torch.Tensor): The first queries. The shape is (batch_size, nheads, seqlen_q, headdim).
|
280 |
+
k(torch.Tensor): The first keys. The shape is (batch_size, nheads, seqlen_k, headdim).
|
281 |
+
v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim).
|
282 |
+
causal(bool): Whether causal masking is applied to attention scores before applying softmax.
|
283 |
+
sm_scale(float): The scaling of attention scores before applying softmax.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
out(torch.Tensor): The output. The shape is (batch_size, nheads, seqlen_q, headdim).
|
287 |
+
"""
|
288 |
+
return FlashAttentionAdditiveBias.apply(q, k, v, bias, causal, sm_scale)
|
289 |
+
|
290 |
+
|
291 |
+
# --------------------------- Forward ---------------------------
|
292 |
+
# NOTE: this function can be overwritten at runtime to use your custom config
|
293 |
+
def get_fwd_config(B, H, M, N, D, causal):
|
294 |
+
if torch.cuda.get_device_capability() == (8, 0):
|
295 |
+
if not causal:
|
296 |
+
if D <= 64:
|
297 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
|
298 |
+
else:
|
299 |
+
if M <= 1024:
|
300 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
|
301 |
+
else:
|
302 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
|
303 |
+
else:
|
304 |
+
if D <= 64:
|
305 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
|
306 |
+
else:
|
307 |
+
if M <= 1024:
|
308 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
|
309 |
+
else:
|
310 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
|
311 |
+
elif torch.cuda.get_device_capability() == (8, 6):
|
312 |
+
if not causal:
|
313 |
+
if D <= 64:
|
314 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
|
315 |
+
else:
|
316 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
|
317 |
+
else: # causal
|
318 |
+
if D <= 64:
|
319 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
|
320 |
+
else:
|
321 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
|
322 |
+
else:
|
323 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
|
324 |
+
return (BLOCK_M, BLOCK_N, num_stages, num_warps)
|
325 |
+
|
326 |
+
|
327 |
+
@triton.jit
|
328 |
+
def _fwd_kernel(
|
329 |
+
Q, K, V, B, sm_scale,
|
330 |
+
L, O,
|
331 |
+
stride_qz, stride_qh, stride_qm, stride_qk,
|
332 |
+
stride_kz, stride_kh, stride_kn, stride_kk,
|
333 |
+
stride_vz, stride_vh, stride_vn, stride_vk,
|
334 |
+
stride_oz, stride_oh, stride_om, stride_ok,
|
335 |
+
stride_bz, stride_bh, stride_bm, stride_bn,
|
336 |
+
Z, H, M, N, P_SEQ,
|
337 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
|
338 |
+
IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
|
339 |
+
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
|
340 |
+
HAS_BIAS: tl.constexpr,
|
341 |
+
):
|
342 |
+
input_dtype = Q.dtype.element_ty
|
343 |
+
# -- grid id --
|
344 |
+
start_m = tl.program_id(0)
|
345 |
+
off_h = tl.program_id(1)
|
346 |
+
off_z = tl.program_id(2)
|
347 |
+
|
348 |
+
# scale sm_scale by log_2(e) and use
|
349 |
+
# 2^x instead of exp in the loop because CSE and LICM
|
350 |
+
# don't work as expected with `exp` in the loop
|
351 |
+
log2e: tl.constexpr = 1.4426950408889634
|
352 |
+
|
353 |
+
# offset pointers for (batch, head)
|
354 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
355 |
+
K += off_z * stride_kz + off_h * stride_kh
|
356 |
+
V += off_z * stride_vz + off_h * stride_vh
|
357 |
+
O += off_z * stride_oz + off_h * stride_oh
|
358 |
+
if HAS_BIAS:
|
359 |
+
B += off_z * stride_bz + off_h * stride_bh
|
360 |
+
L += (off_z * H + off_h) * M # l's shape is (B, H, M)
|
361 |
+
|
362 |
+
offs_m_base = tl.arange(0, BLOCK_M)
|
363 |
+
offs_m = start_m * BLOCK_M + offs_m_base
|
364 |
+
offs_n_base = tl.arange(0, BLOCK_N)
|
365 |
+
offs_k = tl.arange(0, BLOCK_DMODEL)
|
366 |
+
|
367 |
+
# initialize pointers to value-like data
|
368 |
+
q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
|
369 |
+
o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL)
|
370 |
+
l_ptrs = L + offs_m
|
371 |
+
|
372 |
+
# initialize pointer to m and l, fp32 for accumulators
|
373 |
+
m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
|
374 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
375 |
+
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
376 |
+
|
377 |
+
# load q
|
378 |
+
mask_m = offs_m < M
|
379 |
+
if DIVISIBLE_M:
|
380 |
+
q = tl.load(q_ptrs, cache_modifier=".cg")
|
381 |
+
else:
|
382 |
+
q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
|
383 |
+
|
384 |
+
#Dot I trick: to place q in registers, it saves shared memory
|
385 |
+
if BLOCK_DMODEL < 128:
|
386 |
+
I = tl.where(offs_k[:, None] == offs_k,
|
387 |
+
tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
|
388 |
+
tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
|
389 |
+
q = tl.dot(q, I).to(input_dtype)
|
390 |
+
# else:
|
391 |
+
# I = tl.where(offs_m_base[:, None] == offs_m_base,
|
392 |
+
# tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype),
|
393 |
+
# tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype))
|
394 |
+
# q = tl.dot(I, q).to(input_dtype)
|
395 |
+
|
396 |
+
# NOTE: Loop-Bound-For-N
|
397 |
+
# The indices in m-dimension that this block may access is in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
|
398 |
+
# According to the rule of causal masking, then max index in n-dimension that this block may access
|
399 |
+
# is `P_SEQ + (start_m + 1) * BLOCK_M`.
|
400 |
+
# However, the upper bound of index in n-dimension should never exceed the sequence length of k/v(`P_SEQ + N_CTX`).
|
401 |
+
# `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`.
|
402 |
+
# At this case, there would be illegal memory access when loading k & v tiles
|
403 |
+
# if mask_n is not applied for loading(only when `DIVISIBLE_N`` is true).
|
404 |
+
# See also https://github.com/FlagOpen/FlagAttention/pull/8
|
405 |
+
if IS_CAUSAL:
|
406 |
+
hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
|
407 |
+
if LARGER_M:
|
408 |
+
hi = tl.maximum(0, hi)
|
409 |
+
else:
|
410 |
+
hi = N
|
411 |
+
|
412 |
+
# loop over k, v and update accumulators
|
413 |
+
offs_n_init = offs_n_base
|
414 |
+
k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) # (BLOCK_DMODEL, BLOCK_N)
|
415 |
+
v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
|
416 |
+
if HAS_BIAS:
|
417 |
+
bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn) # (BLOCK_M, BLOCK_N)
|
418 |
+
|
419 |
+
for start_n in range(0, hi, BLOCK_N):
|
420 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
421 |
+
offs_n = start_n + offs_n_base
|
422 |
+
|
423 |
+
# -- load k, v --
|
424 |
+
mask_n = offs_n < N
|
425 |
+
if DIVISIBLE_N:
|
426 |
+
k = tl.load(k_ptrs, cache_modifier=".cg")
|
427 |
+
v = tl.load(v_ptrs, cache_modifier=".cg")
|
428 |
+
else:
|
429 |
+
k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg")
|
430 |
+
v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg")
|
431 |
+
|
432 |
+
# -- load bias --
|
433 |
+
if HAS_BIAS:
|
434 |
+
if DIVISIBLE_M and DIVISIBLE_N:
|
435 |
+
b = tl.load(bias_ptrs)
|
436 |
+
else:
|
437 |
+
b = tl.load(bias_ptrs, mask_m[:, None] & mask_n[None, :])
|
438 |
+
|
439 |
+
# -- compute qk ---
|
440 |
+
s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
441 |
+
s += tl.dot(q, k) * sm_scale
|
442 |
+
if HAS_BIAS:
|
443 |
+
s += b
|
444 |
+
|
445 |
+
if not DIVISIBLE_N:
|
446 |
+
s = tl.where(mask_n[None, :], s, float("-inf"))
|
447 |
+
if IS_CAUSAL:
|
448 |
+
causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]
|
449 |
+
s = tl.where(causal_mask, s, float("-inf"))
|
450 |
+
|
451 |
+
# -- compute scaling constant ---
|
452 |
+
m_i_new = tl.maximum(m_i, tl.max(s, 1))
|
453 |
+
alpha = tl.math.exp2((m_i - m_i_new)*log2e)
|
454 |
+
p = tl.math.exp2((s - m_i_new[:, None])*log2e)
|
455 |
+
|
456 |
+
# -- scale and update acc: acc *= alpha[:, None]--
|
457 |
+
acc *= alpha[:, None]
|
458 |
+
acc += tl.dot(p.to(input_dtype), v)
|
459 |
+
|
460 |
+
# -- update m_i and l_i --
|
461 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
462 |
+
m_i = m_i_new
|
463 |
+
# update pointers
|
464 |
+
k_ptrs += BLOCK_N * stride_kn
|
465 |
+
v_ptrs += BLOCK_N * stride_vn
|
466 |
+
if HAS_BIAS:
|
467 |
+
bias_ptrs += BLOCK_N * stride_bn
|
468 |
+
|
469 |
+
# write back l & o
|
470 |
+
if IS_CAUSAL and LARGER_M:
|
471 |
+
is_empty_line = (offs_m + P_SEQ) < 0
|
472 |
+
acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))
|
473 |
+
l = tl.where(is_empty_line, float("-inf"), m_i + tl.log(l_i))
|
474 |
+
else:
|
475 |
+
acc = acc * (1.0 / l_i[:, None])
|
476 |
+
l = m_i + tl.log(l_i) # log(normalizer)
|
477 |
+
|
478 |
+
if DIVISIBLE_M:
|
479 |
+
tl.store(l_ptrs, l, cache_modifier=".cg")
|
480 |
+
tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg")
|
481 |
+
else:
|
482 |
+
tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg")
|
483 |
+
tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg")
|
484 |
+
|
485 |
+
|
486 |
+
# --------------------------- Backward ---------------------------
|
487 |
+
# NOTE: this function can be overwritten at runtime to use your custom config
|
488 |
+
def get_bwd_config(B, H, M, N, D, causal):
|
489 |
+
if torch.cuda.get_device_capability() == (8, 0):
|
490 |
+
if not causal:
|
491 |
+
BLOCK_M = 128 if D <= 64 else 64
|
492 |
+
BLOCK_N = 64
|
493 |
+
num_stages = 2
|
494 |
+
num_warps = 4
|
495 |
+
else:
|
496 |
+
BLOCK_M = 64
|
497 |
+
BLOCK_N = 64
|
498 |
+
num_stages = 3 if D <= 64 else 2
|
499 |
+
num_warps = 4
|
500 |
+
elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6)
|
501 |
+
if not causal:
|
502 |
+
if D <= 64:
|
503 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
|
504 |
+
else:
|
505 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8
|
506 |
+
else:
|
507 |
+
if D <= 64:
|
508 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
|
509 |
+
else:
|
510 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
|
511 |
+
else:
|
512 |
+
BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
|
513 |
+
return (BLOCK_M, BLOCK_N, num_stages, num_warps)
|
514 |
+
|
515 |
+
|
516 |
+
@triton.jit
|
517 |
+
def _bwd_preprocess(
|
518 |
+
Out, DO,
|
519 |
+
Delta,
|
520 |
+
stride_oz, stride_oh, stride_om, stride_ok,
|
521 |
+
stride_doz, stride_doh, stride_dom, stride_dok,
|
522 |
+
stride_dz, stride_dh, stride_dm,
|
523 |
+
M,
|
524 |
+
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
525 |
+
DIVISIBLE_M: tl.constexpr,
|
526 |
+
):
|
527 |
+
off_h = tl.program_id(1)
|
528 |
+
off_z = tl.program_id(2)
|
529 |
+
Out += off_z * stride_oz + off_h * stride_oh
|
530 |
+
DO += off_z * stride_doz + off_h * stride_doh
|
531 |
+
Delta += off_z * stride_dz + off_h * stride_dh
|
532 |
+
|
533 |
+
# compute (Out * Dout).sum() for vector interpretation
|
534 |
+
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
535 |
+
off_n = tl.arange(0, D_HEAD)
|
536 |
+
|
537 |
+
# load
|
538 |
+
o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok
|
539 |
+
do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok
|
540 |
+
|
541 |
+
if DIVISIBLE_M:
|
542 |
+
o = tl.load(o_ptrs).to(tl.float32)
|
543 |
+
do = tl.load(do_ptrs).to(tl.float32)
|
544 |
+
else:
|
545 |
+
mask_m = off_m < M
|
546 |
+
o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32)
|
547 |
+
do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32)
|
548 |
+
|
549 |
+
# compute
|
550 |
+
delta = tl.sum(o * do, axis=1)
|
551 |
+
# write-back
|
552 |
+
d_ptrs = Delta + off_m * stride_dm
|
553 |
+
if DIVISIBLE_M:
|
554 |
+
tl.store(d_ptrs, delta)
|
555 |
+
else:
|
556 |
+
tl.store(d_ptrs, delta, mask=mask_m)
|
557 |
+
|
558 |
+
|
559 |
+
@triton.jit
|
560 |
+
def _bwd_kv_kernel(
|
561 |
+
Q, K, V, B, sm_scale, DO,
|
562 |
+
DK, DV, DS,
|
563 |
+
L,
|
564 |
+
D,
|
565 |
+
stride_qz, stride_qh, stride_qm, stride_qk,
|
566 |
+
stride_kz, stride_kh, stride_kn, stride_kk,
|
567 |
+
stride_vz, stride_vh, stride_vn, stride_vk,
|
568 |
+
stride_bz, stride_bh, stride_bm, stride_bn,
|
569 |
+
stride_doz, stride_doh, stride_dom, stride_dok,
|
570 |
+
stride_dkz, stride_dkh, stride_dkn, stride_dkk,
|
571 |
+
stride_dvz, stride_dvh, stride_dvn, stride_dvk,
|
572 |
+
Z, H, M, N, P_SEQ,
|
573 |
+
lock,
|
574 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
|
575 |
+
CAUSAL: tl.constexpr,
|
576 |
+
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
|
577 |
+
HAS_BIAS: tl.constexpr,
|
578 |
+
RETURN_DS: tl.constexpr,
|
579 |
+
IS_BATCH_REDUCED: tl.constexpr,
|
580 |
+
GROUP_SIZE_BIAS: tl.constexpr,
|
581 |
+
):
|
582 |
+
input_dtype = Q.dtype.element_ty
|
583 |
+
# -- grid id --
|
584 |
+
start_n = tl.program_id(0)
|
585 |
+
off_h = tl.program_id(1)
|
586 |
+
off_z = tl.program_id(2)
|
587 |
+
log2e: tl.constexpr = 1.4426950408889634
|
588 |
+
|
589 |
+
# offset pointers for (batch, head)
|
590 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
591 |
+
K += off_z * stride_kz + off_h * stride_kh
|
592 |
+
V += off_z * stride_vz + off_h * stride_vh
|
593 |
+
if HAS_BIAS:
|
594 |
+
if IS_BATCH_REDUCED:
|
595 |
+
B += off_h * stride_bh
|
596 |
+
else:
|
597 |
+
B += off_z * stride_bz + off_h * stride_bh
|
598 |
+
DO += off_z * stride_doz + off_h * stride_doh
|
599 |
+
|
600 |
+
# offset pointers for batch/head
|
601 |
+
DK += off_z * stride_dkz + off_h * stride_dkh
|
602 |
+
DV += off_z * stride_dvz + off_h * stride_dvh
|
603 |
+
|
604 |
+
# offset pointer for ds tensor and locks for the reduction
|
605 |
+
if RETURN_DS:
|
606 |
+
DS += off_z * stride_bz + off_h * stride_bh
|
607 |
+
|
608 |
+
# offset pointers for batch/head
|
609 |
+
D += (off_z * H + off_h) * M
|
610 |
+
L += (off_z * H + off_h) * M
|
611 |
+
|
612 |
+
if CAUSAL:
|
613 |
+
lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0)
|
614 |
+
lo = (lo // BLOCK_M) * BLOCK_M
|
615 |
+
else:
|
616 |
+
lo = 0
|
617 |
+
|
618 |
+
offs_m_init = lo + tl.arange(0, BLOCK_M)
|
619 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
620 |
+
offs_m_base = tl.arange(0, BLOCK_M)
|
621 |
+
offs_k = tl.arange(0, BLOCK_DMODEL)
|
622 |
+
|
623 |
+
# initialize pointers to value-like data
|
624 |
+
q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
|
625 |
+
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
|
626 |
+
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
|
627 |
+
do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
|
628 |
+
|
629 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL)
|
630 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) # (BLOCK_N, BLOCK_DMODEL)
|
631 |
+
|
632 |
+
if HAS_BIAS:
|
633 |
+
bias_ptrs = B + (offs_m_init[:, None] * stride_bm + offs_n[None, :] * stride_bn)
|
634 |
+
|
635 |
+
if RETURN_DS:
|
636 |
+
ds_ptrs = DS + (offs_m_init[:, None] * stride_bm + offs_n[None, :] * stride_bn)
|
637 |
+
|
638 |
+
# k and v stay in SRAM throughout
|
639 |
+
mask_n = offs_n < N
|
640 |
+
if DIVISIBLE_N:
|
641 |
+
v = tl.load(v_ptrs)
|
642 |
+
k = tl.load(k_ptrs)
|
643 |
+
else:
|
644 |
+
v = tl.load(v_ptrs, mask=mask_n[:, None])
|
645 |
+
k = tl.load(k_ptrs, mask=mask_n[:, None])
|
646 |
+
|
647 |
+
# initialize dk amd dv
|
648 |
+
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
649 |
+
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
650 |
+
|
651 |
+
# loop over a col
|
652 |
+
for start_m in range(lo, M, BLOCK_M):
|
653 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
654 |
+
offs_m = start_m + offs_m_base
|
655 |
+
causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
|
656 |
+
|
657 |
+
# load q1, k1, q2, k2, v, do on-chip
|
658 |
+
mask_m = offs_m < M
|
659 |
+
if DIVISIBLE_M:
|
660 |
+
q = tl.load(q_ptrs)
|
661 |
+
else:
|
662 |
+
valid_mask = mask_m[:, None] # & mask_n
|
663 |
+
q = tl.load(q_ptrs, mask=mask_m[:, None])
|
664 |
+
|
665 |
+
# load bias
|
666 |
+
if HAS_BIAS:
|
667 |
+
if DIVISIBLE_M and DIVISIBLE_N:
|
668 |
+
b = tl.load(bias_ptrs)
|
669 |
+
else:
|
670 |
+
b = tl.load(bias_ptrs, mask=mask_m[:, None] & mask_n[None, :])
|
671 |
+
|
672 |
+
# recompute p = softmax(qk * sm_scale, dim=-1)
|
673 |
+
s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
674 |
+
s += tl.dot(q, tl.trans(k)) * sm_scale
|
675 |
+
|
676 |
+
if HAS_BIAS:
|
677 |
+
s += b
|
678 |
+
|
679 |
+
# NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd)
|
680 |
+
# So masking on s is not needed.
|
681 |
+
# s = tl.where(valid_mask, s , float("-inf"))
|
682 |
+
# if CAUSAL:
|
683 |
+
# s = tl.where(causal_mask, s, float("-inf"))
|
684 |
+
|
685 |
+
# -- recompute p ---
|
686 |
+
if DIVISIBLE_M:
|
687 |
+
l = tl.load(L + offs_m)
|
688 |
+
else:
|
689 |
+
l = tl.load(L + offs_m, mask=mask_m)
|
690 |
+
p = tl.math.exp2((s - l[:, None])*log2e) # (BLOCK_M, BLOCK_N)
|
691 |
+
|
692 |
+
if not DIVISIBLE_M:
|
693 |
+
p = tl.where(valid_mask, p, 0.0)
|
694 |
+
if CAUSAL:
|
695 |
+
p = tl.where(causal_mask, p, 0.0)
|
696 |
+
|
697 |
+
# compute dv = dot(p, do)
|
698 |
+
if DIVISIBLE_M:
|
699 |
+
do = tl.load(do_ptrs)
|
700 |
+
else:
|
701 |
+
do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL)
|
702 |
+
dv += tl.dot(tl.trans(p.to(do.dtype)), do) # (BLOCK_N, BLOCK_DMODEL) # still correct
|
703 |
+
|
704 |
+
# compute dp = dot(v, do)
|
705 |
+
if DIVISIBLE_M:
|
706 |
+
delta = tl.load(D + offs_m)
|
707 |
+
else:
|
708 |
+
delta = tl.load(D + offs_m, mask=mask_m)
|
709 |
+
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
710 |
+
dp += tl.dot(do, tl.trans(v))
|
711 |
+
|
712 |
+
# compute ds = p * (dp - delta[:, None])
|
713 |
+
ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
|
714 |
+
|
715 |
+
if not DIVISIBLE_M:
|
716 |
+
ds = tl.where(valid_mask, ds, 0.0)
|
717 |
+
if CAUSAL:
|
718 |
+
ds = tl.where(causal_mask, ds, 0.0)
|
719 |
+
|
720 |
+
ds = ds.to(input_dtype)
|
721 |
+
# compute dk = dot(ds.T, q) masking
|
722 |
+
dk += tl.dot(tl.trans(ds), q)
|
723 |
+
|
724 |
+
# store ds
|
725 |
+
if RETURN_DS:
|
726 |
+
if DIVISIBLE_M and DIVISIBLE_N:
|
727 |
+
tl.store(ds_ptrs, ds)
|
728 |
+
else:
|
729 |
+
tl.store(ds_ptrs, ds, mask=mask_m[:, None] & mask_n[None, :])
|
730 |
+
|
731 |
+
# increment pointers
|
732 |
+
q_ptrs += BLOCK_M * stride_qm
|
733 |
+
do_ptrs += BLOCK_M * stride_dom
|
734 |
+
if HAS_BIAS:
|
735 |
+
bias_ptrs += BLOCK_M * stride_bm
|
736 |
+
if RETURN_DS:
|
737 |
+
ds_ptrs += BLOCK_M * stride_bm
|
738 |
+
|
739 |
+
dk *= sm_scale
|
740 |
+
if DIVISIBLE_N:
|
741 |
+
tl.store(dk_ptrs, dk.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL)
|
742 |
+
tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL,)
|
743 |
+
else:
|
744 |
+
tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL)
|
745 |
+
tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL,)
|
746 |
+
|
747 |
+
|
748 |
+
@triton.jit
|
749 |
+
def _bwd_q_kernel(
|
750 |
+
Q, K, V, B, sm_scale, DO,
|
751 |
+
DQ,
|
752 |
+
L,
|
753 |
+
D,
|
754 |
+
stride_qz, stride_qh, stride_qm, stride_qk,
|
755 |
+
stride_kz, stride_kh, stride_kn, stride_kk,
|
756 |
+
stride_vz, stride_vh, stride_vn, stride_vk,
|
757 |
+
stride_bz, stride_bh, stride_bm, stride_bn,
|
758 |
+
stride_doz, stride_doh, stride_dom, stride_dok,
|
759 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqk,
|
760 |
+
Z, H, M, N, P_SEQ,
|
761 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
|
762 |
+
CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
|
763 |
+
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
|
764 |
+
HAS_BIAS: tl.constexpr
|
765 |
+
):
|
766 |
+
input_dtype = Q.dtype.element_ty
|
767 |
+
# -- grid id --
|
768 |
+
start_m = tl.program_id(0)
|
769 |
+
off_h = tl.program_id(1)
|
770 |
+
off_z = tl.program_id(2)
|
771 |
+
|
772 |
+
# scale sm_scale by log_2(e) and use
|
773 |
+
# 2^x instead of exp in the loop because CSE and LICM
|
774 |
+
# don't work as expected with `exp` in the loop
|
775 |
+
log2e: tl.constexpr = 1.4426950408889634
|
776 |
+
|
777 |
+
# offset pointers for (batch, head)
|
778 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
779 |
+
K += off_z * stride_kz + off_h * stride_kh
|
780 |
+
V += off_z * stride_vz + off_h * stride_vh
|
781 |
+
if HAS_BIAS:
|
782 |
+
B += off_z * stride_bz + off_h * stride_bh
|
783 |
+
DO += off_z * stride_doz + off_h * stride_doh
|
784 |
+
D += (off_z * H + off_h) * M
|
785 |
+
L += (off_z * H + off_h) * M
|
786 |
+
|
787 |
+
# offset pointers for batch/head
|
788 |
+
DQ += off_z * stride_dqz + off_h * stride_dqh
|
789 |
+
|
790 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
791 |
+
offs_n_base = tl.arange(0, BLOCK_N)
|
792 |
+
offs_n_init = offs_n_base
|
793 |
+
offs_k = tl.arange(0, BLOCK_DMODEL)
|
794 |
+
|
795 |
+
# initialize pointers to value-like data
|
796 |
+
q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
|
797 |
+
k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
|
798 |
+
v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
|
799 |
+
|
800 |
+
if HAS_BIAS:
|
801 |
+
bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn)
|
802 |
+
|
803 |
+
dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) # (BLOCK_M, BLOCK_DMODEL)
|
804 |
+
do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
|
805 |
+
|
806 |
+
# pointer to row-wise quantities in value-like data
|
807 |
+
d_ptrs = D + offs_m
|
808 |
+
l_ptrs = L + offs_m
|
809 |
+
|
810 |
+
# load q: it will stay in SRAM throughout
|
811 |
+
mask_m = offs_m < M
|
812 |
+
if DIVISIBLE_M:
|
813 |
+
q = tl.load(q_ptrs)
|
814 |
+
do = tl.load(do_ptrs)
|
815 |
+
delta = tl.load(d_ptrs)
|
816 |
+
l = tl.load(l_ptrs)
|
817 |
+
else:
|
818 |
+
q = tl.load(q_ptrs, mask=mask_m[:, None])
|
819 |
+
do = tl.load(do_ptrs, mask=mask_m[:, None])
|
820 |
+
delta = tl.load(d_ptrs, mask=mask_m)
|
821 |
+
l = tl.load(l_ptrs, mask=mask_m)
|
822 |
+
|
823 |
+
# initialize dq
|
824 |
+
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
825 |
+
|
826 |
+
# loop over k, v and update accumulator
|
827 |
+
# see note "Loop-Bound-For-N"
|
828 |
+
if CAUSAL:
|
829 |
+
hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
|
830 |
+
if LARGER_M:
|
831 |
+
hi = tl.maximum(0, hi)
|
832 |
+
else:
|
833 |
+
hi = N
|
834 |
+
|
835 |
+
# loop over a row
|
836 |
+
for start_n in range(0, hi, BLOCK_N):
|
837 |
+
offs_n = start_n + offs_n_base
|
838 |
+
|
839 |
+
# load k1, k2, v on chip
|
840 |
+
mask_n = offs_n < N
|
841 |
+
if DIVISIBLE_N:
|
842 |
+
v = tl.load(v_ptrs)
|
843 |
+
k = tl.load(k_ptrs)
|
844 |
+
else:
|
845 |
+
v = tl.load(v_ptrs, mask=mask_n[:, None])
|
846 |
+
k = tl.load(k_ptrs, mask=mask_n[:, None])
|
847 |
+
|
848 |
+
# load bias
|
849 |
+
if HAS_BIAS:
|
850 |
+
if DIVISIBLE_M and DIVISIBLE_N:
|
851 |
+
b = tl.load(bias_ptrs)
|
852 |
+
else:
|
853 |
+
b = tl.load(bias_ptrs, mask=mask_m[:, None] & mask_n[None, :])
|
854 |
+
|
855 |
+
# recompute p = softmax(qk * sm_scale, dim=-1)
|
856 |
+
if not DIVISIBLE_N:
|
857 |
+
valid_mask = mask_n # & mask_m[:, None]
|
858 |
+
if CAUSAL:
|
859 |
+
causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
|
860 |
+
|
861 |
+
s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
862 |
+
s += tl.dot(q, tl.trans(k)) * sm_scale
|
863 |
+
if HAS_BIAS:
|
864 |
+
s += b
|
865 |
+
|
866 |
+
# NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd)
|
867 |
+
# So masking on s is not needed.
|
868 |
+
# if CAUSAL:
|
869 |
+
# s = tl.where(causal_mask & valid_mask, s, float("-inf"))
|
870 |
+
# else:
|
871 |
+
# s = tl.where(valid_mask, s, float("-inf"))
|
872 |
+
p = tl.math.exp2((s - l[:, None])*log2e) # (BLOCK_M, BLOCK_N)
|
873 |
+
|
874 |
+
# compute dp = dot(v, do)
|
875 |
+
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
876 |
+
dp += tl.dot(do.to(input_dtype), tl.trans(v))
|
877 |
+
# no need to mask dp
|
878 |
+
# if CAUSAL:
|
879 |
+
# dp = tl.where(causal_mask & valid_mask, dp, 0.0)
|
880 |
+
# else:
|
881 |
+
# dp = tl.where(valid_mask, dp, 0.0)
|
882 |
+
|
883 |
+
# compute ds = p * (dp - delta[:, None])
|
884 |
+
# move scale out to dq at last
|
885 |
+
ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
|
886 |
+
|
887 |
+
# mask ds to ensure no small values
|
888 |
+
if not DIVISIBLE_N:
|
889 |
+
ds = tl.where(valid_mask, ds, 0.0)
|
890 |
+
if CAUSAL:
|
891 |
+
ds = tl.where(causal_mask, ds, 0.0)
|
892 |
+
|
893 |
+
dq += tl.dot(ds.to(input_dtype), k)
|
894 |
+
|
895 |
+
# increment pointers
|
896 |
+
k_ptrs += BLOCK_N * stride_kn
|
897 |
+
v_ptrs += BLOCK_N * stride_vn
|
898 |
+
if HAS_BIAS:
|
899 |
+
bias_ptrs += BLOCK_N * stride_bn
|
900 |
+
|
901 |
+
dq *= sm_scale
|
902 |
+
if DIVISIBLE_M:
|
903 |
+
tl.store(dq_ptrs, dq.to(input_dtype))
|
904 |
+
else:
|
905 |
+
tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None])
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"decoder_start_token_id": 0,
|
4 |
+
"eos_token_id": 1,
|
5 |
+
"pad_token_id": 3,
|
6 |
+
"transformers_version": "4.46.0.dev0"
|
7 |
+
}
|
modeling_flash_t5.py
ADDED
@@ -0,0 +1,790 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# From: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
2 |
+
|
3 |
+
import copy
|
4 |
+
import math
|
5 |
+
from typing import Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from transformers.modeling_utils import ModuleUtilsMixin
|
12 |
+
from transformers.modeling_outputs import ModelOutput, Seq2SeqModelOutput, BaseModelOutput, Seq2SeqLMOutput
|
13 |
+
from transformers import PreTrainedModel
|
14 |
+
|
15 |
+
try:
|
16 |
+
from .rms_norm import fast_rms_layernorm
|
17 |
+
except ImportError:
|
18 |
+
fast_rms_layernorm = None
|
19 |
+
|
20 |
+
try:
|
21 |
+
from .cross_entropy_loss import cross_entropy_loss as fast_cross_entropy_loss
|
22 |
+
except ImportError:
|
23 |
+
fast_cross_entropy_loss = None
|
24 |
+
|
25 |
+
try:
|
26 |
+
from .flash_attention_v2_bias import flash_attention_v2_bias
|
27 |
+
except ImportError:
|
28 |
+
flash_attention_v2_bias = None
|
29 |
+
|
30 |
+
try:
|
31 |
+
from flash_attn import flash_attn_kvpacked_func, flash_attn_func
|
32 |
+
except ImportError:
|
33 |
+
flash_attn_kvpacked_func, flash_attn_func = None, None
|
34 |
+
|
35 |
+
from .attn_ref import attn_ref
|
36 |
+
|
37 |
+
from .configuration_flash_t5 import FlashT5Config
|
38 |
+
from .positional_encoding import ALiBiPositionalEncoding, RelativePositionalEncoding, RotaryPositionalEncoding, FIRE
|
39 |
+
|
40 |
+
class FlashT5CrossEntropyLoss(nn.Module):
|
41 |
+
def __init__(self, z_loss_factor=0.0, label_smoothing=0.0, use_triton_crossentropy=False, inplace_backward=False):
|
42 |
+
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
if use_triton_crossentropy and fast_cross_entropy_loss is None:
|
46 |
+
raise ImportError("fast_cross_entropy_loss is not available")
|
47 |
+
|
48 |
+
self.use_triton_crossentropy = use_triton_crossentropy
|
49 |
+
self.z_loss_factor = z_loss_factor
|
50 |
+
self.label_smoothing = label_smoothing
|
51 |
+
self.inplace_backward = inplace_backward
|
52 |
+
|
53 |
+
self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
54 |
+
|
55 |
+
def compute_zloss(self, logits: torch.Tensor, z_loss: float):
|
56 |
+
logits_sum = torch.logsumexp(logits, dim=-1, keepdim=True)
|
57 |
+
log_z = torch.squeeze(logits_sum, axis=-1)
|
58 |
+
total_z_loss = z_loss * torch.square(log_z)
|
59 |
+
return total_z_loss.mean()
|
60 |
+
|
61 |
+
def forward(self, logits, labels):
|
62 |
+
|
63 |
+
if self.use_triton_crossentropy:
|
64 |
+
return fast_cross_entropy_loss(logits, labels, \
|
65 |
+
lse_square_scale=self.z_loss_factor, \
|
66 |
+
label_smoothing=self.label_smoothing, \
|
67 |
+
inplace_backward=self.inplace_backward \
|
68 |
+
)[0].mean()
|
69 |
+
|
70 |
+
# use standard method
|
71 |
+
batch, seq_len, d = logits.shape
|
72 |
+
logits_flatten = logits.float().view(batch*seq_len, d) # Must cast to float32 for numerical stability
|
73 |
+
labels_flatten = labels.view(-1)
|
74 |
+
loss = self.cross_entropy_loss(logits_flatten, labels_flatten)
|
75 |
+
z_loss = 0.0
|
76 |
+
if self.z_loss_factor != 0.0:
|
77 |
+
z_loss = self.compute_zloss(logits_flatten[labels_flatten != -100],
|
78 |
+
z_loss=self.z_loss_factor)
|
79 |
+
return loss + z_loss
|
80 |
+
|
81 |
+
class FlashT5LayerNorm(nn.Module):
|
82 |
+
def __init__(self, hidden_size, eps=1e-6, use_triton_layernorm=False):
|
83 |
+
"""
|
84 |
+
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
85 |
+
"""
|
86 |
+
super().__init__()
|
87 |
+
|
88 |
+
if use_triton_layernorm and fast_rms_layernorm is None:
|
89 |
+
raise ImportError("fast_rms_layernorm is not available")
|
90 |
+
|
91 |
+
self.use_triton_layernorm = use_triton_layernorm
|
92 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
93 |
+
self.variance_epsilon = eps
|
94 |
+
|
95 |
+
def forward(self, hidden_states):
|
96 |
+
|
97 |
+
if self.use_triton_layernorm:
|
98 |
+
return fast_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
|
99 |
+
|
100 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
101 |
+
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
102 |
+
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
103 |
+
# half-precision inputs is done in fp32
|
104 |
+
|
105 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
106 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
107 |
+
|
108 |
+
# convert into half-precision if necessary
|
109 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
110 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
111 |
+
|
112 |
+
return self.weight * hidden_states
|
113 |
+
|
114 |
+
class FlashT5DenseAct(nn.Module):
|
115 |
+
def __init__(self, config: FlashT5Config):
|
116 |
+
super().__init__()
|
117 |
+
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
|
118 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
119 |
+
self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
|
120 |
+
|
121 |
+
def forward(self, hidden_states):
|
122 |
+
hidden_states = self.wi(hidden_states)
|
123 |
+
hidden_states = self.act(hidden_states)
|
124 |
+
hidden_states = self.dropout(hidden_states)
|
125 |
+
if (
|
126 |
+
isinstance(self.wo.weight, torch.Tensor)
|
127 |
+
and hidden_states.dtype != self.wo.weight.dtype
|
128 |
+
and self.wo.weight.dtype != torch.int8
|
129 |
+
):
|
130 |
+
hidden_states = hidden_states.to(self.wo.weight.dtype)
|
131 |
+
|
132 |
+
return hidden_states
|
133 |
+
|
134 |
+
class FlashT5DenseGatedAct(nn.Module):
|
135 |
+
def __init__(self, config: FlashT5Config):
|
136 |
+
super().__init__()
|
137 |
+
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
138 |
+
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
139 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
140 |
+
self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
|
141 |
+
|
142 |
+
self.use_gelu_act = config.use_gelu_act
|
143 |
+
|
144 |
+
def forward(self, hidden_states):
|
145 |
+
|
146 |
+
hidden_act = self.act(self.wi_0(hidden_states))
|
147 |
+
hidden_linear = self.wi_1(hidden_states)
|
148 |
+
hidden_states = hidden_act * hidden_linear
|
149 |
+
hidden_states = self.dropout(hidden_states)
|
150 |
+
|
151 |
+
return hidden_states
|
152 |
+
|
153 |
+
class FlashT5LayerFF(nn.Module):
|
154 |
+
def __init__(self, config: FlashT5Config):
|
155 |
+
super().__init__()
|
156 |
+
if config.use_glu_mlp:
|
157 |
+
self.act = FlashT5DenseGatedAct(config)
|
158 |
+
else:
|
159 |
+
self.act = FlashT5DenseAct(config)
|
160 |
+
|
161 |
+
self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
|
162 |
+
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
|
163 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
164 |
+
|
165 |
+
def forward(self, hidden_states):
|
166 |
+
forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
|
167 |
+
forwarded_states = self.act(forwarded_states)
|
168 |
+
forwarded_states = self.wo(forwarded_states)
|
169 |
+
hidden_states = hidden_states + self.dropout(forwarded_states)
|
170 |
+
return hidden_states
|
171 |
+
|
172 |
+
|
173 |
+
class FlashT5Attention(nn.Module, ModuleUtilsMixin):
|
174 |
+
def __init__(self, config: FlashT5Config, has_positional_encoding=False, is_causal=False):
|
175 |
+
super().__init__()
|
176 |
+
self.is_decoder = config.is_decoder
|
177 |
+
self.has_positional_encoding = has_positional_encoding
|
178 |
+
self.is_causal = is_causal
|
179 |
+
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
180 |
+
self.relative_attention_max_distance = config.relative_attention_max_distance
|
181 |
+
self.d_model = config.d_model
|
182 |
+
self.key_value_proj_dim = config.d_kv
|
183 |
+
self.n_heads = config.num_heads
|
184 |
+
self.p_dropout = config.attention_dropout_rate
|
185 |
+
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
186 |
+
self.attention_type = config.attention_type
|
187 |
+
self.position_encoding_type = config.position_encoding_type
|
188 |
+
self.max_sequence_length = config.max_sequence_length
|
189 |
+
self.softmax_scale = config.attention_scale if config.attention_scale is not None else 1.0/math.sqrt(self.n_heads)
|
190 |
+
self.use_full_bias_size = config.use_full_bias_size
|
191 |
+
self.use_masking = config.use_masking
|
192 |
+
|
193 |
+
if self.use_masking and not self.use_full_bias_size:
|
194 |
+
raise ValueError("Masking can only be used with full batch size.")
|
195 |
+
|
196 |
+
if self.attention_type == "triton" and flash_attention_v2_bias is None:
|
197 |
+
raise ImportError("flash_attention_triton is not available")
|
198 |
+
elif self.attention_type.startswith("fa2") and flash_attn_func is None:
|
199 |
+
raise ImportError("Flash Attention 2 is not available")
|
200 |
+
|
201 |
+
if self.attention_type == "fa2_rpe" and self.position_encoding_type != "t5":
|
202 |
+
raise ValueError("fa2_rpe is not compatible with non-T5 position encoding")
|
203 |
+
|
204 |
+
assert (self.p_dropout == 0.0) or (self.attention_type != "triton"), "Triton attention does not support dropout"
|
205 |
+
|
206 |
+
self.pe_encoding = None
|
207 |
+
if self.position_encoding_type == "ALiBi" and has_positional_encoding:
|
208 |
+
# build alibi matrix with an upper bound on seq length
|
209 |
+
self.pe_encoding = ALiBiPositionalEncoding(self.max_sequence_length,
|
210 |
+
self.n_heads,
|
211 |
+
config.alibi_mode,
|
212 |
+
randomized_position=config.use_randomized_position_encoding)
|
213 |
+
elif self.position_encoding_type == "t5" and has_positional_encoding:
|
214 |
+
self.pe_encoding = RelativePositionalEncoding(self.relative_attention_num_buckets,
|
215 |
+
self.relative_attention_max_distance,
|
216 |
+
self.n_heads,
|
217 |
+
self.max_sequence_length,
|
218 |
+
bidirectional=(not self.is_decoder),
|
219 |
+
randomized_position=config.use_randomized_position_encoding)
|
220 |
+
elif self.position_encoding_type == "RoPE":
|
221 |
+
self.pe_encoding = RotaryPositionalEncoding(int(self.key_value_proj_dim * config.rotary_emb_fraction),
|
222 |
+
self.max_sequence_length,
|
223 |
+
config.rotary_base,
|
224 |
+
config.rotary_interleaved,
|
225 |
+
config.rotary_scale_base,
|
226 |
+
randomized_position=config.use_randomized_position_encoding)
|
227 |
+
elif self.position_encoding_type == "FIRE" and has_positional_encoding:
|
228 |
+
self.pe_encoding = FIRE(num_heads=self.n_heads,
|
229 |
+
mlp_width=config.fire_mlp_width,
|
230 |
+
init_c=0.1,
|
231 |
+
init_L=self.relative_attention_max_distance)
|
232 |
+
|
233 |
+
self.Wq = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
234 |
+
self.Wk = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
235 |
+
self.Wv = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
236 |
+
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
|
237 |
+
|
238 |
+
def forward(
|
239 |
+
self,
|
240 |
+
hidden_states,
|
241 |
+
mask=None,
|
242 |
+
key_value_states=None,
|
243 |
+
position_bias=None,
|
244 |
+
):
|
245 |
+
"""
|
246 |
+
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
247 |
+
"""
|
248 |
+
# Input is (batch_size, seq_length, dim)
|
249 |
+
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
|
250 |
+
batch_size, seq_length = hidden_states.shape[:2]
|
251 |
+
key_length = seq_length if key_value_states is None else key_value_states.shape[1]
|
252 |
+
q = self.Wq(hidden_states)
|
253 |
+
if key_value_states is None:
|
254 |
+
k = self.Wk(hidden_states)
|
255 |
+
v = self.Wv(hidden_states)
|
256 |
+
else:
|
257 |
+
k = self.Wk(key_value_states)
|
258 |
+
v = self.Wv(key_value_states)
|
259 |
+
|
260 |
+
q = q.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim)
|
261 |
+
k = k.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
|
262 |
+
v = v.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
|
263 |
+
|
264 |
+
if position_bias is None and self.pe_encoding is not None and self.attention_type != "fa2_rpe":
|
265 |
+
q, k, v, position_bias = self.pe_encoding(q, k, v)
|
266 |
+
|
267 |
+
if position_bias is not None and self.use_full_bias_size:
|
268 |
+
position_bias = position_bias.expand(q.shape[0], q.shape[2], q.shape[1], k.shape[1])
|
269 |
+
if self.attention_type == "fa2_bias" or self.attention_type == "triton":
|
270 |
+
position_bias = position_bias.contiguous()
|
271 |
+
|
272 |
+
if position_bias is not None and mask is not None and self.use_masking:
|
273 |
+
mask = mask.unsqueeze(1)
|
274 |
+
if len(mask.shape) == 3:
|
275 |
+
mask = mask.unsqueeze(3)
|
276 |
+
position_bias = torch.where(mask, position_bias, torch.finfo(hidden_states.dtype).min)
|
277 |
+
|
278 |
+
if self.attention_type == "fa2_bias":
|
279 |
+
output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, \
|
280 |
+
attn_bias=position_bias, causal=self.is_causal)
|
281 |
+
elif self.attention_type == "fa2_rpe":
|
282 |
+
output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, \
|
283 |
+
rpe_weights=self.pe_encoding.relative_attention_bias.weight.t(), \
|
284 |
+
rpe_max_distance=self.relative_attention_max_distance, \
|
285 |
+
causal=self.is_causal)
|
286 |
+
elif self.attention_type == "triton":
|
287 |
+
q = q.permute(0, 2, 1, 3)
|
288 |
+
k = k.permute(0, 2, 1, 3)
|
289 |
+
v = v.permute(0, 2, 1, 3)
|
290 |
+
output = flash_attention_v2_bias(q, k, v, position_bias, self.is_causal, self.softmax_scale)
|
291 |
+
output = output.permute(0, 2, 1, 3)
|
292 |
+
else: # use flash attention
|
293 |
+
q = q.permute(0, 2, 1, 3)
|
294 |
+
k = k.permute(0, 2, 1, 3)
|
295 |
+
v = v.permute(0, 2, 1, 3)
|
296 |
+
output = attn_ref(q, k, v, position_bias, dropout_p=self.p_dropout, sm_scale=self.softmax_scale, causal=self.is_causal)
|
297 |
+
output = output.permute(0, 2, 1, 3)
|
298 |
+
|
299 |
+
output = self.o(output.reshape(output.shape[0], output.shape[1], self.inner_dim))
|
300 |
+
return (output, position_bias)
|
301 |
+
|
302 |
+
|
303 |
+
class FlashT5LayerSelfAttention(nn.Module):
|
304 |
+
def __init__(self, config, has_positional_encoding=False):
|
305 |
+
super().__init__()
|
306 |
+
self.self_attention = FlashT5Attention(config, has_positional_encoding=has_positional_encoding, is_causal=config.is_decoder)
|
307 |
+
self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
|
308 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
309 |
+
|
310 |
+
def forward(
|
311 |
+
self,
|
312 |
+
hidden_states,
|
313 |
+
attention_mask=None,
|
314 |
+
position_bias=None,
|
315 |
+
):
|
316 |
+
normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
|
317 |
+
attention_output = self.self_attention(
|
318 |
+
normed_hidden_states,
|
319 |
+
mask=attention_mask,
|
320 |
+
position_bias=position_bias,
|
321 |
+
)
|
322 |
+
hidden_states = hidden_states + self.dropout(attention_output[0])
|
323 |
+
outputs = (hidden_states,) + attention_output[1:]
|
324 |
+
return outputs
|
325 |
+
|
326 |
+
|
327 |
+
class FlashT5LayerCrossAttention(nn.Module):
|
328 |
+
def __init__(self, config):
|
329 |
+
super().__init__()
|
330 |
+
self.cross_attention = FlashT5Attention(config, has_positional_encoding=False)
|
331 |
+
self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
|
332 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
333 |
+
|
334 |
+
def forward(
|
335 |
+
self,
|
336 |
+
hidden_states,
|
337 |
+
key_value_states,
|
338 |
+
attention_mask=None,
|
339 |
+
position_bias=None,
|
340 |
+
):
|
341 |
+
normed_hidden_states = self.layer_norm(hidden_states)
|
342 |
+
attention_output = self.cross_attention(
|
343 |
+
normed_hidden_states,
|
344 |
+
mask=attention_mask,
|
345 |
+
key_value_states=key_value_states,
|
346 |
+
position_bias=position_bias,
|
347 |
+
)
|
348 |
+
layer_output = hidden_states + self.dropout(attention_output[0])
|
349 |
+
outputs = (layer_output,) + attention_output[1:]
|
350 |
+
return outputs
|
351 |
+
|
352 |
+
|
353 |
+
class FlashT5Block(nn.Module):
|
354 |
+
def __init__(self, config, has_positional_encoding=False):
|
355 |
+
super().__init__()
|
356 |
+
self.is_decoder = config.is_decoder
|
357 |
+
|
358 |
+
self.self_attention_layer = FlashT5LayerSelfAttention(config, has_positional_encoding=has_positional_encoding)
|
359 |
+
|
360 |
+
if self.is_decoder:
|
361 |
+
self.cross_attention_layer = FlashT5LayerCrossAttention(config)
|
362 |
+
|
363 |
+
self.ff_layer = FlashT5LayerFF(config)
|
364 |
+
|
365 |
+
def forward(
|
366 |
+
self,
|
367 |
+
hidden_states,
|
368 |
+
attention_mask=None,
|
369 |
+
position_bias=None,
|
370 |
+
encoder_hidden_states=None,
|
371 |
+
encoder_attention_mask=None,
|
372 |
+
encoder_decoder_position_bias=None,
|
373 |
+
):
|
374 |
+
self_attention_outputs = self.self_attention_layer(
|
375 |
+
hidden_states,
|
376 |
+
attention_mask=attention_mask,
|
377 |
+
position_bias=position_bias,
|
378 |
+
)
|
379 |
+
hidden_states = self_attention_outputs[0]
|
380 |
+
attention_outputs = self_attention_outputs[1:] # Relative position weights
|
381 |
+
|
382 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
383 |
+
cross_attention_outputs = self.cross_attention_layer(
|
384 |
+
hidden_states,
|
385 |
+
key_value_states=encoder_hidden_states,
|
386 |
+
attention_mask=encoder_attention_mask,
|
387 |
+
position_bias=encoder_decoder_position_bias,
|
388 |
+
)
|
389 |
+
hidden_states = cross_attention_outputs[0]
|
390 |
+
|
391 |
+
# Keep relative position weights
|
392 |
+
attention_outputs = attention_outputs + cross_attention_outputs[1:]
|
393 |
+
|
394 |
+
# Apply Feed Forward layer
|
395 |
+
hidden_states = self.ff_layer(hidden_states)
|
396 |
+
|
397 |
+
outputs = (hidden_states,) + attention_outputs
|
398 |
+
return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)
|
399 |
+
|
400 |
+
|
401 |
+
class FlashT5Stack(nn.Module, ModuleUtilsMixin):
|
402 |
+
def __init__(self, config, embed_tokens):
|
403 |
+
super().__init__()
|
404 |
+
assert embed_tokens is not None
|
405 |
+
|
406 |
+
self.config = config
|
407 |
+
self.embed_tokens = embed_tokens
|
408 |
+
self.is_decoder = config.is_decoder
|
409 |
+
|
410 |
+
self.block = nn.ModuleList(
|
411 |
+
[FlashT5Block(config, has_positional_encoding=bool(i == 0)) for i in range(config.num_layers)]
|
412 |
+
)
|
413 |
+
|
414 |
+
self.final_layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
|
415 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
416 |
+
|
417 |
+
def forward(
|
418 |
+
self,
|
419 |
+
input_ids=None,
|
420 |
+
# input_ids: Optional[torch.LongTensor] = None,
|
421 |
+
attention_mask=None,
|
422 |
+
encoder_hidden_states=None,
|
423 |
+
encoder_attention_mask=None,
|
424 |
+
inputs_embeds=None,
|
425 |
+
head_mask=None,
|
426 |
+
cross_attn_head_mask=None,
|
427 |
+
past_key_values=None,
|
428 |
+
use_cache=None,
|
429 |
+
output_attentions=None,
|
430 |
+
output_hidden_states=None,
|
431 |
+
return_dict=None,
|
432 |
+
) -> BaseModelOutput:
|
433 |
+
input_shape = input_ids.size()
|
434 |
+
batch_size, seq_length = input_shape
|
435 |
+
|
436 |
+
if inputs_embeds is None:
|
437 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
438 |
+
|
439 |
+
if torch.is_autocast_enabled() and input_ids.device.type == 'cuda':
|
440 |
+
inputs_embeds = inputs_embeds.to(torch.get_autocast_gpu_dtype())
|
441 |
+
|
442 |
+
# Masking
|
443 |
+
if attention_mask is None:
|
444 |
+
attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=torch.bool)
|
445 |
+
|
446 |
+
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
447 |
+
encoder_seq_length = encoder_hidden_states.shape[1]
|
448 |
+
encoder_attention_mask = torch.ones(
|
449 |
+
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool
|
450 |
+
)
|
451 |
+
|
452 |
+
position_bias = None
|
453 |
+
encoder_decoder_position_bias = None
|
454 |
+
|
455 |
+
hidden_states = self.dropout(inputs_embeds)
|
456 |
+
|
457 |
+
for _, layer_module in enumerate(self.block):
|
458 |
+
layer_outputs = layer_module(
|
459 |
+
hidden_states,
|
460 |
+
attention_mask=attention_mask,
|
461 |
+
position_bias=position_bias,
|
462 |
+
encoder_hidden_states=encoder_hidden_states,
|
463 |
+
encoder_attention_mask=encoder_attention_mask,
|
464 |
+
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
465 |
+
)
|
466 |
+
|
467 |
+
# We share the position biases between the layers - the first layer store them
|
468 |
+
position_bias = layer_outputs[1]
|
469 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
470 |
+
encoder_decoder_position_bias = layer_outputs[2]
|
471 |
+
|
472 |
+
hidden_states = layer_outputs[0]
|
473 |
+
|
474 |
+
hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
|
475 |
+
hidden_states = self.dropout(hidden_states)
|
476 |
+
|
477 |
+
return BaseModelOutput(
|
478 |
+
last_hidden_state=hidden_states
|
479 |
+
)
|
480 |
+
|
481 |
+
|
482 |
+
|
483 |
+
class FlashT5PreTrainedModel(PreTrainedModel):
|
484 |
+
"""
|
485 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
486 |
+
models.
|
487 |
+
"""
|
488 |
+
|
489 |
+
config_class = FlashT5Config
|
490 |
+
base_model_prefix = "transformer"
|
491 |
+
is_parallelizable = False
|
492 |
+
supports_gradient_checkpointing = True
|
493 |
+
_no_split_modules = ["FlashT5Block"]
|
494 |
+
_keep_in_fp32_modules = []
|
495 |
+
|
496 |
+
def _init_weights(self, module):
|
497 |
+
factor = self.config.initializer_factor # Used for testing weights initialization
|
498 |
+
if isinstance(module, FlashT5LayerNorm):
|
499 |
+
module.weight.data.fill_(factor * 1.0)
|
500 |
+
elif isinstance(module, (FlashT5ForConditionalGeneration)):
|
501 |
+
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
502 |
+
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
|
503 |
+
module.lm_head.weight.data.normal_(mean=0.0, std=factor * self.config.d_model ** -0.5)
|
504 |
+
elif isinstance(module, FlashT5DenseGatedAct):
|
505 |
+
d_ff, d_model = module.wi_0.weight.data.size()
|
506 |
+
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
|
507 |
+
module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
|
508 |
+
elif isinstance(module, FlashT5LayerFF):
|
509 |
+
d_ff, d_model = module.wo.weight.data.size()
|
510 |
+
module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
|
511 |
+
elif isinstance(module, FlashT5Attention):
|
512 |
+
d_model = self.config.d_model
|
513 |
+
key_value_proj_dim = self.config.d_kv
|
514 |
+
n_heads = self.config.num_heads
|
515 |
+
module.Wq.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
|
516 |
+
module.Wk.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
517 |
+
module.Wv.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
518 |
+
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
|
519 |
+
if module.has_positional_encoding:
|
520 |
+
if hasattr(module.pe_encoding, "relative_attention_bias"):
|
521 |
+
module.pe_encoding.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
|
522 |
+
|
523 |
+
def _shift_right(self, input_ids):
|
524 |
+
decoder_start_token_id = self.config.decoder_start_token_id
|
525 |
+
pad_token_id = self.config.pad_token_id
|
526 |
+
|
527 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
528 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
529 |
+
shifted_input_ids[..., 0] = decoder_start_token_id
|
530 |
+
|
531 |
+
# replace possible -100 values in labels by `pad_token_id`
|
532 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
533 |
+
|
534 |
+
return shifted_input_ids
|
535 |
+
|
536 |
+
|
537 |
+
class FlashT5Model(FlashT5PreTrainedModel):
|
538 |
+
|
539 |
+
def __init__(self, config: FlashT5Config):
|
540 |
+
super().__init__(config)
|
541 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
542 |
+
|
543 |
+
encoder_config = copy.deepcopy(config)
|
544 |
+
encoder_config.is_decoder = False
|
545 |
+
encoder_config.use_cache = False
|
546 |
+
encoder_config.is_encoder_decoder = False
|
547 |
+
self.encoder = FlashT5Stack(encoder_config, self.shared)
|
548 |
+
|
549 |
+
decoder_config = copy.deepcopy(config)
|
550 |
+
decoder_config.is_decoder = True
|
551 |
+
decoder_config.is_encoder_decoder = False
|
552 |
+
decoder_config.num_layers = config.num_decoder_layers
|
553 |
+
self.decoder = FlashT5Stack(decoder_config, self.shared)
|
554 |
+
|
555 |
+
# Initialize weights and apply final processing
|
556 |
+
self.post_init()
|
557 |
+
|
558 |
+
# Model parallel
|
559 |
+
self.model_parallel = False
|
560 |
+
self.device_map = None
|
561 |
+
|
562 |
+
def get_input_embeddings(self):
|
563 |
+
return self.shared
|
564 |
+
|
565 |
+
def set_input_embeddings(self, new_embeddings):
|
566 |
+
self.shared = new_embeddings
|
567 |
+
self.encoder.set_input_embeddings(new_embeddings)
|
568 |
+
self.decoder.set_input_embeddings(new_embeddings)
|
569 |
+
|
570 |
+
def get_encoder(self):
|
571 |
+
return self.encoder
|
572 |
+
|
573 |
+
def get_decoder(self):
|
574 |
+
return self.decoder
|
575 |
+
|
576 |
+
def forward(
|
577 |
+
self,
|
578 |
+
input_ids=None,
|
579 |
+
# input_ids: Optional[torch.LongTensor] = None,
|
580 |
+
attention_mask=None,
|
581 |
+
encoder_hidden_states=None,
|
582 |
+
encoder_attention_mask=None,
|
583 |
+
inputs_embeds=None,
|
584 |
+
head_mask=None,
|
585 |
+
cross_attn_head_mask=None,
|
586 |
+
past_key_values=None,
|
587 |
+
use_cache=None,
|
588 |
+
output_attentions=None,
|
589 |
+
output_hidden_states=None,
|
590 |
+
return_dict=None,
|
591 |
+
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
592 |
+
|
593 |
+
# Encode if needed (training, first prediction pass)
|
594 |
+
if encoder_outputs is None:
|
595 |
+
encoder_outputs = self.encoder(
|
596 |
+
input_ids=input_ids,
|
597 |
+
attention_mask=attention_mask,
|
598 |
+
inputs_embeds=inputs_embeds
|
599 |
+
)
|
600 |
+
|
601 |
+
hidden_states = encoder_outputs[0]
|
602 |
+
|
603 |
+
# Decode
|
604 |
+
decoder_outputs = self.decoder(
|
605 |
+
input_ids=decoder_input_ids,
|
606 |
+
attention_mask=decoder_attention_mask,
|
607 |
+
inputs_embeds=decoder_inputs_embeds,
|
608 |
+
encoder_hidden_states=hidden_states,
|
609 |
+
encoder_attention_mask=attention_mask
|
610 |
+
)
|
611 |
+
|
612 |
+
return Seq2SeqModelOutput(
|
613 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
614 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
615 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
616 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
617 |
+
)
|
618 |
+
|
619 |
+
class FlashT5ForConditionalGeneration(FlashT5PreTrainedModel):
|
620 |
+
|
621 |
+
def __init__(self, config: FlashT5Config):
|
622 |
+
super().__init__(config)
|
623 |
+
config.is_encoder_decoder = False
|
624 |
+
assert not config.tie_word_embeddings
|
625 |
+
|
626 |
+
self.config = config
|
627 |
+
self.model_dim = config.d_model
|
628 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
629 |
+
|
630 |
+
encoder_config = copy.deepcopy(config)
|
631 |
+
encoder_config.is_decoder = False
|
632 |
+
self.encoder = FlashT5Stack(encoder_config, self.shared)
|
633 |
+
|
634 |
+
decoder_config = copy.deepcopy(config)
|
635 |
+
decoder_config.is_decoder = True
|
636 |
+
decoder_config.num_layers = config.num_decoder_layers
|
637 |
+
self.decoder = FlashT5Stack(decoder_config, self.shared)
|
638 |
+
|
639 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
640 |
+
|
641 |
+
self.loss_fct = FlashT5CrossEntropyLoss(z_loss_factor=config.z_loss,
|
642 |
+
label_smoothing=config.label_smoothing,
|
643 |
+
use_triton_crossentropy=config.use_triton_crossentropy,
|
644 |
+
inplace_backward=config.crossentropy_inplace_backward)
|
645 |
+
|
646 |
+
# Initialize weights and apply final processing
|
647 |
+
self.post_init()
|
648 |
+
|
649 |
+
def prepare_inputs_for_generation(
|
650 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
651 |
+
):
|
652 |
+
# do nothing
|
653 |
+
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
|
654 |
+
|
655 |
+
return model_inputs
|
656 |
+
|
657 |
+
def get_input_embeddings(self):
|
658 |
+
return self.shared
|
659 |
+
|
660 |
+
def set_input_embeddings(self, value):
|
661 |
+
self.shared = value
|
662 |
+
|
663 |
+
def generate(
|
664 |
+
self,
|
665 |
+
input_ids: Optional[torch.LongTensor] = None,
|
666 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
667 |
+
max_length = 32,
|
668 |
+
**kwargs,
|
669 |
+
) -> torch.LongTensor:
|
670 |
+
"""
|
671 |
+
input_ids: B x L_encoder, int64
|
672 |
+
attention_mask: B x L_encoder, int64
|
673 |
+
1 for tokens to attend to, 0 for tokens to ignore
|
674 |
+
|
675 |
+
Generation:
|
676 |
+
Starts with 0, ends with 1, padding is 0
|
677 |
+
|
678 |
+
# For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s
|
679 |
+
"""
|
680 |
+
B, _ = input_ids.size()
|
681 |
+
labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device)
|
682 |
+
encoder_hidden_states = None
|
683 |
+
|
684 |
+
for _ in range(max_length):
|
685 |
+
out = self.forward(
|
686 |
+
input_ids=input_ids,
|
687 |
+
attention_mask=attention_mask,
|
688 |
+
decoder_input_ids=labels,
|
689 |
+
encoder_hidden_states=encoder_hidden_states,
|
690 |
+
)
|
691 |
+
encoder_hidden_states = out.encoder_hidden_states
|
692 |
+
top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
|
693 |
+
labels = torch.cat([labels, top_labels], dim=-1)
|
694 |
+
|
695 |
+
if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B:
|
696 |
+
break
|
697 |
+
|
698 |
+
labels[:, -1] = 1
|
699 |
+
|
700 |
+
# Mask out the padding, i.e., all positions after the first 1 with 0
|
701 |
+
B, L = labels.size()
|
702 |
+
mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
|
703 |
+
labels = labels.masked_fill(~mask, 0)
|
704 |
+
|
705 |
+
return labels
|
706 |
+
|
707 |
+
def forward(
|
708 |
+
self,
|
709 |
+
input_ids: Optional[torch.LongTensor] = None,
|
710 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
711 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
712 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
713 |
+
labels: Optional[torch.LongTensor] = None,
|
714 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
715 |
+
) -> Seq2SeqLMOutput:
|
716 |
+
"""
|
717 |
+
input_ids: B x L_encoder, int64
|
718 |
+
attention_mask: B x L_encoder, int64
|
719 |
+
1 for tokens to attend to, 0 for tokens to ignore
|
720 |
+
labels: B x L_decoder, int64
|
721 |
+
"""
|
722 |
+
if encoder_hidden_states is None:
|
723 |
+
encoder_hidden_states = self.encoder(
|
724 |
+
input_ids=input_ids,
|
725 |
+
attention_mask=attention_mask,
|
726 |
+
)[0]
|
727 |
+
|
728 |
+
hidden_states = encoder_hidden_states
|
729 |
+
|
730 |
+
if labels is not None and decoder_input_ids is None:
|
731 |
+
decoder_input_ids = self._shift_right(labels)
|
732 |
+
|
733 |
+
decoder_outputs = self.decoder(
|
734 |
+
input_ids=decoder_input_ids,
|
735 |
+
attention_mask=decoder_attention_mask,
|
736 |
+
encoder_hidden_states=hidden_states,
|
737 |
+
encoder_attention_mask=attention_mask,
|
738 |
+
)
|
739 |
+
|
740 |
+
sequence_output = decoder_outputs[0]
|
741 |
+
lm_logits = self.lm_head(sequence_output)
|
742 |
+
|
743 |
+
loss = None
|
744 |
+
if labels is not None:
|
745 |
+
loss = self.loss_fct(lm_logits, labels)
|
746 |
+
|
747 |
+
return Seq2SeqLMOutput(
|
748 |
+
loss=loss,
|
749 |
+
logits=lm_logits,
|
750 |
+
encoder_hidden_states=encoder_hidden_states
|
751 |
+
)
|
752 |
+
|
753 |
+
|
754 |
+
class FlashT5EncoderModel(FlashT5PreTrainedModel):
|
755 |
+
def __init__(self, config: FlashT5Config):
|
756 |
+
super().__init__(config)
|
757 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
758 |
+
encoder_config = copy.deepcopy(config)
|
759 |
+
encoder_config.use_cache = False
|
760 |
+
encoder_config.is_encoder_decoder = False
|
761 |
+
self.encoder = FlashT5Stack(encoder_config, self.shared)
|
762 |
+
# Initialize weights and apply final processing
|
763 |
+
self.post_init()
|
764 |
+
# Model parallel
|
765 |
+
self.model_parallel = False
|
766 |
+
self.device_map = None
|
767 |
+
def get_input_embeddings(self):
|
768 |
+
return self.shared
|
769 |
+
def set_input_embeddings(self, new_embeddings):
|
770 |
+
self.shared = new_embeddings
|
771 |
+
self.encoder.set_input_embeddings(new_embeddings)
|
772 |
+
def get_encoder(self):
|
773 |
+
return self.encoder
|
774 |
+
def forward(
|
775 |
+
self,
|
776 |
+
input_ids: Optional[torch.LongTensor] = None,
|
777 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
778 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
779 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
780 |
+
output_attentions: Optional[bool] = None,
|
781 |
+
output_hidden_states: Optional[bool] = None,
|
782 |
+
return_dict: Optional[bool] = None,
|
783 |
+
token_type_ids: Optional[bool] = None,
|
784 |
+
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
785 |
+
encoder_outputs = self.encoder(
|
786 |
+
input_ids=input_ids,
|
787 |
+
attention_mask=attention_mask,
|
788 |
+
inputs_embeds=inputs_embeds
|
789 |
+
)
|
790 |
+
return encoder_outputs
|
optimizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a61ade6d36e7273c43a001cc4e665c4bae25570aecff45ed23f189b8a2b8687
|
3 |
+
size 1174905530
|
positional_encoding.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
try:
|
6 |
+
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_
|
7 |
+
except:
|
8 |
+
apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_ = None, None, None
|
9 |
+
|
10 |
+
class RelativePositionalEncoding(nn.Module):
|
11 |
+
|
12 |
+
def __init__(self, relative_attention_num_buckets, relative_attention_max_distance, n_heads, max_sequence_length, bidirectional=True, randomized_position=False):
|
13 |
+
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
self.relative_attention_num_buckets = relative_attention_num_buckets
|
17 |
+
self.relative_attention_max_distance = relative_attention_max_distance
|
18 |
+
self.n_heads = n_heads
|
19 |
+
self.max_sequence_length = max_sequence_length
|
20 |
+
self.bidirectional = bidirectional
|
21 |
+
self.randomized_position = randomized_position
|
22 |
+
|
23 |
+
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
27 |
+
"""
|
28 |
+
Adapted from Mesh Tensorflow:
|
29 |
+
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
30 |
+
|
31 |
+
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
32 |
+
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
33 |
+
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
34 |
+
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
35 |
+
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
36 |
+
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
37 |
+
|
38 |
+
Args:
|
39 |
+
relative_position: an int32 Tensor
|
40 |
+
bidirectional: a boolean - whether the attention is bidirectional
|
41 |
+
num_buckets: an integer
|
42 |
+
max_distance: an integer
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
46 |
+
"""
|
47 |
+
relative_buckets = 0
|
48 |
+
if bidirectional:
|
49 |
+
num_buckets //= 2
|
50 |
+
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
51 |
+
relative_position = torch.abs(relative_position)
|
52 |
+
else:
|
53 |
+
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
54 |
+
# now relative_position is in the range [0, inf)
|
55 |
+
|
56 |
+
# half of the buckets are for exact increments in positions
|
57 |
+
max_exact = num_buckets // 2
|
58 |
+
is_small = relative_position < max_exact
|
59 |
+
|
60 |
+
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
61 |
+
relative_position_if_large = max_exact + (
|
62 |
+
torch.log(relative_position.float() / max_exact)
|
63 |
+
/ torch.log(torch.tensor(max_distance / max_exact))
|
64 |
+
* (num_buckets - max_exact)
|
65 |
+
).to(torch.long)
|
66 |
+
relative_position_if_large = torch.min(
|
67 |
+
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
|
68 |
+
)
|
69 |
+
|
70 |
+
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
71 |
+
return relative_buckets
|
72 |
+
|
73 |
+
def compute_bias(self, query_length, key_length, device=None):
|
74 |
+
"""Compute binned relative position bias"""
|
75 |
+
if device is None:
|
76 |
+
device = self.relative_attention_bias.weight.device
|
77 |
+
|
78 |
+
if self.randomized_position:
|
79 |
+
context_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
|
80 |
+
context_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:query_length])
|
81 |
+
context_indices_rand[0] = 0 # root the first element of the sequence
|
82 |
+
context_position = context_position[context_indices_rand][:, None]
|
83 |
+
|
84 |
+
memory_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
|
85 |
+
memory_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:key_length])
|
86 |
+
memory_indices_rand[0] = 0 # root the first element of the sequence
|
87 |
+
memory_position = memory_position[memory_indices_rand][None, :]
|
88 |
+
else:
|
89 |
+
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
90 |
+
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
91 |
+
|
92 |
+
relative_position = memory_position - context_position # shape (query_length, key_length)
|
93 |
+
|
94 |
+
relative_position_bucket = self._relative_position_bucket(
|
95 |
+
relative_position, # shape (query_length, key_length)
|
96 |
+
bidirectional=self.bidirectional,
|
97 |
+
num_buckets=self.relative_attention_num_buckets,
|
98 |
+
max_distance=self.relative_attention_max_distance,
|
99 |
+
)
|
100 |
+
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
101 |
+
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
102 |
+
return values
|
103 |
+
|
104 |
+
def forward(self, q, k=None, v=None):
|
105 |
+
|
106 |
+
query_length = q.shape[1]
|
107 |
+
key_length = k.shape[1] if k is not None else query_length
|
108 |
+
bias = self.compute_bias(query_length, key_length, device=q.device).contiguous().to(q.dtype)
|
109 |
+
|
110 |
+
return q, k, v, bias
|
111 |
+
|
112 |
+
|
113 |
+
class ALiBiPositionalEncoding(nn.Module):
|
114 |
+
|
115 |
+
def __init__(self, max_sequence_length, num_heads, mode='symetric', randomized_position=False):
|
116 |
+
|
117 |
+
super().__init__()
|
118 |
+
|
119 |
+
self.max_sequence_length = max_sequence_length
|
120 |
+
self.num_heads = num_heads
|
121 |
+
self.mode = mode
|
122 |
+
self.randomized_position = randomized_position
|
123 |
+
|
124 |
+
self.alibi_bias = self.build_alibi_bias_matrix(num_heads, max_sequence_length, mode)
|
125 |
+
|
126 |
+
@staticmethod
|
127 |
+
def fill_with_neg_inf(t):
|
128 |
+
"""FP16-compatible function that fills a tensor with -inf."""
|
129 |
+
return t.float().fill_(float("-inf")).type_as(t)
|
130 |
+
|
131 |
+
def get_slopes(self, n):
|
132 |
+
|
133 |
+
def get_slopes_power_of_2(n):
|
134 |
+
start = (2**(-2**-(math.log2(n)-3)))
|
135 |
+
ratio = start
|
136 |
+
return [start*ratio**i for i in range(n)]
|
137 |
+
|
138 |
+
if math.log2(n).is_integer():
|
139 |
+
return get_slopes_power_of_2(n) #In the paper, we only train models that have 2^a heads for some a. This function has
|
140 |
+
else: #some good properties that only occur when the input is a power of 2. To maintain that even
|
141 |
+
closest_power_of_2 = 2**math.floor(math.log2(n)) #when the number of heads is not a power of 2, we use this workaround.
|
142 |
+
return get_slopes_power_of_2(closest_power_of_2) + self.get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
|
143 |
+
|
144 |
+
def build_symetric_alibi_bias_matrix(self, num_heads, maxpos):
|
145 |
+
|
146 |
+
context_position = torch.arange(maxpos)[:, None]
|
147 |
+
memory_position = torch.arange(maxpos)[None, :]
|
148 |
+
|
149 |
+
relative_position = memory_position - context_position
|
150 |
+
relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads, -1,-1)
|
151 |
+
|
152 |
+
slopes = torch.Tensor(self.get_slopes(num_heads)) * -1
|
153 |
+
alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
|
154 |
+
return alibi.view(1, num_heads, maxpos, maxpos)
|
155 |
+
|
156 |
+
def build_asymetric_alibi_bias_matrix(self, num_heads, maxpos):
|
157 |
+
_future_mask_right = torch.triu(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(num_heads // 2, 1, 1)
|
158 |
+
_future_mask_left = torch.tril(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(num_heads // 2, 1, 1)
|
159 |
+
|
160 |
+
nonsym_mask = torch.cat((_future_mask_right, _future_mask_left), dim = 0).unsqueeze(0)
|
161 |
+
slopes = torch.Tensor(self.get_slopes(num_heads // 2)) * -1
|
162 |
+
|
163 |
+
context_position = torch.arange(maxpos)[:, None]
|
164 |
+
memory_position = torch.arange(maxpos)[None, :]
|
165 |
+
|
166 |
+
relative_position = memory_position - context_position
|
167 |
+
relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads // 2, -1,-1)
|
168 |
+
|
169 |
+
alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
|
170 |
+
alibi = alibi.view(1, num_heads // 2, maxpos, maxpos)
|
171 |
+
alibi = alibi.repeat(1, 2, 1, 1)
|
172 |
+
|
173 |
+
return alibi.view(1, num_heads, maxpos, maxpos) + nonsym_mask.view(1, num_heads, maxpos, maxpos)
|
174 |
+
|
175 |
+
|
176 |
+
def build_alibi_bias_matrix(self, num_heads, maxpos, mode='symetric'):
|
177 |
+
if mode == 'symetric':
|
178 |
+
return self.build_symetric_alibi_bias_matrix(num_heads, maxpos)
|
179 |
+
elif mode == 'asymetric':
|
180 |
+
return self.build_asymetric_alibi_bias_matrix(num_heads, maxpos)
|
181 |
+
else:
|
182 |
+
raise ValueError("ALiBi mode " + mode + " is not implemented.")
|
183 |
+
|
184 |
+
def forward(self, q, k=None, v=None):
|
185 |
+
|
186 |
+
query_length = q.shape[1]
|
187 |
+
key_length = k.shape[1] if k is not None else query_length
|
188 |
+
assert (self.alibi_bias.shape[1] < query_length) & (self.alibi_bias.shape[1] < key_length), "Sequence length larger than allowed alibi bound"
|
189 |
+
|
190 |
+
if self.randomized_position:
|
191 |
+
query_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:query_length])
|
192 |
+
key_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:key_length])
|
193 |
+
|
194 |
+
# ground sequences
|
195 |
+
query_indices_rand[0] = 0
|
196 |
+
key_indices_rand[0] = 0
|
197 |
+
|
198 |
+
bias = self.alibi_bias[:, :, query_indices_rand, key_indices_rand].to(q.device)
|
199 |
+
|
200 |
+
else:
|
201 |
+
bias = self.alibi_bias[:, :, :query_length, :key_length].to(q.device)
|
202 |
+
|
203 |
+
return q, k, v, bias.to(q.dtype).contiguous()
|
204 |
+
|
205 |
+
class RotaryPositionalEncoding(nn.Module):
|
206 |
+
|
207 |
+
def __init__(self, dim,
|
208 |
+
max_sequence_length,
|
209 |
+
base=10000.0,
|
210 |
+
interleaved=False,
|
211 |
+
scale_base=None,
|
212 |
+
randomized_position=False):
|
213 |
+
|
214 |
+
super().__init__()
|
215 |
+
|
216 |
+
self.max_sequence_length = max_sequence_length
|
217 |
+
self.randomized_position = randomized_position
|
218 |
+
|
219 |
+
self.dim = dim
|
220 |
+
self.base = base
|
221 |
+
self.interleaved = interleaved
|
222 |
+
self.scale_base = scale_base
|
223 |
+
|
224 |
+
inv_freq = self._compute_inv_freq()
|
225 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
226 |
+
|
227 |
+
scale = (
|
228 |
+
(torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
229 |
+
if scale_base is not None
|
230 |
+
else None
|
231 |
+
)
|
232 |
+
self.register_buffer("scale", scale, persistent=False)
|
233 |
+
|
234 |
+
self._cos_cached = None
|
235 |
+
self._sin_cached = None
|
236 |
+
self._cos_k_cached = None
|
237 |
+
self._sin_k_cached = None
|
238 |
+
|
239 |
+
def _compute_inv_freq(self, device=None):
|
240 |
+
return 1.0 / (
|
241 |
+
self.base
|
242 |
+
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
|
243 |
+
)
|
244 |
+
|
245 |
+
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
246 |
+
# Reset the tables if the sequence length has changed,
|
247 |
+
# if we're on a new device (possibly due to tracing for instance),
|
248 |
+
# or if we're switching from inference mode to training
|
249 |
+
if (
|
250 |
+
self._cos_cached is None
|
251 |
+
or self._cos_cached.device != device
|
252 |
+
or self._cos_cached.dtype != dtype
|
253 |
+
or (self.training and self._cos_cached.is_inference())
|
254 |
+
):
|
255 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
256 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
257 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
258 |
+
inv_freq = self._compute_inv_freq(device=device)
|
259 |
+
|
260 |
+
# Don't do einsum, it converts fp32 to fp16 under AMP
|
261 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
262 |
+
t = torch.arange(seqlen, device=device, dtype=dtype)
|
263 |
+
freqs = torch.outer(t, inv_freq)
|
264 |
+
if self.scale is None:
|
265 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
266 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
267 |
+
self._cos_k_cached = None
|
268 |
+
self._sin_k_cached = None
|
269 |
+
else:
|
270 |
+
power = (
|
271 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
272 |
+
- seqlen // 2
|
273 |
+
) / self.scale_base
|
274 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
275 |
+
# We want the multiplication by scale to happen in fp32
|
276 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
277 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
278 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
279 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
280 |
+
|
281 |
+
def forward(self, q, k=None, v=None):
|
282 |
+
|
283 |
+
if self._cos_cached is None:
|
284 |
+
self._update_cos_sin_cache(self.max_sequence_length, device=q.device, dtype=q.dtype)
|
285 |
+
|
286 |
+
if k is None and v is None:
|
287 |
+
q = apply_rotary_emb_qkv_(
|
288 |
+
q,
|
289 |
+
self._cos_cached,
|
290 |
+
self._sin_cached,
|
291 |
+
self._cos_k_cached,
|
292 |
+
self._sin_k_cached,
|
293 |
+
interleaved=self.interleaved,
|
294 |
+
seqlen_offsets=0
|
295 |
+
)
|
296 |
+
elif v is None and k is not None:
|
297 |
+
q = apply_rotary_emb_func(
|
298 |
+
q,
|
299 |
+
self._cos_cached,
|
300 |
+
self._sin_cached,
|
301 |
+
interleaved=self.interleaved,
|
302 |
+
inplace=True,
|
303 |
+
seqlen_offsets=0
|
304 |
+
)
|
305 |
+
|
306 |
+
k = apply_rotary_emb_kv_(
|
307 |
+
k,
|
308 |
+
self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
|
309 |
+
self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
|
310 |
+
interleaved=self.interleaved,
|
311 |
+
seqlen_offsets=0,
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
q = apply_rotary_emb_func(
|
315 |
+
q,
|
316 |
+
self._cos_cached,
|
317 |
+
self._sin_cached,
|
318 |
+
interleaved=self.interleaved,
|
319 |
+
inplace=True,
|
320 |
+
seqlen_offsets=0
|
321 |
+
)
|
322 |
+
|
323 |
+
k = apply_rotary_emb_func(
|
324 |
+
k,
|
325 |
+
self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
|
326 |
+
self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
|
327 |
+
interleaved=self.interleaved,
|
328 |
+
seqlen_offsets=0,
|
329 |
+
)
|
330 |
+
|
331 |
+
v = apply_rotary_emb_func(
|
332 |
+
v,
|
333 |
+
self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
|
334 |
+
self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
|
335 |
+
interleaved=self.interleaved,
|
336 |
+
seqlen_offsets=0,
|
337 |
+
)
|
338 |
+
|
339 |
+
return q, k, v, None
|
340 |
+
|
341 |
+
class FIRE(nn.Module):
|
342 |
+
|
343 |
+
def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512., eps=1e-6):
|
344 |
+
"""
|
345 |
+
FIRE attention bias module.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
num_heads: number of attention heads.
|
349 |
+
mlp_width: Width of MLP.
|
350 |
+
init_c: initial value of log transformation parameter
|
351 |
+
init_L: initial value of thresholding parameter
|
352 |
+
eps: small constant for numerical stability
|
353 |
+
"""
|
354 |
+
|
355 |
+
super(FIRE, self).__init__()
|
356 |
+
|
357 |
+
# Define the MLP layers
|
358 |
+
self.mlp = nn.Sequential(
|
359 |
+
nn.Linear(1, mlp_width),
|
360 |
+
nn.ReLU(),
|
361 |
+
nn.Linear(mlp_width, num_heads)
|
362 |
+
)
|
363 |
+
|
364 |
+
# Initialize c (log transformation parameter)
|
365 |
+
self.c = nn.Parameter(torch.tensor(init_c))
|
366 |
+
|
367 |
+
|
368 |
+
# Initialize L (threshold)
|
369 |
+
self.init_L = nn.Parameter(torch.tensor(init_L),
|
370 |
+
requires_grad=False)
|
371 |
+
# Learn a multiplier to L
|
372 |
+
self.L_multiplier = nn.Parameter(torch.tensor(1.0))
|
373 |
+
self.eps = eps
|
374 |
+
|
375 |
+
def apply_fire(self, seq_length, device):
|
376 |
+
"""
|
377 |
+
Compute FIRE attention bias.
|
378 |
+
|
379 |
+
Args:
|
380 |
+
x: input sequence,
|
381 |
+
shape [bsz, seq_len, num_heads, hidden_dim]
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
attention bias,
|
385 |
+
shape [1, num_heads, seq_len, seq_len]
|
386 |
+
"""
|
387 |
+
positions = torch.arange(seq_length,
|
388 |
+
dtype=torch.float32,
|
389 |
+
device=device)
|
390 |
+
|
391 |
+
rel_distance = positions[:, None] - positions[None, :]
|
392 |
+
|
393 |
+
# Thresholding the normalizer
|
394 |
+
threshold = torch.abs(self.L_multiplier * self.init_L)
|
395 |
+
pos_normalizer = torch.max(positions, threshold)
|
396 |
+
pos_normalizer = pos_normalizer[:, None]
|
397 |
+
|
398 |
+
# Amplifying differences among local positions
|
399 |
+
# with log transform
|
400 |
+
rel_distance = torch.sign(rel_distance) * torch.log(
|
401 |
+
torch.abs(self.c * rel_distance) + 1
|
402 |
+
)
|
403 |
+
pos_normalizer = torch.log(
|
404 |
+
torch.abs(self.c * pos_normalizer) + 1
|
405 |
+
) + self.eps
|
406 |
+
|
407 |
+
# Progressive interpolation
|
408 |
+
normalized_distance = rel_distance / pos_normalizer
|
409 |
+
fire_bias = self.mlp(normalized_distance.unsqueeze(-1))
|
410 |
+
fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2)
|
411 |
+
return fire_bias
|
412 |
+
|
413 |
+
def forward(self, q, k=None, v=None):
|
414 |
+
|
415 |
+
bias = self.apply_fire(q.shape[1], device=q.device).contiguous().to(q.dtype)
|
416 |
+
|
417 |
+
return q, k, v, bias
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:39cac982d9842a1d653966fb42b2fa3df4077334030c76cd8fa165b5aea244ea
|
3 |
+
size 587431346
|
rms_norm.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao.
|
2 |
+
# Copyright 2024 CATIE. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# Modifications to the orignal file
|
17 |
+
# - support for torch.compile
|
18 |
+
|
19 |
+
import triton
|
20 |
+
import triton.language as tl
|
21 |
+
import torch
|
22 |
+
import math
|
23 |
+
from typing import Tuple
|
24 |
+
|
25 |
+
@triton.jit
|
26 |
+
def _rmsnorm_fwd_kernel(
|
27 |
+
X, # pointer to the input
|
28 |
+
Y, # pointer to the output
|
29 |
+
W, # pointer to the weights
|
30 |
+
Rstd, # pointer to the 1/std
|
31 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
32 |
+
stride_y_row,
|
33 |
+
N, # number of columns in X
|
34 |
+
eps, # epsilon to avoid division by zero
|
35 |
+
BLOCK_N: tl.constexpr,
|
36 |
+
IS_EVEN_N: tl.constexpr
|
37 |
+
):
|
38 |
+
|
39 |
+
row = tl.program_id(0)
|
40 |
+
X += row * stride_x_row
|
41 |
+
Y += row * stride_y_row
|
42 |
+
|
43 |
+
# Compute mean and variance
|
44 |
+
cols = tl.arange(0, BLOCK_N)
|
45 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
46 |
+
|
47 |
+
xbar = tl.where(cols < N, x, 0.0)
|
48 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
49 |
+
rstd = 1 / tl.sqrt(var + eps)
|
50 |
+
tl.store(Rstd + row, rstd)
|
51 |
+
|
52 |
+
# Normalize and apply linear transformation
|
53 |
+
mask = cols < N
|
54 |
+
if IS_EVEN_N:
|
55 |
+
w = tl.load(W + cols).to(tl.float32)
|
56 |
+
else:
|
57 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
58 |
+
|
59 |
+
x_hat = x * rstd
|
60 |
+
y = x_hat * w
|
61 |
+
|
62 |
+
# Write output
|
63 |
+
if IS_EVEN_N:
|
64 |
+
tl.store(Y + cols, y)
|
65 |
+
else:
|
66 |
+
tl.store(Y + cols, y, mask=mask)
|
67 |
+
|
68 |
+
@triton.jit
|
69 |
+
def _rmsnorm_bwd_kernel(
|
70 |
+
X, # pointer to the input
|
71 |
+
W, # pointer to the weights
|
72 |
+
DY, # pointer to the output gradient
|
73 |
+
DX, # pointer to the input gradient
|
74 |
+
DW, # pointer to the partial sum of weights gradient
|
75 |
+
Rstd, # pointer to the 1/std
|
76 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
77 |
+
stride_dy_row,
|
78 |
+
stride_dx_row,
|
79 |
+
M, # number of rows in X
|
80 |
+
N, # number of columns in X
|
81 |
+
eps, # epsilon to avoid division by zero
|
82 |
+
rows_per_program,
|
83 |
+
BLOCK_N: tl.constexpr,
|
84 |
+
IS_EVEN_N: tl.constexpr
|
85 |
+
):
|
86 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
87 |
+
row_block_id = tl.program_id(0)
|
88 |
+
row_start = row_block_id * rows_per_program
|
89 |
+
cols = tl.arange(0, BLOCK_N)
|
90 |
+
mask = cols < N
|
91 |
+
X += row_start * stride_x_row
|
92 |
+
|
93 |
+
DY += row_start * stride_dy_row
|
94 |
+
DX += row_start * stride_dx_row
|
95 |
+
|
96 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
97 |
+
|
98 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
99 |
+
|
100 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
101 |
+
|
102 |
+
for row in range(row_start, row_end):
|
103 |
+
# Load data to SRAM
|
104 |
+
if IS_EVEN_N:
|
105 |
+
x = tl.load(X + cols).to(tl.float32)
|
106 |
+
dy = tl.load(DY + cols).to(tl.float32)
|
107 |
+
else:
|
108 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
109 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
110 |
+
|
111 |
+
rstd = tl.load(Rstd + row)
|
112 |
+
|
113 |
+
# Compute dx
|
114 |
+
xhat = x * rstd
|
115 |
+
if not IS_EVEN_N:
|
116 |
+
xhat = tl.where(mask, xhat, 0.0)
|
117 |
+
|
118 |
+
wdy = w * dy
|
119 |
+
dw += dy * xhat
|
120 |
+
|
121 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
122 |
+
dx = (wdy - xhat * c1) * rstd
|
123 |
+
|
124 |
+
tl.store(DX + cols, dx, mask=mask)
|
125 |
+
|
126 |
+
X += stride_x_row
|
127 |
+
|
128 |
+
DY += stride_dy_row
|
129 |
+
DX += stride_dx_row
|
130 |
+
|
131 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
132 |
+
|
133 |
+
|
134 |
+
@torch.library.custom_op("flasht5::rmsnorm_triton_fwd", mutates_args=(), device_types="cuda")
|
135 |
+
def rmsnorm_triton_fwd(
|
136 |
+
X: torch.Tensor,
|
137 |
+
weight: torch.Tensor,
|
138 |
+
eps: float
|
139 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
140 |
+
|
141 |
+
M, N = X.shape
|
142 |
+
|
143 |
+
assert X.stride(-1) == 1
|
144 |
+
|
145 |
+
assert weight.shape == (N,)
|
146 |
+
assert weight.stride(-1) == 1
|
147 |
+
|
148 |
+
# allocate output
|
149 |
+
Y = torch.empty_like(X)
|
150 |
+
assert Y.stride(-1) == 1
|
151 |
+
|
152 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=X.device)
|
153 |
+
|
154 |
+
# Less than 64KB per feature: enqueue fused kernel
|
155 |
+
MAX_FUSED_SIZE = 65536 // X.element_size()
|
156 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
157 |
+
assert N <= BLOCK_N
|
158 |
+
|
159 |
+
# heuristics for number of warps
|
160 |
+
with torch.cuda.device(X.device.index):
|
161 |
+
_rmsnorm_fwd_kernel[(M,)](
|
162 |
+
X,
|
163 |
+
Y,
|
164 |
+
weight,
|
165 |
+
rstd,
|
166 |
+
X.stride(0),
|
167 |
+
Y.stride(0),
|
168 |
+
N,
|
169 |
+
eps,
|
170 |
+
BLOCK_N,
|
171 |
+
(N % BLOCK_N == 0)
|
172 |
+
)
|
173 |
+
|
174 |
+
return Y, rstd
|
175 |
+
|
176 |
+
|
177 |
+
@torch.library.register_fake("flasht5::rmsnorm_triton_fwd")
|
178 |
+
def rmsnorm_triton_fwd_abstract(X, weight, eps):
|
179 |
+
M, N = X.shape
|
180 |
+
|
181 |
+
Y = torch.empty_like(X)
|
182 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=X.device)
|
183 |
+
|
184 |
+
return Y, rstd
|
185 |
+
|
186 |
+
@torch.library.custom_op("flasht5::rmsnorm_triton_bwd", mutates_args=(), device_types="cuda")
|
187 |
+
def rmsnorm_triton_bwd(
|
188 |
+
dy: torch.Tensor,
|
189 |
+
x: torch.Tensor,
|
190 |
+
weight: torch.Tensor,
|
191 |
+
rstd: torch.Tensor,
|
192 |
+
eps: float
|
193 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
194 |
+
M, N = x.shape
|
195 |
+
assert x.stride(-1) == 1
|
196 |
+
assert dy.stride(-1) == 1
|
197 |
+
assert dy.shape == (M, N)
|
198 |
+
|
199 |
+
assert weight.shape == (N,)
|
200 |
+
assert weight.stride(-1) == 1
|
201 |
+
|
202 |
+
# allocate output
|
203 |
+
dx = torch.empty_like(x)
|
204 |
+
|
205 |
+
# Less than 64KB per feature: enqueue fused kernel
|
206 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
207 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
208 |
+
|
209 |
+
assert N <= BLOCK_N
|
210 |
+
|
211 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
212 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
213 |
+
|
214 |
+
rows_per_program = math.ceil(M / sm_count)
|
215 |
+
grid = (sm_count,)
|
216 |
+
with torch.cuda.device(x.device.index):
|
217 |
+
_rmsnorm_bwd_kernel[grid](
|
218 |
+
x,
|
219 |
+
weight,
|
220 |
+
dy,
|
221 |
+
dx,
|
222 |
+
_dw,
|
223 |
+
rstd,
|
224 |
+
x.stride(0),
|
225 |
+
dy.stride(0),
|
226 |
+
dx.stride(0),
|
227 |
+
M,
|
228 |
+
N,
|
229 |
+
eps,
|
230 |
+
rows_per_program,
|
231 |
+
BLOCK_N,
|
232 |
+
(N % BLOCK_N == 0)
|
233 |
+
)
|
234 |
+
dw = _dw.sum(0).to(weight.dtype)
|
235 |
+
|
236 |
+
return dx, dw
|
237 |
+
|
238 |
+
|
239 |
+
@torch.library.register_fake("flasht5::rmsnorm_triton_bwd")
|
240 |
+
def rmsnorm_triton_bwd_abstract(dy, x, weight, rstd, eps):
|
241 |
+
|
242 |
+
M, N = x.shape
|
243 |
+
dx = torch.empty_like(x)
|
244 |
+
dw = torch.empty((1, N), dtype=torch.float32, device=weight.device)
|
245 |
+
|
246 |
+
|
247 |
+
return dx, dw
|
248 |
+
|
249 |
+
|
250 |
+
class Fast_RMS_Layernorm(torch.autograd.Function):
|
251 |
+
@staticmethod
|
252 |
+
def forward(ctx, X, W, eps=1e-6):
|
253 |
+
|
254 |
+
X_orig_shape = X.shape
|
255 |
+
X = X.reshape(-1, X.shape[-1])
|
256 |
+
|
257 |
+
y, rstd, = torch.ops.flasht5.rmsnorm_triton_fwd(X, W, eps)
|
258 |
+
|
259 |
+
y = y.reshape(X_orig_shape)
|
260 |
+
|
261 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
262 |
+
ctx.save_for_backward(X, W, rstd)
|
263 |
+
ctx.x_shape_og = X_orig_shape
|
264 |
+
ctx.eps = eps
|
265 |
+
|
266 |
+
return y
|
267 |
+
|
268 |
+
@staticmethod
|
269 |
+
def backward(ctx, dY):
|
270 |
+
X, weight, rstd = ctx.saved_tensors
|
271 |
+
dY = dY.reshape(-1, dY.shape[-1])
|
272 |
+
|
273 |
+
assert dY.shape == X.shape
|
274 |
+
|
275 |
+
dx, dw = torch.ops.flasht5.rmsnorm_triton_bwd(
|
276 |
+
dY,
|
277 |
+
X,
|
278 |
+
weight,
|
279 |
+
rstd,
|
280 |
+
ctx.eps
|
281 |
+
)
|
282 |
+
|
283 |
+
return dx.reshape(ctx.x_shape_og), dw, None
|
284 |
+
|
285 |
+
def fast_rms_layernorm(X, W, eps):
|
286 |
+
out = Fast_RMS_Layernorm.apply(X, W, eps)
|
287 |
+
return out
|
rng_state.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:672546ccb6198cb6029856dffdff7fa8bbc52726b212606155673ca847d54511
|
3 |
+
size 14244
|
scheduler.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:605bdfbc75175c1891038202d54ff9a471d27b66aed2deedcdaf8b1fa6ca2185
|
3 |
+
size 1256
|
special_tokens_map.json
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<extra_id_0>",
|
4 |
+
"<extra_id_1>",
|
5 |
+
"<extra_id_2>",
|
6 |
+
"<extra_id_3>",
|
7 |
+
"<extra_id_4>",
|
8 |
+
"<extra_id_5>",
|
9 |
+
"<extra_id_6>",
|
10 |
+
"<extra_id_7>",
|
11 |
+
"<extra_id_8>",
|
12 |
+
"<extra_id_9>",
|
13 |
+
"<extra_id_10>",
|
14 |
+
"<extra_id_11>",
|
15 |
+
"<extra_id_12>",
|
16 |
+
"<extra_id_13>",
|
17 |
+
"<extra_id_14>",
|
18 |
+
"<extra_id_15>",
|
19 |
+
"<extra_id_16>",
|
20 |
+
"<extra_id_17>",
|
21 |
+
"<extra_id_18>",
|
22 |
+
"<extra_id_19>",
|
23 |
+
"<extra_id_20>",
|
24 |
+
"<extra_id_21>",
|
25 |
+
"<extra_id_22>",
|
26 |
+
"<extra_id_23>",
|
27 |
+
"<extra_id_24>",
|
28 |
+
"<extra_id_25>",
|
29 |
+
"<extra_id_26>",
|
30 |
+
"<extra_id_27>",
|
31 |
+
"<extra_id_28>",
|
32 |
+
"<extra_id_29>",
|
33 |
+
"<extra_id_30>",
|
34 |
+
"<extra_id_31>",
|
35 |
+
"<extra_id_32>",
|
36 |
+
"<extra_id_33>",
|
37 |
+
"<extra_id_34>",
|
38 |
+
"<extra_id_35>",
|
39 |
+
"<extra_id_36>",
|
40 |
+
"<extra_id_37>",
|
41 |
+
"<extra_id_38>",
|
42 |
+
"<extra_id_39>",
|
43 |
+
"<extra_id_40>",
|
44 |
+
"<extra_id_41>",
|
45 |
+
"<extra_id_42>",
|
46 |
+
"<extra_id_43>",
|
47 |
+
"<extra_id_44>",
|
48 |
+
"<extra_id_45>",
|
49 |
+
"<extra_id_46>",
|
50 |
+
"<extra_id_47>",
|
51 |
+
"<extra_id_48>",
|
52 |
+
"<extra_id_49>",
|
53 |
+
"<extra_id_50>",
|
54 |
+
"<extra_id_51>",
|
55 |
+
"<extra_id_52>",
|
56 |
+
"<extra_id_53>",
|
57 |
+
"<extra_id_54>",
|
58 |
+
"<extra_id_55>",
|
59 |
+
"<extra_id_56>",
|
60 |
+
"<extra_id_57>",
|
61 |
+
"<extra_id_58>",
|
62 |
+
"<extra_id_59>",
|
63 |
+
"<extra_id_60>",
|
64 |
+
"<extra_id_61>",
|
65 |
+
"<extra_id_62>",
|
66 |
+
"<extra_id_63>",
|
67 |
+
"<extra_id_64>",
|
68 |
+
"<extra_id_65>",
|
69 |
+
"<extra_id_66>",
|
70 |
+
"<extra_id_67>",
|
71 |
+
"<extra_id_68>",
|
72 |
+
"<extra_id_69>",
|
73 |
+
"<extra_id_70>",
|
74 |
+
"<extra_id_71>",
|
75 |
+
"<extra_id_72>",
|
76 |
+
"<extra_id_73>",
|
77 |
+
"<extra_id_74>",
|
78 |
+
"<extra_id_75>",
|
79 |
+
"<extra_id_76>",
|
80 |
+
"<extra_id_77>",
|
81 |
+
"<extra_id_78>",
|
82 |
+
"<extra_id_79>",
|
83 |
+
"<extra_id_80>",
|
84 |
+
"<extra_id_81>",
|
85 |
+
"<extra_id_82>",
|
86 |
+
"<extra_id_83>",
|
87 |
+
"<extra_id_84>",
|
88 |
+
"<extra_id_85>",
|
89 |
+
"<extra_id_86>",
|
90 |
+
"<extra_id_87>",
|
91 |
+
"<extra_id_88>",
|
92 |
+
"<extra_id_89>",
|
93 |
+
"<extra_id_90>",
|
94 |
+
"<extra_id_91>",
|
95 |
+
"<extra_id_92>",
|
96 |
+
"<extra_id_93>",
|
97 |
+
"<extra_id_94>",
|
98 |
+
"<extra_id_95>",
|
99 |
+
"<extra_id_96>",
|
100 |
+
"<extra_id_97>",
|
101 |
+
"<extra_id_98>",
|
102 |
+
"<extra_id_99>",
|
103 |
+
"<extra_id_100>",
|
104 |
+
"<extra_id_101>",
|
105 |
+
"<extra_id_102>",
|
106 |
+
"<extra_id_103>",
|
107 |
+
"<extra_id_104>",
|
108 |
+
"<extra_id_105>",
|
109 |
+
"<extra_id_106>",
|
110 |
+
"<extra_id_107>",
|
111 |
+
"<extra_id_108>",
|
112 |
+
"<extra_id_109>",
|
113 |
+
"<extra_id_110>",
|
114 |
+
"<extra_id_111>",
|
115 |
+
"<extra_id_112>",
|
116 |
+
"<extra_id_113>",
|
117 |
+
"<extra_id_114>",
|
118 |
+
"<extra_id_115>",
|
119 |
+
"<extra_id_116>",
|
120 |
+
"<extra_id_117>",
|
121 |
+
"<extra_id_118>",
|
122 |
+
"<extra_id_119>",
|
123 |
+
"<extra_id_120>",
|
124 |
+
"<extra_id_121>",
|
125 |
+
"<extra_id_122>",
|
126 |
+
"<extra_id_123>",
|
127 |
+
"<extra_id_124>",
|
128 |
+
"<extra_id_125>",
|
129 |
+
"<extra_id_126>",
|
130 |
+
"<extra_id_127>",
|
131 |
+
"<extra_id_128>",
|
132 |
+
"<extra_id_129>",
|
133 |
+
"<extra_id_130>",
|
134 |
+
"<extra_id_131>",
|
135 |
+
"<extra_id_132>",
|
136 |
+
"<extra_id_133>",
|
137 |
+
"<extra_id_134>",
|
138 |
+
"<extra_id_135>",
|
139 |
+
"<extra_id_136>",
|
140 |
+
"<extra_id_137>",
|
141 |
+
"<extra_id_138>",
|
142 |
+
"<extra_id_139>",
|
143 |
+
"<extra_id_140>",
|
144 |
+
"<extra_id_141>",
|
145 |
+
"<extra_id_142>",
|
146 |
+
"<extra_id_143>",
|
147 |
+
"<extra_id_144>",
|
148 |
+
"<extra_id_145>",
|
149 |
+
"<extra_id_146>",
|
150 |
+
"<extra_id_147>",
|
151 |
+
"<extra_id_148>",
|
152 |
+
"<extra_id_149>",
|
153 |
+
"<extra_id_150>",
|
154 |
+
"<extra_id_151>",
|
155 |
+
"<extra_id_152>",
|
156 |
+
"<extra_id_153>",
|
157 |
+
"<extra_id_154>",
|
158 |
+
"<extra_id_155>",
|
159 |
+
"<extra_id_156>",
|
160 |
+
"<extra_id_157>",
|
161 |
+
"<extra_id_158>",
|
162 |
+
"<extra_id_159>",
|
163 |
+
"<extra_id_160>",
|
164 |
+
"<extra_id_161>",
|
165 |
+
"<extra_id_162>",
|
166 |
+
"<extra_id_163>",
|
167 |
+
"<extra_id_164>",
|
168 |
+
"<extra_id_165>",
|
169 |
+
"<extra_id_166>",
|
170 |
+
"<extra_id_167>",
|
171 |
+
"<extra_id_168>",
|
172 |
+
"<extra_id_169>",
|
173 |
+
"<extra_id_170>",
|
174 |
+
"<extra_id_171>",
|
175 |
+
"<extra_id_172>",
|
176 |
+
"<extra_id_173>",
|
177 |
+
"<extra_id_174>",
|
178 |
+
"<extra_id_175>",
|
179 |
+
"<extra_id_176>",
|
180 |
+
"<extra_id_177>",
|
181 |
+
"<extra_id_178>",
|
182 |
+
"<extra_id_179>",
|
183 |
+
"<extra_id_180>",
|
184 |
+
"<extra_id_181>",
|
185 |
+
"<extra_id_182>",
|
186 |
+
"<extra_id_183>",
|
187 |
+
"<extra_id_184>",
|
188 |
+
"<extra_id_185>",
|
189 |
+
"<extra_id_186>",
|
190 |
+
"<extra_id_187>",
|
191 |
+
"<extra_id_188>",
|
192 |
+
"<extra_id_189>",
|
193 |
+
"<extra_id_190>",
|
194 |
+
"<extra_id_191>",
|
195 |
+
"<extra_id_192>",
|
196 |
+
"<extra_id_193>",
|
197 |
+
"<extra_id_194>",
|
198 |
+
"<extra_id_195>",
|
199 |
+
"<extra_id_196>",
|
200 |
+
"<extra_id_197>",
|
201 |
+
"<extra_id_198>",
|
202 |
+
"<extra_id_199>",
|
203 |
+
"<extra_id_200>",
|
204 |
+
"<extra_id_201>",
|
205 |
+
"<extra_id_202>",
|
206 |
+
"<extra_id_203>",
|
207 |
+
"<extra_id_204>",
|
208 |
+
"<extra_id_205>",
|
209 |
+
"<extra_id_206>",
|
210 |
+
"<extra_id_207>",
|
211 |
+
"<extra_id_208>",
|
212 |
+
"<extra_id_209>",
|
213 |
+
"<extra_id_210>",
|
214 |
+
"<extra_id_211>",
|
215 |
+
"<extra_id_212>",
|
216 |
+
"<extra_id_213>",
|
217 |
+
"<extra_id_214>",
|
218 |
+
"<extra_id_215>",
|
219 |
+
"<extra_id_216>",
|
220 |
+
"<extra_id_217>",
|
221 |
+
"<extra_id_218>",
|
222 |
+
"<extra_id_219>",
|
223 |
+
"<extra_id_220>",
|
224 |
+
"<extra_id_221>",
|
225 |
+
"<extra_id_222>",
|
226 |
+
"<extra_id_223>",
|
227 |
+
"<extra_id_224>",
|
228 |
+
"<extra_id_225>",
|
229 |
+
"<extra_id_226>",
|
230 |
+
"<extra_id_227>",
|
231 |
+
"<extra_id_228>",
|
232 |
+
"<extra_id_229>",
|
233 |
+
"<extra_id_230>",
|
234 |
+
"<extra_id_231>",
|
235 |
+
"<extra_id_232>",
|
236 |
+
"<extra_id_233>",
|
237 |
+
"<extra_id_234>",
|
238 |
+
"<extra_id_235>",
|
239 |
+
"<extra_id_236>",
|
240 |
+
"<extra_id_237>",
|
241 |
+
"<extra_id_238>",
|
242 |
+
"<extra_id_239>",
|
243 |
+
"<extra_id_240>",
|
244 |
+
"<extra_id_241>",
|
245 |
+
"<extra_id_242>",
|
246 |
+
"<extra_id_243>",
|
247 |
+
"<extra_id_244>",
|
248 |
+
"<extra_id_245>",
|
249 |
+
"<extra_id_246>",
|
250 |
+
"<extra_id_247>",
|
251 |
+
"<extra_id_248>",
|
252 |
+
"<extra_id_249>",
|
253 |
+
"<extra_id_250>",
|
254 |
+
"<extra_id_251>",
|
255 |
+
"<extra_id_252>",
|
256 |
+
"<extra_id_253>",
|
257 |
+
"<extra_id_254>",
|
258 |
+
"<extra_id_255>"
|
259 |
+
],
|
260 |
+
"cls_token": "<cls>",
|
261 |
+
"eos_token": "</s>",
|
262 |
+
"mask_token": "<mask>",
|
263 |
+
"pad_token": "<pad>",
|
264 |
+
"sep_token": "<sep>",
|
265 |
+
"unk_token": "<unk>"
|
266 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,2367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "<cls>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "</s>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "<mask>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"3": {
|
28 |
+
"content": "<pad>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"4": {
|
36 |
+
"content": "<sep>",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
},
|
43 |
+
"5": {
|
44 |
+
"content": "<unk>",
|
45 |
+
"lstrip": false,
|
46 |
+
"normalized": false,
|
47 |
+
"rstrip": false,
|
48 |
+
"single_word": false,
|
49 |
+
"special": true
|
50 |
+
},
|
51 |
+
"6": {
|
52 |
+
"content": "<extra_id_0>",
|
53 |
+
"lstrip": false,
|
54 |
+
"normalized": false,
|
55 |
+
"rstrip": false,
|
56 |
+
"single_word": false,
|
57 |
+
"special": true
|
58 |
+
},
|
59 |
+
"7": {
|
60 |
+
"content": "<extra_id_1>",
|
61 |
+
"lstrip": false,
|
62 |
+
"normalized": false,
|
63 |
+
"rstrip": false,
|
64 |
+
"single_word": false,
|
65 |
+
"special": true
|
66 |
+
},
|
67 |
+
"8": {
|
68 |
+
"content": "<extra_id_2>",
|
69 |
+
"lstrip": false,
|
70 |
+
"normalized": false,
|
71 |
+
"rstrip": false,
|
72 |
+
"single_word": false,
|
73 |
+
"special": true
|
74 |
+
},
|
75 |
+
"9": {
|
76 |
+
"content": "<extra_id_3>",
|
77 |
+
"lstrip": false,
|
78 |
+
"normalized": false,
|
79 |
+
"rstrip": false,
|
80 |
+
"single_word": false,
|
81 |
+
"special": true
|
82 |
+
},
|
83 |
+
"10": {
|
84 |
+
"content": "<extra_id_4>",
|
85 |
+
"lstrip": false,
|
86 |
+
"normalized": false,
|
87 |
+
"rstrip": false,
|
88 |
+
"single_word": false,
|
89 |
+
"special": true
|
90 |
+
},
|
91 |
+
"11": {
|
92 |
+
"content": "<extra_id_5>",
|
93 |
+
"lstrip": false,
|
94 |
+
"normalized": false,
|
95 |
+
"rstrip": false,
|
96 |
+
"single_word": false,
|
97 |
+
"special": true
|
98 |
+
},
|
99 |
+
"12": {
|
100 |
+
"content": "<extra_id_6>",
|
101 |
+
"lstrip": false,
|
102 |
+
"normalized": false,
|
103 |
+
"rstrip": false,
|
104 |
+
"single_word": false,
|
105 |
+
"special": true
|
106 |
+
},
|
107 |
+
"13": {
|
108 |
+
"content": "<extra_id_7>",
|
109 |
+
"lstrip": false,
|
110 |
+
"normalized": false,
|
111 |
+
"rstrip": false,
|
112 |
+
"single_word": false,
|
113 |
+
"special": true
|
114 |
+
},
|
115 |
+
"14": {
|
116 |
+
"content": "<extra_id_8>",
|
117 |
+
"lstrip": false,
|
118 |
+
"normalized": false,
|
119 |
+
"rstrip": false,
|
120 |
+
"single_word": false,
|
121 |
+
"special": true
|
122 |
+
},
|
123 |
+
"15": {
|
124 |
+
"content": "<extra_id_9>",
|
125 |
+
"lstrip": false,
|
126 |
+
"normalized": false,
|
127 |
+
"rstrip": false,
|
128 |
+
"single_word": false,
|
129 |
+
"special": true
|
130 |
+
},
|
131 |
+
"16": {
|
132 |
+
"content": "<extra_id_10>",
|
133 |
+
"lstrip": false,
|
134 |
+
"normalized": false,
|
135 |
+
"rstrip": false,
|
136 |
+
"single_word": false,
|
137 |
+
"special": true
|
138 |
+
},
|
139 |
+
"17": {
|
140 |
+
"content": "<extra_id_11>",
|
141 |
+
"lstrip": false,
|
142 |
+
"normalized": false,
|
143 |
+
"rstrip": false,
|
144 |
+
"single_word": false,
|
145 |
+
"special": true
|
146 |
+
},
|
147 |
+
"18": {
|
148 |
+
"content": "<extra_id_12>",
|
149 |
+
"lstrip": false,
|
150 |
+
"normalized": false,
|
151 |
+
"rstrip": false,
|
152 |
+
"single_word": false,
|
153 |
+
"special": true
|
154 |
+
},
|
155 |
+
"19": {
|
156 |
+
"content": "<extra_id_13>",
|
157 |
+
"lstrip": false,
|
158 |
+
"normalized": false,
|
159 |
+
"rstrip": false,
|
160 |
+
"single_word": false,
|
161 |
+
"special": true
|
162 |
+
},
|
163 |
+
"20": {
|
164 |
+
"content": "<extra_id_14>",
|
165 |
+
"lstrip": false,
|
166 |
+
"normalized": false,
|
167 |
+
"rstrip": false,
|
168 |
+
"single_word": false,
|
169 |
+
"special": true
|
170 |
+
},
|
171 |
+
"21": {
|
172 |
+
"content": "<extra_id_15>",
|
173 |
+
"lstrip": false,
|
174 |
+
"normalized": false,
|
175 |
+
"rstrip": false,
|
176 |
+
"single_word": false,
|
177 |
+
"special": true
|
178 |
+
},
|
179 |
+
"22": {
|
180 |
+
"content": "<extra_id_16>",
|
181 |
+
"lstrip": false,
|
182 |
+
"normalized": false,
|
183 |
+
"rstrip": false,
|
184 |
+
"single_word": false,
|
185 |
+
"special": true
|
186 |
+
},
|
187 |
+
"23": {
|
188 |
+
"content": "<extra_id_17>",
|
189 |
+
"lstrip": false,
|
190 |
+
"normalized": false,
|
191 |
+
"rstrip": false,
|
192 |
+
"single_word": false,
|
193 |
+
"special": true
|
194 |
+
},
|
195 |
+
"24": {
|
196 |
+
"content": "<extra_id_18>",
|
197 |
+
"lstrip": false,
|
198 |
+
"normalized": false,
|
199 |
+
"rstrip": false,
|
200 |
+
"single_word": false,
|
201 |
+
"special": true
|
202 |
+
},
|
203 |
+
"25": {
|
204 |
+
"content": "<extra_id_19>",
|
205 |
+
"lstrip": false,
|
206 |
+
"normalized": false,
|
207 |
+
"rstrip": false,
|
208 |
+
"single_word": false,
|
209 |
+
"special": true
|
210 |
+
},
|
211 |
+
"26": {
|
212 |
+
"content": "<extra_id_20>",
|
213 |
+
"lstrip": false,
|
214 |
+
"normalized": false,
|
215 |
+
"rstrip": false,
|
216 |
+
"single_word": false,
|
217 |
+
"special": true
|
218 |
+
},
|
219 |
+
"27": {
|
220 |
+
"content": "<extra_id_21>",
|
221 |
+
"lstrip": false,
|
222 |
+
"normalized": false,
|
223 |
+
"rstrip": false,
|
224 |
+
"single_word": false,
|
225 |
+
"special": true
|
226 |
+
},
|
227 |
+
"28": {
|
228 |
+
"content": "<extra_id_22>",
|
229 |
+
"lstrip": false,
|
230 |
+
"normalized": false,
|
231 |
+
"rstrip": false,
|
232 |
+
"single_word": false,
|
233 |
+
"special": true
|
234 |
+
},
|
235 |
+
"29": {
|
236 |
+
"content": "<extra_id_23>",
|
237 |
+
"lstrip": false,
|
238 |
+
"normalized": false,
|
239 |
+
"rstrip": false,
|
240 |
+
"single_word": false,
|
241 |
+
"special": true
|
242 |
+
},
|
243 |
+
"30": {
|
244 |
+
"content": "<extra_id_24>",
|
245 |
+
"lstrip": false,
|
246 |
+
"normalized": false,
|
247 |
+
"rstrip": false,
|
248 |
+
"single_word": false,
|
249 |
+
"special": true
|
250 |
+
},
|
251 |
+
"31": {
|
252 |
+
"content": "<extra_id_25>",
|
253 |
+
"lstrip": false,
|
254 |
+
"normalized": false,
|
255 |
+
"rstrip": false,
|
256 |
+
"single_word": false,
|
257 |
+
"special": true
|
258 |
+
},
|
259 |
+
"32": {
|
260 |
+
"content": "<extra_id_26>",
|
261 |
+
"lstrip": false,
|
262 |
+
"normalized": false,
|
263 |
+
"rstrip": false,
|
264 |
+
"single_word": false,
|
265 |
+
"special": true
|
266 |
+
},
|
267 |
+
"33": {
|
268 |
+
"content": "<extra_id_27>",
|
269 |
+
"lstrip": false,
|
270 |
+
"normalized": false,
|
271 |
+
"rstrip": false,
|
272 |
+
"single_word": false,
|
273 |
+
"special": true
|
274 |
+
},
|
275 |
+
"34": {
|
276 |
+
"content": "<extra_id_28>",
|
277 |
+
"lstrip": false,
|
278 |
+
"normalized": false,
|
279 |
+
"rstrip": false,
|
280 |
+
"single_word": false,
|
281 |
+
"special": true
|
282 |
+
},
|
283 |
+
"35": {
|
284 |
+
"content": "<extra_id_29>",
|
285 |
+
"lstrip": false,
|
286 |
+
"normalized": false,
|
287 |
+
"rstrip": false,
|
288 |
+
"single_word": false,
|
289 |
+
"special": true
|
290 |
+
},
|
291 |
+
"36": {
|
292 |
+
"content": "<extra_id_30>",
|
293 |
+
"lstrip": false,
|
294 |
+
"normalized": false,
|
295 |
+
"rstrip": false,
|
296 |
+
"single_word": false,
|
297 |
+
"special": true
|
298 |
+
},
|
299 |
+
"37": {
|
300 |
+
"content": "<extra_id_31>",
|
301 |
+
"lstrip": false,
|
302 |
+
"normalized": false,
|
303 |
+
"rstrip": false,
|
304 |
+
"single_word": false,
|
305 |
+
"special": true
|
306 |
+
},
|
307 |
+
"38": {
|
308 |
+
"content": "<extra_id_32>",
|
309 |
+
"lstrip": false,
|
310 |
+
"normalized": false,
|
311 |
+
"rstrip": false,
|
312 |
+
"single_word": false,
|
313 |
+
"special": true
|
314 |
+
},
|
315 |
+
"39": {
|
316 |
+
"content": "<extra_id_33>",
|
317 |
+
"lstrip": false,
|
318 |
+
"normalized": false,
|
319 |
+
"rstrip": false,
|
320 |
+
"single_word": false,
|
321 |
+
"special": true
|
322 |
+
},
|
323 |
+
"40": {
|
324 |
+
"content": "<extra_id_34>",
|
325 |
+
"lstrip": false,
|
326 |
+
"normalized": false,
|
327 |
+
"rstrip": false,
|
328 |
+
"single_word": false,
|
329 |
+
"special": true
|
330 |
+
},
|
331 |
+
"41": {
|
332 |
+
"content": "<extra_id_35>",
|
333 |
+
"lstrip": false,
|
334 |
+
"normalized": false,
|
335 |
+
"rstrip": false,
|
336 |
+
"single_word": false,
|
337 |
+
"special": true
|
338 |
+
},
|
339 |
+
"42": {
|
340 |
+
"content": "<extra_id_36>",
|
341 |
+
"lstrip": false,
|
342 |
+
"normalized": false,
|
343 |
+
"rstrip": false,
|
344 |
+
"single_word": false,
|
345 |
+
"special": true
|
346 |
+
},
|
347 |
+
"43": {
|
348 |
+
"content": "<extra_id_37>",
|
349 |
+
"lstrip": false,
|
350 |
+
"normalized": false,
|
351 |
+
"rstrip": false,
|
352 |
+
"single_word": false,
|
353 |
+
"special": true
|
354 |
+
},
|
355 |
+
"44": {
|
356 |
+
"content": "<extra_id_38>",
|
357 |
+
"lstrip": false,
|
358 |
+
"normalized": false,
|
359 |
+
"rstrip": false,
|
360 |
+
"single_word": false,
|
361 |
+
"special": true
|
362 |
+
},
|
363 |
+
"45": {
|
364 |
+
"content": "<extra_id_39>",
|
365 |
+
"lstrip": false,
|
366 |
+
"normalized": false,
|
367 |
+
"rstrip": false,
|
368 |
+
"single_word": false,
|
369 |
+
"special": true
|
370 |
+
},
|
371 |
+
"46": {
|
372 |
+
"content": "<extra_id_40>",
|
373 |
+
"lstrip": false,
|
374 |
+
"normalized": false,
|
375 |
+
"rstrip": false,
|
376 |
+
"single_word": false,
|
377 |
+
"special": true
|
378 |
+
},
|
379 |
+
"47": {
|
380 |
+
"content": "<extra_id_41>",
|
381 |
+
"lstrip": false,
|
382 |
+
"normalized": false,
|
383 |
+
"rstrip": false,
|
384 |
+
"single_word": false,
|
385 |
+
"special": true
|
386 |
+
},
|
387 |
+
"48": {
|
388 |
+
"content": "<extra_id_42>",
|
389 |
+
"lstrip": false,
|
390 |
+
"normalized": false,
|
391 |
+
"rstrip": false,
|
392 |
+
"single_word": false,
|
393 |
+
"special": true
|
394 |
+
},
|
395 |
+
"49": {
|
396 |
+
"content": "<extra_id_43>",
|
397 |
+
"lstrip": false,
|
398 |
+
"normalized": false,
|
399 |
+
"rstrip": false,
|
400 |
+
"single_word": false,
|
401 |
+
"special": true
|
402 |
+
},
|
403 |
+
"50": {
|
404 |
+
"content": "<extra_id_44>",
|
405 |
+
"lstrip": false,
|
406 |
+
"normalized": false,
|
407 |
+
"rstrip": false,
|
408 |
+
"single_word": false,
|
409 |
+
"special": true
|
410 |
+
},
|
411 |
+
"51": {
|
412 |
+
"content": "<extra_id_45>",
|
413 |
+
"lstrip": false,
|
414 |
+
"normalized": false,
|
415 |
+
"rstrip": false,
|
416 |
+
"single_word": false,
|
417 |
+
"special": true
|
418 |
+
},
|
419 |
+
"52": {
|
420 |
+
"content": "<extra_id_46>",
|
421 |
+
"lstrip": false,
|
422 |
+
"normalized": false,
|
423 |
+
"rstrip": false,
|
424 |
+
"single_word": false,
|
425 |
+
"special": true
|
426 |
+
},
|
427 |
+
"53": {
|
428 |
+
"content": "<extra_id_47>",
|
429 |
+
"lstrip": false,
|
430 |
+
"normalized": false,
|
431 |
+
"rstrip": false,
|
432 |
+
"single_word": false,
|
433 |
+
"special": true
|
434 |
+
},
|
435 |
+
"54": {
|
436 |
+
"content": "<extra_id_48>",
|
437 |
+
"lstrip": false,
|
438 |
+
"normalized": false,
|
439 |
+
"rstrip": false,
|
440 |
+
"single_word": false,
|
441 |
+
"special": true
|
442 |
+
},
|
443 |
+
"55": {
|
444 |
+
"content": "<extra_id_49>",
|
445 |
+
"lstrip": false,
|
446 |
+
"normalized": false,
|
447 |
+
"rstrip": false,
|
448 |
+
"single_word": false,
|
449 |
+
"special": true
|
450 |
+
},
|
451 |
+
"56": {
|
452 |
+
"content": "<extra_id_50>",
|
453 |
+
"lstrip": false,
|
454 |
+
"normalized": false,
|
455 |
+
"rstrip": false,
|
456 |
+
"single_word": false,
|
457 |
+
"special": true
|
458 |
+
},
|
459 |
+
"57": {
|
460 |
+
"content": "<extra_id_51>",
|
461 |
+
"lstrip": false,
|
462 |
+
"normalized": false,
|
463 |
+
"rstrip": false,
|
464 |
+
"single_word": false,
|
465 |
+
"special": true
|
466 |
+
},
|
467 |
+
"58": {
|
468 |
+
"content": "<extra_id_52>",
|
469 |
+
"lstrip": false,
|
470 |
+
"normalized": false,
|
471 |
+
"rstrip": false,
|
472 |
+
"single_word": false,
|
473 |
+
"special": true
|
474 |
+
},
|
475 |
+
"59": {
|
476 |
+
"content": "<extra_id_53>",
|
477 |
+
"lstrip": false,
|
478 |
+
"normalized": false,
|
479 |
+
"rstrip": false,
|
480 |
+
"single_word": false,
|
481 |
+
"special": true
|
482 |
+
},
|
483 |
+
"60": {
|
484 |
+
"content": "<extra_id_54>",
|
485 |
+
"lstrip": false,
|
486 |
+
"normalized": false,
|
487 |
+
"rstrip": false,
|
488 |
+
"single_word": false,
|
489 |
+
"special": true
|
490 |
+
},
|
491 |
+
"61": {
|
492 |
+
"content": "<extra_id_55>",
|
493 |
+
"lstrip": false,
|
494 |
+
"normalized": false,
|
495 |
+
"rstrip": false,
|
496 |
+
"single_word": false,
|
497 |
+
"special": true
|
498 |
+
},
|
499 |
+
"62": {
|
500 |
+
"content": "<extra_id_56>",
|
501 |
+
"lstrip": false,
|
502 |
+
"normalized": false,
|
503 |
+
"rstrip": false,
|
504 |
+
"single_word": false,
|
505 |
+
"special": true
|
506 |
+
},
|
507 |
+
"63": {
|
508 |
+
"content": "<extra_id_57>",
|
509 |
+
"lstrip": false,
|
510 |
+
"normalized": false,
|
511 |
+
"rstrip": false,
|
512 |
+
"single_word": false,
|
513 |
+
"special": true
|
514 |
+
},
|
515 |
+
"64": {
|
516 |
+
"content": "<extra_id_58>",
|
517 |
+
"lstrip": false,
|
518 |
+
"normalized": false,
|
519 |
+
"rstrip": false,
|
520 |
+
"single_word": false,
|
521 |
+
"special": true
|
522 |
+
},
|
523 |
+
"65": {
|
524 |
+
"content": "<extra_id_59>",
|
525 |
+
"lstrip": false,
|
526 |
+
"normalized": false,
|
527 |
+
"rstrip": false,
|
528 |
+
"single_word": false,
|
529 |
+
"special": true
|
530 |
+
},
|
531 |
+
"66": {
|
532 |
+
"content": "<extra_id_60>",
|
533 |
+
"lstrip": false,
|
534 |
+
"normalized": false,
|
535 |
+
"rstrip": false,
|
536 |
+
"single_word": false,
|
537 |
+
"special": true
|
538 |
+
},
|
539 |
+
"67": {
|
540 |
+
"content": "<extra_id_61>",
|
541 |
+
"lstrip": false,
|
542 |
+
"normalized": false,
|
543 |
+
"rstrip": false,
|
544 |
+
"single_word": false,
|
545 |
+
"special": true
|
546 |
+
},
|
547 |
+
"68": {
|
548 |
+
"content": "<extra_id_62>",
|
549 |
+
"lstrip": false,
|
550 |
+
"normalized": false,
|
551 |
+
"rstrip": false,
|
552 |
+
"single_word": false,
|
553 |
+
"special": true
|
554 |
+
},
|
555 |
+
"69": {
|
556 |
+
"content": "<extra_id_63>",
|
557 |
+
"lstrip": false,
|
558 |
+
"normalized": false,
|
559 |
+
"rstrip": false,
|
560 |
+
"single_word": false,
|
561 |
+
"special": true
|
562 |
+
},
|
563 |
+
"70": {
|
564 |
+
"content": "<extra_id_64>",
|
565 |
+
"lstrip": false,
|
566 |
+
"normalized": false,
|
567 |
+
"rstrip": false,
|
568 |
+
"single_word": false,
|
569 |
+
"special": true
|
570 |
+
},
|
571 |
+
"71": {
|
572 |
+
"content": "<extra_id_65>",
|
573 |
+
"lstrip": false,
|
574 |
+
"normalized": false,
|
575 |
+
"rstrip": false,
|
576 |
+
"single_word": false,
|
577 |
+
"special": true
|
578 |
+
},
|
579 |
+
"72": {
|
580 |
+
"content": "<extra_id_66>",
|
581 |
+
"lstrip": false,
|
582 |
+
"normalized": false,
|
583 |
+
"rstrip": false,
|
584 |
+
"single_word": false,
|
585 |
+
"special": true
|
586 |
+
},
|
587 |
+
"73": {
|
588 |
+
"content": "<extra_id_67>",
|
589 |
+
"lstrip": false,
|
590 |
+
"normalized": false,
|
591 |
+
"rstrip": false,
|
592 |
+
"single_word": false,
|
593 |
+
"special": true
|
594 |
+
},
|
595 |
+
"74": {
|
596 |
+
"content": "<extra_id_68>",
|
597 |
+
"lstrip": false,
|
598 |
+
"normalized": false,
|
599 |
+
"rstrip": false,
|
600 |
+
"single_word": false,
|
601 |
+
"special": true
|
602 |
+
},
|
603 |
+
"75": {
|
604 |
+
"content": "<extra_id_69>",
|
605 |
+
"lstrip": false,
|
606 |
+
"normalized": false,
|
607 |
+
"rstrip": false,
|
608 |
+
"single_word": false,
|
609 |
+
"special": true
|
610 |
+
},
|
611 |
+
"76": {
|
612 |
+
"content": "<extra_id_70>",
|
613 |
+
"lstrip": false,
|
614 |
+
"normalized": false,
|
615 |
+
"rstrip": false,
|
616 |
+
"single_word": false,
|
617 |
+
"special": true
|
618 |
+
},
|
619 |
+
"77": {
|
620 |
+
"content": "<extra_id_71>",
|
621 |
+
"lstrip": false,
|
622 |
+
"normalized": false,
|
623 |
+
"rstrip": false,
|
624 |
+
"single_word": false,
|
625 |
+
"special": true
|
626 |
+
},
|
627 |
+
"78": {
|
628 |
+
"content": "<extra_id_72>",
|
629 |
+
"lstrip": false,
|
630 |
+
"normalized": false,
|
631 |
+
"rstrip": false,
|
632 |
+
"single_word": false,
|
633 |
+
"special": true
|
634 |
+
},
|
635 |
+
"79": {
|
636 |
+
"content": "<extra_id_73>",
|
637 |
+
"lstrip": false,
|
638 |
+
"normalized": false,
|
639 |
+
"rstrip": false,
|
640 |
+
"single_word": false,
|
641 |
+
"special": true
|
642 |
+
},
|
643 |
+
"80": {
|
644 |
+
"content": "<extra_id_74>",
|
645 |
+
"lstrip": false,
|
646 |
+
"normalized": false,
|
647 |
+
"rstrip": false,
|
648 |
+
"single_word": false,
|
649 |
+
"special": true
|
650 |
+
},
|
651 |
+
"81": {
|
652 |
+
"content": "<extra_id_75>",
|
653 |
+
"lstrip": false,
|
654 |
+
"normalized": false,
|
655 |
+
"rstrip": false,
|
656 |
+
"single_word": false,
|
657 |
+
"special": true
|
658 |
+
},
|
659 |
+
"82": {
|
660 |
+
"content": "<extra_id_76>",
|
661 |
+
"lstrip": false,
|
662 |
+
"normalized": false,
|
663 |
+
"rstrip": false,
|
664 |
+
"single_word": false,
|
665 |
+
"special": true
|
666 |
+
},
|
667 |
+
"83": {
|
668 |
+
"content": "<extra_id_77>",
|
669 |
+
"lstrip": false,
|
670 |
+
"normalized": false,
|
671 |
+
"rstrip": false,
|
672 |
+
"single_word": false,
|
673 |
+
"special": true
|
674 |
+
},
|
675 |
+
"84": {
|
676 |
+
"content": "<extra_id_78>",
|
677 |
+
"lstrip": false,
|
678 |
+
"normalized": false,
|
679 |
+
"rstrip": false,
|
680 |
+
"single_word": false,
|
681 |
+
"special": true
|
682 |
+
},
|
683 |
+
"85": {
|
684 |
+
"content": "<extra_id_79>",
|
685 |
+
"lstrip": false,
|
686 |
+
"normalized": false,
|
687 |
+
"rstrip": false,
|
688 |
+
"single_word": false,
|
689 |
+
"special": true
|
690 |
+
},
|
691 |
+
"86": {
|
692 |
+
"content": "<extra_id_80>",
|
693 |
+
"lstrip": false,
|
694 |
+
"normalized": false,
|
695 |
+
"rstrip": false,
|
696 |
+
"single_word": false,
|
697 |
+
"special": true
|
698 |
+
},
|
699 |
+
"87": {
|
700 |
+
"content": "<extra_id_81>",
|
701 |
+
"lstrip": false,
|
702 |
+
"normalized": false,
|
703 |
+
"rstrip": false,
|
704 |
+
"single_word": false,
|
705 |
+
"special": true
|
706 |
+
},
|
707 |
+
"88": {
|
708 |
+
"content": "<extra_id_82>",
|
709 |
+
"lstrip": false,
|
710 |
+
"normalized": false,
|
711 |
+
"rstrip": false,
|
712 |
+
"single_word": false,
|
713 |
+
"special": true
|
714 |
+
},
|
715 |
+
"89": {
|
716 |
+
"content": "<extra_id_83>",
|
717 |
+
"lstrip": false,
|
718 |
+
"normalized": false,
|
719 |
+
"rstrip": false,
|
720 |
+
"single_word": false,
|
721 |
+
"special": true
|
722 |
+
},
|
723 |
+
"90": {
|
724 |
+
"content": "<extra_id_84>",
|
725 |
+
"lstrip": false,
|
726 |
+
"normalized": false,
|
727 |
+
"rstrip": false,
|
728 |
+
"single_word": false,
|
729 |
+
"special": true
|
730 |
+
},
|
731 |
+
"91": {
|
732 |
+
"content": "<extra_id_85>",
|
733 |
+
"lstrip": false,
|
734 |
+
"normalized": false,
|
735 |
+
"rstrip": false,
|
736 |
+
"single_word": false,
|
737 |
+
"special": true
|
738 |
+
},
|
739 |
+
"92": {
|
740 |
+
"content": "<extra_id_86>",
|
741 |
+
"lstrip": false,
|
742 |
+
"normalized": false,
|
743 |
+
"rstrip": false,
|
744 |
+
"single_word": false,
|
745 |
+
"special": true
|
746 |
+
},
|
747 |
+
"93": {
|
748 |
+
"content": "<extra_id_87>",
|
749 |
+
"lstrip": false,
|
750 |
+
"normalized": false,
|
751 |
+
"rstrip": false,
|
752 |
+
"single_word": false,
|
753 |
+
"special": true
|
754 |
+
},
|
755 |
+
"94": {
|
756 |
+
"content": "<extra_id_88>",
|
757 |
+
"lstrip": false,
|
758 |
+
"normalized": false,
|
759 |
+
"rstrip": false,
|
760 |
+
"single_word": false,
|
761 |
+
"special": true
|
762 |
+
},
|
763 |
+
"95": {
|
764 |
+
"content": "<extra_id_89>",
|
765 |
+
"lstrip": false,
|
766 |
+
"normalized": false,
|
767 |
+
"rstrip": false,
|
768 |
+
"single_word": false,
|
769 |
+
"special": true
|
770 |
+
},
|
771 |
+
"96": {
|
772 |
+
"content": "<extra_id_90>",
|
773 |
+
"lstrip": false,
|
774 |
+
"normalized": false,
|
775 |
+
"rstrip": false,
|
776 |
+
"single_word": false,
|
777 |
+
"special": true
|
778 |
+
},
|
779 |
+
"97": {
|
780 |
+
"content": "<extra_id_91>",
|
781 |
+
"lstrip": false,
|
782 |
+
"normalized": false,
|
783 |
+
"rstrip": false,
|
784 |
+
"single_word": false,
|
785 |
+
"special": true
|
786 |
+
},
|
787 |
+
"98": {
|
788 |
+
"content": "<extra_id_92>",
|
789 |
+
"lstrip": false,
|
790 |
+
"normalized": false,
|
791 |
+
"rstrip": false,
|
792 |
+
"single_word": false,
|
793 |
+
"special": true
|
794 |
+
},
|
795 |
+
"99": {
|
796 |
+
"content": "<extra_id_93>",
|
797 |
+
"lstrip": false,
|
798 |
+
"normalized": false,
|
799 |
+
"rstrip": false,
|
800 |
+
"single_word": false,
|
801 |
+
"special": true
|
802 |
+
},
|
803 |
+
"100": {
|
804 |
+
"content": "<extra_id_94>",
|
805 |
+
"lstrip": false,
|
806 |
+
"normalized": false,
|
807 |
+
"rstrip": false,
|
808 |
+
"single_word": false,
|
809 |
+
"special": true
|
810 |
+
},
|
811 |
+
"101": {
|
812 |
+
"content": "<extra_id_95>",
|
813 |
+
"lstrip": false,
|
814 |
+
"normalized": false,
|
815 |
+
"rstrip": false,
|
816 |
+
"single_word": false,
|
817 |
+
"special": true
|
818 |
+
},
|
819 |
+
"102": {
|
820 |
+
"content": "<extra_id_96>",
|
821 |
+
"lstrip": false,
|
822 |
+
"normalized": false,
|
823 |
+
"rstrip": false,
|
824 |
+
"single_word": false,
|
825 |
+
"special": true
|
826 |
+
},
|
827 |
+
"103": {
|
828 |
+
"content": "<extra_id_97>",
|
829 |
+
"lstrip": false,
|
830 |
+
"normalized": false,
|
831 |
+
"rstrip": false,
|
832 |
+
"single_word": false,
|
833 |
+
"special": true
|
834 |
+
},
|
835 |
+
"104": {
|
836 |
+
"content": "<extra_id_98>",
|
837 |
+
"lstrip": false,
|
838 |
+
"normalized": false,
|
839 |
+
"rstrip": false,
|
840 |
+
"single_word": false,
|
841 |
+
"special": true
|
842 |
+
},
|
843 |
+
"105": {
|
844 |
+
"content": "<extra_id_99>",
|
845 |
+
"lstrip": false,
|
846 |
+
"normalized": false,
|
847 |
+
"rstrip": false,
|
848 |
+
"single_word": false,
|
849 |
+
"special": true
|
850 |
+
},
|
851 |
+
"106": {
|
852 |
+
"content": "<extra_id_100>",
|
853 |
+
"lstrip": false,
|
854 |
+
"normalized": false,
|
855 |
+
"rstrip": false,
|
856 |
+
"single_word": false,
|
857 |
+
"special": true
|
858 |
+
},
|
859 |
+
"107": {
|
860 |
+
"content": "<extra_id_101>",
|
861 |
+
"lstrip": false,
|
862 |
+
"normalized": false,
|
863 |
+
"rstrip": false,
|
864 |
+
"single_word": false,
|
865 |
+
"special": true
|
866 |
+
},
|
867 |
+
"108": {
|
868 |
+
"content": "<extra_id_102>",
|
869 |
+
"lstrip": false,
|
870 |
+
"normalized": false,
|
871 |
+
"rstrip": false,
|
872 |
+
"single_word": false,
|
873 |
+
"special": true
|
874 |
+
},
|
875 |
+
"109": {
|
876 |
+
"content": "<extra_id_103>",
|
877 |
+
"lstrip": false,
|
878 |
+
"normalized": false,
|
879 |
+
"rstrip": false,
|
880 |
+
"single_word": false,
|
881 |
+
"special": true
|
882 |
+
},
|
883 |
+
"110": {
|
884 |
+
"content": "<extra_id_104>",
|
885 |
+
"lstrip": false,
|
886 |
+
"normalized": false,
|
887 |
+
"rstrip": false,
|
888 |
+
"single_word": false,
|
889 |
+
"special": true
|
890 |
+
},
|
891 |
+
"111": {
|
892 |
+
"content": "<extra_id_105>",
|
893 |
+
"lstrip": false,
|
894 |
+
"normalized": false,
|
895 |
+
"rstrip": false,
|
896 |
+
"single_word": false,
|
897 |
+
"special": true
|
898 |
+
},
|
899 |
+
"112": {
|
900 |
+
"content": "<extra_id_106>",
|
901 |
+
"lstrip": false,
|
902 |
+
"normalized": false,
|
903 |
+
"rstrip": false,
|
904 |
+
"single_word": false,
|
905 |
+
"special": true
|
906 |
+
},
|
907 |
+
"113": {
|
908 |
+
"content": "<extra_id_107>",
|
909 |
+
"lstrip": false,
|
910 |
+
"normalized": false,
|
911 |
+
"rstrip": false,
|
912 |
+
"single_word": false,
|
913 |
+
"special": true
|
914 |
+
},
|
915 |
+
"114": {
|
916 |
+
"content": "<extra_id_108>",
|
917 |
+
"lstrip": false,
|
918 |
+
"normalized": false,
|
919 |
+
"rstrip": false,
|
920 |
+
"single_word": false,
|
921 |
+
"special": true
|
922 |
+
},
|
923 |
+
"115": {
|
924 |
+
"content": "<extra_id_109>",
|
925 |
+
"lstrip": false,
|
926 |
+
"normalized": false,
|
927 |
+
"rstrip": false,
|
928 |
+
"single_word": false,
|
929 |
+
"special": true
|
930 |
+
},
|
931 |
+
"116": {
|
932 |
+
"content": "<extra_id_110>",
|
933 |
+
"lstrip": false,
|
934 |
+
"normalized": false,
|
935 |
+
"rstrip": false,
|
936 |
+
"single_word": false,
|
937 |
+
"special": true
|
938 |
+
},
|
939 |
+
"117": {
|
940 |
+
"content": "<extra_id_111>",
|
941 |
+
"lstrip": false,
|
942 |
+
"normalized": false,
|
943 |
+
"rstrip": false,
|
944 |
+
"single_word": false,
|
945 |
+
"special": true
|
946 |
+
},
|
947 |
+
"118": {
|
948 |
+
"content": "<extra_id_112>",
|
949 |
+
"lstrip": false,
|
950 |
+
"normalized": false,
|
951 |
+
"rstrip": false,
|
952 |
+
"single_word": false,
|
953 |
+
"special": true
|
954 |
+
},
|
955 |
+
"119": {
|
956 |
+
"content": "<extra_id_113>",
|
957 |
+
"lstrip": false,
|
958 |
+
"normalized": false,
|
959 |
+
"rstrip": false,
|
960 |
+
"single_word": false,
|
961 |
+
"special": true
|
962 |
+
},
|
963 |
+
"120": {
|
964 |
+
"content": "<extra_id_114>",
|
965 |
+
"lstrip": false,
|
966 |
+
"normalized": false,
|
967 |
+
"rstrip": false,
|
968 |
+
"single_word": false,
|
969 |
+
"special": true
|
970 |
+
},
|
971 |
+
"121": {
|
972 |
+
"content": "<extra_id_115>",
|
973 |
+
"lstrip": false,
|
974 |
+
"normalized": false,
|
975 |
+
"rstrip": false,
|
976 |
+
"single_word": false,
|
977 |
+
"special": true
|
978 |
+
},
|
979 |
+
"122": {
|
980 |
+
"content": "<extra_id_116>",
|
981 |
+
"lstrip": false,
|
982 |
+
"normalized": false,
|
983 |
+
"rstrip": false,
|
984 |
+
"single_word": false,
|
985 |
+
"special": true
|
986 |
+
},
|
987 |
+
"123": {
|
988 |
+
"content": "<extra_id_117>",
|
989 |
+
"lstrip": false,
|
990 |
+
"normalized": false,
|
991 |
+
"rstrip": false,
|
992 |
+
"single_word": false,
|
993 |
+
"special": true
|
994 |
+
},
|
995 |
+
"124": {
|
996 |
+
"content": "<extra_id_118>",
|
997 |
+
"lstrip": false,
|
998 |
+
"normalized": false,
|
999 |
+
"rstrip": false,
|
1000 |
+
"single_word": false,
|
1001 |
+
"special": true
|
1002 |
+
},
|
1003 |
+
"125": {
|
1004 |
+
"content": "<extra_id_119>",
|
1005 |
+
"lstrip": false,
|
1006 |
+
"normalized": false,
|
1007 |
+
"rstrip": false,
|
1008 |
+
"single_word": false,
|
1009 |
+
"special": true
|
1010 |
+
},
|
1011 |
+
"126": {
|
1012 |
+
"content": "<extra_id_120>",
|
1013 |
+
"lstrip": false,
|
1014 |
+
"normalized": false,
|
1015 |
+
"rstrip": false,
|
1016 |
+
"single_word": false,
|
1017 |
+
"special": true
|
1018 |
+
},
|
1019 |
+
"127": {
|
1020 |
+
"content": "<extra_id_121>",
|
1021 |
+
"lstrip": false,
|
1022 |
+
"normalized": false,
|
1023 |
+
"rstrip": false,
|
1024 |
+
"single_word": false,
|
1025 |
+
"special": true
|
1026 |
+
},
|
1027 |
+
"128": {
|
1028 |
+
"content": "<extra_id_122>",
|
1029 |
+
"lstrip": false,
|
1030 |
+
"normalized": false,
|
1031 |
+
"rstrip": false,
|
1032 |
+
"single_word": false,
|
1033 |
+
"special": true
|
1034 |
+
},
|
1035 |
+
"129": {
|
1036 |
+
"content": "<extra_id_123>",
|
1037 |
+
"lstrip": false,
|
1038 |
+
"normalized": false,
|
1039 |
+
"rstrip": false,
|
1040 |
+
"single_word": false,
|
1041 |
+
"special": true
|
1042 |
+
},
|
1043 |
+
"130": {
|
1044 |
+
"content": "<extra_id_124>",
|
1045 |
+
"lstrip": false,
|
1046 |
+
"normalized": false,
|
1047 |
+
"rstrip": false,
|
1048 |
+
"single_word": false,
|
1049 |
+
"special": true
|
1050 |
+
},
|
1051 |
+
"131": {
|
1052 |
+
"content": "<extra_id_125>",
|
1053 |
+
"lstrip": false,
|
1054 |
+
"normalized": false,
|
1055 |
+
"rstrip": false,
|
1056 |
+
"single_word": false,
|
1057 |
+
"special": true
|
1058 |
+
},
|
1059 |
+
"132": {
|
1060 |
+
"content": "<extra_id_126>",
|
1061 |
+
"lstrip": false,
|
1062 |
+
"normalized": false,
|
1063 |
+
"rstrip": false,
|
1064 |
+
"single_word": false,
|
1065 |
+
"special": true
|
1066 |
+
},
|
1067 |
+
"133": {
|
1068 |
+
"content": "<extra_id_127>",
|
1069 |
+
"lstrip": false,
|
1070 |
+
"normalized": false,
|
1071 |
+
"rstrip": false,
|
1072 |
+
"single_word": false,
|
1073 |
+
"special": true
|
1074 |
+
},
|
1075 |
+
"134": {
|
1076 |
+
"content": "<extra_id_128>",
|
1077 |
+
"lstrip": false,
|
1078 |
+
"normalized": false,
|
1079 |
+
"rstrip": false,
|
1080 |
+
"single_word": false,
|
1081 |
+
"special": true
|
1082 |
+
},
|
1083 |
+
"135": {
|
1084 |
+
"content": "<extra_id_129>",
|
1085 |
+
"lstrip": false,
|
1086 |
+
"normalized": false,
|
1087 |
+
"rstrip": false,
|
1088 |
+
"single_word": false,
|
1089 |
+
"special": true
|
1090 |
+
},
|
1091 |
+
"136": {
|
1092 |
+
"content": "<extra_id_130>",
|
1093 |
+
"lstrip": false,
|
1094 |
+
"normalized": false,
|
1095 |
+
"rstrip": false,
|
1096 |
+
"single_word": false,
|
1097 |
+
"special": true
|
1098 |
+
},
|
1099 |
+
"137": {
|
1100 |
+
"content": "<extra_id_131>",
|
1101 |
+
"lstrip": false,
|
1102 |
+
"normalized": false,
|
1103 |
+
"rstrip": false,
|
1104 |
+
"single_word": false,
|
1105 |
+
"special": true
|
1106 |
+
},
|
1107 |
+
"138": {
|
1108 |
+
"content": "<extra_id_132>",
|
1109 |
+
"lstrip": false,
|
1110 |
+
"normalized": false,
|
1111 |
+
"rstrip": false,
|
1112 |
+
"single_word": false,
|
1113 |
+
"special": true
|
1114 |
+
},
|
1115 |
+
"139": {
|
1116 |
+
"content": "<extra_id_133>",
|
1117 |
+
"lstrip": false,
|
1118 |
+
"normalized": false,
|
1119 |
+
"rstrip": false,
|
1120 |
+
"single_word": false,
|
1121 |
+
"special": true
|
1122 |
+
},
|
1123 |
+
"140": {
|
1124 |
+
"content": "<extra_id_134>",
|
1125 |
+
"lstrip": false,
|
1126 |
+
"normalized": false,
|
1127 |
+
"rstrip": false,
|
1128 |
+
"single_word": false,
|
1129 |
+
"special": true
|
1130 |
+
},
|
1131 |
+
"141": {
|
1132 |
+
"content": "<extra_id_135>",
|
1133 |
+
"lstrip": false,
|
1134 |
+
"normalized": false,
|
1135 |
+
"rstrip": false,
|
1136 |
+
"single_word": false,
|
1137 |
+
"special": true
|
1138 |
+
},
|
1139 |
+
"142": {
|
1140 |
+
"content": "<extra_id_136>",
|
1141 |
+
"lstrip": false,
|
1142 |
+
"normalized": false,
|
1143 |
+
"rstrip": false,
|
1144 |
+
"single_word": false,
|
1145 |
+
"special": true
|
1146 |
+
},
|
1147 |
+
"143": {
|
1148 |
+
"content": "<extra_id_137>",
|
1149 |
+
"lstrip": false,
|
1150 |
+
"normalized": false,
|
1151 |
+
"rstrip": false,
|
1152 |
+
"single_word": false,
|
1153 |
+
"special": true
|
1154 |
+
},
|
1155 |
+
"144": {
|
1156 |
+
"content": "<extra_id_138>",
|
1157 |
+
"lstrip": false,
|
1158 |
+
"normalized": false,
|
1159 |
+
"rstrip": false,
|
1160 |
+
"single_word": false,
|
1161 |
+
"special": true
|
1162 |
+
},
|
1163 |
+
"145": {
|
1164 |
+
"content": "<extra_id_139>",
|
1165 |
+
"lstrip": false,
|
1166 |
+
"normalized": false,
|
1167 |
+
"rstrip": false,
|
1168 |
+
"single_word": false,
|
1169 |
+
"special": true
|
1170 |
+
},
|
1171 |
+
"146": {
|
1172 |
+
"content": "<extra_id_140>",
|
1173 |
+
"lstrip": false,
|
1174 |
+
"normalized": false,
|
1175 |
+
"rstrip": false,
|
1176 |
+
"single_word": false,
|
1177 |
+
"special": true
|
1178 |
+
},
|
1179 |
+
"147": {
|
1180 |
+
"content": "<extra_id_141>",
|
1181 |
+
"lstrip": false,
|
1182 |
+
"normalized": false,
|
1183 |
+
"rstrip": false,
|
1184 |
+
"single_word": false,
|
1185 |
+
"special": true
|
1186 |
+
},
|
1187 |
+
"148": {
|
1188 |
+
"content": "<extra_id_142>",
|
1189 |
+
"lstrip": false,
|
1190 |
+
"normalized": false,
|
1191 |
+
"rstrip": false,
|
1192 |
+
"single_word": false,
|
1193 |
+
"special": true
|
1194 |
+
},
|
1195 |
+
"149": {
|
1196 |
+
"content": "<extra_id_143>",
|
1197 |
+
"lstrip": false,
|
1198 |
+
"normalized": false,
|
1199 |
+
"rstrip": false,
|
1200 |
+
"single_word": false,
|
1201 |
+
"special": true
|
1202 |
+
},
|
1203 |
+
"150": {
|
1204 |
+
"content": "<extra_id_144>",
|
1205 |
+
"lstrip": false,
|
1206 |
+
"normalized": false,
|
1207 |
+
"rstrip": false,
|
1208 |
+
"single_word": false,
|
1209 |
+
"special": true
|
1210 |
+
},
|
1211 |
+
"151": {
|
1212 |
+
"content": "<extra_id_145>",
|
1213 |
+
"lstrip": false,
|
1214 |
+
"normalized": false,
|
1215 |
+
"rstrip": false,
|
1216 |
+
"single_word": false,
|
1217 |
+
"special": true
|
1218 |
+
},
|
1219 |
+
"152": {
|
1220 |
+
"content": "<extra_id_146>",
|
1221 |
+
"lstrip": false,
|
1222 |
+
"normalized": false,
|
1223 |
+
"rstrip": false,
|
1224 |
+
"single_word": false,
|
1225 |
+
"special": true
|
1226 |
+
},
|
1227 |
+
"153": {
|
1228 |
+
"content": "<extra_id_147>",
|
1229 |
+
"lstrip": false,
|
1230 |
+
"normalized": false,
|
1231 |
+
"rstrip": false,
|
1232 |
+
"single_word": false,
|
1233 |
+
"special": true
|
1234 |
+
},
|
1235 |
+
"154": {
|
1236 |
+
"content": "<extra_id_148>",
|
1237 |
+
"lstrip": false,
|
1238 |
+
"normalized": false,
|
1239 |
+
"rstrip": false,
|
1240 |
+
"single_word": false,
|
1241 |
+
"special": true
|
1242 |
+
},
|
1243 |
+
"155": {
|
1244 |
+
"content": "<extra_id_149>",
|
1245 |
+
"lstrip": false,
|
1246 |
+
"normalized": false,
|
1247 |
+
"rstrip": false,
|
1248 |
+
"single_word": false,
|
1249 |
+
"special": true
|
1250 |
+
},
|
1251 |
+
"156": {
|
1252 |
+
"content": "<extra_id_150>",
|
1253 |
+
"lstrip": false,
|
1254 |
+
"normalized": false,
|
1255 |
+
"rstrip": false,
|
1256 |
+
"single_word": false,
|
1257 |
+
"special": true
|
1258 |
+
},
|
1259 |
+
"157": {
|
1260 |
+
"content": "<extra_id_151>",
|
1261 |
+
"lstrip": false,
|
1262 |
+
"normalized": false,
|
1263 |
+
"rstrip": false,
|
1264 |
+
"single_word": false,
|
1265 |
+
"special": true
|
1266 |
+
},
|
1267 |
+
"158": {
|
1268 |
+
"content": "<extra_id_152>",
|
1269 |
+
"lstrip": false,
|
1270 |
+
"normalized": false,
|
1271 |
+
"rstrip": false,
|
1272 |
+
"single_word": false,
|
1273 |
+
"special": true
|
1274 |
+
},
|
1275 |
+
"159": {
|
1276 |
+
"content": "<extra_id_153>",
|
1277 |
+
"lstrip": false,
|
1278 |
+
"normalized": false,
|
1279 |
+
"rstrip": false,
|
1280 |
+
"single_word": false,
|
1281 |
+
"special": true
|
1282 |
+
},
|
1283 |
+
"160": {
|
1284 |
+
"content": "<extra_id_154>",
|
1285 |
+
"lstrip": false,
|
1286 |
+
"normalized": false,
|
1287 |
+
"rstrip": false,
|
1288 |
+
"single_word": false,
|
1289 |
+
"special": true
|
1290 |
+
},
|
1291 |
+
"161": {
|
1292 |
+
"content": "<extra_id_155>",
|
1293 |
+
"lstrip": false,
|
1294 |
+
"normalized": false,
|
1295 |
+
"rstrip": false,
|
1296 |
+
"single_word": false,
|
1297 |
+
"special": true
|
1298 |
+
},
|
1299 |
+
"162": {
|
1300 |
+
"content": "<extra_id_156>",
|
1301 |
+
"lstrip": false,
|
1302 |
+
"normalized": false,
|
1303 |
+
"rstrip": false,
|
1304 |
+
"single_word": false,
|
1305 |
+
"special": true
|
1306 |
+
},
|
1307 |
+
"163": {
|
1308 |
+
"content": "<extra_id_157>",
|
1309 |
+
"lstrip": false,
|
1310 |
+
"normalized": false,
|
1311 |
+
"rstrip": false,
|
1312 |
+
"single_word": false,
|
1313 |
+
"special": true
|
1314 |
+
},
|
1315 |
+
"164": {
|
1316 |
+
"content": "<extra_id_158>",
|
1317 |
+
"lstrip": false,
|
1318 |
+
"normalized": false,
|
1319 |
+
"rstrip": false,
|
1320 |
+
"single_word": false,
|
1321 |
+
"special": true
|
1322 |
+
},
|
1323 |
+
"165": {
|
1324 |
+
"content": "<extra_id_159>",
|
1325 |
+
"lstrip": false,
|
1326 |
+
"normalized": false,
|
1327 |
+
"rstrip": false,
|
1328 |
+
"single_word": false,
|
1329 |
+
"special": true
|
1330 |
+
},
|
1331 |
+
"166": {
|
1332 |
+
"content": "<extra_id_160>",
|
1333 |
+
"lstrip": false,
|
1334 |
+
"normalized": false,
|
1335 |
+
"rstrip": false,
|
1336 |
+
"single_word": false,
|
1337 |
+
"special": true
|
1338 |
+
},
|
1339 |
+
"167": {
|
1340 |
+
"content": "<extra_id_161>",
|
1341 |
+
"lstrip": false,
|
1342 |
+
"normalized": false,
|
1343 |
+
"rstrip": false,
|
1344 |
+
"single_word": false,
|
1345 |
+
"special": true
|
1346 |
+
},
|
1347 |
+
"168": {
|
1348 |
+
"content": "<extra_id_162>",
|
1349 |
+
"lstrip": false,
|
1350 |
+
"normalized": false,
|
1351 |
+
"rstrip": false,
|
1352 |
+
"single_word": false,
|
1353 |
+
"special": true
|
1354 |
+
},
|
1355 |
+
"169": {
|
1356 |
+
"content": "<extra_id_163>",
|
1357 |
+
"lstrip": false,
|
1358 |
+
"normalized": false,
|
1359 |
+
"rstrip": false,
|
1360 |
+
"single_word": false,
|
1361 |
+
"special": true
|
1362 |
+
},
|
1363 |
+
"170": {
|
1364 |
+
"content": "<extra_id_164>",
|
1365 |
+
"lstrip": false,
|
1366 |
+
"normalized": false,
|
1367 |
+
"rstrip": false,
|
1368 |
+
"single_word": false,
|
1369 |
+
"special": true
|
1370 |
+
},
|
1371 |
+
"171": {
|
1372 |
+
"content": "<extra_id_165>",
|
1373 |
+
"lstrip": false,
|
1374 |
+
"normalized": false,
|
1375 |
+
"rstrip": false,
|
1376 |
+
"single_word": false,
|
1377 |
+
"special": true
|
1378 |
+
},
|
1379 |
+
"172": {
|
1380 |
+
"content": "<extra_id_166>",
|
1381 |
+
"lstrip": false,
|
1382 |
+
"normalized": false,
|
1383 |
+
"rstrip": false,
|
1384 |
+
"single_word": false,
|
1385 |
+
"special": true
|
1386 |
+
},
|
1387 |
+
"173": {
|
1388 |
+
"content": "<extra_id_167>",
|
1389 |
+
"lstrip": false,
|
1390 |
+
"normalized": false,
|
1391 |
+
"rstrip": false,
|
1392 |
+
"single_word": false,
|
1393 |
+
"special": true
|
1394 |
+
},
|
1395 |
+
"174": {
|
1396 |
+
"content": "<extra_id_168>",
|
1397 |
+
"lstrip": false,
|
1398 |
+
"normalized": false,
|
1399 |
+
"rstrip": false,
|
1400 |
+
"single_word": false,
|
1401 |
+
"special": true
|
1402 |
+
},
|
1403 |
+
"175": {
|
1404 |
+
"content": "<extra_id_169>",
|
1405 |
+
"lstrip": false,
|
1406 |
+
"normalized": false,
|
1407 |
+
"rstrip": false,
|
1408 |
+
"single_word": false,
|
1409 |
+
"special": true
|
1410 |
+
},
|
1411 |
+
"176": {
|
1412 |
+
"content": "<extra_id_170>",
|
1413 |
+
"lstrip": false,
|
1414 |
+
"normalized": false,
|
1415 |
+
"rstrip": false,
|
1416 |
+
"single_word": false,
|
1417 |
+
"special": true
|
1418 |
+
},
|
1419 |
+
"177": {
|
1420 |
+
"content": "<extra_id_171>",
|
1421 |
+
"lstrip": false,
|
1422 |
+
"normalized": false,
|
1423 |
+
"rstrip": false,
|
1424 |
+
"single_word": false,
|
1425 |
+
"special": true
|
1426 |
+
},
|
1427 |
+
"178": {
|
1428 |
+
"content": "<extra_id_172>",
|
1429 |
+
"lstrip": false,
|
1430 |
+
"normalized": false,
|
1431 |
+
"rstrip": false,
|
1432 |
+
"single_word": false,
|
1433 |
+
"special": true
|
1434 |
+
},
|
1435 |
+
"179": {
|
1436 |
+
"content": "<extra_id_173>",
|
1437 |
+
"lstrip": false,
|
1438 |
+
"normalized": false,
|
1439 |
+
"rstrip": false,
|
1440 |
+
"single_word": false,
|
1441 |
+
"special": true
|
1442 |
+
},
|
1443 |
+
"180": {
|
1444 |
+
"content": "<extra_id_174>",
|
1445 |
+
"lstrip": false,
|
1446 |
+
"normalized": false,
|
1447 |
+
"rstrip": false,
|
1448 |
+
"single_word": false,
|
1449 |
+
"special": true
|
1450 |
+
},
|
1451 |
+
"181": {
|
1452 |
+
"content": "<extra_id_175>",
|
1453 |
+
"lstrip": false,
|
1454 |
+
"normalized": false,
|
1455 |
+
"rstrip": false,
|
1456 |
+
"single_word": false,
|
1457 |
+
"special": true
|
1458 |
+
},
|
1459 |
+
"182": {
|
1460 |
+
"content": "<extra_id_176>",
|
1461 |
+
"lstrip": false,
|
1462 |
+
"normalized": false,
|
1463 |
+
"rstrip": false,
|
1464 |
+
"single_word": false,
|
1465 |
+
"special": true
|
1466 |
+
},
|
1467 |
+
"183": {
|
1468 |
+
"content": "<extra_id_177>",
|
1469 |
+
"lstrip": false,
|
1470 |
+
"normalized": false,
|
1471 |
+
"rstrip": false,
|
1472 |
+
"single_word": false,
|
1473 |
+
"special": true
|
1474 |
+
},
|
1475 |
+
"184": {
|
1476 |
+
"content": "<extra_id_178>",
|
1477 |
+
"lstrip": false,
|
1478 |
+
"normalized": false,
|
1479 |
+
"rstrip": false,
|
1480 |
+
"single_word": false,
|
1481 |
+
"special": true
|
1482 |
+
},
|
1483 |
+
"185": {
|
1484 |
+
"content": "<extra_id_179>",
|
1485 |
+
"lstrip": false,
|
1486 |
+
"normalized": false,
|
1487 |
+
"rstrip": false,
|
1488 |
+
"single_word": false,
|
1489 |
+
"special": true
|
1490 |
+
},
|
1491 |
+
"186": {
|
1492 |
+
"content": "<extra_id_180>",
|
1493 |
+
"lstrip": false,
|
1494 |
+
"normalized": false,
|
1495 |
+
"rstrip": false,
|
1496 |
+
"single_word": false,
|
1497 |
+
"special": true
|
1498 |
+
},
|
1499 |
+
"187": {
|
1500 |
+
"content": "<extra_id_181>",
|
1501 |
+
"lstrip": false,
|
1502 |
+
"normalized": false,
|
1503 |
+
"rstrip": false,
|
1504 |
+
"single_word": false,
|
1505 |
+
"special": true
|
1506 |
+
},
|
1507 |
+
"188": {
|
1508 |
+
"content": "<extra_id_182>",
|
1509 |
+
"lstrip": false,
|
1510 |
+
"normalized": false,
|
1511 |
+
"rstrip": false,
|
1512 |
+
"single_word": false,
|
1513 |
+
"special": true
|
1514 |
+
},
|
1515 |
+
"189": {
|
1516 |
+
"content": "<extra_id_183>",
|
1517 |
+
"lstrip": false,
|
1518 |
+
"normalized": false,
|
1519 |
+
"rstrip": false,
|
1520 |
+
"single_word": false,
|
1521 |
+
"special": true
|
1522 |
+
},
|
1523 |
+
"190": {
|
1524 |
+
"content": "<extra_id_184>",
|
1525 |
+
"lstrip": false,
|
1526 |
+
"normalized": false,
|
1527 |
+
"rstrip": false,
|
1528 |
+
"single_word": false,
|
1529 |
+
"special": true
|
1530 |
+
},
|
1531 |
+
"191": {
|
1532 |
+
"content": "<extra_id_185>",
|
1533 |
+
"lstrip": false,
|
1534 |
+
"normalized": false,
|
1535 |
+
"rstrip": false,
|
1536 |
+
"single_word": false,
|
1537 |
+
"special": true
|
1538 |
+
},
|
1539 |
+
"192": {
|
1540 |
+
"content": "<extra_id_186>",
|
1541 |
+
"lstrip": false,
|
1542 |
+
"normalized": false,
|
1543 |
+
"rstrip": false,
|
1544 |
+
"single_word": false,
|
1545 |
+
"special": true
|
1546 |
+
},
|
1547 |
+
"193": {
|
1548 |
+
"content": "<extra_id_187>",
|
1549 |
+
"lstrip": false,
|
1550 |
+
"normalized": false,
|
1551 |
+
"rstrip": false,
|
1552 |
+
"single_word": false,
|
1553 |
+
"special": true
|
1554 |
+
},
|
1555 |
+
"194": {
|
1556 |
+
"content": "<extra_id_188>",
|
1557 |
+
"lstrip": false,
|
1558 |
+
"normalized": false,
|
1559 |
+
"rstrip": false,
|
1560 |
+
"single_word": false,
|
1561 |
+
"special": true
|
1562 |
+
},
|
1563 |
+
"195": {
|
1564 |
+
"content": "<extra_id_189>",
|
1565 |
+
"lstrip": false,
|
1566 |
+
"normalized": false,
|
1567 |
+
"rstrip": false,
|
1568 |
+
"single_word": false,
|
1569 |
+
"special": true
|
1570 |
+
},
|
1571 |
+
"196": {
|
1572 |
+
"content": "<extra_id_190>",
|
1573 |
+
"lstrip": false,
|
1574 |
+
"normalized": false,
|
1575 |
+
"rstrip": false,
|
1576 |
+
"single_word": false,
|
1577 |
+
"special": true
|
1578 |
+
},
|
1579 |
+
"197": {
|
1580 |
+
"content": "<extra_id_191>",
|
1581 |
+
"lstrip": false,
|
1582 |
+
"normalized": false,
|
1583 |
+
"rstrip": false,
|
1584 |
+
"single_word": false,
|
1585 |
+
"special": true
|
1586 |
+
},
|
1587 |
+
"198": {
|
1588 |
+
"content": "<extra_id_192>",
|
1589 |
+
"lstrip": false,
|
1590 |
+
"normalized": false,
|
1591 |
+
"rstrip": false,
|
1592 |
+
"single_word": false,
|
1593 |
+
"special": true
|
1594 |
+
},
|
1595 |
+
"199": {
|
1596 |
+
"content": "<extra_id_193>",
|
1597 |
+
"lstrip": false,
|
1598 |
+
"normalized": false,
|
1599 |
+
"rstrip": false,
|
1600 |
+
"single_word": false,
|
1601 |
+
"special": true
|
1602 |
+
},
|
1603 |
+
"200": {
|
1604 |
+
"content": "<extra_id_194>",
|
1605 |
+
"lstrip": false,
|
1606 |
+
"normalized": false,
|
1607 |
+
"rstrip": false,
|
1608 |
+
"single_word": false,
|
1609 |
+
"special": true
|
1610 |
+
},
|
1611 |
+
"201": {
|
1612 |
+
"content": "<extra_id_195>",
|
1613 |
+
"lstrip": false,
|
1614 |
+
"normalized": false,
|
1615 |
+
"rstrip": false,
|
1616 |
+
"single_word": false,
|
1617 |
+
"special": true
|
1618 |
+
},
|
1619 |
+
"202": {
|
1620 |
+
"content": "<extra_id_196>",
|
1621 |
+
"lstrip": false,
|
1622 |
+
"normalized": false,
|
1623 |
+
"rstrip": false,
|
1624 |
+
"single_word": false,
|
1625 |
+
"special": true
|
1626 |
+
},
|
1627 |
+
"203": {
|
1628 |
+
"content": "<extra_id_197>",
|
1629 |
+
"lstrip": false,
|
1630 |
+
"normalized": false,
|
1631 |
+
"rstrip": false,
|
1632 |
+
"single_word": false,
|
1633 |
+
"special": true
|
1634 |
+
},
|
1635 |
+
"204": {
|
1636 |
+
"content": "<extra_id_198>",
|
1637 |
+
"lstrip": false,
|
1638 |
+
"normalized": false,
|
1639 |
+
"rstrip": false,
|
1640 |
+
"single_word": false,
|
1641 |
+
"special": true
|
1642 |
+
},
|
1643 |
+
"205": {
|
1644 |
+
"content": "<extra_id_199>",
|
1645 |
+
"lstrip": false,
|
1646 |
+
"normalized": false,
|
1647 |
+
"rstrip": false,
|
1648 |
+
"single_word": false,
|
1649 |
+
"special": true
|
1650 |
+
},
|
1651 |
+
"206": {
|
1652 |
+
"content": "<extra_id_200>",
|
1653 |
+
"lstrip": false,
|
1654 |
+
"normalized": false,
|
1655 |
+
"rstrip": false,
|
1656 |
+
"single_word": false,
|
1657 |
+
"special": true
|
1658 |
+
},
|
1659 |
+
"207": {
|
1660 |
+
"content": "<extra_id_201>",
|
1661 |
+
"lstrip": false,
|
1662 |
+
"normalized": false,
|
1663 |
+
"rstrip": false,
|
1664 |
+
"single_word": false,
|
1665 |
+
"special": true
|
1666 |
+
},
|
1667 |
+
"208": {
|
1668 |
+
"content": "<extra_id_202>",
|
1669 |
+
"lstrip": false,
|
1670 |
+
"normalized": false,
|
1671 |
+
"rstrip": false,
|
1672 |
+
"single_word": false,
|
1673 |
+
"special": true
|
1674 |
+
},
|
1675 |
+
"209": {
|
1676 |
+
"content": "<extra_id_203>",
|
1677 |
+
"lstrip": false,
|
1678 |
+
"normalized": false,
|
1679 |
+
"rstrip": false,
|
1680 |
+
"single_word": false,
|
1681 |
+
"special": true
|
1682 |
+
},
|
1683 |
+
"210": {
|
1684 |
+
"content": "<extra_id_204>",
|
1685 |
+
"lstrip": false,
|
1686 |
+
"normalized": false,
|
1687 |
+
"rstrip": false,
|
1688 |
+
"single_word": false,
|
1689 |
+
"special": true
|
1690 |
+
},
|
1691 |
+
"211": {
|
1692 |
+
"content": "<extra_id_205>",
|
1693 |
+
"lstrip": false,
|
1694 |
+
"normalized": false,
|
1695 |
+
"rstrip": false,
|
1696 |
+
"single_word": false,
|
1697 |
+
"special": true
|
1698 |
+
},
|
1699 |
+
"212": {
|
1700 |
+
"content": "<extra_id_206>",
|
1701 |
+
"lstrip": false,
|
1702 |
+
"normalized": false,
|
1703 |
+
"rstrip": false,
|
1704 |
+
"single_word": false,
|
1705 |
+
"special": true
|
1706 |
+
},
|
1707 |
+
"213": {
|
1708 |
+
"content": "<extra_id_207>",
|
1709 |
+
"lstrip": false,
|
1710 |
+
"normalized": false,
|
1711 |
+
"rstrip": false,
|
1712 |
+
"single_word": false,
|
1713 |
+
"special": true
|
1714 |
+
},
|
1715 |
+
"214": {
|
1716 |
+
"content": "<extra_id_208>",
|
1717 |
+
"lstrip": false,
|
1718 |
+
"normalized": false,
|
1719 |
+
"rstrip": false,
|
1720 |
+
"single_word": false,
|
1721 |
+
"special": true
|
1722 |
+
},
|
1723 |
+
"215": {
|
1724 |
+
"content": "<extra_id_209>",
|
1725 |
+
"lstrip": false,
|
1726 |
+
"normalized": false,
|
1727 |
+
"rstrip": false,
|
1728 |
+
"single_word": false,
|
1729 |
+
"special": true
|
1730 |
+
},
|
1731 |
+
"216": {
|
1732 |
+
"content": "<extra_id_210>",
|
1733 |
+
"lstrip": false,
|
1734 |
+
"normalized": false,
|
1735 |
+
"rstrip": false,
|
1736 |
+
"single_word": false,
|
1737 |
+
"special": true
|
1738 |
+
},
|
1739 |
+
"217": {
|
1740 |
+
"content": "<extra_id_211>",
|
1741 |
+
"lstrip": false,
|
1742 |
+
"normalized": false,
|
1743 |
+
"rstrip": false,
|
1744 |
+
"single_word": false,
|
1745 |
+
"special": true
|
1746 |
+
},
|
1747 |
+
"218": {
|
1748 |
+
"content": "<extra_id_212>",
|
1749 |
+
"lstrip": false,
|
1750 |
+
"normalized": false,
|
1751 |
+
"rstrip": false,
|
1752 |
+
"single_word": false,
|
1753 |
+
"special": true
|
1754 |
+
},
|
1755 |
+
"219": {
|
1756 |
+
"content": "<extra_id_213>",
|
1757 |
+
"lstrip": false,
|
1758 |
+
"normalized": false,
|
1759 |
+
"rstrip": false,
|
1760 |
+
"single_word": false,
|
1761 |
+
"special": true
|
1762 |
+
},
|
1763 |
+
"220": {
|
1764 |
+
"content": "<extra_id_214>",
|
1765 |
+
"lstrip": false,
|
1766 |
+
"normalized": false,
|
1767 |
+
"rstrip": false,
|
1768 |
+
"single_word": false,
|
1769 |
+
"special": true
|
1770 |
+
},
|
1771 |
+
"221": {
|
1772 |
+
"content": "<extra_id_215>",
|
1773 |
+
"lstrip": false,
|
1774 |
+
"normalized": false,
|
1775 |
+
"rstrip": false,
|
1776 |
+
"single_word": false,
|
1777 |
+
"special": true
|
1778 |
+
},
|
1779 |
+
"222": {
|
1780 |
+
"content": "<extra_id_216>",
|
1781 |
+
"lstrip": false,
|
1782 |
+
"normalized": false,
|
1783 |
+
"rstrip": false,
|
1784 |
+
"single_word": false,
|
1785 |
+
"special": true
|
1786 |
+
},
|
1787 |
+
"223": {
|
1788 |
+
"content": "<extra_id_217>",
|
1789 |
+
"lstrip": false,
|
1790 |
+
"normalized": false,
|
1791 |
+
"rstrip": false,
|
1792 |
+
"single_word": false,
|
1793 |
+
"special": true
|
1794 |
+
},
|
1795 |
+
"224": {
|
1796 |
+
"content": "<extra_id_218>",
|
1797 |
+
"lstrip": false,
|
1798 |
+
"normalized": false,
|
1799 |
+
"rstrip": false,
|
1800 |
+
"single_word": false,
|
1801 |
+
"special": true
|
1802 |
+
},
|
1803 |
+
"225": {
|
1804 |
+
"content": "<extra_id_219>",
|
1805 |
+
"lstrip": false,
|
1806 |
+
"normalized": false,
|
1807 |
+
"rstrip": false,
|
1808 |
+
"single_word": false,
|
1809 |
+
"special": true
|
1810 |
+
},
|
1811 |
+
"226": {
|
1812 |
+
"content": "<extra_id_220>",
|
1813 |
+
"lstrip": false,
|
1814 |
+
"normalized": false,
|
1815 |
+
"rstrip": false,
|
1816 |
+
"single_word": false,
|
1817 |
+
"special": true
|
1818 |
+
},
|
1819 |
+
"227": {
|
1820 |
+
"content": "<extra_id_221>",
|
1821 |
+
"lstrip": false,
|
1822 |
+
"normalized": false,
|
1823 |
+
"rstrip": false,
|
1824 |
+
"single_word": false,
|
1825 |
+
"special": true
|
1826 |
+
},
|
1827 |
+
"228": {
|
1828 |
+
"content": "<extra_id_222>",
|
1829 |
+
"lstrip": false,
|
1830 |
+
"normalized": false,
|
1831 |
+
"rstrip": false,
|
1832 |
+
"single_word": false,
|
1833 |
+
"special": true
|
1834 |
+
},
|
1835 |
+
"229": {
|
1836 |
+
"content": "<extra_id_223>",
|
1837 |
+
"lstrip": false,
|
1838 |
+
"normalized": false,
|
1839 |
+
"rstrip": false,
|
1840 |
+
"single_word": false,
|
1841 |
+
"special": true
|
1842 |
+
},
|
1843 |
+
"230": {
|
1844 |
+
"content": "<extra_id_224>",
|
1845 |
+
"lstrip": false,
|
1846 |
+
"normalized": false,
|
1847 |
+
"rstrip": false,
|
1848 |
+
"single_word": false,
|
1849 |
+
"special": true
|
1850 |
+
},
|
1851 |
+
"231": {
|
1852 |
+
"content": "<extra_id_225>",
|
1853 |
+
"lstrip": false,
|
1854 |
+
"normalized": false,
|
1855 |
+
"rstrip": false,
|
1856 |
+
"single_word": false,
|
1857 |
+
"special": true
|
1858 |
+
},
|
1859 |
+
"232": {
|
1860 |
+
"content": "<extra_id_226>",
|
1861 |
+
"lstrip": false,
|
1862 |
+
"normalized": false,
|
1863 |
+
"rstrip": false,
|
1864 |
+
"single_word": false,
|
1865 |
+
"special": true
|
1866 |
+
},
|
1867 |
+
"233": {
|
1868 |
+
"content": "<extra_id_227>",
|
1869 |
+
"lstrip": false,
|
1870 |
+
"normalized": false,
|
1871 |
+
"rstrip": false,
|
1872 |
+
"single_word": false,
|
1873 |
+
"special": true
|
1874 |
+
},
|
1875 |
+
"234": {
|
1876 |
+
"content": "<extra_id_228>",
|
1877 |
+
"lstrip": false,
|
1878 |
+
"normalized": false,
|
1879 |
+
"rstrip": false,
|
1880 |
+
"single_word": false,
|
1881 |
+
"special": true
|
1882 |
+
},
|
1883 |
+
"235": {
|
1884 |
+
"content": "<extra_id_229>",
|
1885 |
+
"lstrip": false,
|
1886 |
+
"normalized": false,
|
1887 |
+
"rstrip": false,
|
1888 |
+
"single_word": false,
|
1889 |
+
"special": true
|
1890 |
+
},
|
1891 |
+
"236": {
|
1892 |
+
"content": "<extra_id_230>",
|
1893 |
+
"lstrip": false,
|
1894 |
+
"normalized": false,
|
1895 |
+
"rstrip": false,
|
1896 |
+
"single_word": false,
|
1897 |
+
"special": true
|
1898 |
+
},
|
1899 |
+
"237": {
|
1900 |
+
"content": "<extra_id_231>",
|
1901 |
+
"lstrip": false,
|
1902 |
+
"normalized": false,
|
1903 |
+
"rstrip": false,
|
1904 |
+
"single_word": false,
|
1905 |
+
"special": true
|
1906 |
+
},
|
1907 |
+
"238": {
|
1908 |
+
"content": "<extra_id_232>",
|
1909 |
+
"lstrip": false,
|
1910 |
+
"normalized": false,
|
1911 |
+
"rstrip": false,
|
1912 |
+
"single_word": false,
|
1913 |
+
"special": true
|
1914 |
+
},
|
1915 |
+
"239": {
|
1916 |
+
"content": "<extra_id_233>",
|
1917 |
+
"lstrip": false,
|
1918 |
+
"normalized": false,
|
1919 |
+
"rstrip": false,
|
1920 |
+
"single_word": false,
|
1921 |
+
"special": true
|
1922 |
+
},
|
1923 |
+
"240": {
|
1924 |
+
"content": "<extra_id_234>",
|
1925 |
+
"lstrip": false,
|
1926 |
+
"normalized": false,
|
1927 |
+
"rstrip": false,
|
1928 |
+
"single_word": false,
|
1929 |
+
"special": true
|
1930 |
+
},
|
1931 |
+
"241": {
|
1932 |
+
"content": "<extra_id_235>",
|
1933 |
+
"lstrip": false,
|
1934 |
+
"normalized": false,
|
1935 |
+
"rstrip": false,
|
1936 |
+
"single_word": false,
|
1937 |
+
"special": true
|
1938 |
+
},
|
1939 |
+
"242": {
|
1940 |
+
"content": "<extra_id_236>",
|
1941 |
+
"lstrip": false,
|
1942 |
+
"normalized": false,
|
1943 |
+
"rstrip": false,
|
1944 |
+
"single_word": false,
|
1945 |
+
"special": true
|
1946 |
+
},
|
1947 |
+
"243": {
|
1948 |
+
"content": "<extra_id_237>",
|
1949 |
+
"lstrip": false,
|
1950 |
+
"normalized": false,
|
1951 |
+
"rstrip": false,
|
1952 |
+
"single_word": false,
|
1953 |
+
"special": true
|
1954 |
+
},
|
1955 |
+
"244": {
|
1956 |
+
"content": "<extra_id_238>",
|
1957 |
+
"lstrip": false,
|
1958 |
+
"normalized": false,
|
1959 |
+
"rstrip": false,
|
1960 |
+
"single_word": false,
|
1961 |
+
"special": true
|
1962 |
+
},
|
1963 |
+
"245": {
|
1964 |
+
"content": "<extra_id_239>",
|
1965 |
+
"lstrip": false,
|
1966 |
+
"normalized": false,
|
1967 |
+
"rstrip": false,
|
1968 |
+
"single_word": false,
|
1969 |
+
"special": true
|
1970 |
+
},
|
1971 |
+
"246": {
|
1972 |
+
"content": "<extra_id_240>",
|
1973 |
+
"lstrip": false,
|
1974 |
+
"normalized": false,
|
1975 |
+
"rstrip": false,
|
1976 |
+
"single_word": false,
|
1977 |
+
"special": true
|
1978 |
+
},
|
1979 |
+
"247": {
|
1980 |
+
"content": "<extra_id_241>",
|
1981 |
+
"lstrip": false,
|
1982 |
+
"normalized": false,
|
1983 |
+
"rstrip": false,
|
1984 |
+
"single_word": false,
|
1985 |
+
"special": true
|
1986 |
+
},
|
1987 |
+
"248": {
|
1988 |
+
"content": "<extra_id_242>",
|
1989 |
+
"lstrip": false,
|
1990 |
+
"normalized": false,
|
1991 |
+
"rstrip": false,
|
1992 |
+
"single_word": false,
|
1993 |
+
"special": true
|
1994 |
+
},
|
1995 |
+
"249": {
|
1996 |
+
"content": "<extra_id_243>",
|
1997 |
+
"lstrip": false,
|
1998 |
+
"normalized": false,
|
1999 |
+
"rstrip": false,
|
2000 |
+
"single_word": false,
|
2001 |
+
"special": true
|
2002 |
+
},
|
2003 |
+
"250": {
|
2004 |
+
"content": "<extra_id_244>",
|
2005 |
+
"lstrip": false,
|
2006 |
+
"normalized": false,
|
2007 |
+
"rstrip": false,
|
2008 |
+
"single_word": false,
|
2009 |
+
"special": true
|
2010 |
+
},
|
2011 |
+
"251": {
|
2012 |
+
"content": "<extra_id_245>",
|
2013 |
+
"lstrip": false,
|
2014 |
+
"normalized": false,
|
2015 |
+
"rstrip": false,
|
2016 |
+
"single_word": false,
|
2017 |
+
"special": true
|
2018 |
+
},
|
2019 |
+
"252": {
|
2020 |
+
"content": "<extra_id_246>",
|
2021 |
+
"lstrip": false,
|
2022 |
+
"normalized": false,
|
2023 |
+
"rstrip": false,
|
2024 |
+
"single_word": false,
|
2025 |
+
"special": true
|
2026 |
+
},
|
2027 |
+
"253": {
|
2028 |
+
"content": "<extra_id_247>",
|
2029 |
+
"lstrip": false,
|
2030 |
+
"normalized": false,
|
2031 |
+
"rstrip": false,
|
2032 |
+
"single_word": false,
|
2033 |
+
"special": true
|
2034 |
+
},
|
2035 |
+
"254": {
|
2036 |
+
"content": "<extra_id_248>",
|
2037 |
+
"lstrip": false,
|
2038 |
+
"normalized": false,
|
2039 |
+
"rstrip": false,
|
2040 |
+
"single_word": false,
|
2041 |
+
"special": true
|
2042 |
+
},
|
2043 |
+
"255": {
|
2044 |
+
"content": "<extra_id_249>",
|
2045 |
+
"lstrip": false,
|
2046 |
+
"normalized": false,
|
2047 |
+
"rstrip": false,
|
2048 |
+
"single_word": false,
|
2049 |
+
"special": true
|
2050 |
+
},
|
2051 |
+
"256": {
|
2052 |
+
"content": "<extra_id_250>",
|
2053 |
+
"lstrip": false,
|
2054 |
+
"normalized": false,
|
2055 |
+
"rstrip": false,
|
2056 |
+
"single_word": false,
|
2057 |
+
"special": true
|
2058 |
+
},
|
2059 |
+
"257": {
|
2060 |
+
"content": "<extra_id_251>",
|
2061 |
+
"lstrip": false,
|
2062 |
+
"normalized": false,
|
2063 |
+
"rstrip": false,
|
2064 |
+
"single_word": false,
|
2065 |
+
"special": true
|
2066 |
+
},
|
2067 |
+
"258": {
|
2068 |
+
"content": "<extra_id_252>",
|
2069 |
+
"lstrip": false,
|
2070 |
+
"normalized": false,
|
2071 |
+
"rstrip": false,
|
2072 |
+
"single_word": false,
|
2073 |
+
"special": true
|
2074 |
+
},
|
2075 |
+
"259": {
|
2076 |
+
"content": "<extra_id_253>",
|
2077 |
+
"lstrip": false,
|
2078 |
+
"normalized": false,
|
2079 |
+
"rstrip": false,
|
2080 |
+
"single_word": false,
|
2081 |
+
"special": true
|
2082 |
+
},
|
2083 |
+
"260": {
|
2084 |
+
"content": "<extra_id_254>",
|
2085 |
+
"lstrip": false,
|
2086 |
+
"normalized": false,
|
2087 |
+
"rstrip": false,
|
2088 |
+
"single_word": false,
|
2089 |
+
"special": true
|
2090 |
+
},
|
2091 |
+
"261": {
|
2092 |
+
"content": "<extra_id_255>",
|
2093 |
+
"lstrip": false,
|
2094 |
+
"normalized": false,
|
2095 |
+
"rstrip": false,
|
2096 |
+
"single_word": false,
|
2097 |
+
"special": true
|
2098 |
+
}
|
2099 |
+
},
|
2100 |
+
"additional_special_tokens": [
|
2101 |
+
"<extra_id_0>",
|
2102 |
+
"<extra_id_1>",
|
2103 |
+
"<extra_id_2>",
|
2104 |
+
"<extra_id_3>",
|
2105 |
+
"<extra_id_4>",
|
2106 |
+
"<extra_id_5>",
|
2107 |
+
"<extra_id_6>",
|
2108 |
+
"<extra_id_7>",
|
2109 |
+
"<extra_id_8>",
|
2110 |
+
"<extra_id_9>",
|
2111 |
+
"<extra_id_10>",
|
2112 |
+
"<extra_id_11>",
|
2113 |
+
"<extra_id_12>",
|
2114 |
+
"<extra_id_13>",
|
2115 |
+
"<extra_id_14>",
|
2116 |
+
"<extra_id_15>",
|
2117 |
+
"<extra_id_16>",
|
2118 |
+
"<extra_id_17>",
|
2119 |
+
"<extra_id_18>",
|
2120 |
+
"<extra_id_19>",
|
2121 |
+
"<extra_id_20>",
|
2122 |
+
"<extra_id_21>",
|
2123 |
+
"<extra_id_22>",
|
2124 |
+
"<extra_id_23>",
|
2125 |
+
"<extra_id_24>",
|
2126 |
+
"<extra_id_25>",
|
2127 |
+
"<extra_id_26>",
|
2128 |
+
"<extra_id_27>",
|
2129 |
+
"<extra_id_28>",
|
2130 |
+
"<extra_id_29>",
|
2131 |
+
"<extra_id_30>",
|
2132 |
+
"<extra_id_31>",
|
2133 |
+
"<extra_id_32>",
|
2134 |
+
"<extra_id_33>",
|
2135 |
+
"<extra_id_34>",
|
2136 |
+
"<extra_id_35>",
|
2137 |
+
"<extra_id_36>",
|
2138 |
+
"<extra_id_37>",
|
2139 |
+
"<extra_id_38>",
|
2140 |
+
"<extra_id_39>",
|
2141 |
+
"<extra_id_40>",
|
2142 |
+
"<extra_id_41>",
|
2143 |
+
"<extra_id_42>",
|
2144 |
+
"<extra_id_43>",
|
2145 |
+
"<extra_id_44>",
|
2146 |
+
"<extra_id_45>",
|
2147 |
+
"<extra_id_46>",
|
2148 |
+
"<extra_id_47>",
|
2149 |
+
"<extra_id_48>",
|
2150 |
+
"<extra_id_49>",
|
2151 |
+
"<extra_id_50>",
|
2152 |
+
"<extra_id_51>",
|
2153 |
+
"<extra_id_52>",
|
2154 |
+
"<extra_id_53>",
|
2155 |
+
"<extra_id_54>",
|
2156 |
+
"<extra_id_55>",
|
2157 |
+
"<extra_id_56>",
|
2158 |
+
"<extra_id_57>",
|
2159 |
+
"<extra_id_58>",
|
2160 |
+
"<extra_id_59>",
|
2161 |
+
"<extra_id_60>",
|
2162 |
+
"<extra_id_61>",
|
2163 |
+
"<extra_id_62>",
|
2164 |
+
"<extra_id_63>",
|
2165 |
+
"<extra_id_64>",
|
2166 |
+
"<extra_id_65>",
|
2167 |
+
"<extra_id_66>",
|
2168 |
+
"<extra_id_67>",
|
2169 |
+
"<extra_id_68>",
|
2170 |
+
"<extra_id_69>",
|
2171 |
+
"<extra_id_70>",
|
2172 |
+
"<extra_id_71>",
|
2173 |
+
"<extra_id_72>",
|
2174 |
+
"<extra_id_73>",
|
2175 |
+
"<extra_id_74>",
|
2176 |
+
"<extra_id_75>",
|
2177 |
+
"<extra_id_76>",
|
2178 |
+
"<extra_id_77>",
|
2179 |
+
"<extra_id_78>",
|
2180 |
+
"<extra_id_79>",
|
2181 |
+
"<extra_id_80>",
|
2182 |
+
"<extra_id_81>",
|
2183 |
+
"<extra_id_82>",
|
2184 |
+
"<extra_id_83>",
|
2185 |
+
"<extra_id_84>",
|
2186 |
+
"<extra_id_85>",
|
2187 |
+
"<extra_id_86>",
|
2188 |
+
"<extra_id_87>",
|
2189 |
+
"<extra_id_88>",
|
2190 |
+
"<extra_id_89>",
|
2191 |
+
"<extra_id_90>",
|
2192 |
+
"<extra_id_91>",
|
2193 |
+
"<extra_id_92>",
|
2194 |
+
"<extra_id_93>",
|
2195 |
+
"<extra_id_94>",
|
2196 |
+
"<extra_id_95>",
|
2197 |
+
"<extra_id_96>",
|
2198 |
+
"<extra_id_97>",
|
2199 |
+
"<extra_id_98>",
|
2200 |
+
"<extra_id_99>",
|
2201 |
+
"<extra_id_100>",
|
2202 |
+
"<extra_id_101>",
|
2203 |
+
"<extra_id_102>",
|
2204 |
+
"<extra_id_103>",
|
2205 |
+
"<extra_id_104>",
|
2206 |
+
"<extra_id_105>",
|
2207 |
+
"<extra_id_106>",
|
2208 |
+
"<extra_id_107>",
|
2209 |
+
"<extra_id_108>",
|
2210 |
+
"<extra_id_109>",
|
2211 |
+
"<extra_id_110>",
|
2212 |
+
"<extra_id_111>",
|
2213 |
+
"<extra_id_112>",
|
2214 |
+
"<extra_id_113>",
|
2215 |
+
"<extra_id_114>",
|
2216 |
+
"<extra_id_115>",
|
2217 |
+
"<extra_id_116>",
|
2218 |
+
"<extra_id_117>",
|
2219 |
+
"<extra_id_118>",
|
2220 |
+
"<extra_id_119>",
|
2221 |
+
"<extra_id_120>",
|
2222 |
+
"<extra_id_121>",
|
2223 |
+
"<extra_id_122>",
|
2224 |
+
"<extra_id_123>",
|
2225 |
+
"<extra_id_124>",
|
2226 |
+
"<extra_id_125>",
|
2227 |
+
"<extra_id_126>",
|
2228 |
+
"<extra_id_127>",
|
2229 |
+
"<extra_id_128>",
|
2230 |
+
"<extra_id_129>",
|
2231 |
+
"<extra_id_130>",
|
2232 |
+
"<extra_id_131>",
|
2233 |
+
"<extra_id_132>",
|
2234 |
+
"<extra_id_133>",
|
2235 |
+
"<extra_id_134>",
|
2236 |
+
"<extra_id_135>",
|
2237 |
+
"<extra_id_136>",
|
2238 |
+
"<extra_id_137>",
|
2239 |
+
"<extra_id_138>",
|
2240 |
+
"<extra_id_139>",
|
2241 |
+
"<extra_id_140>",
|
2242 |
+
"<extra_id_141>",
|
2243 |
+
"<extra_id_142>",
|
2244 |
+
"<extra_id_143>",
|
2245 |
+
"<extra_id_144>",
|
2246 |
+
"<extra_id_145>",
|
2247 |
+
"<extra_id_146>",
|
2248 |
+
"<extra_id_147>",
|
2249 |
+
"<extra_id_148>",
|
2250 |
+
"<extra_id_149>",
|
2251 |
+
"<extra_id_150>",
|
2252 |
+
"<extra_id_151>",
|
2253 |
+
"<extra_id_152>",
|
2254 |
+
"<extra_id_153>",
|
2255 |
+
"<extra_id_154>",
|
2256 |
+
"<extra_id_155>",
|
2257 |
+
"<extra_id_156>",
|
2258 |
+
"<extra_id_157>",
|
2259 |
+
"<extra_id_158>",
|
2260 |
+
"<extra_id_159>",
|
2261 |
+
"<extra_id_160>",
|
2262 |
+
"<extra_id_161>",
|
2263 |
+
"<extra_id_162>",
|
2264 |
+
"<extra_id_163>",
|
2265 |
+
"<extra_id_164>",
|
2266 |
+
"<extra_id_165>",
|
2267 |
+
"<extra_id_166>",
|
2268 |
+
"<extra_id_167>",
|
2269 |
+
"<extra_id_168>",
|
2270 |
+
"<extra_id_169>",
|
2271 |
+
"<extra_id_170>",
|
2272 |
+
"<extra_id_171>",
|
2273 |
+
"<extra_id_172>",
|
2274 |
+
"<extra_id_173>",
|
2275 |
+
"<extra_id_174>",
|
2276 |
+
"<extra_id_175>",
|
2277 |
+
"<extra_id_176>",
|
2278 |
+
"<extra_id_177>",
|
2279 |
+
"<extra_id_178>",
|
2280 |
+
"<extra_id_179>",
|
2281 |
+
"<extra_id_180>",
|
2282 |
+
"<extra_id_181>",
|
2283 |
+
"<extra_id_182>",
|
2284 |
+
"<extra_id_183>",
|
2285 |
+
"<extra_id_184>",
|
2286 |
+
"<extra_id_185>",
|
2287 |
+
"<extra_id_186>",
|
2288 |
+
"<extra_id_187>",
|
2289 |
+
"<extra_id_188>",
|
2290 |
+
"<extra_id_189>",
|
2291 |
+
"<extra_id_190>",
|
2292 |
+
"<extra_id_191>",
|
2293 |
+
"<extra_id_192>",
|
2294 |
+
"<extra_id_193>",
|
2295 |
+
"<extra_id_194>",
|
2296 |
+
"<extra_id_195>",
|
2297 |
+
"<extra_id_196>",
|
2298 |
+
"<extra_id_197>",
|
2299 |
+
"<extra_id_198>",
|
2300 |
+
"<extra_id_199>",
|
2301 |
+
"<extra_id_200>",
|
2302 |
+
"<extra_id_201>",
|
2303 |
+
"<extra_id_202>",
|
2304 |
+
"<extra_id_203>",
|
2305 |
+
"<extra_id_204>",
|
2306 |
+
"<extra_id_205>",
|
2307 |
+
"<extra_id_206>",
|
2308 |
+
"<extra_id_207>",
|
2309 |
+
"<extra_id_208>",
|
2310 |
+
"<extra_id_209>",
|
2311 |
+
"<extra_id_210>",
|
2312 |
+
"<extra_id_211>",
|
2313 |
+
"<extra_id_212>",
|
2314 |
+
"<extra_id_213>",
|
2315 |
+
"<extra_id_214>",
|
2316 |
+
"<extra_id_215>",
|
2317 |
+
"<extra_id_216>",
|
2318 |
+
"<extra_id_217>",
|
2319 |
+
"<extra_id_218>",
|
2320 |
+
"<extra_id_219>",
|
2321 |
+
"<extra_id_220>",
|
2322 |
+
"<extra_id_221>",
|
2323 |
+
"<extra_id_222>",
|
2324 |
+
"<extra_id_223>",
|
2325 |
+
"<extra_id_224>",
|
2326 |
+
"<extra_id_225>",
|
2327 |
+
"<extra_id_226>",
|
2328 |
+
"<extra_id_227>",
|
2329 |
+
"<extra_id_228>",
|
2330 |
+
"<extra_id_229>",
|
2331 |
+
"<extra_id_230>",
|
2332 |
+
"<extra_id_231>",
|
2333 |
+
"<extra_id_232>",
|
2334 |
+
"<extra_id_233>",
|
2335 |
+
"<extra_id_234>",
|
2336 |
+
"<extra_id_235>",
|
2337 |
+
"<extra_id_236>",
|
2338 |
+
"<extra_id_237>",
|
2339 |
+
"<extra_id_238>",
|
2340 |
+
"<extra_id_239>",
|
2341 |
+
"<extra_id_240>",
|
2342 |
+
"<extra_id_241>",
|
2343 |
+
"<extra_id_242>",
|
2344 |
+
"<extra_id_243>",
|
2345 |
+
"<extra_id_244>",
|
2346 |
+
"<extra_id_245>",
|
2347 |
+
"<extra_id_246>",
|
2348 |
+
"<extra_id_247>",
|
2349 |
+
"<extra_id_248>",
|
2350 |
+
"<extra_id_249>",
|
2351 |
+
"<extra_id_250>",
|
2352 |
+
"<extra_id_251>",
|
2353 |
+
"<extra_id_252>",
|
2354 |
+
"<extra_id_253>",
|
2355 |
+
"<extra_id_254>",
|
2356 |
+
"<extra_id_255>"
|
2357 |
+
],
|
2358 |
+
"clean_up_tokenization_spaces": true,
|
2359 |
+
"cls_token": "<cls>",
|
2360 |
+
"eos_token": "</s>",
|
2361 |
+
"mask_token": "<mask>",
|
2362 |
+
"model_max_length": 1024,
|
2363 |
+
"pad_token": "<pad>",
|
2364 |
+
"sep_token": "<sep>",
|
2365 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
2366 |
+
"unk_token": "<unk>"
|
2367 |
+
}
|
trainer_state.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48108a519b2d546b6b73832c5c3752b2c0920e3ce76ab654621c6c98f2de2ef0
|
3 |
+
size 5240
|