initial commit
Browse files- LICENSE +202 -0
- README.md +56 -3
- elm/infer_elm.py +132 -0
- elm/model.py +418 -0
- elm/positional_embeddings.py +86 -0
- elm/utils.py +25 -0
- models/.gitattributes +2 -0
- requirements.txt +2 -0
- run.py +24 -0
LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,56 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SliceX AI™ ELM (Efficient Language Models)
|
2 |
+
This repository contains code to run our ELM models.
|
3 |
+
|
4 |
+
Models are located in the "models" folder. ELM models in this repository comes in three sizes (elm-1.0, elm-0.75 and elm-0.25) and supports the following use-cases.
|
5 |
+
- news_classification
|
6 |
+
- toxicity_detection
|
7 |
+
- news_content_generation
|
8 |
+
|
9 |
+
## Download ELM repo
|
10 |
+
```bash
|
11 |
+
git clone [email protected]:slicexai/elm-0.25-v0.1
|
12 |
+
sudo apt-get intall git-lfs
|
13 |
+
git lfs install
|
14 |
+
```
|
15 |
+
(Optional) Installing git-lfs without sudo,
|
16 |
+
```bash
|
17 |
+
wget https://github.com/git-lfs/git-lfs/releases/download/v3.2.0/git-lfs-linux-amd64-v3.2.0.tar.gz
|
18 |
+
tar -xzf git-lfs-linux-amd64-v3.2.0.tar.gz
|
19 |
+
PATH=$PATH:/<absolute-path>/git-lfs-3.2.0/
|
20 |
+
git lfs install
|
21 |
+
```
|
22 |
+
|
23 |
+
## Download ELM task-specific model checkpoints
|
24 |
+
```bash
|
25 |
+
cd elm-0.25-v0.1
|
26 |
+
git lfs pull -I models/elm-0.25_news_classification/ckpt.pt
|
27 |
+
git lfs pull -I models/elm-0.25_toxicity_detection/ckpt.pt
|
28 |
+
git lfs pull -I models/elm-0.25_news_content_generation/ckpt.pt
|
29 |
+
```
|
30 |
+
|
31 |
+
## Installation
|
32 |
+
```bash
|
33 |
+
pip install -r requirements.txt
|
34 |
+
```
|
35 |
+
|
36 |
+
## How to use - Run ELM on a sample task (e.g., news classification)
|
37 |
+
```bash
|
38 |
+
python run.py <elm-model-directory>
|
39 |
+
E.g. python run.py models/elm-0.25_news_classification
|
40 |
+
```
|
41 |
+
Prompts for the specific tasks can be found in the corresponding checkpoint directory. See an example below in the form of `models/elm-0.25_news_classification/example_prompts.json`.
|
42 |
+
```json
|
43 |
+
{
|
44 |
+
"inputs": ["GM May Close Plant in Europe DETROIT (Reuters) - General Motors Corp. <A HREF=\"http://www.investor.reuters.com/FullQuote.aspx?ticker=GM.N target=/stocks/quickinfo/fullquote\">GM.N</A> will likely cut some jobs in Europe and may close a plant there as part of a restructuring plan under development to try to return the region to profitability, the U.S. automaker said on Wednesday."],
|
45 |
+
"template": "[INST]Below is a news article. Please classify it under one of the following classes (World, Business, Sports, Sci/Tech). Please format your response as a JSON payload.\n\n### Article: {input}\n\n### JSON Response:[/INST]"
|
46 |
+
}
|
47 |
+
```
|
48 |
+
|
49 |
+
Running the above command returns the following response
|
50 |
+
|
51 |
+
```json
|
52 |
+
{
|
53 |
+
"prompt": "[INST]Below is a news article. Please classify it under one of the following classes (World, Business, Sports, Sci/Tech). Please format your response as a JSON payload.\n\n### Article: GM May Close Plant in Europe DETROIT (Reuters) - General Motors Corp. <A HREF=\"http://www.investor.reuters.com/FullQuote.aspx?ticker=GM.N target=/stocks/quickinfo/fullquote\">GM.N</A> will likely cut some jobs in Europe and may close a plant there as part of a restructuring plan under development to try to return the region to profitability, the U.S. automaker said on Wednesday.\n\n### JSON Response:[/INST]",
|
54 |
+
"response": "{'text_label': 'Business'}"
|
55 |
+
}
|
56 |
+
```
|
elm/infer_elm.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
from elm.model import *
|
4 |
+
from elm.utils import batchify
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
import json
|
7 |
+
|
8 |
+
|
9 |
+
def load_elm_model_and_tokenizer(local_path,
|
10 |
+
model_config_dict,
|
11 |
+
device="cuda",
|
12 |
+
load_partial=True,
|
13 |
+
get_num_layers_from_ckpt=True):
|
14 |
+
"""Load ELM model and tokenizer from local checkpoint."""
|
15 |
+
model_args = ModelArgs(**model_config_dict)
|
16 |
+
model = load_elm_model_from_ckpt(local_path, device=device, model_args=model_args, load_partial=load_partial, get_num_layers_from_ckpt=get_num_layers_from_ckpt)
|
17 |
+
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained(local_path)
|
19 |
+
tokenizer.padding_side = "left"
|
20 |
+
tokenizer.truncation_side = "left"
|
21 |
+
return model, tokenizer
|
22 |
+
|
23 |
+
|
24 |
+
def generate_elm_response_given_model(prompts, model, tokenizer,
|
25 |
+
device="cuda",
|
26 |
+
max_ctx_word_len=1024,
|
27 |
+
max_ctx_token_len=0,
|
28 |
+
max_new_tokens=500,
|
29 |
+
temperature=0.8, # set to 0 for greedy decoding
|
30 |
+
top_k=200,
|
31 |
+
return_tok_cnt=False,
|
32 |
+
return_gen_only=False,
|
33 |
+
early_stop_on_eos=False):
|
34 |
+
"""Generate responses from ELM model given an input list of prompts ([str])."""
|
35 |
+
if max_ctx_token_len > 0:
|
36 |
+
inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_ctx_token_len).to(device)
|
37 |
+
else:
|
38 |
+
prompts = [" ".join(p.split(" ")[-max_ctx_word_len:]) for p in prompts]
|
39 |
+
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
|
40 |
+
|
41 |
+
results = []
|
42 |
+
|
43 |
+
input_tok_cnt = torch.numel(inputs.input_ids)
|
44 |
+
|
45 |
+
model.eval()
|
46 |
+
|
47 |
+
out_tok_cnt = 0
|
48 |
+
with torch.no_grad():
|
49 |
+
temperature = temperature
|
50 |
+
top_k = top_k
|
51 |
+
|
52 |
+
outputs = model.generate(inputs.input_ids, max_new_tokens, temperature=temperature, top_k=top_k,
|
53 |
+
return_gen_only=return_gen_only)
|
54 |
+
|
55 |
+
if return_tok_cnt:
|
56 |
+
out_tok_cnt += torch.numel(outputs)
|
57 |
+
|
58 |
+
if early_stop_on_eos:
|
59 |
+
mod_outputs = []
|
60 |
+
for i in range(len(outputs)):
|
61 |
+
curr_out = outputs[i]
|
62 |
+
|
63 |
+
eos_loc_id = -1
|
64 |
+
for j in range(len(outputs[i])):
|
65 |
+
tok_id = outputs[i][j]
|
66 |
+
if tok_id == tokenizer.eos_token_id:
|
67 |
+
eos_loc_id = j
|
68 |
+
break
|
69 |
+
if eos_loc_id >= 0:
|
70 |
+
curr_out = outputs[i][:eos_loc_id]
|
71 |
+
mod_outputs.append(curr_out)
|
72 |
+
outputs = mod_outputs
|
73 |
+
detokenized_output = tokenizer.batch_decode(outputs, skip_special_tokens=False)
|
74 |
+
|
75 |
+
results = detokenized_output
|
76 |
+
|
77 |
+
if return_tok_cnt:
|
78 |
+
return results, (input_tok_cnt, out_tok_cnt)
|
79 |
+
|
80 |
+
return results
|
81 |
+
|
82 |
+
def generate_elm_responses(elm_model_path,
|
83 |
+
prompts,
|
84 |
+
device=None,
|
85 |
+
elm_model_config={},
|
86 |
+
eval_batch_size=1,
|
87 |
+
verbose=True):
|
88 |
+
|
89 |
+
|
90 |
+
if not device:
|
91 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
92 |
+
print(f"Setting device to {device}")
|
93 |
+
|
94 |
+
model_config_dict = {
|
95 |
+
"hidden_size": elm_model_config.get("hidden_size", 2048),
|
96 |
+
"max_inp_len": elm_model_config.get("max_inp_len", 2048),
|
97 |
+
"num_attention_heads": elm_model_config.get("num_attention_heads", 32),
|
98 |
+
"num_layers": elm_model_config.get("num_layers", 48),
|
99 |
+
"bits": elm_model_config.get("bits", 256),
|
100 |
+
"vocab_size": elm_model_config.get("vocab_size", 50304),
|
101 |
+
"dropout": elm_model_config.get("dropout", 0.1),
|
102 |
+
"use_rotary_embeddings": elm_model_config.get("use_rotary_embeddings", True)
|
103 |
+
}
|
104 |
+
|
105 |
+
model, tokenizer = load_elm_model_and_tokenizer(local_path=elm_model_path, model_config_dict=model_config_dict, device=device, load_partial=True)
|
106 |
+
|
107 |
+
#prompts = [prompt if "[INST]" in prompt else f"[INST]{prompt}[/INST]" for prompt in prompts]
|
108 |
+
max_new_tokens = 128
|
109 |
+
if "classification" in elm_model_path or "detection" in elm_model_path:
|
110 |
+
max_new_tokens = 12
|
111 |
+
result = []
|
112 |
+
for prompt_batch in batchify(prompts, eval_batch_size):
|
113 |
+
responses, _ = generate_elm_response_given_model(prompt_batch,
|
114 |
+
model,
|
115 |
+
tokenizer,
|
116 |
+
device=device,
|
117 |
+
max_ctx_word_len=1024,
|
118 |
+
max_ctx_token_len=512,
|
119 |
+
max_new_tokens=max_new_tokens,
|
120 |
+
return_tok_cnt=True,
|
121 |
+
return_gen_only=False,
|
122 |
+
temperature=0.0,
|
123 |
+
early_stop_on_eos=True)
|
124 |
+
|
125 |
+
for prompt, response in zip(prompt_batch, responses):
|
126 |
+
response = response.split("[/INST]")[-1].strip()
|
127 |
+
result.append(response)
|
128 |
+
if verbose:
|
129 |
+
print(json.dumps({"prompt": prompt, "response": response}, indent=4))
|
130 |
+
print("\n***\n")
|
131 |
+
return result
|
132 |
+
|
elm/model.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
import copy
|
4 |
+
import inspect
|
5 |
+
import math
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from typing import List, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from elm.utils import *
|
17 |
+
from elm.positional_embeddings import *
|
18 |
+
|
19 |
+
|
20 |
+
def get_elm_model_map(model_name):
|
21 |
+
"""Map the model type to corresponding class."""
|
22 |
+
elm_model_map = {
|
23 |
+
"rambutan": RambutanSlice,
|
24 |
+
}
|
25 |
+
|
26 |
+
return elm_model_map.get(model_name, RambutanSlice)
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class ModelArgs:
|
31 |
+
"""ELM Model Args"""
|
32 |
+
model_name_or_path: str = "ELM"
|
33 |
+
compile_model: bool = False
|
34 |
+
elm_model_class: Optional[str] = "rambutan"
|
35 |
+
hidden_size: Optional[int] = 2048
|
36 |
+
max_inp_len: Optional[int] = 2048
|
37 |
+
attn_window_size: Optional[int] = max_inp_len
|
38 |
+
num_attention_heads: Optional[int] = 32
|
39 |
+
layernorm_eps: float = 1e-5
|
40 |
+
attention_dropout: float = 0.1
|
41 |
+
hidden_dropout: float = 0.1
|
42 |
+
num_layers: Optional[int] = 16
|
43 |
+
bits: Optional[int] = 256
|
44 |
+
vocab_size: Optional[int] = 50304
|
45 |
+
dropout: Optional[int] = 0.1
|
46 |
+
use_rotary_embeddings: Optional[bool] = True
|
47 |
+
tokenizer: Optional[str] = None
|
48 |
+
|
49 |
+
|
50 |
+
class ELM(torch.nn.Module):
|
51 |
+
"""ELM (SliceX GPT) model."""
|
52 |
+
def __init__(self,
|
53 |
+
model_args: ModelArgs):
|
54 |
+
"""Initialize an ELM model instance."""
|
55 |
+
super().__init__()
|
56 |
+
|
57 |
+
self.model_args = model_args
|
58 |
+
|
59 |
+
elm_model_class = model_args.elm_model_class
|
60 |
+
hidden_size = model_args.hidden_size
|
61 |
+
max_inp_len = model_args.max_inp_len
|
62 |
+
num_attention_heads = model_args.num_attention_heads
|
63 |
+
layernorm_eps = model_args.layernorm_eps
|
64 |
+
attention_dropout = model_args.attention_dropout
|
65 |
+
hidden_dropout = model_args.hidden_dropout
|
66 |
+
num_layers = model_args.num_layers
|
67 |
+
bits = model_args.bits
|
68 |
+
vocab_size = model_args.vocab_size
|
69 |
+
use_rotary_embeddings = model_args.use_rotary_embeddings
|
70 |
+
|
71 |
+
layer_class = get_elm_model_map(elm_model_class)
|
72 |
+
|
73 |
+
self.slice_transformer = torch.nn.ModuleDict(dict(
|
74 |
+
temb = torch.nn.Embedding(vocab_size, hidden_size),
|
75 |
+
pemb = torch.nn.Embedding(max_inp_len, hidden_size) if not use_rotary_embeddings else None,
|
76 |
+
drop = torch.nn.Dropout(hidden_dropout),
|
77 |
+
h = torch.nn.ModuleList([ layer_class(model_args=model_args) for _ in range(num_layers) ]),
|
78 |
+
ln_f = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps),
|
79 |
+
))
|
80 |
+
|
81 |
+
self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
|
82 |
+
|
83 |
+
print("Number of model parameters: %.2fM" % (self.get_num_params(False)/1e6,))
|
84 |
+
|
85 |
+
|
86 |
+
def forward(self,
|
87 |
+
x: torch.Tensor,
|
88 |
+
attention_mask: Optional[torch.Tensor] = None,
|
89 |
+
targets: Optional[torch.Tensor] = None):
|
90 |
+
device = x.device
|
91 |
+
batch, seqlen = x.size()
|
92 |
+
|
93 |
+
|
94 |
+
tok_emb = self.slice_transformer.temb(x)
|
95 |
+
|
96 |
+
if not self.model_args.use_rotary_embeddings:
|
97 |
+
pos = torch.arange(0, seqlen, dtype=torch.long, device=device)
|
98 |
+
pos_emb = self.slice_transformer.pemb(pos)
|
99 |
+
x = self.slice_transformer.drop(tok_emb + pos_emb)
|
100 |
+
else:
|
101 |
+
x = self.slice_transformer.drop(tok_emb)
|
102 |
+
|
103 |
+
tlayer_id = 0
|
104 |
+
ignore_index_id = -100
|
105 |
+
loss = torch.zeros(1).to(device)
|
106 |
+
loss_denom = 0
|
107 |
+
|
108 |
+
for tlayer in self.slice_transformer.h:
|
109 |
+
x = tlayer(x, attention_mask=attention_mask)
|
110 |
+
|
111 |
+
tlayer_id += 1
|
112 |
+
|
113 |
+
x = self.slice_transformer.ln_f(x)
|
114 |
+
|
115 |
+
if targets is not None:
|
116 |
+
logits = self.lm_head(x)
|
117 |
+
|
118 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
119 |
+
shift_targets = targets[..., 1:].contiguous()
|
120 |
+
curr_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
|
121 |
+
shift_targets.view(-1),
|
122 |
+
ignore_index=ignore_index_id)
|
123 |
+
loss += curr_loss.float()
|
124 |
+
loss_denom += 1
|
125 |
+
else:
|
126 |
+
logits = self.lm_head(x[:, [-1], :])
|
127 |
+
|
128 |
+
loss = loss / loss_denom
|
129 |
+
|
130 |
+
return logits, loss
|
131 |
+
|
132 |
+
|
133 |
+
def get_num_params(self, non_embedding=True):
|
134 |
+
"""
|
135 |
+
Return the number of parameters in the model.
|
136 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
137 |
+
This assumes parameter tying between input and final layer embeddings. Oherwise
|
138 |
+
If there is no parameter sharing , set the flag to False to include parameters for both layers.
|
139 |
+
"""
|
140 |
+
n_params = sum(p.numel() for p in self.parameters())
|
141 |
+
if non_embedding and not self.model_args.use_rotary_embeddings:
|
142 |
+
n_params -= self.slice_transformer.pemb.weight.numel()
|
143 |
+
return n_params
|
144 |
+
|
145 |
+
|
146 |
+
@torch.no_grad()
|
147 |
+
def generate(self, x, max_new_tokens, temperature=0.8, top_k=200, top_p=0.9,
|
148 |
+
return_gen_only=False):
|
149 |
+
max_inp_len = self.model_args.max_inp_len
|
150 |
+
|
151 |
+
for _ in range(max_new_tokens):
|
152 |
+
x_ctxt = x if x.size(1) <= max_inp_len else x[:, -max_inp_len:]
|
153 |
+
|
154 |
+
logits, _ = self(x_ctxt)
|
155 |
+
|
156 |
+
next_id = None
|
157 |
+
|
158 |
+
if temperature <= 0:
|
159 |
+
next_id = torch.argmax(logits, dim=-1)
|
160 |
+
else:
|
161 |
+
logits = logits[:, -1, :] / temperature
|
162 |
+
|
163 |
+
if top_k is not None:
|
164 |
+
v, k = torch.topk(logits, min(top_k, logits.size(-1)))
|
165 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
166 |
+
|
167 |
+
probs = F.softmax(logits, dim=-1)
|
168 |
+
|
169 |
+
if top_p is None:
|
170 |
+
next_id = torch.multinomial(probs, num_samples=1)
|
171 |
+
else:
|
172 |
+
next_id = sample_top_p(probs, top_p)
|
173 |
+
x = torch.cat((x, next_id), dim=1)
|
174 |
+
|
175 |
+
if return_gen_only:
|
176 |
+
return x[:,-max_new_tokens:]
|
177 |
+
|
178 |
+
return x
|
179 |
+
|
180 |
+
|
181 |
+
class RambutanMLP(torch.nn.Module):
|
182 |
+
"""RambutanMLP version of MLP module used in the ELM (SliceX GPT) Transformer block."""
|
183 |
+
def __init__(self, dim=768, bits=32, dropout = 0.0):
|
184 |
+
super(RambutanMLP, self).__init__()
|
185 |
+
self.dim = dim
|
186 |
+
self.bits = bits
|
187 |
+
|
188 |
+
self.dropout = torch.nn.Dropout(dropout)
|
189 |
+
|
190 |
+
self.A1_c_w = torch.nn.Linear(self.dim, self.bits, bias=True)
|
191 |
+
|
192 |
+
self.Hexperts = 4
|
193 |
+
self.Hexpertemb = torch.nn.Embedding(self.bits, self.dim)
|
194 |
+
|
195 |
+
self.expert_aggr = torch.nn.Linear(self.Hexperts, 1)
|
196 |
+
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
h_c = torch.nn.functional.softmax(self.A1_c_w(x), dim=-1)
|
200 |
+
|
201 |
+
v, i = torch.topk(h_c, self.Hexperts)
|
202 |
+
|
203 |
+
if len(x.size()) < 3:
|
204 |
+
p = v.unsqueeze(-1).expand(-1,-1,self.dim)
|
205 |
+
else:
|
206 |
+
p = v.unsqueeze(-1).expand(-1,-1,-1,self.dim)
|
207 |
+
|
208 |
+
h_emb = p * self.Hexpertemb(i)
|
209 |
+
|
210 |
+
if len(x.size()) < 3:
|
211 |
+
out = self.expert_aggr(h_emb.transpose(1,2)).reshape(h_emb.size(0), -1)
|
212 |
+
else:
|
213 |
+
out = self.expert_aggr(h_emb.transpose(-2,-1)).reshape(x.size())
|
214 |
+
|
215 |
+
out = x * out
|
216 |
+
out = self.dropout(out)
|
217 |
+
|
218 |
+
return out
|
219 |
+
|
220 |
+
|
221 |
+
class RambutanSlice(torch.nn.Module):
|
222 |
+
"""Rambutan version of ELM (SliceX GPT) Transformer block."""
|
223 |
+
def __init__(self,
|
224 |
+
model_args: ModelArgs):
|
225 |
+
super().__init__()
|
226 |
+
|
227 |
+
self.model_args = model_args
|
228 |
+
|
229 |
+
self.num_attention_heads = model_args.num_attention_heads
|
230 |
+
self.kv_channels = model_args.hidden_size // model_args.num_attention_heads
|
231 |
+
self.ln1 = torch.nn.LayerNorm(model_args.hidden_size, eps=model_args.layernorm_eps)
|
232 |
+
self.ln2 = torch.nn.LayerNorm(model_args.hidden_size, eps=model_args.layernorm_eps)
|
233 |
+
self.mlp = RambutanMLP(dim=model_args.hidden_size, bits=model_args.bits)
|
234 |
+
self.cattn = RambutanCausalSelfAttention(model_args=model_args)
|
235 |
+
|
236 |
+
|
237 |
+
def forward(self,
|
238 |
+
x: torch.Tensor,
|
239 |
+
attention_mask: torch.Tensor = None):
|
240 |
+
res = x
|
241 |
+
|
242 |
+
x = self.ln1(x)
|
243 |
+
x = self.cattn(x, attention_mask=attention_mask)
|
244 |
+
|
245 |
+
x = res + x
|
246 |
+
res = x
|
247 |
+
x = self.ln2(x)
|
248 |
+
x = self.mlp(x)
|
249 |
+
|
250 |
+
return x + res
|
251 |
+
|
252 |
+
|
253 |
+
class RambutanCausalSelfAttention(torch.nn.Module):
|
254 |
+
"""Rambutan version of self-attention module used in the ELM (SliceX GPT) transformer block."""
|
255 |
+
|
256 |
+
def __init__(self,
|
257 |
+
model_args: ModelArgs):
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
self.model_args = model_args
|
261 |
+
|
262 |
+
n_embd = model_args.hidden_size
|
263 |
+
n_head = model_args.num_attention_heads
|
264 |
+
bias = False
|
265 |
+
dropout = model_args.attention_dropout
|
266 |
+
|
267 |
+
assert n_embd % n_head == 0
|
268 |
+
|
269 |
+
self.c_attn = torch.nn.Linear(n_embd, 3 * n_embd, bias=bias)
|
270 |
+
|
271 |
+
self.c_proj = torch.nn.Linear(n_embd, n_embd, bias=bias)
|
272 |
+
|
273 |
+
self.attn_dropout = torch.nn.Dropout(dropout)
|
274 |
+
self.resid_dropout = torch.nn.Dropout(dropout)
|
275 |
+
self.n_head = n_head
|
276 |
+
self.n_embd = n_embd
|
277 |
+
self.dropout = dropout
|
278 |
+
|
279 |
+
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
280 |
+
|
281 |
+
if not self.flash:
|
282 |
+
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
283 |
+
self.rotary_embeddings = (
|
284 |
+
RotaryEmbedding(n_embd // n_head) if model_args.use_rotary_embeddings else None
|
285 |
+
)
|
286 |
+
|
287 |
+
|
288 |
+
def forward(self, x, attention_mask: torch.Tensor = None):
|
289 |
+
B, T, C = x.size()
|
290 |
+
device = x.device
|
291 |
+
|
292 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
293 |
+
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
294 |
+
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
295 |
+
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
296 |
+
|
297 |
+
if self.rotary_embeddings:
|
298 |
+
q, k = self.rotary_embeddings(q=q, k=k)
|
299 |
+
|
300 |
+
is_causal = True
|
301 |
+
attn_mask = None
|
302 |
+
|
303 |
+
if attention_mask is not None:
|
304 |
+
att_mask_input = attention_mask
|
305 |
+
att_mask_input = att_mask_input.unsqueeze(-1).expand(B, T, T)
|
306 |
+
|
307 |
+
if is_causal:
|
308 |
+
att_mask_causal = torch.tril(torch.ones(T, T)).view(1,T,T).expand(B,T,T).to(device)
|
309 |
+
attn_mask = (att_mask_causal * att_mask_input)
|
310 |
+
else:
|
311 |
+
attn_mask = att_mask_input
|
312 |
+
|
313 |
+
attn_mask = attn_mask.unsqueeze(1).expand(B, self.n_head, T, T)
|
314 |
+
attn_mask.float().to(device)
|
315 |
+
|
316 |
+
|
317 |
+
if self.flash:
|
318 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0, is_causal=True)
|
319 |
+
else:
|
320 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
321 |
+
|
322 |
+
if is_causal and attn_mask is None:
|
323 |
+
attn_mask = torch.tril(torch.ones(T, T)).view(1,T,T).expand(B,T,T).to(device)
|
324 |
+
attn_mask = attn_mask.unsqueeze(1).expand(B, self.n_head, T, T)
|
325 |
+
|
326 |
+
if attn_mask is not None:
|
327 |
+
att = att.masked_fill(attn_mask == 0, torch.finfo(att.dtype).min)
|
328 |
+
|
329 |
+
att = F.softmax(att, dim=-1)
|
330 |
+
att = self.attn_dropout(att)
|
331 |
+
y = att @ v
|
332 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
333 |
+
|
334 |
+
y = self.resid_dropout(self.c_proj(y))
|
335 |
+
|
336 |
+
return y
|
337 |
+
|
338 |
+
|
339 |
+
def init_elm_model(model_args=ModelArgs(), device="cuda", model_config_dict=None):
|
340 |
+
"""Initialize ELM model using default or model_config parameters."""
|
341 |
+
if model_config_dict:
|
342 |
+
model_args = ModelArgs(**model_config_dict)
|
343 |
+
|
344 |
+
dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
345 |
+
|
346 |
+
model = ELM(model_args=model_args).to(dtype=dtype)
|
347 |
+
|
348 |
+
return model
|
349 |
+
|
350 |
+
def get_h_layers_in_ckpt(ckpt_state_dict,
|
351 |
+
layer_name_template = 'slice_transformer.h.{layer_num}.'):
|
352 |
+
num_layers_in_ckpt = 0
|
353 |
+
from collections import defaultdict
|
354 |
+
layer_wise_dict = defaultdict(lambda: defaultdict(list))
|
355 |
+
|
356 |
+
layer_num_found = True
|
357 |
+
while layer_num_found:
|
358 |
+
layer_num_found = False
|
359 |
+
for layer_name in ckpt_state_dict.keys():
|
360 |
+
if layer_name_template.format(layer_num=num_layers_in_ckpt) in layer_name:
|
361 |
+
layer_wise_dict[num_layers_in_ckpt][layer_name] = ckpt_state_dict[layer_name]
|
362 |
+
layer_num_found = True
|
363 |
+
num_layers_in_ckpt += 1
|
364 |
+
return layer_wise_dict
|
365 |
+
|
366 |
+
def load_elm_model_from_ckpt(ckpt_dir, device='cuda', load_partial=False, model_args=ModelArgs(), get_num_layers_from_ckpt=True):
|
367 |
+
"""Load ELM model from local checkpoint."""
|
368 |
+
print(f"Loading ELM checkpoint from {ckpt_dir}")
|
369 |
+
ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')
|
370 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
371 |
+
|
372 |
+
if get_num_layers_from_ckpt:
|
373 |
+
layer_name_template = 'slice_transformer.h.{layer_num}.'
|
374 |
+
ckpt_layer_wise_dict = get_h_layers_in_ckpt(checkpoint['model'],
|
375 |
+
layer_name_template = layer_name_template)
|
376 |
+
model_args.num_layers = len(ckpt_layer_wise_dict)
|
377 |
+
model = init_elm_model(model_args=model_args, device=device)
|
378 |
+
ckpt_state_dict = checkpoint['model']
|
379 |
+
|
380 |
+
unwanted_prefix = '_orig_mod.'
|
381 |
+
for k,v in list(ckpt_state_dict.items()):
|
382 |
+
if k.startswith(unwanted_prefix):
|
383 |
+
ckpt_state_dict[k[len(unwanted_prefix):]] = ckpt_state_dict.pop(k)
|
384 |
+
|
385 |
+
if load_partial:
|
386 |
+
mod_state_dict = model.state_dict()
|
387 |
+
for k,v in list(ckpt_state_dict.items()):
|
388 |
+
if k in mod_state_dict:
|
389 |
+
v_size = v.size()
|
390 |
+
mod_size = mod_state_dict[k].size()
|
391 |
+
|
392 |
+
if v_size == mod_size:
|
393 |
+
mod_state_dict[k] = v
|
394 |
+
else:
|
395 |
+
if len(v_size) == 1:
|
396 |
+
mod_state_dict[k][:v_size[-1]] = v
|
397 |
+
elif len(v_size) == 2:
|
398 |
+
mod_state_dict[k][:v_size[-2], :v_size[-1]] = v
|
399 |
+
|
400 |
+
ckpt_state_dict = mod_state_dict
|
401 |
+
load_status = model.load_state_dict(ckpt_state_dict)
|
402 |
+
print(load_status)
|
403 |
+
model.to(device)
|
404 |
+
|
405 |
+
return model
|
406 |
+
|
407 |
+
|
408 |
+
def sample_top_p(probs, threshold):
|
409 |
+
"""Perform top-p sampling on probability distribution using a threshold."""
|
410 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
411 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
412 |
+
mask = probs_sum - probs_sort > threshold
|
413 |
+
probs_sort[mask] = 0.0
|
414 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
415 |
+
next_token = torch.multinomial(probs_sort, num_samples=1)
|
416 |
+
next_token = torch.gather(probs_idx, -1, next_token)
|
417 |
+
|
418 |
+
return next_token
|
elm/positional_embeddings.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
|
5 |
+
def rotate_half(x):
|
6 |
+
x1, x2 = x.chunk(2, dim=-1)
|
7 |
+
return torch.cat((-x2, x1), dim=-1)
|
8 |
+
|
9 |
+
|
10 |
+
@torch.jit.script
|
11 |
+
def apply_rotary_pos_emb(x, cos, sin):
|
12 |
+
# NOTE: This could probably be moved to Triton
|
13 |
+
|
14 |
+
# Handle a possible sequence length mismatch in between q and k
|
15 |
+
cos = cos[:, :, : x.shape[-2], :]
|
16 |
+
sin = sin[:, :, : x.shape[-2], :]
|
17 |
+
|
18 |
+
return (x * cos) + (rotate_half(x) * sin)
|
19 |
+
|
20 |
+
|
21 |
+
class RotaryEmbedding(torch.nn.Module):
|
22 |
+
"""
|
23 |
+
Rotary position embeddings from RoFormer (Su et. al, 2021).
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, dim_model: int, *_, **__):
|
27 |
+
super().__init__()
|
28 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
29 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
|
30 |
+
self.register_buffer("inv_freq", inv_freq)
|
31 |
+
|
32 |
+
self._seq_len_cached = None
|
33 |
+
self._cos_cached = None
|
34 |
+
self._sin_cached = None
|
35 |
+
|
36 |
+
def update_cos_sin_tables(self, x, seq_dimension=1):
|
37 |
+
seq_len = x.shape[seq_dimension]
|
38 |
+
|
39 |
+
# Reset the tables if the sequence length has changed,
|
40 |
+
# or if we're on a new device (possibly due to tracing for instance)
|
41 |
+
if (
|
42 |
+
seq_len != self._seq_len_cached
|
43 |
+
or self._cos_cached.device != x.device
|
44 |
+
or self._cos_cached.dtype != x.dtype
|
45 |
+
):
|
46 |
+
self._seq_len_cached = seq_len
|
47 |
+
t = torch.arange(
|
48 |
+
x.shape[seq_dimension], device=x.device, dtype=torch.float32
|
49 |
+
)
|
50 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
|
51 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
52 |
+
|
53 |
+
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
|
54 |
+
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
|
55 |
+
|
56 |
+
return self._cos_cached, self._sin_cached
|
57 |
+
|
58 |
+
def forward(
|
59 |
+
self, q: torch.Tensor, k: torch.Tensor
|
60 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
61 |
+
self._cos_cached, self._sin_cached = self.update_cos_sin_tables(
|
62 |
+
k, seq_dimension=-2
|
63 |
+
)
|
64 |
+
|
65 |
+
return (
|
66 |
+
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
|
67 |
+
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def __test_rope__():
|
72 |
+
dtype=torch.float16
|
73 |
+
batch=4
|
74 |
+
seqlen=2048
|
75 |
+
dim=4096
|
76 |
+
num_heads=32
|
77 |
+
dim_key_head=dim // num_heads
|
78 |
+
|
79 |
+
x=torch.randn(batch,seqlen,num_heads,dim_key_head).to(dtype=dtype).to('cuda')
|
80 |
+
|
81 |
+
rpe=RotaryEmbedding(dim_key_head).to(dtype=dtype).to('cuda')
|
82 |
+
q,k=rpe(q=x,k=x)
|
83 |
+
|
84 |
+
|
85 |
+
#__test_rope__()
|
86 |
+
|
elm/utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
from prettytable import PrettyTable
|
4 |
+
|
5 |
+
def count_parameters(model):
|
6 |
+
"""Count the number of parameters in the model."""
|
7 |
+
table = PrettyTable(["Modules", "Parameters"])
|
8 |
+
total_params = 0
|
9 |
+
|
10 |
+
for name, parameter in model.named_parameters():
|
11 |
+
if not parameter.requires_grad: continue
|
12 |
+
params = parameter.numel()
|
13 |
+
table.add_row([name, params])
|
14 |
+
total_params+=params
|
15 |
+
|
16 |
+
print(table)
|
17 |
+
print(f"Total Trainable Params: {total_params}")
|
18 |
+
|
19 |
+
return total_params
|
20 |
+
|
21 |
+
|
22 |
+
def batchify(lst, n):
|
23 |
+
"""Divide a list into chunks of size n."""
|
24 |
+
return [lst[i:i + n] for i in range(0, len(lst), n)]
|
25 |
+
|
models/.gitattributes
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
“*.pt” filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
run.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
from elm.infer_elm import generate_elm_responses
|
5 |
+
|
6 |
+
parser = argparse.ArgumentParser(description='run prompts with elm model.')
|
7 |
+
parser.add_argument('elm_model_path', help='Path to the elm_model_path')
|
8 |
+
|
9 |
+
|
10 |
+
def get_prompt_config_file(elm_model_path):
|
11 |
+
return os.path.join(elm_model_path, "example_prompts.json")
|
12 |
+
|
13 |
+
def run(elm_model_path: str):
|
14 |
+
prompt_config_file = get_prompt_config_file(elm_model_path)
|
15 |
+
|
16 |
+
with open(prompt_config_file, "r") as f:
|
17 |
+
prompt_info = json.load(f)
|
18 |
+
prompts = [prompt_info["template"].format(input=input) for input in prompt_info["inputs"]]
|
19 |
+
print(f"Loaded prompts from: {prompt_config_file}")
|
20 |
+
generate_elm_responses(elm_model_path, prompts, verbose=True)
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
args = parser.parse_args()
|
24 |
+
run(args.elm_model_path)
|