Microsoft Open Source commited on
Commit
4443628
·
1 Parent(s): cd3817b

chore(root): Initial files upload.

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore DELETED
@@ -1,160 +0,0 @@
1
- # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
- *.py[cod]
4
- *$py.class
5
-
6
- # C extensions
7
- *.so
8
-
9
- # Distribution / packaging
10
- .Python
11
- build/
12
- develop-eggs/
13
- dist/
14
- downloads/
15
- eggs/
16
- .eggs/
17
- lib/
18
- lib64/
19
- parts/
20
- sdist/
21
- var/
22
- wheels/
23
- share/python-wheels/
24
- *.egg-info/
25
- .installed.cfg
26
- *.egg
27
- MANIFEST
28
-
29
- # PyInstaller
30
- # Usually these files are written by a python script from a template
31
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
- *.manifest
33
- *.spec
34
-
35
- # Installer logs
36
- pip-log.txt
37
- pip-delete-this-directory.txt
38
-
39
- # Unit test / coverage reports
40
- htmlcov/
41
- .tox/
42
- .nox/
43
- .coverage
44
- .coverage.*
45
- .cache
46
- nosetests.xml
47
- coverage.xml
48
- *.cover
49
- *.py,cover
50
- .hypothesis/
51
- .pytest_cache/
52
- cover/
53
-
54
- # Translations
55
- *.mo
56
- *.pot
57
-
58
- # Django stuff:
59
- *.log
60
- local_settings.py
61
- db.sqlite3
62
- db.sqlite3-journal
63
-
64
- # Flask stuff:
65
- instance/
66
- .webassets-cache
67
-
68
- # Scrapy stuff:
69
- .scrapy
70
-
71
- # Sphinx documentation
72
- docs/_build/
73
-
74
- # PyBuilder
75
- .pybuilder/
76
- target/
77
-
78
- # Jupyter Notebook
79
- .ipynb_checkpoints
80
-
81
- # IPython
82
- profile_default/
83
- ipython_config.py
84
-
85
- # pyenv
86
- # For a library or package, you might want to ignore these files since the code is
87
- # intended to run in multiple environments; otherwise, check them in:
88
- # .python-version
89
-
90
- # pipenv
91
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
- # install all needed dependencies.
95
- #Pipfile.lock
96
-
97
- # poetry
98
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
- # This is especially recommended for binary packages to ensure reproducibility, and is more
100
- # commonly ignored for libraries.
101
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
- #poetry.lock
103
-
104
- # pdm
105
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
- #pdm.lock
107
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
- # in version control.
109
- # https://pdm.fming.dev/#use-with-ide
110
- .pdm.toml
111
-
112
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
- __pypackages__/
114
-
115
- # Celery stuff
116
- celerybeat-schedule
117
- celerybeat.pid
118
-
119
- # SageMath parsed files
120
- *.sage.py
121
-
122
- # Environments
123
- .env
124
- .venv
125
- env/
126
- venv/
127
- ENV/
128
- env.bak/
129
- venv.bak/
130
-
131
- # Spyder project settings
132
- .spyderproject
133
- .spyproject
134
-
135
- # Rope project settings
136
- .ropeproject
137
-
138
- # mkdocs documentation
139
- /site
140
-
141
- # mypy
142
- .mypy_cache/
143
- .dmypy.json
144
- dmypy.json
145
-
146
- # Pyre type checker
147
- .pyre/
148
-
149
- # pytype static type analyzer
150
- .pytype/
151
-
152
- # Cython debug symbols
153
- cython_debug/
154
-
155
- # PyCharm
156
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
- # and can be added to the global gitignore or merged into this file. For a more nuclear
159
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
- #.idea/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [[email protected]](mailto:[email protected]) with questions or concerns
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
README.md ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ license_link: https://huggingface.co/microsoft/Phi-3-small-8k-instruct/resolve/main/LICENSE
4
+
5
+ language:
6
+ - multilingual
7
+ pipeline_tag: text-generation
8
+ tags:
9
+ - nlp
10
+ - code
11
+ inference:
12
+ parameters:
13
+ temperature: 0.7
14
+ widget:
15
+ - messages:
16
+ - role: user
17
+ content: Can you provide ways to eat combinations of bananas and dragonfruits?
18
+ ---
19
+ ## Model Summary
20
+
21
+ The Phi-3-Small-8K-Instruct is a 7B parameters, lightweight, state-of-the-art open model trained with the Phi-3 datasets that includes both synthetic data and the filtered publicly available websites data with a focus on high-quality and reasoning dense properties.
22
+ The model belongs to the Phi-3 family with the Small version in two variants [8K](https://huggingface.co/microsoft/Phi-3-small-8k-instruct) and [128K](https://huggingface.co/microsoft/Phi-3-small-128k-instruct) which is the context length (in tokens) that it can support.
23
+
24
+ The model has underwent a post-training process that incorporates both supervised fine-tuning and direct preference optimization for the instruction following and safety measures.
25
+ When assessed against benchmarks testing common sense, language understanding, math, code, long context and logical reasoning, Phi-3-Small-8K-Instruct showcased a robust and state-of-the-art performance among models with less than 13 billion parameters.
26
+
27
+ Resources and Technical Documentation:
28
+
29
+ + [Phi-3 Microsoft Blog](https://aka.ms/phi3blog-april)
30
+ + [Phi-3 Technical Report](https://aka.ms/phi3-tech-report)
31
+ + [Phi-3 on Azure AI Studio](https://aka.ms/phi3-azure-ai)
32
+
33
+ | | Short Context | Long Context |
34
+ | ------- | ------------- | ------------ |
35
+ | Mini | 4K [[HF]](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx) ; [[GGUF]](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf) | 128K [[HF]](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct-onnx)|
36
+ | Small | 8K [[HF]](https://huggingface.co/microsoft/Phi-3-small-8k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-small-8k-instruct-onnx) | 128K [[HF]](https://huggingface.co/microsoft/Phi-3-small-128k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-small-128k-instruct-onnx)|
37
+ | Medium | 4K [[HF]](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct-onnx) | 128K [[HF]](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct-onnx)|
38
+
39
+ ## Intended Uses
40
+
41
+ **Primary use cases**
42
+
43
+ The model is intended for broad commercial and research use in English. The model provides uses for general purpose AI systems and applications which require:
44
+
45
+ 1) Memory/compute constrained environments
46
+ 2) Latency bound scenarios
47
+ 3) Strong reasoning (especially code, math and logic)
48
+
49
+ Our model is designed to accelerate research on language and multimodal models, for use as a building block for generative AI powered features.
50
+
51
+ **Use case considerations**
52
+
53
+ Our models are not specifically designed or evaluated for all downstream purposes. Developers should consider common limitations of language models as they select use cases, and evaluate and mitigate for accuracy, safety, and fariness before using within a specific downstream use case, particularly for high risk scenarios. Developers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case.
54
+
55
+ Nothing contained in this Model Card should be interpreted as or deemed a restriction or modification to the license the model is released under.
56
+
57
+ ## How to Use
58
+
59
+ Phi-3-Small-8K-Instruct has been integrated in the development version () of `transformers`. Until the official version is released through `pip`, ensure that you are doing one of the following:
60
+ * Install tiktoken (0.6.0) ans triton (2.3.0)
61
+
62
+ * When loading the model, ensure that `trust_remote_code=True` is passed as an argument of the `from_pretrained()` function.
63
+
64
+ * Update your local `transformers` to the development version: `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers`. The previous command is an alternative to cloning and installing from the source.
65
+
66
+ The current `transformers` version can be verified with: `pip list | grep transformers`.
67
+
68
+ Phi-3-Small-8K-Instruct is also available in [Azure AI](https://ai.azure.com/explore/models?&selectedCollection=phi).
69
+
70
+ ### Tokenizer
71
+
72
+ Phi-3-Small-8K-Instruct supports a vocabulary size of up to `100352` tokens.
73
+
74
+ ### Chat Format
75
+
76
+ Given the nature of the training data, the Phi-3-Small-8K-Instruct model is best suited for prompts using the chat format as follows.
77
+ You can provide the prompt as a question with a generic template as follow:
78
+ ```markdown
79
+ <|endoftext|><|user|>\nQuestion <|end|>\n<|assistant|>
80
+ ```
81
+ For example:
82
+ ```markdown
83
+ <|endoftext|><|user|>
84
+ How to explain Internet for a medieval knight?<|end|>
85
+ <|assistant|>
86
+ ```
87
+
88
+ where the model generates the text after `<|assistant|>` . In case of few-shots prompt, the prompt can be formatted as the following:
89
+
90
+ ```markdown
91
+ <|endoftext|><|user|>
92
+ I am going to Paris, what should I see?<|end|>
93
+ <|assistant|>
94
+ Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."<|end|>
95
+ <|user|>
96
+ What is so great about #1?<|end|>
97
+ <|assistant|>
98
+ ```
99
+
100
+ ### Sample inference code
101
+
102
+ This code snippets show how to get quickly started with running the model on a GPU:
103
+
104
+ ```python
105
+ import torch
106
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
107
+
108
+ torch.random.manual_seed(0)
109
+ model_id = "microsoft/Phi-3-small-8k-instruct"
110
+ model = AutoModelForCausalLM.from_pretrained(
111
+ model_id,
112
+ torch_dtype="auto",
113
+ trust_remote_code=True,
114
+ )
115
+ assert torch.cuda.is_available(), "This model needs a GPU to run ..."
116
+ device = torch.cuda.current_device()
117
+ model = model.to(device)
118
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
119
+
120
+ messages = [
121
+ {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
122
+ {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
123
+ {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
124
+ ]
125
+
126
+ pipe = pipeline(
127
+ "text-generation",
128
+ model=model,
129
+ tokenizer=tokenizer,
130
+ device=device
131
+ )
132
+
133
+ generation_args = {
134
+ "max_new_tokens": 500,
135
+ "return_full_text": False,
136
+ "temperature": 0.0,
137
+ "do_sample": False,
138
+ }
139
+
140
+ output = pipe(messages, **generation_args)
141
+ print(output[0]['generated_text'])
142
+ ```
143
+
144
+ *Some applications/frameworks might not include a BOS token (`<|endoftext|>`) at the start of the conversation. Please ensure that it is included since it provides more reliable results.*
145
+
146
+ ## Responsible AI Considerations
147
+
148
+ Like other language models, the Phi series models can potentially behave in ways that are unfair, unreliable, or offensive. Some of the limiting behaviors to be aware of include:
149
+
150
+ + Quality of Service: the Phi models are trained primarily on English text. Languages other than English will experience worse performance. English language varieties with less representation in the training data might experience worse performance than standard American English.
151
+ + Representation of Harms & Perpetuation of Stereotypes: These models can over- or under-represent groups of people, erase representation of some groups, or reinforce demeaning or negative stereotypes. Despite safety post-training, these limitations may still be present due to differing levels of representation of different groups or prevalence of examples of negative stereotypes in training data that reflect real-world patterns and societal biases.
152
+ + Inappropriate or Offensive Content: these models may produce other types of inappropriate or offensive content, which may make it inappropriate to deploy for sensitive contexts without additional mitigations that are specific to the use case.
153
+ + Information Reliability: Language models can generate nonsensical content or fabricate content that might sound reasonable but is inaccurate or outdated.
154
+ + Limited Scope for Code: Majority of Phi-3 training data is based in Python and use common packages such as "typing, math, random, collections, datetime, itertools". If the model generates Python scripts that utilize other packages or scripts in other languages, we strongly recommend users manually verify all API uses.
155
+
156
+ Developers should apply responsible AI best practices and are responsible for ensuring that a specific use case complies with relevant laws and regulations (e.g. privacy, trade, etc.). Important areas for consideration include:
157
+
158
+ + Allocation: Models may not be suitable for scenarios that could have consequential impact on legal status or the allocation of resources or life opportunities (ex: housing, employment, credit, etc.) without further assessments and additional debiasing techniques.
159
+ + High-Risk Scenarios: Developers should assess suitability of using models in high-risk scenarios where unfair, unreliable or offensive outputs might be extremely costly or lead to harm. This includes providing advice in sensitive or expert domains where accuracy and reliability are critical (ex: legal or health advice). Additional safeguards should be implemented at the application level according to the deployment context.
160
+ + Misinformation: Models may produce inaccurate information. Developers should follow transparency best practices and inform end-users they are interacting with an AI system. At the application level, developers can build feedback mechanisms and pipelines to ground responses in use-case specific, contextual information, a technique known as Retrieval Augmented Generation (RAG).
161
+ + Generation of Harmful Content: Developers should assess outputs for their context and use available safety classifiers or custom solutions appropriate for their use case.
162
+ + Misuse: Other forms of misuse such as fraud, spam, or malware production may be possible, and developers should ensure that their applications do not violate applicable laws and regulations.
163
+
164
+
165
+ ## Training
166
+
167
+ ### Model
168
+
169
+ * Architecture: Phi-3 Small-8K-Instruct has 7B parameters and is a dense decoder-only Transformer model. The model is fine-tuned with Supervised fine-tuning (SFT) and Direct Preference Optimization (DPO) to ensure alignment with human preferences and safety guidlines.
170
+ * Inputs: Text. It is best suited for prompts using chat format.
171
+ * Context length: 8K tokens
172
+ * GPUs: 1024 H100-80G
173
+ * Training time: 18 days
174
+ * Training data: 4.8T tokens
175
+ * Outputs: Generated text in response to the input
176
+ * Dates: Our models were trained between February and April 2024
177
+ * Status: This is a static model trained on an offline dataset with cutoff date October 2023. Future versions of the tuned models may be released as we improve models.
178
+ * Release dates The model weight is released on May 21, 2024.
179
+
180
+ ### Datasets
181
+
182
+ Our training data includes a wide variety of sources, totaling 4.8 trillion tokens (including 10% multilingual), and is a combination of
183
+ 1) Publicly available documents filtered rigorously for quality, selected high-quality educational data, and code;
184
+ 2) Newly created synthetic, “textbook-like” data for the purpose of teaching math, coding, common sense reasoning, general knowledge of the world (science, daily activities, theory of mind, etc.);
185
+ 3) High quality chat format supervised data covering various topics to reflect human preferences on different aspects such as instruct-following, truthfulness, honesty and helpfulness.
186
+
187
+ We are focusing on the quality of data that could potentially improve the reasoning ability for the model, and we filter the publicly available documents to contain the correct level of knowledge. As an example, the result of a game in premier league in a particular day might be good training data for frontier models, but we need to remove such information to leave more model capacity for reasoning for the small size models. More details about data can be found in the [Phi-3 Technical Report](https://aka.ms/phi3-tech-report).
188
+
189
+ ## Benchmarks
190
+
191
+ We report the results for Phi-3-Small-8K-Instruct on standard open-source benchmarks measuring the model's reasoning ability (both common sense reasoning and logical reasoning). We compare to Mixtral-8x7b, Gemini-Pro, Gemma 7B, Llama-3-8B-Instruct, GPT-3.5-Turbo-1106, and GPT-4-Turbo-1106.
192
+
193
+ All the reported numbers are produced with the exact same pipeline to ensure that the numbers are comparable. These numbers might differ from other published numbers due to slightly different choices in the evaluation.
194
+
195
+ As is now standard, we use few-shot prompts to evaluate the models, at temperature 0.
196
+ The prompts and number of shots are part of a Microsoft internal tool to evaluate language models, and in particular we did no optimization to the pipeline for Phi-3.
197
+ More specifically, we do not change prompts, pick different few-shot examples, change prompt format, or do any other form of optimization for the model.
198
+
199
+ The number of k–shot examples is listed per-benchmark.
200
+
201
+ |Benchmark|Phi-3-Small-8K-Instruct<br>7b|Gemma<br>7B|Mixtral<br>8x7B|Llama-3-Instruct<br>8b|GPT-3.5-Turbo<br>version 1106|Gemini<br>Pro|GPT-4-Turbo<br>version 1106 (Chat)|
202
+ |---------|-----------------------|--------|-------------|-------------------|-----------------|----------|------------------------|
203
+ |AGI Eval<br>5-shot|45.1|42.1|45.2|42.0|48.4|49.0|59.6|
204
+ |MMLU<br>5-shot|75.7|63.6|70.5|66.5|71.4|66.7|84.0|
205
+ |BigBench Hard<br>3-shot|79.1|59.6|69.7|51.5|68.3|75.6|87.7|
206
+ |ANLI<br>7-shot|58.1|48.7|55.2|57.3|58.1|64.2|71.7|
207
+ |HellaSwag<br>5-shot|77.0|49.8|70.4|71.1|78.8|76.2|88.3|
208
+ |ARC Challenge<br>10-shot|90.7|78.3|87.3|82.8|87.4|88.3|95.6|
209
+ |ARC Easy<br>10-shot|97.0|91.4|95.6|93.4|96.3|96.1|98.8|
210
+ |BoolQ<br>2-shot|84.8|66.0|76.6|80.9|79.1|86.4|91.3|
211
+ |CommonsenseQA<br>10-shot|80.0|76.2|78.1|79.0|79.6|81.8|86.7|
212
+ |MedQA<br>2-shot|65.4|49.6|62.2|60.5|63.4|58.2|83.7|
213
+ |OpenBookQA<br>10-shot|88.0|78.6|85.8|82.6|86.0|86.4|93.4|
214
+ |PIQA<br>5-shot|86.9|78.1|86.0|75.7|86.6|86.2|90.1|
215
+ |Social IQA<br>5-shot|79.2|65.5|75.9|73.9|68.3|75.4|81.7|
216
+ |TruthfulQA (MC2)<br>10-shot|70.2|52.1|60.1|63.2|67.7|72.6|85.2|
217
+ |WinoGrande<br>5-shot|81.5|55.6|62.0|65.0|68.8|72.2|86.7|
218
+ |TriviaQA<br>5-shot|58.1|72.3|82.2|67.7|85.8|80.2|73.3|
219
+ |GSM8K Chain of Thought<br>8-shot|89.6|59.8|64.7|77.4|78.1|80.4|94.2|
220
+ |HumanEval<br>0-shot|61.0|34.1|37.8|60.4|62.2|64.4|79.9|
221
+ |MBPP<br>3-shot|71.7|51.5|60.2|67.7|77.8|73.2|86.7|
222
+ |Average|75.7|61.8|69.8|69.4|74.3|75.4|85.2|
223
+
224
+ We take a closer look at different categories across 80 public benchmark datasets at the table below:
225
+
226
+ |Benchmark|Phi-3-Small-8K-Instruct<br>7b|Gemma<br>7B|Mixtral<br>8x7B|Llama-3-Instruct<br>8b|GPT-3.5-Turbo<br>version 1106|Gemini<br>Pro|GPT-4-Turbo<br>version 1106 (Chat)|
227
+ |--------|------------------------|--------|-------------|-------------------|-------------------|----------|------------------------|
228
+ |Popular aggregated benchmark|71.1|59.4|66.2|59.9|67.0|67.5|80.5|
229
+ |Reasoning|82.4|69.1|77.0|75.7|78.3|80.4|89.3|
230
+ |Language understanding|70.6|58.4|64.9|65.4|70.4|75.3|81.6|
231
+ |Code generation|60.7|45.6|52.7|56.4|70.4|66.7|76.1|
232
+ |Math|51.6|35.8|40.3|41.1|52.8|50.9|67.1|
233
+ |Factual knowledge|38.6|46.7|58.6|43.1|63.4|54.6|45.9|
234
+ |Multilingual|62.5|63.2|63.4|65.0|69.1|76.5|82.0|
235
+ |Robustness|72.9|38.4|51.0|64.5|69.3|69.7|84.6|
236
+
237
+
238
+ ## Software
239
+
240
+ * [PyTorch](https://github.com/pytorch/pytorch)
241
+ * [DeepSpeed](https://github.com/microsoft/DeepSpeed)
242
+ * [Transformers](https://github.com/huggingface/transformers)
243
+ * [Flash-Attention](https://github.com/HazyResearch/flash-attention)
244
+ * [Tiktoken](https://github.com/openai/tiktoken)
245
+ * [Triton](https://github.com/openai/triton)
246
+
247
+ ## Hardware
248
+ Note that by default, the Phi-3-Small model uses flash attention, which requires certain types of GPU hardware to run. We have tested on the following GPU types:
249
+ * NVIDIA A100
250
+ * NVIDIA A6000
251
+ * NVIDIA H100
252
+
253
+ If you want to run the model on:
254
+ + Optimized inference on GPU, CPU, and Mobile: use the **ONNX** models [8K](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct-onnx)
255
+
256
+
257
+ ## Cross Platform Support
258
+
259
+ ONNX runtime ecosystem now supports Phi3 small models across platforms and hardware.
260
+ Optimized phi-3 models are also published here in ONNX format, to run with ONNX Runtime on CPU and GPU across devices, including server platforms, Windows, Linux and Mac desktops, and mobile CPUs, with the precision best suited to each of these targets. DirectML GPU acceleration is supported for Windows desktops GPUs (AMD, Intel, and NVIDIA).
261
+ Along with DML, ONNX Runtime provides cross platform support for Phi3 Small across a range of devices CPU, GPU, and mobile.
262
+ Here are some of the optimized configurations we have added:
263
+
264
+ 1. ONNX models for int4 DML: Quantized to int4 via AWQ
265
+ 2. ONNX model for fp16 CUDA
266
+ 3. ONNX model for int4 CUDA: Quantized to int4 via RTN
267
+ 4. ONNX model for int4 CPU and Mobile: Quantized to int4 via RTN
268
+
269
+ ## License
270
+
271
+ The model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-small-8k/resolve/main/LICENSE).
272
+
273
+ ## Trademarks
274
+
275
+ This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
SECURITY.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
2
+
3
+ ## Security
4
+
5
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6
+
7
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8
+
9
+ ## Reporting Security Issues
10
+
11
+ **Please do not report security vulnerabilities through public GitHub issues.**
12
+
13
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14
+
15
+ If you prefer to submit without logging in, send email to [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16
+
17
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18
+
19
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20
+
21
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22
+ * Full paths of source file(s) related to the manifestation of the issue
23
+ * The location of the affected source code (tag/branch/commit or direct URL)
24
+ * Any special configuration required to reproduce the issue
25
+ * Step-by-step instructions to reproduce the issue
26
+ * Proof-of-concept or exploit code (if possible)
27
+ * Impact of the issue, including how an attacker might exploit the issue
28
+
29
+ This information will help us triage your report more quickly.
30
+
31
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32
+
33
+ ## Preferred Languages
34
+
35
+ We prefer all communications to be in English.
36
+
37
+ ## Policy
38
+
39
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40
+
41
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
SUPPORT.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO: The maintainer of this repo has not yet edited this file
2
+
3
+ **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4
+
5
+ - **No CSS support:** Fill out this template with information about how to file issues and get help.
6
+ - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7
+ - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8
+
9
+ *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10
+
11
+ # Support
12
+
13
+ ## How to file issues and get help
14
+
15
+ This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16
+ issues before filing new issues to avoid duplicates. For new issues, file your bug or
17
+ feature request as a new Issue.
18
+
19
+ For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20
+ FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21
+ CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22
+
23
+ ## Microsoft Support Policy
24
+
25
+ Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
cl100k_base.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "Phi-3-small-8k-instruct",
3
+ "architectures": [
4
+ "Phi3SmallForCausalLM"
5
+ ],
6
+ "attention_dropout_prob": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_phi3_small.Phi3SmallConfig",
9
+ "AutoModelForCausalLM": "modeling_phi3_small.Phi3SmallForCausalLM",
10
+ "AutoTokenizer": [
11
+ "tokenization_phi3_small.Phi3SmallTokenizer",
12
+ "tokenization_phi3_small.Phi3SmallTokenizer"
13
+ ]
14
+ },
15
+ "blocksparse_block_size": 64,
16
+ "blocksparse_homo_head_pattern": false,
17
+ "blocksparse_num_local_blocks": 16,
18
+ "blocksparse_triton_kernel_block_size": 64,
19
+ "blocksparse_vert_stride": 8,
20
+ "bos_token_id": 100257,
21
+ "dense_attention_every_n_layers": 2,
22
+ "embedding_dropout_prob": 0.1,
23
+ "eos_token_id": 100257,
24
+ "ff_dim_multiplier": null,
25
+ "ff_intermediate_size": 14336,
26
+ "ffn_dropout_prob": 0.1,
27
+ "gegelu_limit": 20.0,
28
+ "gegelu_pad_to_256": true,
29
+ "hidden_act": "gegelu",
30
+ "hidden_size": 4096,
31
+ "initializer_range": 0.02,
32
+ "layer_norm_epsilon": 1e-05,
33
+ "max_position_embeddings": 8192,
34
+ "model_type": "phi3small",
35
+ "mup_attn_multiplier": 1.0,
36
+ "mup_embedding_multiplier": 10.0,
37
+ "mup_use_scaling": true,
38
+ "mup_width_multiplier": 8.0,
39
+ "num_attention_heads": 32,
40
+ "num_hidden_layers": 32,
41
+ "num_key_value_heads": 8,
42
+ "pad_sequence_to_multiple_of_64": true,
43
+ "reorder_and_upcast_attn": false,
44
+ "rope_embedding_base": 1000000,
45
+ "rope_position_scale": 1.0,
46
+ "torch_dtype": "bfloat16",
47
+ "transformers_version": "4.38.1",
48
+ "use_cache": true,
49
+ "vocab_size": 100352
50
+ }
configuration_phi3_small.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ from functools import cached_property
22
+
23
+ """ Phi3Small model configuration """
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ def next_mult(x, y):
28
+ return (x + y - 1) // y * y
29
+
30
+ class Phi3SmallConfig(PretrainedConfig):
31
+ """
32
+ This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
33
+ instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a
34
+ configuration with the defaults will yield a similar configuration to that of the GPT-2
35
+ [gpt2](https://huggingface.co/gpt2) architecture.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 50257):
43
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
45
+ n_positions (`int`, *optional*, defaults to 1024):
46
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
47
+ just in case (e.g., 512 or 1024 or 2048).
48
+ n_embd (`int`, *optional*, defaults to 768):
49
+ Dimensionality of the embeddings and hidden states.
50
+ n_layer (`int`, *optional*, defaults to 12):
51
+ Number of hidden layers in the Transformer encoder.
52
+ n_head (`int`, *optional*, defaults to 12):
53
+ Number of attention heads for each attention layer in the Transformer encoder.
54
+ n_inner (`int`, *optional*, defaults to None):
55
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
56
+ activation_function (`str`, *optional*, defaults to `"gelu"`):
57
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
58
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
59
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
60
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
61
+ The dropout ratio for the embeddings.
62
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
63
+ The dropout ratio for the attention.
64
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
65
+ The epsilon to use in the layer normalization layers.
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ use_cache (`bool`, *optional*, defaults to `True`):
69
+ Whether or not the model should return the last key/values attentions (not used by all models).
70
+ scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
71
+ Whether to additionally scale attention weights by `1 / layer_idx + 1`.
72
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
73
+ Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
74
+ dot-product/softmax to float() when training with mixed precision.
75
+
76
+ Example:
77
+
78
+ ```python
79
+ >>> from transformers import Phi3SmallConfig, Phi3SmallModel
80
+
81
+ >>> # Initializing a Phi3Small configuration
82
+ >>> configuration = Phi3SmallConfig()
83
+
84
+ >>> # Initializing a model (with random weights) from the configuration
85
+ >>> model = Phi3SmallModel(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+
91
+ model_type = "phi3small"
92
+ keys_to_ignore_at_inference = ["past_key_values"]
93
+
94
+
95
+ def __init__(
96
+ self,
97
+ # General information about the model
98
+ vocab_size: int =100352,
99
+ max_position_embeddings: int = 8192,
100
+ # RoPE Related Parameters
101
+ rope_embedding_base: float = 10**6,
102
+ rope_position_scale: float = 1.0,
103
+ rope_scaling: Optional[Dict[str, Union[float, List[float], int]]] = None,
104
+ # General Model Parameters
105
+ hidden_size: int = 4096,
106
+ num_hidden_layers: int = 32,
107
+ # KV Shared Attention Configurations
108
+ num_attention_heads: int = 32,
109
+ num_key_value_heads: int = 8,
110
+ # GEGELU Related Parameters
111
+ hidden_act: str = "gegelu",
112
+ gegelu_limit: float = 20.0,
113
+ gegelu_pad_to_256: bool = True,
114
+ ff_dim_multiplier: Optional[int] = None,
115
+ ff_intermediate_size: Optional[int] = 14336,
116
+ # Block Sparse Attention
117
+ blocksparse_homo_head_pattern: bool = False,
118
+ blocksparse_block_size: int = 64,
119
+ blocksparse_num_local_blocks: int = 16,
120
+ blocksparse_vert_stride: int = 8,
121
+ blocksparse_triton_kernel_block_size: int = 64,
122
+ # Frequency of block-sparsity
123
+ dense_attention_every_n_layers: Optional[int] = 2,
124
+ # Reegularization parameters
125
+ embedding_dropout_prob: float =0.1,
126
+ attention_dropout_prob: float = 0.0,
127
+ ffn_dropout_prob: float = 0.1,
128
+ layer_norm_epsilon=1e-5,
129
+ initializer_range=0.02,
130
+ # MuP parameters
131
+ mup_use_scaling: bool = True,
132
+ mup_width_multiplier: bool = 8.0,
133
+ mup_embedding_multiplier: bool = 10.0,
134
+ mup_attn_multiplier: bool =1.0,
135
+ use_cache=True,
136
+ # The model does not have a bos token id
137
+ # However, in order for some of the downstream libraries to not break
138
+ # we set this to be the same as the eos_token_id
139
+ bos_token_id: int = 100257,
140
+ eos_token_id: int = 100257,
141
+ reorder_and_upcast_attn=False,
142
+ # Configuration to pad sequence length to a multiple of 64
143
+ pad_sequence_to_multiple_of_64: bool = True,
144
+ **kwargs,
145
+ ):
146
+ self.vocab_size = vocab_size
147
+ self.max_position_embeddings = max_position_embeddings
148
+ self.rope_embedding_base = rope_embedding_base
149
+ self.rope_position_scale = rope_position_scale
150
+ self.rope_scaling = rope_scaling
151
+ self.hidden_size = hidden_size
152
+ # QK Shared Attention
153
+ self.num_hidden_layers = num_hidden_layers
154
+ self.num_attention_heads = num_attention_heads
155
+ self.num_key_value_heads = num_key_value_heads
156
+ # Block Sparse Attention Pattern
157
+ self.blocksparse_homo_head_pattern = blocksparse_homo_head_pattern
158
+ self.blocksparse_block_size = blocksparse_block_size
159
+ self.blocksparse_num_local_blocks = blocksparse_num_local_blocks
160
+ self.blocksparse_vert_stride = blocksparse_vert_stride
161
+ self.blocksparse_triton_kernel_block_size = blocksparse_triton_kernel_block_size
162
+ # Frequency of block sparsity
163
+ self.dense_attention_every_n_layers = dense_attention_every_n_layers
164
+
165
+ # Activation function
166
+ self.hidden_act = hidden_act
167
+ self.gegelu_limit = gegelu_limit
168
+ self.gegelu_pad_to_256 = gegelu_pad_to_256
169
+ self.ff_dim_multiplier = ff_dim_multiplier
170
+ self.ff_intermediate_size = ff_intermediate_size
171
+ if self.ff_dim_multiplier is None and self.ff_intermediate_size is None:
172
+ raise ValueError(f"Cannot have both {self.ff_dim_multiplier} and {self.ff_intermediate_size} as None")
173
+ if self.ff_dim_multiplier is not None and self.ff_intermediate_size is not None:
174
+ raise ValueError(f"Cannot specify both {self.ff_dim_multiplier} and {self.ff_intermediate_size}.")
175
+ # General regularization
176
+ self.embedding_dropout_prob = embedding_dropout_prob
177
+ self.attention_dropout_prob = attention_dropout_prob
178
+ self.ffn_dropout_prob = ffn_dropout_prob
179
+
180
+ self.layer_norm_epsilon = layer_norm_epsilon
181
+ self.initializer_range = initializer_range
182
+
183
+ # MuP parameters
184
+ self.mup_use_scaling = mup_use_scaling
185
+ self.mup_width_multiplier = mup_width_multiplier
186
+ self.mup_embedding_multiplier = mup_embedding_multiplier
187
+ self.mup_attn_multiplier = mup_attn_multiplier
188
+ self.use_cache = use_cache
189
+
190
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
191
+ self.pad_sequence_to_multiple_of_64 = pad_sequence_to_multiple_of_64
192
+
193
+ self.bos_token_id = bos_token_id
194
+ self.eos_token_id = eos_token_id
195
+
196
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
197
+
198
+ @cached_property
199
+ def dummy_token_indices(self) -> List[int]:
200
+ # Importing here to avoid circular imports
201
+ from .tokenization_phi3_small import Phi3SmallTokenizer
202
+ tokenizer = Phi3SmallTokenizer()
203
+ return tokenizer.dummy_token_indices
204
+
205
+ @property
206
+ def intermediate_size(self) -> int:
207
+ if self.ff_intermediate_size is not None:
208
+ return self.ff_intermediate_size
209
+ intermediate_size = (self.ff_dim_multiplier) * (self.hidden_size // 3) * 2
210
+ if self.gegelu_pad_to_256:
211
+ intermediate_size = next_mult(intermediate_size, 256)
212
+ return intermediate_size
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 100257,
4
+ "eos_token_id": [
5
+ 100257,
6
+ 100266
7
+ ],
8
+ "transformers_version": "4.38.1"
9
+ }
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a8435e8fd0cc2a302f057814bb7e2650f16a4812a9b34339e3769e213276797
3
+ size 4832943104
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0be58e1371e8630fff0f8655d6be99a2dfc6ccfb4e00bc4fa85e831b8042eac6
3
+ size 4799608224
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77aa243e7aa0a19eb37eb8dabc6f30de9a779c606cb476e5ac432d742fe7e917
3
+ size 4799608240
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cb5772a577868e7e794bed074c19b0d5284a5f9a0a89b537c33623873940f3a
3
+ size 352437304
model.safetensors.index.json ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 14784548864
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
7
+ "model.final_layernorm.bias": "model-00004-of-00004.safetensors",
8
+ "model.final_layernorm.weight": "model-00004-of-00004.safetensors",
9
+ "model.layers.0.input_layernorm.bias": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.dense.bias": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.dense.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.0.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
20
+ "model.layers.0.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.input_layernorm.bias": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.1.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
29
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
30
+ "model.layers.1.self_attn.dense.bias": "model-00001-of-00004.safetensors",
31
+ "model.layers.1.self_attn.dense.weight": "model-00001-of-00004.safetensors",
32
+ "model.layers.1.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
33
+ "model.layers.1.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
34
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
35
+ "model.layers.10.input_layernorm.bias": "model-00002-of-00004.safetensors",
36
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.10.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
38
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.10.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
40
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
41
+ "model.layers.10.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
42
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.10.self_attn.dense.bias": "model-00002-of-00004.safetensors",
44
+ "model.layers.10.self_attn.dense.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.10.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
46
+ "model.layers.10.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
48
+ "model.layers.11.input_layernorm.bias": "model-00002-of-00004.safetensors",
49
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.11.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
51
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.11.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
53
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.11.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
55
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.11.self_attn.dense.bias": "model-00002-of-00004.safetensors",
57
+ "model.layers.11.self_attn.dense.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.11.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
59
+ "model.layers.11.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
61
+ "model.layers.12.input_layernorm.bias": "model-00002-of-00004.safetensors",
62
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.12.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
64
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
65
+ "model.layers.12.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
66
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.12.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
68
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.12.self_attn.dense.bias": "model-00002-of-00004.safetensors",
70
+ "model.layers.12.self_attn.dense.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.12.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
72
+ "model.layers.12.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
74
+ "model.layers.13.input_layernorm.bias": "model-00002-of-00004.safetensors",
75
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.13.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
77
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.13.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
79
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.13.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
81
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.13.self_attn.dense.bias": "model-00002-of-00004.safetensors",
83
+ "model.layers.13.self_attn.dense.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.13.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
85
+ "model.layers.13.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
86
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
87
+ "model.layers.14.input_layernorm.bias": "model-00002-of-00004.safetensors",
88
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
89
+ "model.layers.14.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
90
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
91
+ "model.layers.14.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
92
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.14.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
94
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.14.self_attn.dense.bias": "model-00002-of-00004.safetensors",
96
+ "model.layers.14.self_attn.dense.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.14.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
98
+ "model.layers.14.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
100
+ "model.layers.15.input_layernorm.bias": "model-00002-of-00004.safetensors",
101
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
102
+ "model.layers.15.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
103
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.15.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
105
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.15.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
107
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.15.self_attn.dense.bias": "model-00002-of-00004.safetensors",
109
+ "model.layers.15.self_attn.dense.weight": "model-00002-of-00004.safetensors",
110
+ "model.layers.15.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
111
+ "model.layers.15.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
112
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
113
+ "model.layers.16.input_layernorm.bias": "model-00002-of-00004.safetensors",
114
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
115
+ "model.layers.16.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
116
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.16.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
118
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
119
+ "model.layers.16.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
120
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.16.self_attn.dense.bias": "model-00002-of-00004.safetensors",
122
+ "model.layers.16.self_attn.dense.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.16.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
124
+ "model.layers.16.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
125
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
126
+ "model.layers.17.input_layernorm.bias": "model-00002-of-00004.safetensors",
127
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
128
+ "model.layers.17.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
129
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
130
+ "model.layers.17.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
131
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
132
+ "model.layers.17.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
133
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
134
+ "model.layers.17.self_attn.dense.bias": "model-00002-of-00004.safetensors",
135
+ "model.layers.17.self_attn.dense.weight": "model-00002-of-00004.safetensors",
136
+ "model.layers.17.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
137
+ "model.layers.17.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
138
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
139
+ "model.layers.18.input_layernorm.bias": "model-00002-of-00004.safetensors",
140
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
141
+ "model.layers.18.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
142
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
143
+ "model.layers.18.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
144
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
145
+ "model.layers.18.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
146
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
147
+ "model.layers.18.self_attn.dense.bias": "model-00002-of-00004.safetensors",
148
+ "model.layers.18.self_attn.dense.weight": "model-00002-of-00004.safetensors",
149
+ "model.layers.18.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
150
+ "model.layers.18.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
151
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
152
+ "model.layers.19.input_layernorm.bias": "model-00002-of-00004.safetensors",
153
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
154
+ "model.layers.19.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
155
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
156
+ "model.layers.19.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
157
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
158
+ "model.layers.19.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
159
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
160
+ "model.layers.19.self_attn.dense.bias": "model-00002-of-00004.safetensors",
161
+ "model.layers.19.self_attn.dense.weight": "model-00002-of-00004.safetensors",
162
+ "model.layers.19.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
163
+ "model.layers.19.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
164
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
165
+ "model.layers.2.input_layernorm.bias": "model-00001-of-00004.safetensors",
166
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
167
+ "model.layers.2.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
168
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
169
+ "model.layers.2.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
170
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
171
+ "model.layers.2.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
172
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
173
+ "model.layers.2.self_attn.dense.bias": "model-00001-of-00004.safetensors",
174
+ "model.layers.2.self_attn.dense.weight": "model-00001-of-00004.safetensors",
175
+ "model.layers.2.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
176
+ "model.layers.2.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
177
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
178
+ "model.layers.20.input_layernorm.bias": "model-00003-of-00004.safetensors",
179
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.20.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
181
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
182
+ "model.layers.20.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
183
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
184
+ "model.layers.20.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
185
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
186
+ "model.layers.20.self_attn.dense.bias": "model-00002-of-00004.safetensors",
187
+ "model.layers.20.self_attn.dense.weight": "model-00002-of-00004.safetensors",
188
+ "model.layers.20.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
189
+ "model.layers.20.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
190
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
191
+ "model.layers.21.input_layernorm.bias": "model-00003-of-00004.safetensors",
192
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.21.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
194
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.21.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
196
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
197
+ "model.layers.21.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
198
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
199
+ "model.layers.21.self_attn.dense.bias": "model-00003-of-00004.safetensors",
200
+ "model.layers.21.self_attn.dense.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.21.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
202
+ "model.layers.21.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
204
+ "model.layers.22.input_layernorm.bias": "model-00003-of-00004.safetensors",
205
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
206
+ "model.layers.22.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
207
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.22.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
209
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.22.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
211
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.22.self_attn.dense.bias": "model-00003-of-00004.safetensors",
213
+ "model.layers.22.self_attn.dense.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.22.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
215
+ "model.layers.22.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
217
+ "model.layers.23.input_layernorm.bias": "model-00003-of-00004.safetensors",
218
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.23.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
220
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
221
+ "model.layers.23.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
222
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
223
+ "model.layers.23.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
224
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.23.self_attn.dense.bias": "model-00003-of-00004.safetensors",
226
+ "model.layers.23.self_attn.dense.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.23.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
228
+ "model.layers.23.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
230
+ "model.layers.24.input_layernorm.bias": "model-00003-of-00004.safetensors",
231
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.24.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
233
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.24.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
235
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.24.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
237
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.24.self_attn.dense.bias": "model-00003-of-00004.safetensors",
239
+ "model.layers.24.self_attn.dense.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.24.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
241
+ "model.layers.24.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
242
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
243
+ "model.layers.25.input_layernorm.bias": "model-00003-of-00004.safetensors",
244
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
245
+ "model.layers.25.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
246
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
247
+ "model.layers.25.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
248
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.25.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
250
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.25.self_attn.dense.bias": "model-00003-of-00004.safetensors",
252
+ "model.layers.25.self_attn.dense.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.25.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
254
+ "model.layers.25.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
256
+ "model.layers.26.input_layernorm.bias": "model-00003-of-00004.safetensors",
257
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
258
+ "model.layers.26.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
259
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.26.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
261
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
262
+ "model.layers.26.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
263
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
264
+ "model.layers.26.self_attn.dense.bias": "model-00003-of-00004.safetensors",
265
+ "model.layers.26.self_attn.dense.weight": "model-00003-of-00004.safetensors",
266
+ "model.layers.26.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
267
+ "model.layers.26.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
268
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
269
+ "model.layers.27.input_layernorm.bias": "model-00003-of-00004.safetensors",
270
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
271
+ "model.layers.27.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
272
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
273
+ "model.layers.27.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
274
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
275
+ "model.layers.27.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
276
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
277
+ "model.layers.27.self_attn.dense.bias": "model-00003-of-00004.safetensors",
278
+ "model.layers.27.self_attn.dense.weight": "model-00003-of-00004.safetensors",
279
+ "model.layers.27.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
280
+ "model.layers.27.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
281
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
282
+ "model.layers.28.input_layernorm.bias": "model-00003-of-00004.safetensors",
283
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
284
+ "model.layers.28.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
285
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
286
+ "model.layers.28.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
287
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
288
+ "model.layers.28.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
289
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
290
+ "model.layers.28.self_attn.dense.bias": "model-00003-of-00004.safetensors",
291
+ "model.layers.28.self_attn.dense.weight": "model-00003-of-00004.safetensors",
292
+ "model.layers.28.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
293
+ "model.layers.28.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
294
+ "model.layers.28.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
295
+ "model.layers.29.input_layernorm.bias": "model-00003-of-00004.safetensors",
296
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
297
+ "model.layers.29.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
298
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
299
+ "model.layers.29.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
300
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
301
+ "model.layers.29.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
302
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
303
+ "model.layers.29.self_attn.dense.bias": "model-00003-of-00004.safetensors",
304
+ "model.layers.29.self_attn.dense.weight": "model-00003-of-00004.safetensors",
305
+ "model.layers.29.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
306
+ "model.layers.29.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
307
+ "model.layers.29.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
308
+ "model.layers.3.input_layernorm.bias": "model-00001-of-00004.safetensors",
309
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
310
+ "model.layers.3.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
311
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
312
+ "model.layers.3.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
313
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
314
+ "model.layers.3.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
315
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
316
+ "model.layers.3.self_attn.dense.bias": "model-00001-of-00004.safetensors",
317
+ "model.layers.3.self_attn.dense.weight": "model-00001-of-00004.safetensors",
318
+ "model.layers.3.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
319
+ "model.layers.3.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
320
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
321
+ "model.layers.30.input_layernorm.bias": "model-00003-of-00004.safetensors",
322
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
323
+ "model.layers.30.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
324
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
325
+ "model.layers.30.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
326
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
327
+ "model.layers.30.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
328
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
329
+ "model.layers.30.self_attn.dense.bias": "model-00003-of-00004.safetensors",
330
+ "model.layers.30.self_attn.dense.weight": "model-00003-of-00004.safetensors",
331
+ "model.layers.30.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
332
+ "model.layers.30.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
333
+ "model.layers.30.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
334
+ "model.layers.31.input_layernorm.bias": "model-00004-of-00004.safetensors",
335
+ "model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
336
+ "model.layers.31.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
337
+ "model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
338
+ "model.layers.31.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
339
+ "model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
340
+ "model.layers.31.post_attention_layernorm.bias": "model-00004-of-00004.safetensors",
341
+ "model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
342
+ "model.layers.31.self_attn.dense.bias": "model-00003-of-00004.safetensors",
343
+ "model.layers.31.self_attn.dense.weight": "model-00003-of-00004.safetensors",
344
+ "model.layers.31.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
345
+ "model.layers.31.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
346
+ "model.layers.31.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
347
+ "model.layers.4.input_layernorm.bias": "model-00001-of-00004.safetensors",
348
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
349
+ "model.layers.4.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
350
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
351
+ "model.layers.4.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
352
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
353
+ "model.layers.4.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
354
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
355
+ "model.layers.4.self_attn.dense.bias": "model-00001-of-00004.safetensors",
356
+ "model.layers.4.self_attn.dense.weight": "model-00001-of-00004.safetensors",
357
+ "model.layers.4.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
358
+ "model.layers.4.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
359
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
360
+ "model.layers.5.input_layernorm.bias": "model-00001-of-00004.safetensors",
361
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
362
+ "model.layers.5.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
363
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
364
+ "model.layers.5.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
365
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
366
+ "model.layers.5.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
367
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
368
+ "model.layers.5.self_attn.dense.bias": "model-00001-of-00004.safetensors",
369
+ "model.layers.5.self_attn.dense.weight": "model-00001-of-00004.safetensors",
370
+ "model.layers.5.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
371
+ "model.layers.5.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
372
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
373
+ "model.layers.6.input_layernorm.bias": "model-00001-of-00004.safetensors",
374
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
375
+ "model.layers.6.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
376
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
377
+ "model.layers.6.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
378
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
379
+ "model.layers.6.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
380
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
381
+ "model.layers.6.self_attn.dense.bias": "model-00001-of-00004.safetensors",
382
+ "model.layers.6.self_attn.dense.weight": "model-00001-of-00004.safetensors",
383
+ "model.layers.6.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
384
+ "model.layers.6.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
385
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
386
+ "model.layers.7.input_layernorm.bias": "model-00001-of-00004.safetensors",
387
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
388
+ "model.layers.7.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
389
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
390
+ "model.layers.7.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
391
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
392
+ "model.layers.7.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
393
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
394
+ "model.layers.7.self_attn.dense.bias": "model-00001-of-00004.safetensors",
395
+ "model.layers.7.self_attn.dense.weight": "model-00001-of-00004.safetensors",
396
+ "model.layers.7.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
397
+ "model.layers.7.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
398
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
399
+ "model.layers.8.input_layernorm.bias": "model-00001-of-00004.safetensors",
400
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
401
+ "model.layers.8.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
402
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
403
+ "model.layers.8.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
404
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
405
+ "model.layers.8.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
406
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
407
+ "model.layers.8.self_attn.dense.bias": "model-00001-of-00004.safetensors",
408
+ "model.layers.8.self_attn.dense.weight": "model-00001-of-00004.safetensors",
409
+ "model.layers.8.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
410
+ "model.layers.8.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
411
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
412
+ "model.layers.9.input_layernorm.bias": "model-00002-of-00004.safetensors",
413
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
414
+ "model.layers.9.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
415
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
416
+ "model.layers.9.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
417
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
418
+ "model.layers.9.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
419
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
420
+ "model.layers.9.self_attn.dense.bias": "model-00001-of-00004.safetensors",
421
+ "model.layers.9.self_attn.dense.weight": "model-00001-of-00004.safetensors",
422
+ "model.layers.9.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
423
+ "model.layers.9.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
424
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors"
425
+ }
426
+ }
modeling_phi3_small.py ADDED
@@ -0,0 +1,1140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any, Dict, Optional, List, Tuple, Union
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ from einops import rearrange
9
+
10
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast, CausalLMOutputWithPast, BaseModelOutputWithPast
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.utils import logging
13
+
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+
16
+ from .triton_flash_blocksparse_attn import BlockSparseParams
17
+ from .triton_blocksparse_attention_layer import BlockSparseAttentionLayer
18
+ from .positional_embedding import RotaryEmbedding
19
+
20
+ from .configuration_phi3_small import Phi3SmallConfig
21
+
22
+ # Flash Attention Related Imports
23
+ is_flash_attention_available = False
24
+ try:
25
+ import flash_attn
26
+ if int(flash_attn.__version__.split('.')[0]) < 2:
27
+ from flash_attn.flash_attn_interface import (
28
+ flash_attn_func,
29
+ flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
30
+ )
31
+
32
+ # rename `max_seqlen`
33
+ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, **kwargs):
34
+ return flash_attn_func(qkv, cu_seqlens, dropout_p=dropout_p, max_s=max_seqlen, **kwargs)
35
+
36
+ else:
37
+ from flash_attn.flash_attn_interface import (
38
+ flash_attn_varlen_kvpacked_func,
39
+ )
40
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
41
+ is_flash_attention_available = True
42
+ except ImportError:
43
+ pass
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ LegacyCache = Tuple[Tuple[torch.FloatTensor]]
48
+
49
+ # Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
50
+ def info_value_of_dtype(dtype: torch.dtype):
51
+ """
52
+ Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
53
+ """
54
+ if dtype == torch.bool:
55
+ raise TypeError("Does not support torch.bool")
56
+ elif dtype.is_floating_point:
57
+ return torch.finfo(dtype)
58
+ else:
59
+ return torch.iinfo(dtype)
60
+
61
+
62
+ # Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
63
+ def min_value_of_dtype(dtype: torch.dtype):
64
+ """
65
+ Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.
66
+ """
67
+ return info_value_of_dtype(dtype).min
68
+
69
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
70
+ def _get_unpad_data(attention_mask):
71
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
72
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
73
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
74
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
75
+ return (
76
+ indices,
77
+ cu_seqlens,
78
+ max_seqlen_in_batch,
79
+ )
80
+
81
+
82
+ @torch.jit.script
83
+ def quick_gelu(x):
84
+ return x * torch.sigmoid(1.702 * x)
85
+
86
+
87
+ @torch.jit.script
88
+ def gegelu(input, limit: Optional[float] = None):
89
+ a_gelu, a_linear = input[..., ::2], input[..., 1::2]
90
+ if limit is not None:
91
+ a_gelu = torch.where(
92
+ torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
93
+ )
94
+ a_linear = torch.where(
95
+ torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
96
+ )
97
+ out_gelu = quick_gelu(a_gelu)
98
+ return out_gelu * (a_linear + 1)
99
+
100
+ def collapse_first_n_dims(x: torch.Tensor, n: int) -> torch.Tensor:
101
+ """
102
+ Collapse the first `n` dimensions of a tensor into a single dimension.
103
+
104
+ Args:
105
+ x (torch.Tensor): The input tensor.
106
+ n (int): The number of dimensions to collapse.
107
+
108
+ Returns:
109
+ torch.Tensor: The output tensor.
110
+ """
111
+ return x.view(-1, *x.shape[n:])
112
+
113
+ def pad_tensor_to_next_mult_of(
114
+ tensor: torch.Tensor,
115
+ dim: int,
116
+ n: int,
117
+ ) -> Tuple[torch.Tensor, int]:
118
+ """
119
+ Pads a tensor along a specified dimension to the next multiple of a given number.
120
+
121
+ Args:
122
+ tensor (torch.Tensor): The input tensor.
123
+ dim (int): The dimension along which to pad the tensor.
124
+ n (int): The number to pad the tensor to the next multiple of.
125
+
126
+ Returns:
127
+ Tuple[torch.Tensor, int]: A tuple containing the padded tensor and the amount of padding added.
128
+ """
129
+ residual = tensor.size(dim) % n
130
+ if residual == 0:
131
+ return tensor, 0
132
+ padding = n - residual
133
+ padding_tensor = torch.zeros((*tensor.size()[:dim], padding, *tensor.size()[dim + 1:]), device=tensor.device, dtype=tensor.dtype)
134
+ return torch.cat([tensor, padding_tensor], dim=dim), padding
135
+
136
+ def strip_padding_from_tensor(
137
+ tensor: torch.Tensor,
138
+ dim: int,
139
+ residual: int,
140
+ ) -> torch.Tensor:
141
+ """
142
+ Removes padding from a tensor along a specified dimension.
143
+
144
+ Args:
145
+ tensor (torch.Tensor): The input tensor.
146
+ dim (int): The dimension along which to remove padding.
147
+ residual (int): The amount of padding to remove.
148
+
149
+ Returns:
150
+ torch.Tensor: The tensor with padding removed along the specified dimension.
151
+ """
152
+ return torch.narrow(tensor, dim, 0, tensor.size(dim) - residual)
153
+
154
+ class Phi3SmallMLP(nn.Module):
155
+ def __init__(self, config: Phi3SmallConfig):
156
+ super().__init__()
157
+ self.config = config
158
+ assert self.config.hidden_act == "gegelu", "Only `gegelu` is supported for the 4.7 series of models .."
159
+ self.hidden_size = config.hidden_size
160
+ self.gegelu_limit = config.gegelu_limit
161
+ self.intermediate_size = config.intermediate_size
162
+
163
+ self.up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size)
164
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
165
+ self.dropout = nn.Dropout(config.ffn_dropout_prob)
166
+
167
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
168
+ return self.dropout(
169
+ self.down_proj(
170
+ gegelu(self.up_proj(x), limit=self.gegelu_limit)
171
+ )
172
+ )
173
+
174
+
175
+ class Phi3SmallSelfAttention(nn.Module):
176
+ def __init__(self, config: Phi3SmallConfig, layer_idx: Optional[int] = None) -> None:
177
+ super().__init__()
178
+ self.config = config
179
+ self.layer_idx = layer_idx
180
+ if layer_idx is None:
181
+ logger.warning_once(
182
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
183
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
184
+ "when creating this class."
185
+ )
186
+
187
+ self.hidden_size = config.hidden_size
188
+ # Number of Query Heads
189
+ self.num_heads = config.num_attention_heads
190
+ self.head_dim = self.hidden_size // self.num_heads
191
+ # Number of Key Value Heads
192
+ self.num_key_value_heads = config.num_key_value_heads
193
+ self.num_q_per_kv = self.num_heads // self.num_key_value_heads
194
+ self.max_position_embeddings = config.max_position_embeddings
195
+ self.rope_embedding_base = config.rope_embedding_base
196
+ self.rope_position_scale = config.rope_position_scale
197
+ self.is_causal = True
198
+
199
+ self.attention_dropout_rate = config.attention_dropout_prob
200
+
201
+ norm_factor = None
202
+ if config.mup_use_scaling:
203
+ norm_factor = self.head_dim / config.mup_attn_multiplier
204
+ else:
205
+ norm_factor = math.sqrt(self.head_dim)
206
+ self.softmax_scale = 1.0 / norm_factor
207
+
208
+ self.query_key_value = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim)
209
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
210
+
211
+ self.blocksparse_params = None
212
+ # layer_idx is 0 indexed because that's what the KV Cache expects.
213
+ if self.config.dense_attention_every_n_layers and ((self.layer_idx + 1) % self.config.dense_attention_every_n_layers == 0):
214
+ logger.info(
215
+ f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
216
+ f"{self.config.dense_attention_every_n_layers}"
217
+ )
218
+ assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
219
+ else:
220
+ # BlockSparse related Parameters
221
+ self.blocksparse_params = BlockSparseParams.from_config(config)
222
+
223
+ if self.blocksparse:
224
+ active_head_range = None
225
+ """
226
+ ... note(bapatra)::
227
+
228
+ In case of tensor parallelism and while using the heterogeneous head patterns,
229
+ the active head range needs to be modified based on the tensor parallel rank
230
+ and the tensor parallel world size.
231
+
232
+ This is because in the case of heterogeneous head patterns, the kernel needs to know
233
+ which head is on which device, so that it can pick the corresponding blocksparse head
234
+ pattern correctly.
235
+
236
+ Example:
237
+ ```python
238
+
239
+ if not self.blocksparse_params.homo_head_pattern:
240
+ tp_rank = torch.distributed.get_rank() % tp_world_size
241
+ num_heads_per_partition = num_heads // tp_world_size
242
+ active_head_range = (tp_rank * num_heads_per_partition, (tp_rank + 1) * num_heads_per_partition)
243
+
244
+ ```
245
+
246
+ """
247
+
248
+ self._blocksparse_layer = BlockSparseAttentionLayer(
249
+ n_heads=self.num_heads,
250
+ max_seq_len=self.max_position_embeddings,
251
+ sparse_block_size=self.blocksparse_params.block_size,
252
+ local_blocks=self.blocksparse_params.num_local_blocks,
253
+ vert_stride=self.blocksparse_params.vert_stride,
254
+ kernel_block_size=self.blocksparse_params.kernel_block_size,
255
+ homo_head=self.blocksparse_params.homo_head_pattern,
256
+ active_head_range=active_head_range,
257
+ )
258
+ self.rotary_emb = RotaryEmbedding.from_config(config)
259
+
260
+
261
+ @property
262
+ def blocksparse(self):
263
+ return self.blocksparse_params is not None
264
+
265
+ def _split_heads(self, mixed_x_layer: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
266
+ bs, sq, _ = mixed_x_layer.size()
267
+ r"""
268
+ The main idea is that we group tensors as
269
+ [bs, sq, (q00, q01, ... q0m, k0, v0), (q10, q11, ... q1m, k1, v1), ... (qn0, qn1, ... qnm, kn, vn)]
270
+ That ways, when the MP column sharding happens, this tensor will be sharded keeping all the
271
+ queries and keys intact. In order to get the correct qkv, we first break into groups, and then
272
+ index into the groups.
273
+ """
274
+
275
+ intermediate_shape = (bs, sq, -1, (self.num_q_per_kv + 2), self.head_dim)
276
+ mixed_x_layer = mixed_x_layer.view(*intermediate_shape)
277
+ q = mixed_x_layer[:, :, :, :-2]
278
+ k = mixed_x_layer[:, :, :, [-2]]
279
+ v = mixed_x_layer[:, :, :, [-1]]
280
+ q, k, v = [
281
+ rearrange(
282
+ x,
283
+ "bs sq group nh hn -> bs sq (group nh) hn"
284
+ ) for x in (q, k, v)
285
+ ]
286
+ return q, k, v
287
+
288
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._unpad_input
289
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
290
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
291
+
292
+
293
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
294
+
295
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
296
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
297
+
298
+ if query_length == kv_seq_len:
299
+ query_layer = index_first_axis(
300
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
301
+ )
302
+ cu_seqlens_q = cu_seqlens_k
303
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
304
+ indices_q = indices_k
305
+ elif query_length == 1:
306
+ max_seqlen_in_batch_q = 1
307
+ cu_seqlens_q = torch.arange(
308
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
309
+ ) # There is a memcpy here, that is very bad.
310
+ indices_q = cu_seqlens_q[:-1]
311
+ query_layer = query_layer.squeeze(1)
312
+ else:
313
+ # The -q_len: slice assumes left padding.
314
+ attention_mask = attention_mask[:, -query_length:]
315
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
316
+
317
+ return (
318
+ query_layer,
319
+ key_layer,
320
+ value_layer,
321
+ indices_q,
322
+ (cu_seqlens_q, cu_seqlens_k),
323
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
324
+ )
325
+
326
+ def _apply_blocksparse_attention(
327
+ self,
328
+ q: torch.Tensor,
329
+ k: torch.Tensor,
330
+ v: torch.Tensor,
331
+ attention_mask: Optional[torch.LongTensor],
332
+ return_attention_probs: bool = False,
333
+ ) -> torch.Tensor:
334
+ """
335
+ Applies blocksparse attention to the input tensors.
336
+
337
+ Args:
338
+ q (torch.Tensor): The query tensor of shape (bs, nqp, seq_len, hn).
339
+ k (torch.Tensor): The key tensor of shape (bs, nkp, seq_len, hn).
340
+ v (torch.Tensor): The value tensor of shape (bs, nkp, seq_len, hn).
341
+ attention_mask (Optional[torch.LongTensor]): The attention mask tensor of shape (bs, seq_len).
342
+ return_attention_probs (bool, optional): Whether to return attention probabilities. Defaults to False.
343
+
344
+ Returns:
345
+ torch.Tensor: The context layer tensor of shape (bs, nqp, seq_len, hn).
346
+ """
347
+ assert not return_attention_probs, "return_attention_probs is not supported for blocksparse attention"
348
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
349
+ # shape: (bs, nqp, seq_len, hn)
350
+ if torch.is_grad_enabled():
351
+ # Training or non-batched inference
352
+ context_layer = self._blocksparse_layer(
353
+ q=q, k=k, v=v, sm_scale=self.softmax_scale
354
+ )
355
+ elif attention_mask is None:
356
+ if q.size(0) != 1:
357
+ logger.warning_once(
358
+ "You are attempting to do batched inference without passing the attention mask.\n"
359
+ "This is okay if you are running loglikelihood requests. However, if you want to do generation, "
360
+ "this probably won't work as expected. Please pass the attention mask to the forward function."
361
+ )
362
+ context_layer = self._blocksparse_layer(
363
+ q=q, k=k, v=v, sm_scale=self.softmax_scale
364
+ )
365
+ else:
366
+ """
367
+ Shapes of tensors are as follows:
368
+ q: (bs, nqp, seq_len, hdim)
369
+ k: (bs, nkp, seq_len, hdim)
370
+ v: (bs, nkp, seq_len, hdim)
371
+ We first need to transpose the shapes to fit what the
372
+ kernel needs, and the reinvert it back at the end of the operations
373
+ """
374
+ assert attention_mask.ndim == 2, "The kernel, like flash-attention-2, only supports 2d attention masks ..."
375
+ left_paddings = attention_mask.shape[1] - attention_mask.sum(dim=-1)
376
+ # shape: (bs, seq_len, nqp, hdim)
377
+ q = q.transpose(1, 2).contiguous()
378
+ # shape: (bs, seq_len, nkp, hdim)
379
+ k = k.transpose(1, 2).contiguous()
380
+ # shape: (bs, seq_len, nkp, hdim)
381
+ v = v.transpose(1, 2).contiguous()
382
+ context_layer = self._blocksparse_layer(
383
+ q=q, k=k, v=v, sm_scale=self.softmax_scale, left_paddings=left_paddings.to(torch.int32)
384
+ )
385
+ # shape: (bs, nqp, seq_len, hdim)
386
+ context_layer = context_layer.transpose(1, 2).contiguous()
387
+ return context_layer
388
+
389
+ def _apply_dense_attention(
390
+ self,
391
+ q: torch.Tensor,
392
+ k: torch.Tensor,
393
+ v: torch.Tensor,
394
+ attention_mask: torch.Tensor,
395
+ return_attention_probs: bool = False,
396
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
397
+ """
398
+ Apply dense attention
399
+
400
+ Args:
401
+ q (torch.Tensor):
402
+ The query tensor, shape: (bs, num_query_heads, seq_len, head_size)
403
+ k (torch.Tensor):
404
+ The key tensor, shape: (bs, num_query_heads, seq_len, head_size)
405
+ v (torch.Tensor):
406
+ The value tensor, shape: (bs, num_query_heads, seq_len, head_size)
407
+
408
+ return_attention_probs (bool, optional):
409
+ Return the attention probabilities. Defaults to False.
410
+
411
+ Returns:
412
+ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
413
+ Return the output of the attention aggregation. If `return_attention_probs` is True, then
414
+ also return the attention probabilities
415
+
416
+ .. note::
417
+ Right now, am assuming the expansion for the query key values is already done
418
+ outside. But ideally, since Flash attention handles the MQA correctly, we can
419
+ avoid doing that.
420
+
421
+ """
422
+ attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
423
+ # Get into the correct shape for the Flash Attention API
424
+ # shape: (bs, seq_len, nqp, hn)
425
+ q = q.transpose(1, 2).contiguous()
426
+ query_length = q.size(1)
427
+ # shape: (bs, seq_len, npq, hn)
428
+ k = k.transpose(1, 2).contiguous()
429
+ # shape: (bs, seq_len, npq, hn)
430
+ v = v.transpose(1, 2).contiguous()
431
+
432
+ if attention_mask is not None:
433
+ causal = q.size(2) == k.size(2)
434
+ batch_size = q.shape[0]
435
+ flat_q, flat_k, flat_v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
436
+ q, k, v, attention_mask, query_length
437
+ )
438
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
439
+ max_seqlen_q, max_seqlen_k = max_seq_lens
440
+ flat_kv = torch.cat((flat_k.unsqueeze(1), flat_v.unsqueeze(1)), dim=1)
441
+ attn_output_unpad = flash_attn_varlen_kvpacked_func(
442
+ q=flat_q,
443
+ kv=flat_kv,
444
+ cu_seqlens_q=cu_seqlens_q,
445
+ cu_seqlens_k=cu_seqlens_k,
446
+ max_seqlen_q=max_seqlen_q,
447
+ max_seqlen_k=max_seqlen_k,
448
+ dropout_p=attention_dropout_prob,
449
+ softmax_scale=self.softmax_scale,
450
+ causal=causal,
451
+ return_attn_probs=return_attention_probs
452
+ )
453
+ attention_output = pad_input(
454
+ attn_output_unpad, indices_q, batch_size, query_length
455
+ )
456
+ else:
457
+ kv = torch.cat((k.unsqueeze(2), v.unsqueeze(2)), dim=2)
458
+ cu_seqlens_q = torch.arange(
459
+ 0, (q.size(0) + 1), device=q.device, dtype=torch.int32
460
+ ) * q.size(1)
461
+ cu_seqlens_kv = torch.arange(
462
+ 0, (kv.size(0) + 1), device=kv.device, dtype=torch.int32
463
+ ) * kv.size(1)
464
+ max_seqlen_q = q.size(1)
465
+ max_seqlen_k = kv.size(1)
466
+ attention_output = flash_attn_varlen_kvpacked_func(
467
+ q=collapse_first_n_dims(q, 2),
468
+ kv=collapse_first_n_dims(kv, 2),
469
+ cu_seqlens_q=cu_seqlens_q,
470
+ cu_seqlens_k=cu_seqlens_kv,
471
+ max_seqlen_q=max_seqlen_q,
472
+ max_seqlen_k=max_seqlen_k,
473
+ dropout_p=attention_dropout_prob,
474
+ softmax_scale=self.softmax_scale,
475
+ causal=q.size(1) == kv.size(1),
476
+ return_attn_probs=return_attention_probs
477
+ )
478
+ if return_attention_probs:
479
+ (context_layer, attn_probs) = attention_output
480
+ context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
481
+ return (context_layer, attn_probs)
482
+ context_layer = attention_output
483
+ context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
484
+ return context_layer
485
+
486
+
487
+ def expand_kv_to_q_size(self, kv: torch.Tensor, num_q_per_kv: int) -> torch.Tensor:
488
+ """
489
+ Expand the key-value tensor to match the size of the query tensor.
490
+
491
+ Args:
492
+ kv (torch.Tensor): The key-value tensor of shape (bsz, nkp, 2, seq_len, hdim).
493
+ num_q_per_kv (int): The number of queries per key-value.
494
+
495
+ Returns:
496
+ torch.Tensor: The expanded key-value tensor of shape (bsz, nqp, 2, seq_len, hdim).
497
+ Where nqp = num_q_per_kv * nkp
498
+
499
+ .. note::
500
+ Right now, I am using a repeat_interleave to expand the kv to the size of q.
501
+ This incurs a memory penalty, since the tensors are actually copied.
502
+ TODO: If this does yield benefits, then potentially we can use the re-written
503
+ flash attention kernel that can handle the MQA.
504
+ """
505
+
506
+ repeats = torch.tensor([num_q_per_kv] * kv.size(1)).to(kv.device)
507
+ total = repeats.sum()
508
+ expanded_kv = torch.repeat_interleave(
509
+ kv,
510
+ repeats=repeats,
511
+ dim=1,
512
+ output_size=total
513
+ )
514
+ return expanded_kv
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states: torch.Tensor,
519
+ attention_mask: Optional[torch.Tensor] = None,
520
+ position_ids: Optional[torch.LongTensor] = None,
521
+ past_key_values: Optional[Cache] = None,
522
+ output_attentions: bool = False,
523
+ use_cache: bool = False,
524
+ **kwargs,
525
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
526
+ """
527
+ The forward function of the Self Attention Layer.
528
+
529
+ Args:
530
+ hidden_states (torch.Tensor):
531
+ The input tensor of shape (bs, q_len, h).
532
+ attention_mask (Optional[torch.Tensor], optional):
533
+ The attention mask tensor of shape (bs, seq_len). This is the 2D attention mask tensor as is standard in the flash-attention
534
+ kernel.
535
+ Defaults to None.
536
+ position_ids (Optional[torch.LongTensor], optional):
537
+ The position ids tensor of shape (bs, q_len). Defaults to None. Unused by the function.
538
+ past_key_value (Optional[Cache], optional):
539
+ The previous kv cache values. Defaults to None.
540
+ output_attentions (bool, optional):
541
+ Whether to return the attention scores. Defaults to False.
542
+ .. note::
543
+ For the blocksparse attention kernel, we do not support returning the attention scores.
544
+ use_cache (bool, optional):
545
+ Whether to use the cache for storing the kv. Defaults to False.
546
+
547
+ Returns:
548
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
549
+ The output tensor of shape (bs, q_len, h),
550
+ the attention scores tensor of shape (bs, nqp, q_len, seq_len) if `output_attentions` is True,
551
+ and the updated cache values if `use_cache` is True.
552
+
553
+ Notations:
554
+ ------------
555
+ bs: batch size
556
+ sq_len: sequence length of the entire sequence
557
+ q_len: sequence length of the query
558
+ cache_sq: sequence length in the cache
559
+ If there is no cache then cache_sq = 0
560
+ and sq_len = q_len
561
+ otherwise sq_len = q_len + cache_sq
562
+ h: hidden size
563
+ nq: number of query heads
564
+ nkv: number of key heads
565
+ hn: hidden size per head
566
+ hn = h // nq
567
+ nqp: number of query heads (per MP partition)
568
+ nqp = nq // (num mp partitions)
569
+ nkvp: number of key-value heads (per MP partition)
570
+ nkvp = nk // (num mp partitions)
571
+
572
+ """
573
+ # shape: (bs, q_len, h)
574
+ bsz, q_len, _ = hidden_states.size()
575
+
576
+ # shape: (bs, q_len, (nqp + 2 * nkvp) * hn)
577
+ mixed_x_layer = self.query_key_value(hidden_states)
578
+ # shape: (bs, q_len, nqp, hn), shape: (bs, q_len, nkvp, hn), shape: (bs, q_len, nkvp, hn)
579
+ q, k, v = self._split_heads(mixed_x_layer)
580
+
581
+ # shape: (bs, qnp, q_len, hn)
582
+ query_states = q.permute(0, 2, 1, 3).contiguous()
583
+ # shape: (bs, nkvp, q_len, hn)
584
+ key_states = k.permute(0, 2, 1, 3).contiguous()
585
+ # shape: (bs, nkvp, q_len, hn)
586
+ value_states = v.permute(0, 2, 1, 3).contiguous()
587
+
588
+ kv_seq_len = key_states.shape[-2]
589
+ if past_key_values is not None:
590
+ if self.layer_idx is None:
591
+ raise ValueError(
592
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
593
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
594
+ "with a layer index."
595
+ )
596
+ if self.rotary_emb is not None:
597
+ seqlen_offset = past_key_values.get_usable_length(kv_seq_len, layer_idx=self.layer_idx)
598
+ # shape: (bs, nqp, q_len, hn), shape: (bs, nkvp, q_len, hn)
599
+ query_states, key_states = self.rotary_emb(
600
+ query_states, key_states, seq_dimension=2, seqlen_offset=seqlen_offset
601
+ )
602
+ key_states, value_states = past_key_values.update(key_states=key_states, value_states=value_states, layer_idx=self.layer_idx)
603
+ else:
604
+ # In this case seq_len = q_len and cache_sq = 0
605
+ if self.rotary_emb is not None:
606
+ # shape: (bs, nqp, seq_len, hn), shape: (bs, nkvp, seq_len, hn)
607
+ query_states, key_states = self.rotary_emb(query_states, key_states, seq_dimension=2)
608
+
609
+ # shape: (bs, nkvp, 2, seq_len, hn)
610
+ kv_states = torch.cat((key_states.unsqueeze(2), value_states.unsqueeze(2)), dim=2)
611
+ # shape: (bs, nqp, 2, seq_len, hn)
612
+ expanded_kv_states = self.expand_kv_to_q_size(kv_states, num_q_per_kv=self.num_q_per_kv)
613
+ # shape: (bs, nqp, seq_len, hn), shape: (bs, nqp, seq_len, hn)
614
+ expanded_key_states, expanded_value_states = expanded_kv_states[:, :, 0], expanded_kv_states[:, :, 1]
615
+ if self.blocksparse:
616
+ attn_function_output = self._apply_blocksparse_attention(
617
+ q=query_states,
618
+ k=expanded_key_states,
619
+ v=expanded_value_states,
620
+ attention_mask=attention_mask,
621
+ return_attention_probs=output_attentions
622
+ )
623
+ else:
624
+ attn_function_output = self._apply_dense_attention(
625
+ q=query_states,
626
+ k=expanded_key_states,
627
+ v=expanded_value_states,
628
+ attention_mask=attention_mask,
629
+ return_attention_probs=output_attentions
630
+ )
631
+
632
+ attn_weights = None
633
+ if output_attentions:
634
+ attn_output, attn_weights = attn_function_output
635
+ else:
636
+ # shape: (bs, nqp, seq_len, hn)
637
+ attn_output = attn_function_output
638
+ # shape: (bs, seq_len, nqp, hn)
639
+ attn_output = attn_output.transpose(1, 2).contiguous()
640
+
641
+ # shape: (bs, seq_len, h)
642
+ attn_output = attn_output.view(bsz, q_len, -1)
643
+ attn_output = self.dense(attn_output)
644
+ return attn_output, attn_weights, past_key_values
645
+
646
+
647
+ class Phi3SmallDecoderLayer(nn.Module):
648
+ def __init__(self, config: Phi3SmallConfig, layer_idx: int):
649
+ super().__init__()
650
+ self.hidden_size = config.hidden_size
651
+ self.self_attn = Phi3SmallSelfAttention(config, layer_idx)
652
+ self.mlp = Phi3SmallMLP(config)
653
+
654
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
655
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
656
+
657
+ def forward(
658
+ self,
659
+ hidden_states: torch.Tensor,
660
+ attention_mask: Optional[torch.Tensor] = None,
661
+ position_ids: Optional[torch.LongTensor] = None,
662
+ past_key_values: Optional[Cache] = None,
663
+ output_attentions: Optional[bool] = None,
664
+ use_cache: Optional[bool] = None,
665
+ **kwargs,
666
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Cache]]:
667
+ residual = hidden_states
668
+ hidden_states = self.input_layernorm(hidden_states)
669
+
670
+ # Self Attention
671
+ hidden_states, self_attn_weights, present_key_values = self.self_attn(
672
+ hidden_states=hidden_states,
673
+ attention_mask=attention_mask,
674
+ position_ids=position_ids,
675
+ past_key_values=past_key_values,
676
+ output_attentions=output_attentions,
677
+ use_cache=use_cache,
678
+ )
679
+ hidden_states = residual + hidden_states
680
+
681
+ # Fully Connected
682
+ residual = hidden_states
683
+ hidden_states = self.post_attention_layernorm(hidden_states)
684
+ hidden_states = self.mlp(hidden_states)
685
+ hidden_states = residual + hidden_states
686
+
687
+ outputs = (hidden_states,)
688
+
689
+ if output_attentions:
690
+ outputs += (self_attn_weights,)
691
+
692
+ if use_cache:
693
+ outputs += (present_key_values,)
694
+
695
+ return outputs
696
+
697
+
698
+
699
+ class Phi3SmallPreTrainedModel(PreTrainedModel):
700
+ config_class = Phi3SmallConfig
701
+ base_model_prefix = "model"
702
+ supports_gradient_checkpointing = True
703
+ _no_split_modules = ["Phi3SmallDecoderLayer"]
704
+ skip_keys_device_placement = "past_key_values"
705
+ _supports_flash_attn_2 = True
706
+ _supports_sdpa = False
707
+ _supports_cache_class = True
708
+
709
+ def _init_weights(self, module: nn.Module):
710
+ std = self.config.initializer_range
711
+ if isinstance(module, nn.Linear):
712
+ # Slightly different from the TF version which uses truncated_normal for initialization
713
+ # cf https://github.com/pytorch/pytorch/pull/5617
714
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
715
+ elif isinstance(module, nn.Embedding):
716
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
717
+ if module.padding_idx is not None:
718
+ module.weight.data[module.padding_idx].zero_()
719
+ elif isinstance(module, nn.LayerNorm):
720
+ module.bias.data.zero_()
721
+ module.weight.data.fill_(1.0)
722
+
723
+ # The output projection on the decoder attention layer as well as the down_proj in the MLP are scaled
724
+ # differently (dubbed `output_layer_init_method` in the Megatron code). This is replicated here
725
+ for name, p in module.named_parameters():
726
+ if any(x in name for x in ("c_proj.weight", "down_proj.weight", "o_proj.weight")):
727
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
728
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers)))
729
+
730
+
731
+ class Phi3SmallModel(Phi3SmallPreTrainedModel):
732
+
733
+ def __init__(self, config):
734
+ super().__init__(config)
735
+ self.config = config
736
+
737
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
738
+
739
+ # Embedding Dropout
740
+ self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)
741
+
742
+ # MuP Embedding scaling
743
+ self.mup_embedding_multiplier = config.mup_embedding_multiplier
744
+
745
+ self.layers = nn.ModuleList([Phi3SmallDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
746
+
747
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
748
+
749
+ self.gradient_checkpointing = False
750
+
751
+ # Initialize weights and apply final processing
752
+ self.post_init()
753
+
754
+ def get_input_embeddings(self):
755
+ return self.embed_tokens
756
+
757
+ def set_input_embeddings(self, value):
758
+ self.embed_tokens = value
759
+
760
+ @property
761
+ def pad_sequence_to_multiple_of_64(self):
762
+ # We only need to do this for the backward pass. So only required
763
+ # when we are in the context of generating gradients
764
+ return self.config.pad_sequence_to_multiple_of_64 and torch.is_grad_enabled()
765
+
766
+ def forward(
767
+ self,
768
+ input_ids: torch.LongTensor = None,
769
+ attention_mask: Optional[torch.Tensor] = None,
770
+ position_ids: Optional[torch.LongTensor] = None,
771
+ past_key_values: Optional[Union[Cache, LegacyCache]] = None,
772
+ inputs_embeds: Optional[torch.FloatTensor] = None,
773
+ use_cache: Optional[bool] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None,
777
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
778
+
779
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
780
+ output_hidden_states = (
781
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
782
+ )
783
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
784
+
785
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
786
+
787
+ if input_ids is not None and inputs_embeds is not None:
788
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
789
+ elif input_ids is not None:
790
+ batch_size, seq_length = input_ids.shape
791
+ elif inputs_embeds is not None:
792
+ batch_size, seq_length, _ = inputs_embeds.shape
793
+ else:
794
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
795
+
796
+ if self.gradient_checkpointing and self.training:
797
+ if use_cache:
798
+ logger.warning_once(
799
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
800
+ )
801
+ use_cache = False
802
+
803
+ past_key_values_length = 0
804
+
805
+ if use_cache:
806
+ use_legacy_cache = not isinstance(past_key_values, Cache)
807
+ if use_legacy_cache:
808
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
809
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
810
+
811
+ if position_ids is None:
812
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
813
+ position_ids = torch.arange(
814
+ past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device
815
+ )
816
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
817
+ else:
818
+ position_ids = position_ids.view(-1, seq_length).long()
819
+
820
+ if attention_mask is not None:
821
+ if batch_size <= 0:
822
+ raise ValueError("batch_size has to be defined and > 0")
823
+
824
+ if inputs_embeds is None:
825
+ inputs_embeds = self.embed_tokens(input_ids)
826
+ inputs_embeds = self.embedding_dropout(inputs_embeds)
827
+
828
+ if self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0:
829
+ inputs_embeds = inputs_embeds * self.mup_embedding_multiplier
830
+
831
+ residual = 0
832
+ if self.pad_sequence_to_multiple_of_64:
833
+ # note(bapatra): Since we don't particularly use the position_ids and the attention mask
834
+ # we don't need to pad them
835
+ inputs_embeds, residual = pad_tensor_to_next_mult_of(tensor=inputs_embeds, dim=1, n=64)
836
+
837
+ hidden_states = inputs_embeds
838
+
839
+ # decoder layers
840
+ all_hidden_states = () if output_hidden_states else None
841
+ all_self_attns = () if output_attentions else None
842
+ next_decoder_cache = None
843
+
844
+ for decoder_layer in self.layers:
845
+ if output_hidden_states:
846
+ all_hidden_states += (hidden_states,)
847
+
848
+ if self.gradient_checkpointing and self.training:
849
+ layer_outputs = self._gradient_checkpointing_func(
850
+ decoder_layer.__call__,
851
+ hidden_states,
852
+ attention_mask,
853
+ position_ids,
854
+ past_key_values,
855
+ output_attentions,
856
+ use_cache,
857
+ )
858
+ else:
859
+ layer_outputs = decoder_layer(
860
+ hidden_states,
861
+ attention_mask=attention_mask,
862
+ position_ids=position_ids,
863
+ past_key_values=past_key_values,
864
+ output_attentions=output_attentions,
865
+ use_cache=use_cache,
866
+ )
867
+ hidden_states = layer_outputs[0]
868
+
869
+ if use_cache:
870
+ # Following the Mistral schema for layer return values
871
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
872
+ if output_attentions:
873
+ all_self_attns += (layer_outputs[1],)
874
+
875
+ hidden_states = self.final_layernorm(hidden_states)
876
+
877
+ if residual > 0:
878
+ hidden_states = strip_padding_from_tensor(tensor=hidden_states, dim=1, residual=residual)
879
+
880
+ # add hidden states from the last decoder layer
881
+ if output_hidden_states:
882
+ all_hidden_states += (hidden_states,)
883
+
884
+ next_cache = None
885
+ if use_cache:
886
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
887
+
888
+ if not return_dict:
889
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
890
+ return BaseModelOutputWithPast(
891
+ last_hidden_state=hidden_states,
892
+ past_key_values=next_cache,
893
+ hidden_states=all_hidden_states,
894
+ attentions=all_self_attns,
895
+ )
896
+
897
+
898
+ class Phi3SmallForCausalLM(Phi3SmallPreTrainedModel):
899
+ _tied_weights_keys = ["lm_head.weight"]
900
+
901
+ def __init__(self, config):
902
+ super().__init__(config)
903
+ self.model = Phi3SmallModel(config)
904
+ self.vocab_size = config.vocab_size
905
+ self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
906
+ self.mup_width_multiplier = config.mup_width_multiplier
907
+
908
+ # Create the mask for the dummy tokens in the vocabulary
909
+ dummy_token_indices = config.dummy_token_indices
910
+ dummy_tokens_mask = torch.zeros(self.vocab_size).bool()
911
+ dummy_tokens_mask[dummy_token_indices] = True
912
+ # shape: (vocab_size,)
913
+ self.register_buffer("dummy_tokens_mask", dummy_tokens_mask, persistent=False)
914
+
915
+ # Initialize weights and apply final processing
916
+ self.post_init()
917
+
918
+ def get_input_embeddings(self):
919
+ return self.model.embed_tokens
920
+
921
+ def set_input_embeddings(self, value):
922
+ self.model.embed_tokens = value
923
+
924
+ def get_output_embeddings(self):
925
+ return self.lm_head
926
+
927
+ def set_output_embeddings(self, value):
928
+ self.lm_head = value
929
+
930
+ def set_decoder(self, decoder):
931
+ self.model = decoder
932
+
933
+ def get_decoder(self):
934
+ return self.model
935
+
936
+ def forward(
937
+ self,
938
+ input_ids: torch.LongTensor = None,
939
+ attention_mask: Optional[torch.Tensor] = None,
940
+ position_ids: Optional[torch.LongTensor] = None,
941
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
942
+ inputs_embeds: Optional[torch.FloatTensor] = None,
943
+ labels: Optional[torch.LongTensor] = None,
944
+ use_cache: Optional[bool] = None,
945
+ output_attentions: Optional[bool] = None,
946
+ output_hidden_states: Optional[bool] = None,
947
+ return_dict: Optional[bool] = None,
948
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
949
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
950
+ output_hidden_states = (
951
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
952
+ )
953
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
954
+
955
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
956
+ outputs = self.model(
957
+ input_ids=input_ids,
958
+ attention_mask=attention_mask,
959
+ position_ids=position_ids,
960
+ past_key_values=past_key_values,
961
+ inputs_embeds=inputs_embeds,
962
+ use_cache=use_cache,
963
+ output_attentions=output_attentions,
964
+ output_hidden_states=output_hidden_states,
965
+ return_dict=return_dict,
966
+ )
967
+
968
+ hidden_states = outputs[0]
969
+ logits = self.lm_head(hidden_states)
970
+ logits = logits.float()
971
+ if self.mup_width_multiplier:
972
+ logits = logits / self.mup_width_multiplier
973
+ logits = logits.masked_fill(self.dummy_tokens_mask, min_value_of_dtype(logits.dtype))
974
+
975
+ loss = None
976
+ if labels is not None:
977
+ # Shift so that tokens < n predict n
978
+ shift_logits = logits[..., :-1, :].contiguous()
979
+ shift_labels = labels[..., 1:].contiguous()
980
+ # Flatten the tokens
981
+ loss_fct = nn.CrossEntropyLoss()
982
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
983
+ shift_labels = shift_labels.view(-1)
984
+ # Enable model parallelism
985
+ shift_labels = shift_labels.to(shift_logits.device)
986
+ loss = loss_fct(shift_logits, shift_labels)
987
+
988
+ if not return_dict:
989
+ output = (logits,) + outputs[1:]
990
+ return (loss,) + output if loss is not None else output
991
+
992
+ return CausalLMOutputWithPast(
993
+ loss=loss,
994
+ logits=logits,
995
+ past_key_values=outputs.past_key_values,
996
+ hidden_states=outputs.hidden_states,
997
+ attentions=outputs.attentions,
998
+ )
999
+
1000
+ def prepare_inputs_for_generation(
1001
+ self,
1002
+ input_ids: torch.LongTensor,
1003
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1004
+ attention_mask: Optional[torch.FloatTensor] = None,
1005
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1006
+ **kwargs
1007
+ ) -> Dict[str, Any]:
1008
+ # only last token for inputs_ids if past is defined in kwargs
1009
+ if past_key_values:
1010
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1011
+
1012
+ position_ids = kwargs.get("position_ids", None)
1013
+
1014
+ if attention_mask is not None and position_ids is None:
1015
+ # create position_ids on the fly for batch generation
1016
+ position_ids = attention_mask.long().cumsum(-1) - 1
1017
+ position_ids.masked_fill_(attention_mask == 0, 1)
1018
+ if past_key_values:
1019
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1020
+ else:
1021
+ position_ids = None
1022
+
1023
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1024
+ if inputs_embeds is not None and past_key_values is None:
1025
+ model_inputs = {"inputs_embeds": inputs_embeds}
1026
+ else:
1027
+ model_inputs = {"input_ids": input_ids}
1028
+
1029
+ model_inputs.update(
1030
+ {
1031
+ "past_key_values": past_key_values,
1032
+ "use_cache": kwargs.get("use_cache"),
1033
+ "position_ids": position_ids,
1034
+ "attention_mask": attention_mask,
1035
+ }
1036
+ )
1037
+ return model_inputs
1038
+
1039
+
1040
+ # Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral -> Phi3Small
1041
+ class Phi3SmallForSequenceClassification(Phi3SmallPreTrainedModel):
1042
+ def __init__(self, config):
1043
+ super().__init__(config)
1044
+ self.num_labels = config.num_labels
1045
+ self.model = Phi3SmallModel(config)
1046
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1047
+
1048
+ # Initialize weights and apply final processing
1049
+ self.post_init()
1050
+
1051
+ def get_input_embeddings(self):
1052
+ return self.model.embed_tokens
1053
+
1054
+ def set_input_embeddings(self, value):
1055
+ self.model.embed_tokens = value
1056
+
1057
+
1058
+ def forward(
1059
+ self,
1060
+ input_ids: torch.LongTensor = None,
1061
+ attention_mask: Optional[torch.Tensor] = None,
1062
+ position_ids: Optional[torch.LongTensor] = None,
1063
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1064
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1065
+ labels: Optional[torch.LongTensor] = None,
1066
+ use_cache: Optional[bool] = None,
1067
+ output_attentions: Optional[bool] = None,
1068
+ output_hidden_states: Optional[bool] = None,
1069
+ return_dict: Optional[bool] = None,
1070
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1071
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1072
+
1073
+ transformer_outputs = self.model(
1074
+ input_ids,
1075
+ attention_mask=attention_mask,
1076
+ position_ids=position_ids,
1077
+ past_key_values=past_key_values,
1078
+ inputs_embeds=inputs_embeds,
1079
+ use_cache=use_cache,
1080
+ output_attentions=output_attentions,
1081
+ output_hidden_states=output_hidden_states,
1082
+ return_dict=return_dict,
1083
+ )
1084
+ hidden_states = transformer_outputs[0]
1085
+ logits = self.score(hidden_states)
1086
+
1087
+ if input_ids is not None:
1088
+ batch_size = input_ids.shape[0]
1089
+ else:
1090
+ batch_size = inputs_embeds.shape[0]
1091
+
1092
+ if self.config.pad_token_id is None and batch_size != 1:
1093
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1094
+ if self.config.pad_token_id is None:
1095
+ sequence_lengths = -1
1096
+ else:
1097
+ if input_ids is not None:
1098
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1099
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1100
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1101
+ sequence_lengths = sequence_lengths.to(logits.device)
1102
+ else:
1103
+ sequence_lengths = -1
1104
+
1105
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1106
+
1107
+ loss = None
1108
+ if labels is not None:
1109
+ labels = labels.to(logits.device)
1110
+ if self.config.problem_type is None:
1111
+ if self.num_labels == 1:
1112
+ self.config.problem_type = "regression"
1113
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1114
+ self.config.problem_type = "single_label_classification"
1115
+ else:
1116
+ self.config.problem_type = "multi_label_classification"
1117
+
1118
+ if self.config.problem_type == "regression":
1119
+ loss_fct = nn.MSELoss()
1120
+ if self.num_labels == 1:
1121
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1122
+ else:
1123
+ loss = loss_fct(pooled_logits, labels)
1124
+ elif self.config.problem_type == "single_label_classification":
1125
+ loss_fct = nn.CrossEntropyLoss()
1126
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1127
+ elif self.config.problem_type == "multi_label_classification":
1128
+ loss_fct = nn.BCEWithLogitsLoss()
1129
+ loss = loss_fct(pooled_logits, labels)
1130
+ if not return_dict:
1131
+ output = (pooled_logits,) + transformer_outputs[1:]
1132
+ return ((loss,) + output) if loss is not None else output
1133
+
1134
+ return SequenceClassifierOutputWithPast(
1135
+ loss=loss,
1136
+ logits=pooled_logits,
1137
+ past_key_values=transformer_outputs.past_key_values,
1138
+ hidden_states=transformer_outputs.hidden_states,
1139
+ attentions=transformer_outputs.attentions,
1140
+ )
positional_embedding.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Orginally Taken verbatim from xformers library
3
+ https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py
4
+
5
+ The difference is that xformers seems to assume the inputs to be
6
+ (bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim)
7
+
8
+ """
9
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
10
+ #
11
+ # This source code is licensed under the BSD license found in the
12
+ # LICENSE file in the root directory of this source tree.
13
+
14
+
15
+ # CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
16
+ # NOTE: Almost the same right now, moving parts to Triton is the next step
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Dict, Union
20
+
21
+ import torch
22
+ import dataclasses
23
+ from transformers.utils import logging
24
+
25
+ from transformers import PretrainedConfig
26
+
27
+ is_dacite_available = False
28
+ try:
29
+ import dacite
30
+ is_dacite_available = True
31
+ except ImportError:
32
+ pass
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ @dataclasses.dataclass
37
+ class LongRopeConfig(object):
38
+ short_factor: List[float]
39
+ long_factor: List[float]
40
+ original_max_position_embeddings: int
41
+ type: str = "longrope"
42
+ short_mscale: float = -1
43
+ long_mscale: float = -1
44
+
45
+
46
+ def __post_init__(self):
47
+ assert self.type in ("longrope", "su"), f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"
48
+
49
+
50
+ @classmethod
51
+ def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
52
+ if is_dacite_available:
53
+ # Preferred since we can also type check the input
54
+ return dacite.from_dict(data_class=cls, data=config_dict)
55
+ kwargs = {}
56
+ for field in dataclasses.fields(cls):
57
+ if field.name in config_dict:
58
+ if field.init:
59
+ kwargs[field.name] = config_dict[field.name]
60
+ else:
61
+ raise ValueError(f"Field {field.name} is not initiable")
62
+ else:
63
+ if field.default is dataclasses.MISSING:
64
+ raise ValueError(f"Field {field.name} is required")
65
+ extra_keys = set(config_dict.keys()) - set(kwargs.keys())
66
+ if len(extra_keys) > 0:
67
+ for key in extra_keys:
68
+ logger.error(f"Unrecognized key {key} in config_dict")
69
+ raise ValueError(f"Unrecognized keys in config_dict")
70
+ return cls(**kwargs)
71
+
72
+ def rotate_half(x):
73
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
74
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
75
+
76
+
77
+
78
+ @torch.jit.script
79
+ def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
80
+ # NOTE: This could probably be moved to Triton
81
+
82
+ if seq_dimension == 0:
83
+ cos = cos[: x.shape[0], None, None, :]
84
+ sin = sin[: x.shape[0], None, None, :]
85
+ elif seq_dimension == 1:
86
+ # Handle a possible sequence length mismatch in between q and k
87
+ cos = cos[None, : x.shape[1], None, :]
88
+ sin = sin[None, : x.shape[1], None, :]
89
+ elif seq_dimension == 2:
90
+ cos = cos[None, None, : x.shape[2], :]
91
+ sin = sin[None, None, : x.shape[2], :]
92
+
93
+ return (x * cos) + (rotate_half(x) * sin)
94
+
95
+
96
+
97
+ class RotaryEmbedding(torch.nn.Module):
98
+ """
99
+ Adapted from the xformers library
100
+
101
+ The rotary position embeddings from RoFormer_ (Su et. al).
102
+ A crucial insight from the method is that the query and keys are
103
+ transformed by rotation matrices which depend on the relative positions.
104
+ Other implementations are available in the Rotary Transformer repo_ and in
105
+ GPT-NeoX_, GPT-NeoX was an inspiration
106
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
107
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
108
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
109
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
110
+ (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
111
+
112
+ # Arguments
113
+ :param dim_mode: head dimention
114
+ :param max_seq_len:
115
+ :param default_seq_dimension: which dim is the sequence length
116
+ :param dtype: cos/sin dtype
117
+ :param use_fused_kernel: if to use customized fused kernel.
118
+ Note: if used, q, k will be modified inplace. Ok for both forward & backward.
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ dim_model: int,
124
+ *,
125
+ max_seq_len: Optional[int] = None,
126
+ dtype: Optional[torch.dtype] = None,
127
+ base=10000,
128
+ position_scale=1,
129
+ device: Optional[torch.device] = None,
130
+ longrope_config: Optional[LongRopeConfig] = None,
131
+ ):
132
+ super().__init__()
133
+ self.base = base
134
+ self.dim_model = dim_model
135
+ self.max_seq_len = max_seq_len
136
+ self.longrope_config = longrope_config
137
+
138
+ if self.is_longrope:
139
+ # Keep the maximum range vector, and slice from it as needed
140
+ self.register_buffer(
141
+ "range_vector",
142
+ torch.arange(max_seq_len, device=device, dtype=torch.float32),
143
+ persistent=False
144
+ )
145
+ self.register_buffer(
146
+ "short_factors",
147
+ torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
148
+ persistent=False
149
+ )
150
+ self.register_buffer(
151
+ "long_factors",
152
+ torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
153
+ persistent=False
154
+ )
155
+ else:
156
+ # Generate and save the inverse frequency buffer (non trainable)
157
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
158
+ self.register_buffer("inv_freq", inv_freq)
159
+
160
+ self.position_scale = position_scale
161
+
162
+ if not self.is_longrope:
163
+ dtype = dtype or torch.get_default_dtype()
164
+ self._set_cos_sin_cache(
165
+ seq_len=max_seq_len,
166
+ device=self.inv_freq.device,
167
+ dtype=dtype,
168
+ )
169
+ @property
170
+ def is_longrope(self):
171
+ return self.longrope_config is not None
172
+
173
+ @property
174
+ def original_max_seq_len(self):
175
+ if self.longrope_config is not None:
176
+ return self.longrope_config.original_max_position_embeddings
177
+ logger.warning_once(
178
+ (
179
+ "``original_max_seq_len'' is being accessed, but longrope_config has not been set. "
180
+ "Please only do this if you are sure about the context."
181
+ )
182
+ )
183
+ return self.max_seq_len
184
+
185
+ def get_range_vector(self, seq_len: int, device: torch.device):
186
+ if self.is_longrope:
187
+ assert seq_len < self.range_vector.shape[0], f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
188
+ if self.range_vector.device != device:
189
+ self.range_vector = self.range_vector.to(device)
190
+ return self.range_vector[:seq_len]
191
+ return torch.arange(seq_len, device=device, dtype=torch.float32)
192
+
193
+
194
+ def _calc_mscale(self, scale: torch.Tensor) -> torch.Tensor:
195
+ if scale <= 1.0:
196
+ return 1.0
197
+ return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))
198
+
199
+ def _set_cos_sin_cache(
200
+ self,
201
+ seq_len: int,
202
+ device: Optional[torch.device] = None,
203
+ dtype: Optional[torch.dtype] = None,
204
+ ) -> None:
205
+ dtype = dtype or torch.get_default_dtype()
206
+ self.max_seq_len_cached = seq_len
207
+ t = (torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale).type_as(self.inv_freq)
208
+ device_type = device.type if device is not None else "cpu"
209
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
210
+ with torch.autocast(device_type=device_type, enabled=False):
211
+ # shape: (seq_len, dim_model // 2)
212
+ freqs = torch.outer(t, self.inv_freq)
213
+ # shape: (seq_len, dim_model)
214
+ emb = torch.cat((freqs, freqs), dim=-1)
215
+ cos = emb.cos()
216
+ sin = emb.sin()
217
+ self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
218
+ self.register_buffer("sin_cached", sin.to(dtype), persistent=False)
219
+
220
+ def forward(
221
+ self, q: torch.Tensor,
222
+ k: torch.Tensor,
223
+ seq_dimension: int = 1,
224
+ seqlen_offset: int = 0,
225
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
226
+ """q, k does not include `seqlen_offset`
227
+ q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
228
+ k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
229
+ """
230
+ if seq_dimension < 0:
231
+ seq_dimension = k.ndim + seq_dimension
232
+ assert seq_dimension in (0, 1, 2)
233
+ seq_len = k.shape[seq_dimension] + seqlen_offset
234
+
235
+ if self.is_longrope:
236
+ if seq_len > self.original_max_seq_len:
237
+ t = self.get_range_vector(seq_len, device=q.device)
238
+ rescale_factors = self.long_factors.to(q.device)
239
+ long_mscale = self.longrope_config.long_mscale
240
+ mscale = long_mscale if long_mscale > 0 else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
241
+ else:
242
+ t = self.get_range_vector(self.original_max_seq_len, device=q.device)
243
+ rescale_factors = self.short_factors.to(q.device)
244
+ short_mscale = self.longrope_config.short_mscale
245
+ mscale = short_mscale if short_mscale > 0 else 1.0
246
+ assert rescale_factors.shape == (self.dim_model // 2, ), (
247
+ f"misaligned shape for LongRoPE rescale factors:\n"
248
+ f"\tExpected {(self.dim_model // 2, )}, got {rescale_factors.shape}."
249
+ )
250
+ inv_freq = 1.0 / (rescale_factors * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model)))
251
+ device_type = q.device.type if q.device is not None else "cpu"
252
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
253
+ with torch.autocast(device_type=device_type, enabled=False):
254
+ freqs = torch.outer(t, inv_freq)
255
+ emb = torch.cat((freqs, freqs), dim=-1)
256
+ cos = emb.cos() * mscale
257
+ sin = emb.sin() * mscale
258
+ cos_cached = cos.to(q.dtype)
259
+ sin_cached = sin.to(q.dtype)
260
+ else:
261
+ if seq_len > self.max_seq_len_cached:
262
+ self._set_cos_sin_cache(
263
+ seq_len=seq_len,
264
+ device=k.device,
265
+ dtype=k.dtype,
266
+ )
267
+ cos_cached = self.cos_cached
268
+ sin_cached = self.sin_cached
269
+ return (
270
+ apply_rotary_pos_emb(
271
+ q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
272
+ ),
273
+ apply_rotary_pos_emb(
274
+ k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
275
+ ),
276
+ )
277
+
278
+ @classmethod
279
+ def from_config(cls, config: PretrainedConfig) -> "RotaryEmbedding":
280
+ kwargs = dict(
281
+ dim_model=config.hidden_size // config.num_attention_heads,
282
+ max_seq_len=config.max_position_embeddings,
283
+ base=config.rope_embedding_base,
284
+ position_scale=config.rope_position_scale,
285
+ )
286
+ if config.rope_scaling is not None:
287
+ kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
288
+ return cls(**kwargs)
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>"
5
+ }
tokenization_phi3_small.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/tokenization_qwen.py
2
+ import os
3
+ from typing import Collection, List, Optional, Dict, Set, Tuple, Union
4
+
5
+ from functools import cached_property
6
+
7
+ import base64
8
+
9
+ from transformers import PreTrainedTokenizer, AddedToken, AutoConfig
10
+ from transformers.models.auto.tokenization_auto import get_tokenizer_config
11
+ import tiktoken
12
+
13
+
14
+ """
15
+ This tokenizer is almost identical to tiktoken.get_encoding("cl100k_base")
16
+ with a few additional special tokens to support the ChatML format.
17
+
18
+ TODO(bapatra): Right now, I do not save the special tokens to the vocab file.
19
+ Maybe in the future, that would be useful? Can add that support later.
20
+
21
+ """
22
+
23
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
24
+ with open(tiktoken_bpe_file, "rb") as f:
25
+ contents = f.read()
26
+ return {
27
+ base64.b64decode(token): int(rank)
28
+ for token, rank in (line.split() for line in contents.splitlines() if line)
29
+ }
30
+
31
+ # On the megatron codebase, we pad vocabularies to ensure matrix multiplication is fast.
32
+ # this in turn causes some indices to be empty. We account for these empty indices by adding
33
+ # dummy tokens to the tokenizer.
34
+
35
+ EFFECTIVE_PADDED_VOCAB_SIZE = 100352
36
+ ACTUAL_VOCAB_SIZE = 100276
37
+
38
+
39
+ DUMMY_TOKENS = {
40
+ f"<|dummy_id_{11 + offset}|>": 100276 + offset
41
+ for offset in range(1, EFFECTIVE_PADDED_VOCAB_SIZE - ACTUAL_VOCAB_SIZE)
42
+ }
43
+
44
+ SPECIAL_TOKENS = {
45
+ # tiktoken.get_encoding("cl100k_base")._special_tokens
46
+ '<|endoftext|>': 100257,
47
+ '<|fim_prefix|>': 100258,
48
+ '<|fim_middle|>': 100259,
49
+ '<|fim_suffix|>': 100260,
50
+ # Special tokens for post-training
51
+ "<|system|>": 100261,
52
+ "<|user|>": 100262,
53
+ "<|assistant|>": 100263,
54
+ # Dummy unused tokens
55
+ "<|dummy_id_0|>": 100264,
56
+ "<|dummy_id_1|>": 100265,
57
+ # Special tokens for post-training continued
58
+ "<|end|>": 100266,
59
+ # Some dummy tokens, so that tokenization is contiguous and does not cause issues
60
+ # Note that the 100256th token of tiktoken.get_encoding("cl100k_base") does not
61
+ # actually map to anything. So we use a dummy token here.
62
+ "<|dummy_id_2|>": 100256,
63
+ # Likewise, tokens from 100267 to 100275 are also unused
64
+ "<|dummy_id_3|>": 100267,
65
+ "<|dummy_id_4|>": 100268,
66
+ "<|dummy_id_5|>": 100269,
67
+ "<|dummy_id_6|>": 100270,
68
+ "<|dummy_id_7|>": 100271,
69
+ "<|dummy_id_8|>": 100272,
70
+ "<|dummy_id_9|>": 100273,
71
+ "<|dummy_id_10|>": 100274,
72
+ "<|dummy_id_11|>": 100275,
73
+ # The final end of prompt token
74
+ # (unused, but present as a part of tiktoken.get_encoding("cl100k_base")._special_tokens)
75
+ '<|endofprompt|>': 100276,
76
+ # Dummy tokens to account for padding of the tokenizer
77
+ # We pad to ensure tensor cores are used for vocab multiplication
78
+ **DUMMY_TOKENS
79
+ }
80
+
81
+ class Phi3SmallTokenizer(PreTrainedTokenizer):
82
+ vocab_files_names = {
83
+ "vocab_file": "cl100k_base.tiktoken"
84
+ }
85
+
86
+ model_input_names: List[str] = ["input_ids", "attention_mask"]
87
+ padding_side = "left"
88
+
89
+ def __init__(
90
+ self,
91
+ vocab_file: Optional[str] = None,
92
+ errors: str = "replace",
93
+ **kwargs
94
+ ) -> None:
95
+ # PreTrainedTokenizer's init calls _add_tokens, which in turn checks
96
+ # if the token is present in `self.special_tokens``. Hence instantiating it here.
97
+ # The way Qwen gets around this is by checking against SPECIAL_TOKENS
98
+ # But I think it's better to check against the objects own `special_tokens`
99
+ # in case we eventually want to allow the tokenizer to have special tokens.
100
+ self.special_tokens = SPECIAL_TOKENS
101
+
102
+ super().__init__(**kwargs)
103
+ self.errors = errors
104
+
105
+ base = tiktoken.get_encoding("cl100k_base")
106
+ if vocab_file is None:
107
+ self.mergeable_ranks: Dict[bytes, int] = base._mergeable_ranks
108
+ else:
109
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
110
+
111
+ self.pat_str = base._pat_str
112
+
113
+ enc = tiktoken.Encoding(
114
+ name="phi3small",
115
+ pat_str=self.pat_str,
116
+ mergeable_ranks=self.mergeable_ranks,
117
+ special_tokens=self.special_tokens,
118
+ )
119
+ self.tokenizer = enc
120
+
121
+ self.decoder: Dict[int, bytes] = {
122
+ v: k for k, v in self.mergeable_ranks.items()
123
+ }
124
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
125
+
126
+ self.eod_id = self.tokenizer.eot_token
127
+ self._eos_token = self._convert_id_to_token(self.eod_id)
128
+
129
+ # Setting the bos_token to be the same as the eos_token
130
+ # Note that this is **not** the correct thing to do, and is done
131
+ # just so that some of the downstream libraries do not break.
132
+ self._bos_token = self._eos_token
133
+
134
+ # Assign the special tokens to class variables
135
+ self.system_id = self.special_tokens["<|system|>"]
136
+ self.user_id = self.special_tokens["<|user|>"]
137
+ self.assistant_id = self.special_tokens["<|assistant|>"]
138
+ self.end_id = self.special_tokens["<|end|>"]
139
+
140
+ @cached_property
141
+ def dummy_token_indices(self) -> List[int]:
142
+ # There are some additional special tokens in the cl100k_base tokenizer
143
+ # that we do not use. Hence, we also consider them to be dummy tokens.
144
+ additional_tokens = [
145
+ "<|fim_prefix|>",
146
+ "<|fim_middle|>",
147
+ "<|fim_suffix|>",
148
+ "<|endofprompt|>"
149
+ ]
150
+ dummy_token_indices = [index for token, index in self.special_tokens.items() if "dummy_id" in token]
151
+ dummy_token_indices.extend([self.special_tokens[token] for token in additional_tokens])
152
+ return sorted(dummy_token_indices)
153
+
154
+ def __getstate__(self):
155
+ state = self.__dict__.copy()
156
+ del state["tokenizer"]
157
+ return state
158
+
159
+ def __setstate__(self, state):
160
+ self.__dict__ = state
161
+ enc = tiktoken.Encoding(
162
+ name="cl100k_im",
163
+ pat_str=self.pat_str,
164
+ mergeable_ranks=self.mergeable_ranks,
165
+ special_tokens=self.special_tokens,
166
+ )
167
+ self.tokenizer = enc
168
+
169
+ def __len__(self):
170
+ return self.tokenizer.n_vocab
171
+
172
+ @classmethod
173
+ def from_pretrained(
174
+ cls,
175
+ pretrained_model_name_or_path: Union[str, os.PathLike],
176
+ *init_inputs,
177
+ **kwargs,
178
+ ):
179
+ cls_kwargs = kwargs
180
+ # First try to load from the tokenization config if it exists
181
+ tokenization_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
182
+ if tokenization_config:
183
+ cls_kwargs.update(
184
+ dict(
185
+ model_max_length=tokenization_config["model_max_length"],
186
+ chat_template=tokenization_config.get("chat_template", None)
187
+ )
188
+ )
189
+ else:
190
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
191
+ cls_kwargs["model_max_length"] = config.max_position_embeddings
192
+ return cls(**cls_kwargs)
193
+
194
+ def get_vocab(self) -> Dict[Union[str, bytes], int]:
195
+ return {**self.mergeable_ranks, **self.special_tokens}
196
+
197
+ def convert_tokens_to_ids(
198
+ self,
199
+ tokens: Union[bytes, str, List[Union[bytes, str]]]
200
+ ) -> Union[int, List[int]]:
201
+ ids = []
202
+ if isinstance(tokens, (str, bytes)):
203
+ if tokens in self.special_tokens:
204
+ return self.special_tokens[tokens]
205
+ else:
206
+ return self.mergeable_ranks.get(tokens)
207
+ ids: List[int] = []
208
+ for token in tokens:
209
+ ids.append(self.convert_tokens_to_ids(token))
210
+ return ids
211
+
212
+ def _add_tokens(
213
+ self,
214
+ new_tokens: Union[List[str], List[AddedToken]],
215
+ special_tokens: bool = False,
216
+ ) -> int:
217
+ if not special_tokens and new_tokens:
218
+ raise ValueError("Only special tokens can be added to this tokenizer")
219
+ for token in new_tokens:
220
+ surface_form = token.content if isinstance(token, AddedToken) else token
221
+ if surface_form not in self.special_tokens:
222
+ raise ValueError(
223
+ "For now, we do not support unknown special tokens\n"
224
+ "In the future, if there is a need for this, we can add special tokens to the tokenizer\n"
225
+ "starting from rank 100261 - 100263 and then 100266 - 100275.\n"
226
+ "And finally, we can re-construct the enc object back\n"
227
+ )
228
+ return 0
229
+
230
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
231
+ file_path = os.path.join(save_directory, "cl100k_base.tiktoken")
232
+ with open(file_path, "w") as f:
233
+ for token, rank in self.mergeable_ranks.items():
234
+ line = base64.b64encode(token).decode("utf-8") + " " + str(rank) + "\n"
235
+ f.write(line)
236
+ return (file_path,)
237
+
238
+ def tokenize(
239
+ self,
240
+ text: str,
241
+ allowed_special: Union[Set, str] = "all",
242
+ disallowed_special: Union[Collection, str] = (),
243
+ **kwargs
244
+ ) -> List[Union[bytes, str]]:
245
+ tokens: List[Union[bytes, str]] = []
246
+ for token_id in self.tokenizer.encode(
247
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
248
+ ):
249
+ tokens.append(self.decoder[token_id])
250
+ return tokens
251
+
252
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
253
+ """
254
+ Converts a sequence of tokens in a single string.
255
+ """
256
+ text = ""
257
+ temp = b""
258
+ for t in tokens:
259
+ if isinstance(t, str):
260
+ if temp:
261
+ text += temp.decode("utf-8", errors=self.errors)
262
+ temp = b""
263
+ text += t
264
+ elif isinstance(t, bytes):
265
+ temp += t
266
+ else:
267
+ raise TypeError("token should only be of type types or str")
268
+ if temp:
269
+ text += temp.decode("utf-8", errors=self.errors)
270
+ return text
271
+
272
+ @property
273
+ def vocab_size(self):
274
+ return self.tokenizer.n_vocab
275
+
276
+ @property
277
+ def eos_token_id(self) -> int:
278
+ return self.eod_id
279
+
280
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
281
+ """Converts an id to a token, special tokens included"""
282
+ if index in self.decoder:
283
+ return self.decoder[index]
284
+ raise ValueError("unknown ids")
285
+
286
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
287
+ """Converts a token to an id using the vocab, special tokens included"""
288
+ if token in self.special_tokens:
289
+ return self.special_tokens[token]
290
+ if token in self.mergeable_ranks:
291
+ return self.mergeable_ranks[token]
292
+ raise ValueError("unknown token")
293
+
294
+ def _tokenize(self, text: str, **kwargs):
295
+ """
296
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
297
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
298
+ Do NOT take care of added tokens.
299
+ """
300
+ raise NotImplementedError
301
+
302
+ def _decode(
303
+ self,
304
+ token_ids: Union[int, List[int]],
305
+ skip_special_tokens: bool = False,
306
+ errors: str = None,
307
+ **kwargs,
308
+ ) -> str:
309
+ if isinstance(token_ids, int):
310
+ token_ids = [token_ids]
311
+ if skip_special_tokens:
312
+ token_ids = [i for i in token_ids if i < self.eod_id]
313
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
314
+
315
+
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_phi3_small.Phi3SmallTokenizer",
6
+ "tokenization_phi3_small.Phi3SmallTokenizer"
7
+ ]
8
+ },
9
+ "bos_token": "<|endoftext|>",
10
+ "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
11
+ "clean_up_tokenization_spaces": true,
12
+ "eos_token": "<|endoftext|>",
13
+ "model_max_length": 8192,
14
+ "pad_token": "<|endoftext|>",
15
+ "tokenizer_class": "Phi3SmallTokenizer"
16
+ }
triton_blocksparse_attention_layer.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, TypeVar
3
+ import torch.nn as nn
4
+ import torch
5
+ import triton
6
+
7
+ from functools import lru_cache
8
+
9
+
10
+ from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd
11
+
12
+
13
+ Layout = Tuple[torch.LongTensor, torch.LongTensor]
14
+
15
+
16
+ def create_sparse_attn_mask(
17
+ n_heads: int,
18
+ max_seq_len: int,
19
+ max_seq_len_k: int,
20
+ dtype: torch.dtype,
21
+ device: torch.device,
22
+ BLOCK: int,
23
+ local_blocks: int,
24
+ vert_stride: int,
25
+ homo_head: bool,
26
+ return_dense: bool
27
+ ) -> Tuple[Layout, torch.Tensor, Optional[torch.Tensor]]:
28
+ layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
29
+ n_heads=n_heads,
30
+ q_len=max_seq_len,
31
+ N_CTX=max_seq_len_k,
32
+ dtype=dtype,
33
+ device=device,
34
+ BLOCK=BLOCK,
35
+ local_blocks=local_blocks,
36
+ vert_stride=vert_stride,
37
+ homo_head=homo_head,
38
+ return_dense=return_dense
39
+ )
40
+ return layout, block_sparse_pattern
41
+
42
+
43
+ class BlockSparseAttentionLayer(nn.Module):
44
+ def __init__(
45
+ self,
46
+ n_heads: int,
47
+ max_seq_len: int,
48
+ sparse_block_size: int,
49
+ local_blocks: int,
50
+ vert_stride: int,
51
+ kernel_block_size: Optional[int] = None,
52
+ homo_head: bool = False,
53
+ active_head_range: Optional[Tuple[int]] = None
54
+ ) -> None:
55
+ super().__init__()
56
+
57
+ self.n_heads = n_heads
58
+ self.max_seq_len = max_seq_len
59
+ self.sparse_block_size = sparse_block_size
60
+ self.kernel_block_size = kernel_block_size or sparse_block_size
61
+ self.local_blocks = local_blocks
62
+ self.vert_stride = vert_stride
63
+ self.homo_head = homo_head
64
+ self.active_head_range = active_head_range
65
+
66
+ # Internal Parameters used by the layer
67
+ self._sparse_block_mask = None
68
+ self._sparse_layout = None
69
+ self._dtype = None
70
+ self._device = None
71
+
72
+ # TODO(bapatra): Ideally, I'd want to keep all the code for
73
+ # forward to be handled here, and not branch for training and inference.
74
+ # However, that refactor would need a lot of testing. For now, using the
75
+ # training op as is, and will refactor again later.
76
+
77
+ def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
78
+ self._sparse_block_mask = self._sparse_block_mask[h_start: h_end]
79
+ self._sparse_layout[0] = self._sparse_layout[0][h_start: h_end]
80
+ self._sparse_layout[1] = self._sparse_layout[1][h_start: h_end]
81
+
82
+ def _initialize_internals(
83
+ self,
84
+ dtype: torch.dtype,
85
+ device: torch.device
86
+ ) -> None:
87
+ self._dtype, self._device = dtype, device
88
+ self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
89
+ n_heads=self.n_heads,
90
+ max_seq_len=self.max_seq_len,
91
+ max_seq_len_k=self.max_seq_len,
92
+ dtype=dtype,
93
+ device=device,
94
+ BLOCK=self.sparse_block_size,
95
+ local_blocks=self.local_blocks,
96
+ vert_stride=self.vert_stride,
97
+ homo_head=self.homo_head,
98
+ return_dense=False,
99
+ )
100
+ if (not self.homo_head) and (self.active_head_range is not None):
101
+ assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
102
+ h_start, h_end = self.active_head_range
103
+ self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)
104
+
105
+ assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
106
+ assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"
107
+ if self.sparse_block_size // self.kernel_block_size > 1:
108
+ _mul = self.sparse_block_size // self.kernel_block_size
109
+ # need to consider if block_m and block_n are different
110
+ self._sparse_block_mask = torch.kron(self._sparse_block_mask, self._sparse_block_mask.new_ones(_mul, _mul))
111
+ num_sparse_blocks = self._sparse_block_mask.size(-1)
112
+ block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
113
+ self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)
114
+
115
+
116
+ def forward(
117
+ self,
118
+ q: torch.Tensor,
119
+ k: torch.Tensor,
120
+ v: torch.Tensor,
121
+ sm_scale: float,
122
+ *,
123
+ # Arguments Related to Block Attention Inference
124
+ left_paddings: Optional[torch.LongTensor] = None,
125
+ seqlens: Optional[torch.LongTensor] = None,
126
+ # Arguements Related to Variable Length Inference
127
+ cu_seqlens_k: Optional[torch.LongTensor] = None,
128
+ cu_seqlens_q: Optional[torch.LongTensor] = None,
129
+ ) -> torch.Tensor:
130
+
131
+ if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
132
+ blocksparse_op = get_local_strided_sparse_attention_op(
133
+ n_heads=self.n_heads,
134
+ max_seq_len=self.max_seq_len,
135
+ sparse_block_size=self.sparse_block_size,
136
+ kernel_block_size=self.kernel_block_size,
137
+ local_blocks=self.local_blocks,
138
+ vert_stride=self.vert_stride,
139
+ homo_head=self.homo_head,
140
+ device=q.device,
141
+ inference=not self.training
142
+ )
143
+ return blocksparse_op(q, k, v, sm_scale)
144
+
145
+ assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"
146
+ # First set internals if they have not been set
147
+ if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
148
+ self._initialize_internals(dtype=q.dtype, device=q.device)
149
+
150
+ if k.dim() == 3:
151
+ assert cu_seqlens_k is not None
152
+ return blocksparse_flash_attn_varlen_fwd(
153
+ q=q,
154
+ k=k,
155
+ v=v,
156
+ cu_seqlens_k=cu_seqlens_k,
157
+ cu_seqlens_q=cu_seqlens_q,
158
+ sm_scale=sm_scale,
159
+ sparse_layout=self._sparse_layout,
160
+ block_size=self.kernel_block_size,
161
+ max_seqlen=self.max_seq_len,
162
+ )
163
+ if k.dim() == 4:
164
+ assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
165
+ return blocksparse_flash_attn_padded_fwd(
166
+ q=q,
167
+ k=k,
168
+ v=v,
169
+ sm_scale=sm_scale,
170
+ sparse_layout=self._sparse_layout,
171
+ left_paddings=left_paddings,
172
+ seqlens=seqlens,
173
+ block_size=self.kernel_block_size,
174
+ max_seqlen=self.max_seq_len,
175
+ )
176
+ raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')
triton_flash_blocksparse_attn.py ADDED
@@ -0,0 +1,1943 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Eric Lin (xihlin)
3
+ """
4
+ """
5
+ ... note(bapatra)::
6
+ This is written as one big file, instead of splitting into logical components because I was running into issues with transformers auto module
7
+ imports when splitting into different files. I've tried keeping the logical partitions demarkated with comment blocks, but it is not ideal.
8
+ In the future, would be really good to revisit this and refactor into a more readable file structure.
9
+
10
+ """
11
+ from typing import TypeVar
12
+ from functools import lru_cache
13
+ import math
14
+ import pytest
15
+ import torch
16
+ import numpy as np
17
+
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ import os
22
+
23
+ import dataclasses
24
+
25
+ Phi3SmallConfig = TypeVar('Phi3SmallConfig')
26
+
27
+ # triton 2.0.0: fail at backward on A100, for the examples, if h_dim=128.
28
+
29
+ # Done
30
+ # 1. strided of qkv
31
+ # 2. seq len not power of 2
32
+ # 3. bf16 with Triton May, 2023
33
+
34
+ # TODO:
35
+ # 1. wip: support non-contiguous backward, also help reduce memory allocation in training (q, k, v split)
36
+ # 2. block sparse with different BLOCK_M, BLOCK_N?
37
+ # 3. for Lq not divided by BLOCK_M, BLOCK_N, only apply mask to K/V on last batch, still need to apply mask on Q.
38
+ # Attempt, fail to compile
39
+ # 4. For 2nd iter of inference, BLOCK_M=1, how to make things work? K/V maynot divided by BLOCK_N.
40
+ # 5. The inner loop can also be paralled via bigger num_stage(better) or on different thread-block (via m/L and atomic update, but this no-comm/sync between blocks)
41
+
42
+
43
+ ###########################################################
44
+ ################### Kernel Parameters #####################
45
+ ###########################################################
46
+
47
+ @dataclasses.dataclass
48
+ class BlockSparseParams(object):
49
+ block_size: int
50
+ kernel_block_size: int
51
+ num_local_blocks: int
52
+ vert_stride: int
53
+ homo_head_pattern: bool = False
54
+
55
+ @classmethod
56
+ def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
57
+ return cls(
58
+ block_size=config.blocksparse_block_size,
59
+ kernel_block_size=config.blocksparse_triton_kernel_block_size,
60
+ num_local_blocks=config.blocksparse_num_local_blocks,
61
+ vert_stride=config.blocksparse_vert_stride,
62
+ homo_head_pattern=config.blocksparse_homo_head_pattern,
63
+ )
64
+
65
+
66
+ ###########################################################
67
+ ###########################################################
68
+
69
+ ###########################################################
70
+ ################### Utility Functions #####################
71
+ ###########################################################
72
+
73
+ # helper functions for 3D sparse pattern
74
+ # these function are not optimized and very inefficient. Avoid calling them too frequent.
75
+ # currently, it is only called within `get_local_strided_sparse_attention_op`, which is cached.
76
+ def dense_to_crow_col(x):
77
+ ''' Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
78
+ param:
79
+ TODO:
80
+ 1. improve efficiency, is it faster if done in CPU, or customize a cuda kernel for it?
81
+ NOTE: col_indices padded -1
82
+ '''
83
+ pad = -1
84
+ dim = x.dim()
85
+ assert x.dim() in (2, 3)
86
+ if x.dim() == 2:
87
+ x = x[None]
88
+ x = [xi.to_sparse_csr() for xi in x]
89
+ crows = torch.vstack([xi.crow_indices() for xi in x])
90
+ cols = [xi.col_indices() for xi in x]
91
+ max_cols = max(len(xi) for xi in cols)
92
+ cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols]
93
+ cols = torch.vstack(cols)
94
+ if dim == 2:
95
+ crows = crows[0]
96
+ cols = cols[0]
97
+ return crows, cols
98
+
99
+
100
+ def crow_col_to_dense(crows, cols, dtype=torch.float16):
101
+ dim = crows.dim()
102
+ if dim == 1:
103
+ crows = crows[None]
104
+ cols = cols[None]
105
+ device = crows.device
106
+ crows, cols = crows.cpu(), cols.cpu() # faster in cpu
107
+ shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
108
+ x = torch.zeros(shape, dtype=dtype)
109
+ for i in range(shape[0]):
110
+ for j in range(shape[1]):
111
+ x[i, j, cols[i, crows[i, j]:crows[i, j+1]]] = 1
112
+ if dim == 1:
113
+ x = x[0]
114
+ return x.to(device)
115
+
116
+
117
+ def dense_to_ccol_row(x):
118
+ '''Similar, but to CSC format
119
+ '''
120
+ x = x.transpose(-2, -1)
121
+ return dense_to_crow_col(x)
122
+
123
+
124
+ def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
125
+ return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
126
+
127
+
128
+ def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False):
129
+ '''
130
+ :return: a tuple of 3:
131
+ - tuple of crow_indices, col_indices representation of CSR format.
132
+ - block dense mask
133
+ - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
134
+ '''
135
+ with torch.no_grad():
136
+ N_BLOCK = triton.cdiv(N_CTX, BLOCK)
137
+ q_pos = torch.arange(N_BLOCK)[:, None]
138
+ k_pos = torch.arange(N_BLOCK)[None]
139
+ mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0
140
+ block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
141
+ N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
142
+ block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr()
143
+ if return_dense:
144
+ mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
145
+ causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
146
+ mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask
147
+ return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense
148
+ else:
149
+ return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None
150
+
151
+
152
+ def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
153
+ '''
154
+ :return: a tuple of 3:
155
+ - tuple of crow_indices, col_indices representation of CSR format.
156
+ - block dense mask
157
+ - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
158
+ '''
159
+ if homo_head:
160
+ with torch.no_grad():
161
+ (crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
162
+ crow = crow[None].expand(n_heads, crow.shape[0])
163
+ col = col[None].expand(n_heads, col.shape[0])
164
+ if return_dense:
165
+ mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
166
+ return (crow, col), block_mask_dense, mask_dense
167
+
168
+ with torch.no_grad():
169
+ N_BLOCK = triton.cdiv(N_CTX, BLOCK)
170
+ q_pos = torch.arange(N_BLOCK)[None, :, None]
171
+ k_pos = torch.arange(N_BLOCK)[None, None]
172
+ head_sliding_step = max(1, int(vert_stride / n_heads)) # if vert_stride <= n_heads, rotating the heads
173
+ mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)]
174
+ mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
175
+ block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
176
+ N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
177
+ block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:]
178
+ if return_dense:
179
+ mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
180
+ causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
181
+ mask_dense = mask_dense[..., -q_len:, :N_CTX] * causal_mask[None]
182
+ return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense
183
+ else:
184
+ return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None
185
+
186
+
187
+ def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
188
+ return _get_sparse_attn_mask(q.size(1), q.size(2), N_CTX, q.dtype, q.device, *args, **kwargs)
189
+
190
+ ###########################################################
191
+ ###########################################################
192
+
193
+ ###########################################################
194
+ ###################### Training Kernels ###################
195
+ ###########################################################
196
+
197
+ # TODO: only apply loading/saving mask on the last iteration for EVEN_N_BLOCK, useful for 1st iteration of inference.
198
+ # Experiment failed inside loop.
199
+ # Another idea: only on saving? load even out of boundary(will it causes illegal access error)?
200
+ @triton.jit
201
+ def _fwd_kernel(
202
+ Q, K, V, sm_scale,
203
+ layout_crow_ptr,
204
+ layout_col_ptr,
205
+ layout_crow_stride_h, layout_crow_stride_m,
206
+ layout_col_stride_h, layout_col_stride_m,
207
+ TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug. TMP, L, M are assumed to have contiguous layouts
208
+ Out,
209
+ stride_qz, stride_qh, stride_qm, stride_qd,
210
+ stride_kz, stride_kh, stride_kn, stride_kd,
211
+ stride_vz, stride_vh, stride_vn, stride_vd,
212
+ stride_oz, stride_oh, stride_om, stride_od,
213
+ Z, H, N_CTX,
214
+ PAST_LEN,
215
+ Q_ROUNDED_LEN,
216
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
217
+ BLOCK_N: tl.constexpr,
218
+ EVEN_M_BLOCK: tl.constexpr,
219
+ EVEN_N_BLOCK: tl.constexpr,
220
+ INFERENCE: tl.constexpr,
221
+ NUM_DBLOCKS: tl.constexpr,
222
+ ):
223
+ Q_LEN = N_CTX - PAST_LEN
224
+ start_m = tl.program_id(0)
225
+ off_hz = tl.program_id(1)
226
+ off_h = off_hz % H
227
+ off_z = off_hz // H
228
+ Q += off_z * stride_qz + off_h * stride_qh
229
+ K += off_z * stride_kz + off_h * stride_kh
230
+ V += off_z * stride_vz + off_h * stride_vh
231
+ # initialize offsets
232
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
233
+ offs_n = tl.arange(0, BLOCK_N)
234
+ offs_d = tl.arange(0, BLOCK_DMODEL)
235
+ off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
236
+ # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
237
+ off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
238
+ off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
239
+ # Initialize pointers to Q, K, V
240
+ q_ptrs = Q + off_q
241
+ k_ptrs = K + off_k
242
+ v_ptrs = V + off_v
243
+ # initialize pointer to m and l
244
+ t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m
245
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
246
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
247
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
248
+ if NUM_DBLOCKS >= 2:
249
+ acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
250
+
251
+ # load q: it will stay in SRAM throughout
252
+ if EVEN_M_BLOCK:
253
+ q = tl.load(q_ptrs)
254
+ if NUM_DBLOCKS >= 2:
255
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
256
+ else:
257
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
258
+ if NUM_DBLOCKS >= 2:
259
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN)
260
+
261
+ layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
262
+ start_l = tl.load(layout_ptr).to(tl.int32)
263
+ end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)
264
+
265
+ # loop over k, v and update accumulator
266
+ for col_idx_idx in range(start_l, end_l):
267
+ col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
268
+ start_n = col_idx * BLOCK_N
269
+ # -- compute qk ----
270
+ if EVEN_N_BLOCK:
271
+ k = tl.load(k_ptrs + start_n * stride_kn)
272
+ else:
273
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX)
274
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
275
+ qk += tl.dot(q, k)
276
+
277
+ if NUM_DBLOCKS >= 2:
278
+ if EVEN_N_BLOCK:
279
+ k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
280
+ else:
281
+ k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX)
282
+ qk += tl.dot(q2, k)
283
+
284
+ qk *= sm_scale
285
+ qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))
286
+ # -- compute m_ij, p, l_ij
287
+ m_ij = tl.max(qk, 1)
288
+ p = tl.exp(qk - m_ij[:, None])
289
+ l_ij = tl.sum(p, 1)
290
+ # -- update m_i and l_i
291
+ m_i_new = tl.maximum(m_i, m_ij)
292
+ alpha = tl.exp(m_i - m_i_new)
293
+ beta = tl.exp(m_ij - m_i_new)
294
+ l_i_new = alpha * l_i + beta * l_ij
295
+ # -- update output accumulator --
296
+ # scale p
297
+ p_scale = beta / l_i_new
298
+ p = p * p_scale[:, None]
299
+ # scale acc
300
+ acc_scale = l_i / l_i_new * alpha
301
+ # tl.store(t_ptrs, acc_scale)
302
+ # acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
303
+ acc = acc * acc_scale[:, None]
304
+ if NUM_DBLOCKS >= 2:
305
+ acc2 = acc2 * acc_scale[:, None]
306
+ p = p.to(Q.dtype.element_ty)
307
+ # update acc
308
+ if EVEN_N_BLOCK:
309
+ v = tl.load(v_ptrs + start_n * stride_vn)
310
+ else:
311
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX)
312
+ acc += tl.dot(p, v)
313
+
314
+ if NUM_DBLOCKS >= 2:
315
+ if EVEN_N_BLOCK:
316
+ v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
317
+ else:
318
+ v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX)
319
+ acc2 += tl.dot(p, v)
320
+
321
+ # update m_i and l_i
322
+ l_i = l_i_new
323
+ m_i = m_i_new
324
+
325
+ # rematerialize offsets to save registers
326
+ # start_m = tl.program_id(0)
327
+ # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
328
+ # write back l and m
329
+ if not INFERENCE:
330
+ l_ptrs = L + off_hz * N_CTX + offs_m
331
+ m_ptrs = M + off_hz * N_CTX + offs_m
332
+ if EVEN_M_BLOCK:
333
+ tl.store(l_ptrs, l_i)
334
+ tl.store(m_ptrs, m_i)
335
+ else:
336
+ tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN)
337
+ tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN)
338
+ # initialize pointers to output
339
+ # offs_n = tl.arange(0, BLOCK_DMODEL)
340
+ off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
341
+ out_ptrs = Out + off_o
342
+ tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
343
+ if NUM_DBLOCKS >= 2:
344
+ tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN)
345
+
346
+
347
+ ## backward
348
+ @triton.heuristics(
349
+ {
350
+ 'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
351
+ }
352
+ )
353
+ @triton.jit
354
+ def _bwd_preprocess(
355
+ Out, DO, L, # assume contiguous for Out, DO, L, NewDO, Delta layout.
356
+ NewDO, Delta,
357
+ N_CTX,
358
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
359
+ EVEN_M_BLOCK: tl.constexpr,
360
+ ):
361
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
362
+ off_d = tl.arange(0, D_HEAD)
363
+ # load
364
+ if EVEN_M_BLOCK:
365
+ o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
366
+ do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
367
+ else:
368
+ o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
369
+ do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
370
+ denom = tl.load(L + off_m).to(tl.float32)
371
+ # compute
372
+ do = do / denom[:, None]
373
+ delta = tl.sum(o * do, axis=1)
374
+ # write-back
375
+ if EVEN_M_BLOCK:
376
+ tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do)
377
+ else:
378
+ tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX)
379
+ tl.store(Delta + off_m, delta)
380
+
381
+
382
+ # Does not suuport unequal seqlen(q) and seqlen(k)
383
+ @triton.heuristics(
384
+ {
385
+ 'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
386
+ 'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
387
+ }
388
+ )
389
+ @triton.jit
390
+ def _bwd_kernel(
391
+ Q, K, V, sm_scale,
392
+ layout_ccol_ptr,
393
+ layout_row_ptr,
394
+ layout_ccol_stride_h, layout_ccol_stride_m,
395
+ layout_row_stride_h, layout_row_stride_m,
396
+ Out, DO, # assume contigous: Out, Do, DQ, DK, DV, L, M, D, seq(q) == seq(k), with stride_oz, stride_oh, stride_om, stride_od,
397
+ DQ, DK, DV,
398
+ L, M,
399
+ D,
400
+ stride_qz, stride_qh, stride_qm, stride_qd,
401
+ stride_kz, stride_kh, stride_kn, stride_kd,
402
+ stride_vz, stride_vh, stride_vn, stride_vd,
403
+ stride_oz, stride_oh, stride_om, stride_od,
404
+ # stride_dz, stride_dh, stride_dm, stride_dd,
405
+ Z, H, N_CTX,
406
+ num_block,
407
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
408
+ BLOCK_N: tl.constexpr,
409
+ EVEN_M_BLOCK: tl.constexpr,
410
+ EVEN_N_BLOCK: tl.constexpr,
411
+ NUM_DBLOCKS: tl.constexpr,
412
+ ):
413
+ start_n = tl.program_id(0)
414
+ off_hz = tl.program_id(1)
415
+ off_z = off_hz // H
416
+ off_h = off_hz % H
417
+ # offset pointers for batch/head
418
+ Q += off_z * stride_qz + off_h * stride_qh
419
+ K += off_z * stride_kz + off_h * stride_kh
420
+ V += off_z * stride_vz + off_h * stride_vh
421
+ DO += off_z * stride_oz + off_h * stride_oh
422
+ DQ += off_z * stride_oz + off_h * stride_oh
423
+ DK += off_z * stride_oz + off_h * stride_oh
424
+ DV += off_z * stride_oz + off_h * stride_oh
425
+ # Look like this loop can be parallelled
426
+ # for start_n in range(0, num_block):
427
+
428
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
429
+ offs_m = tl.arange(0, BLOCK_M)
430
+ offs_d = tl.arange(0, BLOCK_DMODEL)
431
+ # initialize pointers to value-like data
432
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
433
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
434
+
435
+ # pointer to row-wise quantities in value-like data
436
+ D_ptrs = D + off_hz * N_CTX
437
+ m_ptrs = M + off_hz * N_CTX
438
+ # initialize dv amd dk
439
+ dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
440
+ dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
441
+ # k and v stay in SRAM throughout
442
+ if EVEN_N_BLOCK:
443
+ k = tl.load(k_ptrs)
444
+ v = tl.load(v_ptrs)
445
+ else:
446
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX)
447
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX)
448
+
449
+ if NUM_DBLOCKS >= 2:
450
+ dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
451
+ dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
452
+ if EVEN_N_BLOCK:
453
+ k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
454
+ v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
455
+ else:
456
+ k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX)
457
+ v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX)
458
+
459
+ # loop over rows
460
+
461
+ layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
462
+ start_l = tl.load(layout_ptr).to(tl.int32)
463
+ end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)
464
+
465
+ for row_idx_idx in range(start_l, end_l):
466
+ row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
467
+ start_m = row_idx * BLOCK_M
468
+
469
+ # offs_qm = start_m + tl.arange(0, BLOCK_M)
470
+ offs_m_curr = start_m + offs_m
471
+ q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
472
+ do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
473
+ dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
474
+
475
+ # load q, k, v, do on-chip
476
+ if EVEN_M_BLOCK:
477
+ q = tl.load(q_ptrs)
478
+ else:
479
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX)
480
+ # re-compute p = softmax(qk, dim=-1).T
481
+ # NOTE: `do` is pre-divided by `l`; no normalization here
482
+ qk = tl.dot(q, tl.trans(k))
483
+
484
+ if NUM_DBLOCKS >= 2:
485
+ if EVEN_M_BLOCK:
486
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
487
+ else:
488
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX)
489
+ qk += tl.dot(q2, tl.trans(k2))
490
+
491
+ qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))
492
+
493
+ if EVEN_M_BLOCK:
494
+ m = tl.load(m_ptrs + offs_m_curr)
495
+ else:
496
+ m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
497
+ p = tl.exp(qk * sm_scale - m[:, None])
498
+
499
+ # compute dv
500
+ if EVEN_M_BLOCK:
501
+ do = tl.load(do_ptrs)
502
+ else:
503
+ do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX)
504
+
505
+ if NUM_DBLOCKS >= 2:
506
+ if EVEN_M_BLOCK:
507
+ do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
508
+ else:
509
+ do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX)
510
+
511
+ dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
512
+
513
+ if NUM_DBLOCKS >= 2:
514
+ dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2)
515
+
516
+ # compute dp = dot(v, do)
517
+ if EVEN_M_BLOCK:
518
+ Di = tl.load(D_ptrs + offs_m_curr)
519
+ else:
520
+ Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
521
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
522
+ dp += tl.dot(do, tl.trans(v))
523
+
524
+ if NUM_DBLOCKS >= 2:
525
+ dp += tl.dot(do2, tl.trans(v2))
526
+
527
+ # compute ds = p * (dp - delta[:, None])
528
+ ds = p * dp * sm_scale
529
+ # compute dk = dot(ds.T, q)
530
+ dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
531
+ if NUM_DBLOCKS >= 2:
532
+ dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2)
533
+
534
+ # # compute dq
535
+ dq = tl.dot(ds.to(Q.dtype.element_ty), k)
536
+ if EVEN_M_BLOCK:
537
+ tl.atomic_add(dq_ptrs, dq)
538
+ else:
539
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)
540
+
541
+ if NUM_DBLOCKS >= 2:
542
+ dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2)
543
+ dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
544
+ if EVEN_M_BLOCK:
545
+ tl.atomic_add(dq_ptrs2, dq2)
546
+ else:
547
+ tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX)
548
+
549
+ # write-back
550
+ dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
551
+ dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
552
+ if EVEN_N_BLOCK:
553
+ tl.store(dv_ptrs, dv)
554
+ tl.store(dk_ptrs, dk)
555
+ else:
556
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)
557
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)
558
+
559
+ if NUM_DBLOCKS >= 2:
560
+ dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
561
+ dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
562
+ if EVEN_N_BLOCK:
563
+ tl.store(dv_ptrs2, dv2)
564
+ tl.store(dk_ptrs2, dk2)
565
+ else:
566
+ tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX)
567
+ tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX)
568
+
569
+
570
+
571
+ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
572
+ '''
573
+ :param q, k, v: [batch, n_heads, seq_len, model_dim]. len of q is allowed to be different than k/v.
574
+ :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices, and CSR.col_indices used to preresent a sparse tensor.
575
+ Each element represent a block, i.e, all elements in a block to be attentdd, or not attended at all..
576
+ '''
577
+ assert q.shape[-1] == k.shape[-1] == v.shape[-1]
578
+ assert k.shape[2] == v.shape[2]
579
+ o = out if out is not None else torch.empty_like(q).contiguous()
580
+ grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
581
+
582
+ q_rounded_len = grid[0] * BLOCK_M
583
+ tmp = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
584
+
585
+ if inference is None:
586
+ inference = (not q.requires_grad) and (not k.requires_grad) and (not v.requires_grad)
587
+
588
+ if inference:
589
+ L, m = tmp, tmp # no need to use create new tensor
590
+ else:
591
+ L = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
592
+ m = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
593
+
594
+ if layout_col_indices.dim() == 1:
595
+ layout_crow_indices = layout_crow_indices[None].expand(q.shape[1] , -1)
596
+ layout_col_indices = layout_col_indices[None].expand(q.shape[1] , -1)
597
+
598
+ assert q.shape[-1] in [64, 128]
599
+ BLOCK_DMODEL = 64
600
+
601
+ if num_warps is None:
602
+ MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
603
+ num_warps = max(1, 2 ** int(math.log2(MIN_D / 16)))
604
+ # print(f'> {BLOCK_M=}, {BLOCK_N=}, {BLOCK_DMODEL=}, {num_warps=}, {num_stages=}')
605
+ else:
606
+ assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.'''
607
+
608
+ ## For debugging:
609
+ # print(f'>> {q.shape=}, {k.shape=}, {BLOCK_M=}, {BLOCK_N=}, {num_warps=}, {BLOCK_DMODEL=}, {q.stride()=}, {k.stride()=}')
610
+ # print(f'>> {layout_crow_indices=}\n{layout_col_indices=}\n {layout_crow_indices.stride()=}, {layout_crow_indices.stride()=}')
611
+ # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
612
+ # {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
613
+
614
+ _fwd_kernel[grid](
615
+ q, k, v, sm_scale,
616
+ layout_crow_indices,
617
+ layout_col_indices,
618
+ layout_crow_indices.stride(0), layout_crow_indices.stride(1),
619
+ layout_col_indices.stride(0), layout_col_indices.stride(1),
620
+ tmp, L, m,
621
+ o,
622
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
623
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
624
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
625
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
626
+ q.shape[0], q.shape[1], k.shape[2],
627
+ k.shape[2] - q.shape[2],
628
+ q_rounded_len,
629
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
630
+ BLOCK_DMODEL=BLOCK_DMODEL,
631
+ EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
632
+ EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
633
+ INFERENCE=inference,
634
+ NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
635
+ num_warps=num_warps,
636
+ num_stages=num_stages,
637
+ )
638
+ if inference:
639
+ L, m = None, None
640
+
641
+ ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices)
642
+ ctx.BLOCK_M = BLOCK_M
643
+ ctx.BLOCK_N = BLOCK_N
644
+ ctx.BLOCK_DMODEL = BLOCK_DMODEL
645
+ # ctx.BLOCK = BLOCK
646
+ ctx.grid = grid
647
+ ctx.sm_scale = sm_scale
648
+ ctx.num_warps = num_warps
649
+ ctx.num_stages = num_stages
650
+ return o
651
+
652
+
653
+ def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
654
+ # q, k, v, o, l, m = ctx.saved_tensors
655
+ q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
656
+
657
+ ## this following too slow to do online, so get it from inputs, which is cached.
658
+ # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
659
+ # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
660
+
661
+ if not do.is_contiguous():
662
+ do = do.contiguous()
663
+ ## for debugging
664
+ # print(f'----> do is not contiguous: {do.stride()=}')
665
+ # raise ValueError(f'>>>> output grad is not contiguous: {do.stride()=}')
666
+
667
+ if not o.is_contiguous():
668
+ # TODO: currently only work with contiguous q/k/v.
669
+ raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')
670
+
671
+
672
+ if layout_ccol_indices.dim() == 1:
673
+ layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1)
674
+ layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1)
675
+
676
+ # do = do.contiguous()
677
+ dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32)
678
+ dk = dk if dk is not None else torch.empty_like(k)
679
+ dv =dv if dv is not None else torch.empty_like(v)
680
+ do_scaled = torch.empty_like(do)
681
+ delta = torch.empty_like(l)
682
+
683
+ assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()
684
+
685
+ _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
686
+ o, do, l,
687
+ do_scaled, delta,
688
+ k.shape[2],
689
+ BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
690
+ )
691
+
692
+ grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1])
693
+
694
+ _bwd_kernel[grid](
695
+ q, k, v, ctx.sm_scale,
696
+ layout_ccol_indices,
697
+ layout_row_indices,
698
+ layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
699
+ layout_row_indices.stride(0), layout_row_indices.stride(1),
700
+ o, do_scaled,
701
+ dq, dk, dv,
702
+ l, m,
703
+ delta,
704
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
705
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
706
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
707
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
708
+ q.shape[0], q.shape[1], q.shape[2],
709
+ ctx.grid[0],
710
+ BLOCK_M=ctx.BLOCK_M,
711
+ BLOCK_N=ctx.BLOCK_N,
712
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL,
713
+ NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
714
+ num_warps=ctx.num_warps,
715
+ num_stages=1,
716
+ )
717
+ return dq, dk, dv, None, None, None
718
+
719
+
720
+ class _sparse_attention(torch.autograd.Function):
721
+
722
+ @staticmethod
723
+ def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
724
+ BLOCK = 128
725
+ # shape constraints
726
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK, BLOCK)
727
+
728
+ @staticmethod
729
+ def backward(ctx, do):
730
+ # q, k, v, o, l, m = ctx.saved_tensors
731
+ q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
732
+ # TODO: the following is very inefficient.
733
+ # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
734
+ layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
735
+ return _backward(ctx, do, layout_ccol_indices, layout_row_indices)
736
+
737
+
738
+
739
+ # suppressed
740
+ class _sparse_attention_inference(_sparse_attention):
741
+ # TODO: does not work now, as BLOCK_M cannot be <1, as shape for tl.dot cannot be smaller than 16.
742
+ @staticmethod
743
+ def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
744
+ BLOCK = 128
745
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, 1, BLOCK)
746
+
747
+
748
+
749
+ def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
750
+ class _sparse_attention_config(_sparse_attention):
751
+ @staticmethod
752
+ def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
753
+ # shape constraints
754
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
755
+ **kwargs
756
+ )
757
+ return _sparse_attention_config.apply
758
+
759
+
760
+ @lru_cache(maxsize=8)
761
+ def get_local_strided_sparse_attention_op(
762
+ n_heads: int,
763
+ max_seq_len:int,
764
+ sparse_block_size: int=128,
765
+ local_blocks: int=4,
766
+ vert_stride: int=4,
767
+ homo_head: bool=False,
768
+ dtype=torch.bfloat16,
769
+ device='cuda',
770
+ active_head_range=None,
771
+ verbose=True,
772
+ **kwargs):
773
+ '''
774
+ :param n_heads: total number of attention heads (regardless of tensor/model parallel)
775
+ :param max_seq_len: max sequence length. Need to be bigger or equal to the length of sequences.
776
+ :param sparse_block_size: sparse block size. Default to 128
777
+ :param local_blocks: number of nearest block to attend to. Default to 4, i.e., attention to previous 4xblock_size tokens.
778
+ :param vert_stride: Default to 4. Meaning
779
+ :param homo_head: if all head shared the same pattern.
780
+ :param active_head_range: tuple of start & end of the heads, e..g, (8, 16). Default to use all heads.
781
+ Mainly for tensor/model parallelization where heads are splitted to different GPUs.
782
+ '''
783
+
784
+ if verbose:
785
+ print((f'> new block_sparse_attn op constructed with config: '
786
+ f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
787
+ f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}'))
788
+ # assert math.log2(max_seq_len) % 2 == 0, f"max_seq_len should be power of 2 to be more efficient"
789
+ _, block_sparse_pattern, _ = _get_sparse_attn_mask(n_heads, max_seq_len, max_seq_len, dtype, device,
790
+ BLOCK=sparse_block_size, local_blocks=local_blocks,
791
+ vert_stride=vert_stride, homo_head=homo_head,
792
+ return_dense=False)
793
+ if (not homo_head) and (active_head_range is not None):
794
+ assert isinstance(active_head_range, tuple)
795
+ assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
796
+ h_start, h_end = active_head_range
797
+ block_sparse_pattern = block_sparse_pattern[h_start:h_end]
798
+ # print(block_sparse_pattern)
799
+ return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)
800
+
801
+
802
+ def get_sparse_attn_op(
803
+ sparse_pattern: torch.tensor,
804
+ sparse_block_size: int=128,
805
+ kernel_block_size=128,
806
+ qkv_format='q,k,v',
807
+ **kwargs):
808
+ '''
809
+ Ccreate a block-sparse op with fixed layout. This is to avoid the need to of create CSR layout and convert it to CSC layout everytime,
810
+ which is very inefficient (use python loops on CPU. PyTorch 1.13 supports CSR->CSC, may help.)
811
+
812
+ :param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
813
+ This tensor should have lower-triangular matrices on the last 2 dimensions for causal attention
814
+ :param sparse_block_size: sparse block size. Default to 128
815
+ :param kernel_block_size: the tile/block size to launch a triton instance. Default to None, i.e., same as `sparse_block_size`
816
+ :param qkv_format: Choices=['q,k,v', 'q, kv', 'qkv'], i.e., separated q,k,v, or kv packed, or qkv packed. Currently, only 'q,k,v' is supported.
817
+
818
+ :param kwargs: keyward arguments passed to `_forward`
819
+ '''
820
+ # assert qkv_format in ('q,k,v', 'q, kv', 'qkv') # to save from running `concat` at forward/backward
821
+
822
+ assert qkv_format == 'q,k,v'
823
+
824
+ if kernel_block_size is None:
825
+ kernel_block_size = sparse_block_size
826
+ else:
827
+ assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
828
+ assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given"
829
+
830
+
831
+ # print(f'>> {sparse_pattern.shape=}')
832
+ # print(f'{sparse_pattern=}')
833
+ if sparse_block_size // kernel_block_size > 1:
834
+ _mul = sparse_block_size // kernel_block_size
835
+ # need to consider if block_m and block_n are different
836
+ sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
837
+ num_sparse_blocks = sparse_pattern.size(-1)
838
+ block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
839
+ sparse_pattern *= block_causal_mask.type_as(sparse_pattern)
840
+ # print(f'>> after: {sparse_pattern.shape=}')
841
+ # print(f'{sparse_pattern=}')
842
+
843
+ BLOCK_N = kernel_block_size
844
+ NUM_BLOCK = sparse_pattern.size(-1)
845
+ MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK
846
+
847
+ grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
848
+ # sparse csc layout for backward
849
+ grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)
850
+
851
+
852
+ # cache GPU backward layout. limit the size to avoid OOM as time goes.
853
+ # For inference, one only needs to cache one block as sequence length always increases
854
+ # Therefore, this cache needs to be reconstructed per every `block_size`-steps.
855
+ # For training/finetune, set to 8 to increase cache hit.
856
+ # Given an input, the block_len will be the same for all layers, so cache is very helpful.
857
+
858
+ max_cache_size = 1 if kwargs.get('inference', False) else 8
859
+
860
+ @lru_cache(maxsize=max_cache_size)
861
+ def get_backward_layout_by_block_len(block_len):
862
+ assert block_len <= NUM_BLOCK
863
+ if block_len == NUM_BLOCK:
864
+ return (grand_layout_ccol_indices, grand_layout_row_indices)
865
+ return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])
866
+
867
+ # for debugging
868
+ # if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
869
+ # print(f'> {sparse_pattern.cpu().tolist()=}')
870
+ # print('----')
871
+ # print(f'> {grand_layout_crow_indices.cpu().tolist()=}\n{grand_layout_col_indices.cpu().tolist()=}')
872
+
873
+
874
+ # q, k, v separated
875
+ class _q_k_v_sparse_attention(torch.autograd.Function):
876
+ @staticmethod
877
+ def forward(ctx, q, k, v, sm_scale):
878
+ # assert q.shape[2] == 1 or q.shape[2] == k.shape[2]
879
+ # shape constraints
880
+ MIN_BLOCK_SIZE = 16
881
+ assert BLOCK_N >= MIN_BLOCK_SIZE
882
+ BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N # BLOCK_M has to be power of 2
883
+
884
+ # this following code only works for causal attention
885
+ K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size)
886
+ # Q_START_BLOCKS = K_BLOCKS - 1 if q.shape[2] == 1 else 0
887
+ Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)
888
+ # print(Q_START_BLOCKS, K_BLOCKS)
889
+
890
+ layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
891
+ layout_col_indices = grand_layout_col_indices
892
+ # print(BLOCK_M, BLOCK_N, Q_START_BLOCKS, K_BLOCKS+1, layout_crow_indices, layout_col_indices)
893
+
894
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
895
+ **kwargs
896
+ )
897
+ @staticmethod
898
+ def backward(ctx, do):
899
+ q, k = ctx.saved_tensors[:2]
900
+ assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
901
+ # assume q, k have same length
902
+ block_len = triton.cdiv(do.shape[2], kernel_block_size)
903
+ backward_layout = get_backward_layout_by_block_len(block_len)
904
+ return _backward(ctx, do, *backward_layout)[:4]
905
+
906
+
907
+ def _q_k_v_sparse_attention_fn(*args):
908
+ return _q_k_v_sparse_attention.apply(*args)
909
+
910
+ _q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
911
+ _q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
912
+ _q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
913
+ _q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
914
+ _q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices
915
+
916
+ return _q_k_v_sparse_attention_fn
917
+
918
+ ###########################################################
919
+ ###########################################################
920
+
921
+ ###########################################################
922
+ ################ Inference Kernels ########################
923
+ ###########################################################
924
+
925
+ def blocksparse_flash_attn_padded_fwd(
926
+ q, k, v, # (batch, tokens, n_heads, head_size)
927
+ sm_scale,
928
+ sparse_layout,
929
+ *,
930
+ left_paddings = None,
931
+ seqlens = None,
932
+ block_size = 64,
933
+ max_seqlen = None
934
+ ):
935
+ '''
936
+ q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size)
937
+ left_paddings: (batch, ), number of left paddings for each sample.
938
+ seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
939
+ '''
940
+ batches, q_len, n_heads, head_size = q.shape
941
+ _, k_len, n_kv_heads, _ = k.shape
942
+
943
+
944
+ assert q.dim() == k.dim() == v.dim() == 4
945
+ assert q.size(2) % k.size(2) == 0
946
+ assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
947
+ assert k.shape == v.shape # TODO: allow diff head_size for k, v
948
+ assert q_len == 1 or q_len == k_len, \
949
+ f'q length can only 1 for decoding for same as k length for prefilling.'
950
+
951
+ q_k_ratio = q.size(2) // k.size(2)
952
+
953
+ if max_seqlen:
954
+ assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'
955
+
956
+ # paddings always has zero output, a little slower than using empty
957
+ out = q.new_zeros(q.shape)
958
+
959
+ layout_crow_indices, layout_col_indices = sparse_layout
960
+ block_d = triton.next_power_of_2(head_size)
961
+
962
+ if left_paddings is not None:
963
+ assert left_paddings.shape == (batches,)
964
+ k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
965
+ else:
966
+ k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)
967
+
968
+ if seqlens is not None:
969
+ k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
970
+ assert k_batch_ends.max() <= k_len, f'seqlens (+left_paddings if any) exceeds seqlen.'
971
+ else:
972
+ k_batch_ends = torch.zeros_like(k_batch_starts) + k_len
973
+
974
+ if q_len == 1:
975
+ q_batch_starts = torch.zeros_like(k_batch_starts)
976
+ q_batch_ends = q_batch_starts + 1
977
+ else:
978
+ q_batch_starts = k_batch_starts
979
+ q_batch_ends = k_batch_ends
980
+
981
+ # switch to use cpu to avoid too many kernel lauch when iterate over
982
+ q_lens = (q_batch_ends - q_batch_starts).cpu()
983
+ n_blocks = (q_lens + block_size - 1) // block_size
984
+
985
+ q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
986
+ dtype=q_batch_starts.dtype,
987
+ device=q_batch_starts.device)
988
+ q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
989
+ dtype=q_batch_starts.dtype,
990
+ device=q_batch_starts.device)
991
+
992
+ grid = (len(q_start_sids), n_heads)
993
+
994
+ _fwd_kernel_batch_inference[grid](
995
+ q, k, v, out,
996
+ sm_scale,
997
+ q_batch_starts,
998
+ q_batch_ends,
999
+ k_batch_starts,
1000
+ k_batch_ends,
1001
+ q_batch_ids,
1002
+ q_start_sids,
1003
+
1004
+ *q.stride(),
1005
+ *k.stride(),
1006
+ *v.stride(),
1007
+ *out.stride(),
1008
+
1009
+ layout_crow_indices,
1010
+ layout_col_indices,
1011
+ *layout_crow_indices.stride(),
1012
+ *layout_col_indices.stride(),
1013
+
1014
+ q_k_ratio,
1015
+ HAS_BATCH_DIM = True,
1016
+ D_HEAD = head_size,
1017
+ BLOCK_M = block_size,
1018
+ BLOCK_N = block_size,
1019
+ BLOCK_D = block_d,
1020
+ BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
1021
+ EVEN_D = block_d == head_size,
1022
+ num_warps = 1 if q_len == 1 else 4,
1023
+ num_stages = 3
1024
+ )
1025
+
1026
+ return out
1027
+
1028
+
1029
+ def blocksparse_flash_attn_varlen_fwd(
1030
+ q, k, v, # (#tokens, n_heads, head_size)
1031
+ cu_seqlens_k,
1032
+ cu_seqlens_q,
1033
+ sm_scale,
1034
+ sparse_layout,
1035
+ *,
1036
+ block_size=64,
1037
+ max_seqlen = None
1038
+ ):
1039
+ # split q to blocks
1040
+ _, n_heads, head_size = q.shape
1041
+ batch_size = cu_seqlens_k.size(0) - 1
1042
+
1043
+
1044
+ # print(f'> {q.shape=}, {k.shape=}')
1045
+ assert q.dim() == k.dim() == v.dim() == 3
1046
+ assert q.size(1) % k.size(1) == 0
1047
+ assert q.size(2) == k.size(2)
1048
+ assert k.shape == v.shape # TODO: allow diff head_size for k, v
1049
+ assert cu_seqlens_k.dim() == 1
1050
+
1051
+ q_k_ratio = q.size(1) // k.size(1)
1052
+
1053
+ if cu_seqlens_q is None:
1054
+ if q.size(0) == batch_size: # decoding only
1055
+ cu_seqlens_q = torch.arange(0, batch_size + 1,
1056
+ dtype=cu_seqlens_k.dtype,
1057
+ device=cu_seqlens_k.device)
1058
+ elif q.size(0) == k.size(0):
1059
+ cu_seqlens_q = cu_seqlens_k
1060
+ else:
1061
+ raise ValueError('cu_seqlens_q must be specified if it is mix of prefilling and decoding.')
1062
+ else:
1063
+ assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
1064
+
1065
+ # switch to use cpu to avoid too many kernel lauch when iterate over
1066
+ q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
1067
+ k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
1068
+
1069
+ assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \
1070
+ 'length of q should either be 1 (decoding) or same as k (prefilling).'
1071
+
1072
+ if max_seqlen:
1073
+ assert k_lens.max() <= max_seqlen
1074
+
1075
+ n_blocks = (q_lens + block_size - 1) // block_size
1076
+
1077
+ q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
1078
+ dtype=cu_seqlens_q.dtype,
1079
+ device=cu_seqlens_q.device)
1080
+ q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
1081
+ dtype=cu_seqlens_q.dtype,
1082
+ device=cu_seqlens_q.device)
1083
+
1084
+
1085
+ out = q.new_empty(q.shape)
1086
+ cu_seqlens_q = cu_seqlens_q.contiguous()
1087
+ cu_seqlens_k = cu_seqlens_k.contiguous()
1088
+
1089
+ layout_crow_indices, layout_col_indices = sparse_layout
1090
+ block_d = triton.next_power_of_2(head_size)
1091
+
1092
+ decoding_only = (q_lens == 1).all()
1093
+
1094
+ grid = (len(q_start_sids), n_heads)
1095
+
1096
+ _fwd_kernel_batch_inference[grid](
1097
+ q, k, v, out,
1098
+ sm_scale,
1099
+ cu_seqlens_q[:-1],
1100
+ cu_seqlens_q[1:],
1101
+ cu_seqlens_k[:-1],
1102
+ cu_seqlens_k[1:],
1103
+ q_batch_ids,
1104
+ q_start_sids,
1105
+
1106
+ 0, *q.stride(),
1107
+ 0, *k.stride(),
1108
+ 0, *v.stride(),
1109
+ 0, *out.stride(),
1110
+
1111
+ layout_crow_indices,
1112
+ layout_col_indices,
1113
+ *layout_crow_indices.stride(),
1114
+ *layout_col_indices.stride(),
1115
+
1116
+ q_k_ratio,
1117
+ HAS_BATCH_DIM = False,
1118
+ D_HEAD = head_size,
1119
+ BLOCK_M = block_size,
1120
+ BLOCK_N = block_size,
1121
+ BLOCK_D = block_d,
1122
+ BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
1123
+ EVEN_D = block_d == head_size,
1124
+ num_warps = 1 if decoding_only else 4,
1125
+ num_stages = 3
1126
+ )
1127
+
1128
+ return out
1129
+
1130
+
1131
+ @triton.jit
1132
+ def _fwd_kernel_inner(
1133
+ acc, l_i, m_i,
1134
+ q, Q,
1135
+ k_block_col_idx,
1136
+ layout_col_ptr,
1137
+ layout_col_stride_h, layout_col_stride_m,
1138
+ k_ptrs,
1139
+ v_ptrs,
1140
+ off_h, offs_m, offs_n, offs_d,
1141
+ stride_kt, stride_vt,
1142
+ sm_scale,
1143
+ k_seqlen,
1144
+ past_len,
1145
+ LAST_K_BLOCK: tl.constexpr,
1146
+ BLOCK_M_LOADING: tl.constexpr,
1147
+ BLOCK_N: tl.constexpr,
1148
+ D_HEAD: tl.constexpr,
1149
+ EVEN_D: tl.constexpr,
1150
+ M_LT_N: tl.constexpr
1151
+ ):
1152
+ k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32)
1153
+ start_n = k_block_id * BLOCK_N
1154
+ # -- compute qk ----
1155
+ if LAST_K_BLOCK:
1156
+ if EVEN_D:
1157
+ k = tl.load(k_ptrs + start_n * stride_kt,
1158
+ mask=offs_n[None, :] + start_n < k_seqlen)
1159
+ else:
1160
+ # mask = mask & (offs_d[:, ])
1161
+ k = tl.load(k_ptrs + start_n * stride_kt,
1162
+ mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD))
1163
+ else:
1164
+ if EVEN_D:
1165
+ k = tl.load(k_ptrs + start_n * stride_kt)
1166
+ else:
1167
+ k = tl.load(k_ptrs + start_n * stride_kt,
1168
+ mask=offs_d[:, None] < D_HEAD)
1169
+
1170
+
1171
+ qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
1172
+ qk += tl.dot(q, k)
1173
+
1174
+ qk *= sm_scale
1175
+
1176
+ # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
1177
+ if LAST_K_BLOCK | M_LT_N:
1178
+ qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
1179
+
1180
+ # -- compute m_ij, p, l_ij
1181
+ m_ij = tl.max(qk, 1)
1182
+ p = tl.exp(qk - m_ij[:, None])
1183
+
1184
+ l_ij = tl.sum(p, 1)
1185
+ # -- update m_i and l_i
1186
+ m_i_new = tl.maximum(m_i, m_ij)
1187
+ alpha = tl.exp(m_i - m_i_new)
1188
+ beta = tl.exp(m_ij - m_i_new)
1189
+ l_i_new = alpha * l_i + beta * l_ij
1190
+ # -- update output accumulator --
1191
+ # scale p
1192
+ p_scale = beta / l_i_new
1193
+ p = p * p_scale[:, None]
1194
+ # scale acc
1195
+ acc_scale = l_i / l_i_new * alpha
1196
+ acc = acc * acc_scale[:, None]
1197
+
1198
+ p = p.to(Q.dtype.element_ty)
1199
+ # update acc
1200
+ if LAST_K_BLOCK:
1201
+ if EVEN_D:
1202
+ v = tl.load(v_ptrs + start_n * stride_vt,
1203
+ mask=offs_n[:, None] + start_n < k_seqlen)
1204
+ else:
1205
+ v = tl.load(v_ptrs + start_n * stride_vt,
1206
+ mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD))
1207
+ else:
1208
+ if EVEN_D:
1209
+ v = tl.load(v_ptrs + start_n * stride_vt)
1210
+ else:
1211
+ v = tl.load(v_ptrs + start_n * stride_vt,
1212
+ mask=offs_d[None, :] < D_HEAD)
1213
+
1214
+ acc += tl.dot(p, v)
1215
+ # update m_i and l_i
1216
+ l_i = l_i_new
1217
+ m_i = m_i_new
1218
+ return acc, l_i, m_i
1219
+
1220
+
1221
+ @triton.heuristics(
1222
+ {
1223
+ 'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
1224
+ }
1225
+ )
1226
+ @triton.jit
1227
+ def _fwd_kernel_batch_inference(
1228
+ Q, K, V, Out,
1229
+
1230
+ sm_scale,
1231
+ q_batch_starts,
1232
+ q_batch_ends,
1233
+ k_batch_starts,
1234
+ k_batch_ends,
1235
+ q_batch_ids,
1236
+ q_start_sids,
1237
+
1238
+ stride_qb, stride_qt, stride_qh, stride_qd,
1239
+ stride_kb, stride_kt, stride_kh, stride_kd,
1240
+ stride_vb, stride_vt, stride_vh, stride_vd,
1241
+ stride_ob, stride_ot, stride_oh, stride_od,
1242
+
1243
+ layout_crow_ptr,
1244
+ layout_col_ptr,
1245
+ layout_crow_stride_h, layout_crow_stride_m,
1246
+ layout_col_stride_h, layout_col_stride_m,
1247
+
1248
+ q_k_ratio,
1249
+
1250
+ HAS_BATCH_DIM: tl.constexpr,
1251
+ D_HEAD: tl.constexpr,
1252
+ BLOCK_M: tl.constexpr,
1253
+ BLOCK_N: tl.constexpr,
1254
+ BLOCK_D: tl.constexpr,
1255
+ BLOCK_M_LOADING: tl.constexpr,
1256
+ EVEN_D: tl.constexpr,
1257
+ M_LT_N: tl.constexpr
1258
+ ):
1259
+ '''
1260
+ NOTATION:
1261
+ pid: position id
1262
+ sid: storage id
1263
+ sbid: storage block id
1264
+ pbid: position block id
1265
+ offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
1266
+
1267
+ q and blocks in KV needs to be contiguous
1268
+
1269
+ Arguments:
1270
+ kv_seq_lens: for compute past_len
1271
+ kv_storage_offsets: similar to block_tables in vllm, except it is dynamic.
1272
+ TODO: fix this
1273
+
1274
+ TODO:
1275
+ Optimize grouped-attn
1276
+
1277
+ CUDA graph support issue
1278
+ 1. grid is dynamic: vllm set up multiple cuda graph in decoding phase, with diff max token size (16, 32, ...)
1279
+ since we mix prompt and decoing phase here, it can be more complex.
1280
+ need to set up diff cuda-graph for diff (off_zm, off_z)
1281
+
1282
+ # indeed, q_batch_ids can be padded to maximum number of grid[0], i.e., assume all decoding
1283
+ therefore, cu_seqlens_q, kv_seq_lens
1284
+
1285
+ '''
1286
+ off_zm = tl.program_id(0)
1287
+ off_h = tl.program_id(1)
1288
+
1289
+ off_h_for_kv = off_h // q_k_ratio
1290
+ off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
1291
+ q_start_sid = tl.load(q_start_sids + off_zm)
1292
+ start_m = q_start_sid // BLOCK_M
1293
+
1294
+ if HAS_BATCH_DIM:
1295
+ Q += off_z * stride_qb
1296
+ K += off_z * stride_kb
1297
+ V += off_z * stride_vb
1298
+ Out += off_z * stride_ob
1299
+
1300
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
1301
+ offs_n = tl.arange(0, BLOCK_N)
1302
+ offs_d = tl.arange(0, BLOCK_D)
1303
+
1304
+ q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
1305
+ q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
1306
+
1307
+ k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
1308
+ k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
1309
+
1310
+ past_len = k_seqlen - q_seqlen
1311
+
1312
+ Q += q_cu_start * stride_qt + off_h * stride_qh
1313
+ K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
1314
+ V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
1315
+ Out += q_cu_start * stride_ot + off_h * stride_oh
1316
+
1317
+ q_pbid = (past_len + q_start_sid) // BLOCK_M
1318
+
1319
+ if EVEN_D:
1320
+ q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
1321
+ mask=offs_m[:, None] < q_seqlen)
1322
+ else:
1323
+ q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
1324
+ mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
1325
+ other=0)
1326
+
1327
+ sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m
1328
+
1329
+ # TODO: load at once, supported in new Triton
1330
+ k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
1331
+ k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
1332
+
1333
+ m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
1334
+ l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
1335
+ acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
1336
+
1337
+ k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
1338
+ v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
1339
+
1340
+ for k_block_col_idx in range(k_block_start, k_block_end - 1):
1341
+ acc, l_i, m_i = _fwd_kernel_inner(
1342
+ acc, l_i, m_i,
1343
+ q, Q,
1344
+ k_block_col_idx,
1345
+ layout_col_ptr,
1346
+ layout_col_stride_h, layout_col_stride_m,
1347
+ k_ptrs,
1348
+ v_ptrs,
1349
+ off_h, offs_m, offs_n, offs_d,
1350
+ stride_kt, stride_vt,
1351
+ sm_scale,
1352
+ k_seqlen,
1353
+ past_len,
1354
+ False,
1355
+ BLOCK_M_LOADING,
1356
+ BLOCK_N,
1357
+ D_HEAD,
1358
+ EVEN_D,
1359
+ M_LT_N
1360
+ )
1361
+
1362
+ acc, l_i, m_i = _fwd_kernel_inner(
1363
+ acc, l_i, m_i,
1364
+ q, Q,
1365
+ k_block_end - 1,
1366
+ layout_col_ptr,
1367
+ layout_col_stride_h, layout_col_stride_m,
1368
+ k_ptrs,
1369
+ v_ptrs,
1370
+ off_h, offs_m, offs_n, offs_d,
1371
+ stride_kt, stride_vt,
1372
+ sm_scale,
1373
+ k_seqlen,
1374
+ past_len,
1375
+ True,
1376
+ BLOCK_M_LOADING,
1377
+ BLOCK_N,
1378
+ D_HEAD,
1379
+ EVEN_D,
1380
+ M_LT_N
1381
+ )
1382
+
1383
+ # write output
1384
+ if EVEN_D:
1385
+ tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
1386
+ mask=offs_m[:, None] < q_seqlen)
1387
+ else:
1388
+ tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
1389
+ mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))
1390
+
1391
+
1392
+ ###########################################################
1393
+ ###########################################################
1394
+
1395
+ ###########################################################
1396
+ ################## Testing Utilities ######################
1397
+ ###########################################################
1398
+
1399
+
1400
+ def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
1401
+ '''
1402
+ q, k, v: shape=(batch, n_heads, seq, dim)
1403
+ '''
1404
+ # for verification
1405
+ if sm_scale is None:
1406
+ sm_scale = math.sqrt(float(q.size(-1)))
1407
+
1408
+ if block_attn_mask is not None:
1409
+ assert attn_mask is None
1410
+ outs = []
1411
+ for s in range(0, q.size(2), block_size):
1412
+ e = min(s + block_size, q.size(2))
1413
+ q_block = q[:, :, s:e]
1414
+ attn = torch.einsum('bhmd,bhnd->bhmn', q_block, k[:, :, :e]).float() * sm_scale
1415
+ mask = block_attn_mask[..., s // block_size, : (s // block_size + 1)]
1416
+ mask = torch.kron(mask, torch.ones(block_size, block_size, device=mask.device))
1417
+ mask[..., :, s:].masked_fill_(torch.arange(0, block_size)[:, None] <= torch.arange(0, block_size)[None, :], 0)
1418
+ attn = attn.masked_fill((1 - mask).bool(), float('-inf'))
1419
+ attn = attn.softmax(-1)
1420
+ out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e])
1421
+ outs.append(out)
1422
+ torch_output = torch.cat(outs, dim=2)
1423
+ else:
1424
+ attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
1425
+ # import ipdb; ipdb.set_trace()
1426
+ if attn_mask is not None:
1427
+ attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))
1428
+ # print(f'> torch attn: {attn.exp().sum(-1)=}')
1429
+
1430
+ attn = attn.softmax(-1)
1431
+ if do is not None:
1432
+ dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
1433
+ print(f'> torch_attn computed dv: {dv=}')
1434
+ torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
1435
+ return torch_output
1436
+
1437
+ ###########################################################
1438
+ ###########################################################
1439
+
1440
+ ###########################################################
1441
+ #################### Unit Tests ###########################
1442
+ ###########################################################
1443
+
1444
+
1445
+ @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
1446
+ def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
1447
+ sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
1448
+ Q_LEN = Q_LEN or N_CTX
1449
+ torch.manual_seed(20)
1450
+ q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
1451
+ k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
1452
+ v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
1453
+
1454
+ if sm_scale is None:
1455
+ sm_scale = 1. / math.sqrt(D_HEAD)
1456
+
1457
+ # for debugging
1458
+ # print(f'>> {q.shape=}, {k.shape=}, {v.shape=}, {homo_head=}, {kernel_block_size=}, {sparse_block_size=}, {local_blocks=}, {vert_stride=}')
1459
+ sm_scale = 0.0078125
1460
+ if backward:
1461
+ q.requires_grad_(), k.requires_grad_(), v.requires_grad_()
1462
+
1463
+ # qkv = torch.empty((Z, N_CTX, 3*H*D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
1464
+ # q = qkv[..., :H*D_HEAD]
1465
+ # k = qkv[..., H*D_HEAD:2*H*D_HEAD]
1466
+ # v = qkv[..., 2*H*D_HEAD:]
1467
+ # q = q.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
1468
+ # k = k.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
1469
+ # v = v.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
1470
+
1471
+ # if Q_LEN and Q_LEN < N_CTX:
1472
+ # q = q[:, :, -Q_LEN:] # .contiguous()
1473
+
1474
+ # q = q.requires_grad_()
1475
+ # k = k.requires_grad_()
1476
+ # v = v.requires_grad_()
1477
+
1478
+ dout = torch.randn_like(q).contiguous()
1479
+
1480
+ # dout = torch.eye(N_CTX)[:, :D_HEAD][None, None].expand_as(q).type_as(q).contiguous()
1481
+ # print(dout)
1482
+
1483
+ mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
1484
+ local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)
1485
+
1486
+ if sparse_attention_fn is None:
1487
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
1488
+ sparse_block_size=sparse_block_size,
1489
+ local_blocks=local_blocks,
1490
+ vert_stride=vert_stride,
1491
+ homo_head=homo_head,
1492
+ device=q.device,
1493
+ dtype=q.dtype,
1494
+ kernel_block_size=kernel_block_size)
1495
+ # reference implementation
1496
+ ref_out = torch_attention(q, k, v, mask_dense, sm_scale)
1497
+
1498
+ # lengths = torch.full((Z,), fill_value=N_CTX, device='cuda')
1499
+ # cu_seqlens = torch.zeros((Z + 1,), device='cuda', dtype=torch.int32)
1500
+ # cu_seqlens[1:] = lengths.cumsum(0)
1501
+ # # qkv = torch.randn((Z * N_CTX, 3, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1502
+
1503
+ # qkv_list = list(map(lambda x: x.permute(0, 2, 1, 3).contiguous().view(Z * N_CTX, 1, H, D_HEAD), [q, k, v]))
1504
+ # qkv = torch.cat(qkv_list, dim=1)
1505
+ # ref_out0 = flash_attn_func(qkv, cu_seqlens, dropout_p=0, max_s=N_CTX, softmax_scale=sm_scale, causal=True)
1506
+ # ref_out = ref_out0.view(Z, N_CTX, H, D_HEAD).permute(0, 2, 1, 3).contiguous()
1507
+
1508
+
1509
+ if backward:
1510
+ ref_out.backward(dout)
1511
+ ref_dv, v.grad = v.grad.clone(), None
1512
+ ref_dk, k.grad = k.grad.clone(), None
1513
+ ref_dq, q.grad = q.grad.clone(), None
1514
+
1515
+ tri_out = sparse_attention_fn(q, k, v, sm_scale)
1516
+
1517
+ decimal = 1 if dtype == torch.bfloat16 else 2
1518
+ assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'
1519
+
1520
+ if backward:
1521
+ tri_out.backward(dout)
1522
+ tri_dv, v.grad = v.grad.clone(), None
1523
+ tri_dk, k.grad = k.grad.clone(), None
1524
+ tri_dq, q.grad = q.grad.clone(), None
1525
+
1526
+ if backward:
1527
+ assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
1528
+ assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
1529
+ assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
1530
+
1531
+ print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
1532
+
1533
+ ###########################################################
1534
+
1535
+ if __name__ == '__main__':
1536
+
1537
+ GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip()
1538
+ # print(GPU_TYPE)
1539
+ support_backward = True # 'A100' in GPU_TYPE. Wasn't supportted in consumer A1000.
1540
+
1541
+ ###############
1542
+ # benchmarking
1543
+
1544
+ HAS_DENSE_TRITON_FLASH = False
1545
+ # try:
1546
+ # from triton.ops.flash_attention import attention as triton_attention
1547
+ # HAS_DENSE_TRITON_FLASH = True
1548
+ # except:
1549
+ # HAS_DENSE_TRITON_FLASH = False
1550
+ # print('> cannot import Trition flash attn')
1551
+
1552
+ try:
1553
+ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func
1554
+ HAS_FLASH = True
1555
+ except BaseException:
1556
+ HAS_FLASH = False
1557
+ print('> cannot import flash_attn')
1558
+
1559
+
1560
+ # BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
1561
+ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128 # 6.7B model, with 4k len
1562
+ # BATCH, N_HEADS, N_CTX, D_HEAD = 4, 16, 4096, 128 # 204m model
1563
+
1564
+ BLOCK_SIZE = 64
1565
+ LOCAl_BLOCKS = 8 # 4
1566
+ VERT_STRIDE = 1 # 16 # 8
1567
+ HOMO_HEAD = False
1568
+ sparse_type = 'home' if HOMO_HEAD else 'hetero'
1569
+ dtype = torch.bfloat16
1570
+
1571
+
1572
+ modes = ['fwd', 'bwd'] if support_backward else ['fwd']
1573
+
1574
+ configs = [triton.testing.Benchmark(
1575
+ x_names=['SEQ_LEN'],
1576
+ x_vals=[2**i for i in range(8, 16)],
1577
+ line_arg='provider',
1578
+ line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
1579
+ line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
1580
+ styles=[('red', '-'), ('blue', '-'), ('green', '-')],
1581
+ ylabel='ms',
1582
+ plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
1583
+ args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
1584
+ ) for mode in modes]
1585
+
1586
+
1587
+ @triton.testing.perf_report(configs)
1588
+ def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
1589
+ assert mode in ['fwd', 'bwd']
1590
+ warmup = 25
1591
+ rep = 100
1592
+ N_CTX = SEQ_LEN
1593
+ if provider == 'triton':
1594
+ q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1595
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1596
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1597
+ sm_scale = 1.3
1598
+ fn = lambda: triton_attention(q, k, v, sm_scale)
1599
+ if mode == 'bwd':
1600
+ o = fn()
1601
+ do = torch.randn_like(o)
1602
+ fn = lambda: o.backward(do, retain_graph=True)
1603
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1604
+ return ms
1605
+ if provider == 'triton_sparse':
1606
+ q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1607
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1608
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1609
+ sm_scale = 1.3
1610
+ # q_pos = torch.arange(N_CTX // BLOCK, device='cuda')[:, None]
1611
+ # k_pos = torch.arange(N_CTX // BLOCK, device='cuda')[None]
1612
+ # local_blocks = 4 # num_block per attn, block_size is tied to BLOCK
1613
+ # vert_stride =N_CTX + 1 # 4
1614
+ # mask_vert_strided = torch.arange(N_CTX // BLOCK, device='cuda') % vert_stride == vert_stride - 1
1615
+ # mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).type_as(q)
1616
+ # mask = mask_dense.to_sparse_csr()
1617
+ # mask_csr, _ = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK, local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD)
1618
+
1619
+ if sparse_attention_fn is None:
1620
+ # sparse_attention_fn = sparse_attention
1621
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
1622
+ local_blocks=LOCAl_BLOCKS,
1623
+ vert_stride=VERT_STRIDE,
1624
+ homo_head=HOMO_HEAD,
1625
+ sparse_block_size=BLOCK_SIZE,
1626
+ kernel_block_size=BLOCK_SIZE,
1627
+ device=q.device)
1628
+ # sparse_attention_fn = sparse_attention_factory(128, 128, num_warps=8)
1629
+
1630
+ # fn = lambda: sparse_attention_fn(q, k, v, mask_csr[0], mask_csr[1], sm_scale)
1631
+ fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
1632
+ if mode == 'bwd':
1633
+ o = fn()
1634
+ do = torch.randn_like(o)
1635
+ fn = lambda: o.backward(do, retain_graph=True)
1636
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1637
+ return ms
1638
+ if provider == 'flash':
1639
+ lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
1640
+ cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
1641
+ cu_seqlens[1:] = lengths.cumsum(0)
1642
+ qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
1643
+ fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
1644
+ if mode == 'bwd':
1645
+ o = fn()
1646
+ do = torch.randn_like(o)
1647
+ fn = lambda: o.backward(do, retain_graph=True)
1648
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1649
+ return ms
1650
+
1651
+ # if provider == 'torch':
1652
+ # q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1653
+ # k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1654
+ # v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1655
+ # sm_scale = 1.3
1656
+ # causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(q)
1657
+ # fn = lambda: torch_attention(q, k, v, causal_mask, sm_scale)
1658
+ # ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
1659
+ # return ms
1660
+
1661
+
1662
+ BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1 # 6.7B model, with 4k len
1663
+
1664
+ BLOCK_SIZE = 64
1665
+ LOCAl_BLOCKS = 8 # 4
1666
+ VERT_STRIDE = 16 # 8
1667
+ HOMO_HEAD = False
1668
+ sparse_type = 'home' if HOMO_HEAD else 'hetero'
1669
+ dtype = torch.bfloat16
1670
+ MAX_N_CTX = 8192
1671
+
1672
+ configs = [triton.testing.Benchmark(
1673
+ x_names=['PAST_LEN'],
1674
+ x_vals=[2**i - 1 for i in range(8, 14)],
1675
+ line_arg='provider',
1676
+ line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
1677
+ line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
1678
+ styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
1679
+ ylabel='ms',
1680
+ plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
1681
+ args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
1682
+ ) for mode in ['fwd']]
1683
+ @triton.testing.perf_report(configs)
1684
+ def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
1685
+ assert mode in ['fwd']
1686
+ warmup = 25
1687
+ rep = 100
1688
+ N_CTX = PAST_LEN + Q_LEN
1689
+ if provider == 'torch':
1690
+ q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1691
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1692
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1693
+ sm_scale = 1.3
1694
+ mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
1695
+ local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=VERT_STRIDE, return_dense=True)
1696
+
1697
+ fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
1698
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1699
+ return ms
1700
+ if provider == 'triton_sparse':
1701
+ q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1702
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1703
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1704
+ sm_scale = 1.3
1705
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
1706
+ local_blocks=LOCAl_BLOCKS,
1707
+ vert_stride=VERT_STRIDE,
1708
+ homo_head=HOMO_HEAD,
1709
+ sparse_block_size=BLOCK_SIZE,
1710
+ kernel_block_size=BLOCK_SIZE,
1711
+ device=q.device,
1712
+ inference=True)
1713
+
1714
+ fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
1715
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1716
+ return ms
1717
+ if provider == 'triton_dense':
1718
+ q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1719
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1720
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1721
+ sm_scale = 1.3
1722
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
1723
+ local_blocks=1,
1724
+ vert_stride=1,
1725
+ homo_head=True,
1726
+ sparse_block_size=BLOCK_SIZE,
1727
+ kernel_block_size=BLOCK_SIZE,
1728
+ device=q.device,
1729
+ inference=True)
1730
+
1731
+ fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
1732
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1733
+ return ms
1734
+ if provider == 'flash':
1735
+ assert Q_LEN == 1
1736
+ lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
1737
+ cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
1738
+ cu_seqlens[1:] = lengths.cumsum(0)
1739
+ cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)
1740
+
1741
+ # (total_q, nheads, headdim),
1742
+ q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1743
+ k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1744
+ v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1745
+
1746
+ fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
1747
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1748
+ return ms
1749
+
1750
+
1751
+ test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
1752
+ # bench_flash_attention.run(save_path='.', print_data=True)
1753
+
1754
+ bench_flash_attention_inference.run(save_path='.', print_data=True)
1755
+ exit()
1756
+ # head_dim=64
1757
+ test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
1758
+ dtype=torch.bfloat16, homo_head=False, backward=support_backward)
1759
+ # uneven length, bf16
1760
+ test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
1761
+ kernel_block_size=64, local_blocks=8, vert_stride=8)
1762
+ test_op(3, 2, 2047, 128, homo_head=False, backward=False)
1763
+
1764
+ # diff kernel/sparse block size
1765
+ test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)
1766
+ # inference
1767
+ # test_op(1, 4, 512 + 256, 128, Q_LEN=1, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
1768
+
1769
+ # dense flash attn
1770
+ test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
1771
+ backward=support_backward, local_blocks=1, vert_stride=1)
1772
+
1773
+ # fp16
1774
+ test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
1775
+
1776
+ # longer sequence
1777
+ test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
1778
+ test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
1779
+
1780
+ # homo head
1781
+ test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
1782
+ test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)
1783
+
1784
+ # sparse_attention_fn = sparse_attention_factory(16, 128, num_warps=1, INFERENCE=True)
1785
+ # test_op(8, 1, 2047, 128, 1, backward=False, sparse_attention_fn=None)
1786
+ # test_op_inference(3, 2, 2048, 128, 2048)
1787
+ # test_op_inference(3, 2, 2047, 64, 2047)
1788
+ # test_op_inference(3, 2, 256, 64, 128)
1789
+ # test_op_inference(3, 2, 2048, 64, 1)
1790
+
1791
+ bench_flash_attention.run(save_path='.', print_data=True)
1792
+ # bench_flash_attention_inference.run(save_path='.', print_data=True)
1793
+
1794
+ # ========================
1795
+ # Some Benchmark Results #
1796
+ # ========================
1797
+
1798
+ # fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-fwd
1799
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1800
+ # 0 256.0 0.057184 0.069646 0.052567
1801
+ # 1 512.0 0.131688 0.187658 0.110212
1802
+ # 2 1024.0 0.391844 0.524990 0.247875
1803
+ # 3 2048.0 1.305190 1.456685 0.596506
1804
+ # 4 4096.0 4.623019 4.968653 1.600277
1805
+ # 5 8192.0 17.513062 18.332262 4.802458
1806
+ # 6 16384.0 68.453377 70.337540 16.052908
1807
+ # 7 32768.0 270.655487 276.020233 57.938946
1808
+ # fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-bwd (num_warp=8):
1809
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1810
+ # 0 256.0 0.190120 0.150313 0.181451
1811
+ # 1 512.0 0.406348 0.391767 0.391177
1812
+ # 2 1024.0 1.029704 1.182967 0.885741
1813
+ # 3 2048.0 2.985456 3.843399 2.040469
1814
+ # 4 4096.0 9.808897 13.073701 5.069609
1815
+ # 5 8192.0 34.995201 47.863808 13.948782
1816
+ # 6 16384.0 132.740097 182.579193 42.816513
1817
+ # 7 32768.0 542.223389 714.820618 147.053574
1818
+ # fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
1819
+ # PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
1820
+ # 0 256.0 0.050949 0.032357 0.107513
1821
+ # 1 512.0 0.073624 0.050651 0.199086
1822
+ # 2 1024.0 0.107472 0.080379 0.245445
1823
+ # 3 2048.0 0.178423 0.129448 0.338259
1824
+ # 4 4096.0 0.327647 0.223106 0.517048
1825
+ # 5 8192.0 0.588423 0.411263 0.884606
1826
+ # 6 16384.0 1.098898 0.798941 1.611809
1827
+ # 7 32768.0 2.094537 1.594726 3.044160
1828
+
1829
+
1830
+ # 6.7B
1831
+ # fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-fwd:
1832
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1833
+ # 0 256.0 0.069208 0.082156 0.065097
1834
+ # 1 512.0 0.138271 0.201393 0.144467
1835
+ # 2 1024.0 0.391521 0.624614 0.322382
1836
+ # 3 2048.0 1.268443 2.406325 0.784367
1837
+ # 4 4096.0 4.455703 9.139097 2.100856
1838
+ # 5 8192.0 16.764315 35.289600 6.328320
1839
+ # 6 16384.0 65.221634 138.401794 21.069057
1840
+ # 7 32768.0 257.251343 548.085754 76.111870
1841
+ # fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-bwd:
1842
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1843
+ # 0 256.0 0.297118 0.266469 0.255255
1844
+ # 1 512.0 0.672826 0.613685 0.552954
1845
+ # 2 1024.0 1.718434 1.705066 1.251953
1846
+ # 3 2048.0 4.936755 5.403875 2.927895
1847
+ # 4 4096.0 15.911594 18.959362 7.436288
1848
+ # 5 8192.0 55.357441 70.808578 21.140224
1849
+ # 6 16384.0 208.188416 273.617920 68.018173
1850
+ # 7 32768.0 806.037476 1081.453613 218.720261
1851
+ # fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
1852
+ # PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
1853
+ # 0 256.0 0.050151 0.032337 0.107593
1854
+ # 1 512.0 0.073409 0.051737 0.200200
1855
+ # 2 1024.0 0.107533 0.082099 0.247067
1856
+ # 3 2048.0 0.177259 0.128891 0.338510
1857
+ # 4 4096.0 0.325866 0.223621 0.524842
1858
+ # 5 8192.0 0.586926 0.408913 0.885490
1859
+ # 6 16384.0 1.100834 0.793277 1.612271
1860
+ # 7 32768.0 2.098851 1.595831 3.064544
1861
+
1862
+ # fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-fwd:
1863
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1864
+ # 0 256.0 0.066673 0.082037 0.065085
1865
+ # 1 512.0 0.137379 0.201880 0.143473
1866
+ # 2 1024.0 0.390675 0.624234 0.312046
1867
+ # 3 2048.0 1.267739 2.406950 0.696045
1868
+ # 4 4096.0 4.445138 9.136333 1.665788
1869
+ # 5 8192.0 16.768614 35.265533 4.380486
1870
+ # 6 16384.0 65.235970 138.393600 12.997633
1871
+ # 7 32768.0 257.317902 550.442993 42.821121
1872
+ # fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-bwd:
1873
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1874
+ # 0 256.0 0.296461 0.266581 0.254022
1875
+ # 1 512.0 0.671427 0.613643 0.551283
1876
+ # 2 1024.0 1.719918 1.704295 1.229982
1877
+ # 3 2048.0 4.945305 5.403364 2.721906
1878
+ # 4 4096.0 15.934293 18.960999 6.259371
1879
+ # 5 8192.0 55.406593 70.832130 15.676929
1880
+ # 6 16384.0 208.750595 275.004425 44.837891
1881
+ # 7 32768.0 808.057861 1080.647705 141.856766
1882
+ # fused-attention-inference-batch4-head32-d128-sparse-local4-vert8-hetero:
1883
+ # PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
1884
+ # 0 256.0 0.050739 0.032886 0.107837
1885
+ # 1 512.0 0.073507 0.051996 0.200293
1886
+ # 2 1024.0 0.106394 0.080679 0.240610
1887
+ # 3 2048.0 0.177659 0.127660 0.287625
1888
+ # 4 4096.0 0.326326 0.226971 0.377500
1889
+ # 5 8192.0 0.586339 0.407367 0.559266
1890
+ # 6 16384.0 1.102279 0.786221 0.920976
1891
+ # 7 32768.0 2.097370 1.545090 1.644288
1892
+
1893
+
1894
+ ################
1895
+ ##### fp16 #####
1896
+ ################
1897
+
1898
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
1899
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1900
+ # 0 256.0 0.032518 0.035472 0.029939
1901
+ # 1 512.0 0.054266 0.087841 0.054320
1902
+ # 2 1024.0 0.133447 0.263090 0.102045
1903
+ # 3 2048.0 0.384615 1.023293 0.201763
1904
+ # 4 4096.0 1.300890 4.023936 0.449555
1905
+ # 5 8192.0 4.774144 15.816704 1.150854
1906
+ # 6 16384.0 18.220032 62.771198 3.356001
1907
+ # 7 32768.0 71.405571 250.273788 10.976142
1908
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
1909
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1910
+ # 0 256.0 0.083342 0.069742 0.079496
1911
+ # 1 512.0 0.159894 0.170995 0.151705
1912
+ # 2 1024.0 0.386071 0.522407 0.331443
1913
+ # 3 2048.0 1.067715 1.737333 0.715248
1914
+ # 4 4096.0 3.382731 6.219520 1.597457
1915
+ # 5 8192.0 11.857793 23.560448 3.879035
1916
+ # 6 16384.0 44.422142 91.251709 10.626843
1917
+ # 7 32768.0 175.011841 359.473145 32.340992
1918
+
1919
+
1920
+ ################
1921
+ ##### bf16 #####
1922
+ ################
1923
+
1924
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
1925
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1926
+ # 0 256.0 0.037636 0.035902 0.031512
1927
+ # 1 512.0 0.058591 0.087229 0.058125
1928
+ # 2 1024.0 0.143337 0.263919 0.108443
1929
+ # 3 2048.0 0.414458 1.025985 0.214114
1930
+ # 4 4096.0 1.390841 4.020010 0.480550
1931
+ # 5 8192.0 5.067938 15.808171 1.230874
1932
+ # 6 16384.0 19.442280 62.765057 3.597274
1933
+ # 7 32768.0 75.501572 250.443771 11.768959
1934
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
1935
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1936
+ # 0 256.0 0.084404 0.070663 0.082613
1937
+ # 1 512.0 0.161510 0.172882 0.157661
1938
+ # 2 1024.0 0.388954 0.526047 0.339855
1939
+ # 3 2048.0 1.075814 1.736057 0.732420
1940
+ # 4 4096.0 3.401622 6.221376 1.636039
1941
+ # 5 8192.0 11.915136 23.483391 3.968725
1942
+ # 6 16384.0 44.660225 91.302910 10.857130
1943
+ # 7 32768.0 175.038467 359.048187 32.778240