kmfoda commited on
Commit
78082ac
·
verified ·
1 Parent(s): e0b5d91

Upload GPTOptim

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/root/optimized-gpt2-1b",
3
+ "activation_function": "gelu_new",
4
+ "all_reduce_scores": {
5
+ "0": "NON_PARTICIPATING",
6
+ "1": "NON_PARTICIPATING",
7
+ "10": "NON_PARTICIPATING",
8
+ "100": "NON_PARTICIPATING",
9
+ "101": "NON_PARTICIPATING",
10
+ "102": "NON_PARTICIPATING",
11
+ "103": "NON_PARTICIPATING",
12
+ "104": "NON_PARTICIPATING",
13
+ "105": "SUCCESS",
14
+ "106": "NON_PARTICIPATING",
15
+ "107": "NON_PARTICIPATING",
16
+ "108": "NON_PARTICIPATING",
17
+ "109": "NON_PARTICIPATING",
18
+ "11": "NON_PARTICIPATING",
19
+ "110": "NON_PARTICIPATING",
20
+ "111": "NON_PARTICIPATING",
21
+ "112": "NON_PARTICIPATING",
22
+ "113": "NON_PARTICIPATING",
23
+ "114": "NON_PARTICIPATING",
24
+ "115": "SUCCESS",
25
+ "116": "NON_PARTICIPATING",
26
+ "117": "NON_PARTICIPATING",
27
+ "118": "NON_PARTICIPATING",
28
+ "119": "NON_PARTICIPATING",
29
+ "12": "NON_PARTICIPATING",
30
+ "120": "NON_PARTICIPATING",
31
+ "121": "NON_PARTICIPATING",
32
+ "122": "NON_PARTICIPATING",
33
+ "123": "NON_PARTICIPATING",
34
+ "124": "NON_PARTICIPATING",
35
+ "125": "NON_PARTICIPATING",
36
+ "126": "NON_PARTICIPATING",
37
+ "127": "NON_PARTICIPATING",
38
+ "128": "NON_PARTICIPATING",
39
+ "129": "NON_PARTICIPATING",
40
+ "13": "NON_PARTICIPATING",
41
+ "130": "NON_PARTICIPATING",
42
+ "131": "NON_PARTICIPATING",
43
+ "132": "NON_PARTICIPATING",
44
+ "133": "NON_PARTICIPATING",
45
+ "134": "NON_PARTICIPATING",
46
+ "135": "NON_PARTICIPATING",
47
+ "136": "NON_PARTICIPATING",
48
+ "137": "NON_PARTICIPATING",
49
+ "138": "NON_PARTICIPATING",
50
+ "139": "SUCCESS",
51
+ "14": "NON_PARTICIPATING",
52
+ "140": "NON_PARTICIPATING",
53
+ "141": "NON_PARTICIPATING",
54
+ "142": "NON_PARTICIPATING",
55
+ "143": "NON_PARTICIPATING",
56
+ "144": "NON_PARTICIPATING",
57
+ "145": "NON_PARTICIPATING",
58
+ "146": "SUCCESS",
59
+ "147": "NON_PARTICIPATING",
60
+ "148": "NON_PARTICIPATING",
61
+ "149": "NON_PARTICIPATING",
62
+ "15": "SUCCESS",
63
+ "150": "NON_PARTICIPATING",
64
+ "151": "NON_PARTICIPATING",
65
+ "152": "NON_PARTICIPATING",
66
+ "153": "SUCCESS",
67
+ "154": "NON_PARTICIPATING",
68
+ "155": "SUCCESS",
69
+ "156": "NON_PARTICIPATING",
70
+ "157": "NON_PARTICIPATING",
71
+ "158": "NON_PARTICIPATING",
72
+ "159": "NON_PARTICIPATING",
73
+ "16": "SUCCESS",
74
+ "160": "NON_PARTICIPATING",
75
+ "161": "NON_PARTICIPATING",
76
+ "162": "NON_PARTICIPATING",
77
+ "163": "NON_PARTICIPATING",
78
+ "164": "NON_PARTICIPATING",
79
+ "165": "NON_PARTICIPATING",
80
+ "166": "SUCCESS",
81
+ "167": "NON_PARTICIPATING",
82
+ "168": "NON_PARTICIPATING",
83
+ "169": "SUCCESS",
84
+ "17": "NON_PARTICIPATING",
85
+ "170": "NON_PARTICIPATING",
86
+ "171": "SUCCESS",
87
+ "172": "NON_PARTICIPATING",
88
+ "173": "NON_PARTICIPATING",
89
+ "174": "NON_PARTICIPATING",
90
+ "175": "NON_PARTICIPATING",
91
+ "176": "NON_PARTICIPATING",
92
+ "177": "NON_PARTICIPATING",
93
+ "178": "NON_PARTICIPATING",
94
+ "179": "NON_PARTICIPATING",
95
+ "18": "NON_PARTICIPATING",
96
+ "180": "NON_PARTICIPATING",
97
+ "181": "NON_PARTICIPATING",
98
+ "182": "NON_PARTICIPATING",
99
+ "183": "NON_PARTICIPATING",
100
+ "184": "NON_PARTICIPATING",
101
+ "185": "NON_PARTICIPATING",
102
+ "186": "NON_PARTICIPATING",
103
+ "187": "NON_PARTICIPATING",
104
+ "188": "NON_PARTICIPATING",
105
+ "189": "NON_PARTICIPATING",
106
+ "19": "NON_PARTICIPATING",
107
+ "190": "NON_PARTICIPATING",
108
+ "191": "NON_PARTICIPATING",
109
+ "192": "NON_PARTICIPATING",
110
+ "193": "NON_PARTICIPATING",
111
+ "194": "NON_PARTICIPATING",
112
+ "195": "NON_PARTICIPATING",
113
+ "196": "NON_PARTICIPATING",
114
+ "197": "SUCCESS",
115
+ "198": "NON_PARTICIPATING",
116
+ "199": "NON_PARTICIPATING",
117
+ "2": "NON_PARTICIPATING",
118
+ "20": "NON_PARTICIPATING",
119
+ "200": "NON_PARTICIPATING",
120
+ "201": "NON_PARTICIPATING",
121
+ "202": "NON_PARTICIPATING",
122
+ "203": "SUCCESS",
123
+ "204": "NON_PARTICIPATING",
124
+ "205": "NON_PARTICIPATING",
125
+ "206": "NON_PARTICIPATING",
126
+ "207": "NON_PARTICIPATING",
127
+ "208": "NON_PARTICIPATING",
128
+ "209": "NON_PARTICIPATING",
129
+ "21": "NON_PARTICIPATING",
130
+ "210": "NON_PARTICIPATING",
131
+ "211": "NON_PARTICIPATING",
132
+ "212": "NON_PARTICIPATING",
133
+ "213": "NON_PARTICIPATING",
134
+ "214": "NON_PARTICIPATING",
135
+ "215": "NON_PARTICIPATING",
136
+ "216": "NON_PARTICIPATING",
137
+ "217": "NON_PARTICIPATING",
138
+ "218": "SUCCESS",
139
+ "219": "NON_PARTICIPATING",
140
+ "22": "SUCCESS",
141
+ "220": "NON_PARTICIPATING",
142
+ "221": "NON_PARTICIPATING",
143
+ "222": "NON_PARTICIPATING",
144
+ "223": "NON_PARTICIPATING",
145
+ "224": "NON_PARTICIPATING",
146
+ "225": "NON_PARTICIPATING",
147
+ "226": "NON_PARTICIPATING",
148
+ "227": "NON_PARTICIPATING",
149
+ "228": "NON_PARTICIPATING",
150
+ "229": "NON_PARTICIPATING",
151
+ "23": "NON_PARTICIPATING",
152
+ "230": "NON_PARTICIPATING",
153
+ "231": "NON_PARTICIPATING",
154
+ "232": "NON_PARTICIPATING",
155
+ "233": "NON_PARTICIPATING",
156
+ "234": "NON_PARTICIPATING",
157
+ "235": "NON_PARTICIPATING",
158
+ "236": "NON_PARTICIPATING",
159
+ "237": "NON_PARTICIPATING",
160
+ "238": "NON_PARTICIPATING",
161
+ "239": "NON_PARTICIPATING",
162
+ "24": "NON_PARTICIPATING",
163
+ "240": "NON_PARTICIPATING",
164
+ "241": "SUCCESS",
165
+ "242": "NON_PARTICIPATING",
166
+ "243": "NON_PARTICIPATING",
167
+ "244": "NON_PARTICIPATING",
168
+ "245": "NON_PARTICIPATING",
169
+ "246": "NON_PARTICIPATING",
170
+ "247": "NON_PARTICIPATING",
171
+ "248": "NON_PARTICIPATING",
172
+ "249": "NON_PARTICIPATING",
173
+ "25": "SUCCESS",
174
+ "250": "NON_PARTICIPATING",
175
+ "251": "NON_PARTICIPATING",
176
+ "252": "NON_PARTICIPATING",
177
+ "253": "NON_PARTICIPATING",
178
+ "254": "NON_PARTICIPATING",
179
+ "255": "NON_PARTICIPATING",
180
+ "26": "NON_PARTICIPATING",
181
+ "27": "NON_PARTICIPATING",
182
+ "28": "NON_PARTICIPATING",
183
+ "29": "NON_PARTICIPATING",
184
+ "3": "NON_PARTICIPATING",
185
+ "30": "NON_PARTICIPATING",
186
+ "31": "NON_PARTICIPATING",
187
+ "32": "NON_PARTICIPATING",
188
+ "33": "NON_PARTICIPATING",
189
+ "34": "NON_PARTICIPATING",
190
+ "35": "NON_PARTICIPATING",
191
+ "36": "NON_PARTICIPATING",
192
+ "37": "SUCCESS",
193
+ "38": "NON_PARTICIPATING",
194
+ "39": "SUCCESS",
195
+ "4": "SUCCESS",
196
+ "40": "NON_PARTICIPATING",
197
+ "41": "NON_PARTICIPATING",
198
+ "42": "NON_PARTICIPATING",
199
+ "43": "NON_PARTICIPATING",
200
+ "44": "NON_PARTICIPATING",
201
+ "45": "NON_PARTICIPATING",
202
+ "46": "NON_PARTICIPATING",
203
+ "47": "NON_PARTICIPATING",
204
+ "48": "NON_PARTICIPATING",
205
+ "49": "NON_PARTICIPATING",
206
+ "5": "NON_PARTICIPATING",
207
+ "50": "SUCCESS",
208
+ "51": "NON_PARTICIPATING",
209
+ "52": "NON_PARTICIPATING",
210
+ "53": "NON_PARTICIPATING",
211
+ "54": "NON_PARTICIPATING",
212
+ "55": "NON_PARTICIPATING",
213
+ "56": "NON_PARTICIPATING",
214
+ "57": "SUCCESS",
215
+ "58": "NON_PARTICIPATING",
216
+ "59": "NON_PARTICIPATING",
217
+ "6": "NON_PARTICIPATING",
218
+ "60": "NON_PARTICIPATING",
219
+ "61": "NON_PARTICIPATING",
220
+ "62": "NON_PARTICIPATING",
221
+ "63": "NON_PARTICIPATING",
222
+ "64": "NON_PARTICIPATING",
223
+ "65": "SUCCESS",
224
+ "66": "NON_PARTICIPATING",
225
+ "67": "NON_PARTICIPATING",
226
+ "68": "SUCCESS",
227
+ "69": "NON_PARTICIPATING",
228
+ "7": "NON_PARTICIPATING",
229
+ "70": "NON_PARTICIPATING",
230
+ "71": "NON_PARTICIPATING",
231
+ "72": "SUCCESS",
232
+ "73": "SUCCESS",
233
+ "74": "NON_PARTICIPATING",
234
+ "75": "NON_PARTICIPATING",
235
+ "76": "SUCCESS",
236
+ "77": "NON_PARTICIPATING",
237
+ "78": "NON_PARTICIPATING",
238
+ "79": "NON_PARTICIPATING",
239
+ "8": "NON_PARTICIPATING",
240
+ "80": "SUCCESS",
241
+ "81": "NON_PARTICIPATING",
242
+ "82": "NON_PARTICIPATING",
243
+ "83": "NON_PARTICIPATING",
244
+ "84": "NON_PARTICIPATING",
245
+ "85": "NON_PARTICIPATING",
246
+ "86": "NON_PARTICIPATING",
247
+ "87": "NON_PARTICIPATING",
248
+ "88": "NON_PARTICIPATING",
249
+ "89": "NON_PARTICIPATING",
250
+ "9": "NON_PARTICIPATING",
251
+ "90": "NON_PARTICIPATING",
252
+ "91": "SUCCESS",
253
+ "92": "NON_PARTICIPATING",
254
+ "93": "NON_PARTICIPATING",
255
+ "94": "NON_PARTICIPATING",
256
+ "95": "NON_PARTICIPATING",
257
+ "96": "NON_PARTICIPATING",
258
+ "97": "NON_PARTICIPATING",
259
+ "98": "NON_PARTICIPATING",
260
+ "99": "SUCCESS"
261
+ },
262
+ "architectures": [
263
+ "GPTOptim"
264
+ ],
265
+ "attn_pdrop": 0.1,
266
+ "auto_map": {
267
+ "AutoConfig": "configuration_gpt_optimized.GPTOptimConfig",
268
+ "AutoModelForCausalLM": "modeling_gpt_optimized.GPTOptim"
269
+ },
270
+ "block_size": 1024,
271
+ "bos_token_id": 50256,
272
+ "embd_pdrop": 0.1,
273
+ "eos_token_id": 50256,
274
+ "initializer_range": 0.02,
275
+ "layer_norm_epsilon": 1e-05,
276
+ "model_type": "gpt_optimized",
277
+ "n_embd": 1280,
278
+ "n_head": 32,
279
+ "n_inner": null,
280
+ "n_layer": 48,
281
+ "n_positions": 1024,
282
+ "reorder_and_upcast_attn": false,
283
+ "resid_pdrop": 0.1,
284
+ "scale_attn_by_inverse_layer_idx": false,
285
+ "scale_attn_weights": true,
286
+ "summary_activation": null,
287
+ "summary_first_dropout": 0.1,
288
+ "summary_proj_to_labels": true,
289
+ "summary_type": "cls_index",
290
+ "summary_use_proj": true,
291
+ "torch_dtype": "float32",
292
+ "transformers_version": "4.39.3",
293
+ "use_cache": true,
294
+ "vocab_size": 50257
295
+ }
configuration_gpt_optimized.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, GPT2Config
2
+ from typing import List
3
+
4
+
5
+ class GPTOptimConfig(GPT2Config):
6
+ model_type = "gpt_optimized"
7
+
8
+ def __init__(
9
+ self,
10
+ block_size: int = 1024, # max sequence length
11
+ vocab_size: int = 50257, # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
12
+ n_layer: int = 16, # number of layers
13
+ n_head: int = 16, # number of heads
14
+ n_embd: int = 1024, # embedding dimension
15
+ **kwargs,
16
+ ):
17
+ super().__init__(**kwargs)
18
+ self.block_size = block_size
19
+ self.vocab_size = vocab_size
20
+ self.n_layer = n_layer
21
+ self.n_head = n_head
22
+ self.n_embd = n_embd
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2c240204fac1bf66e112ce3be2384a0097a2ea95b57ed2a4896c6cd01ecf5f7
3
+ size 4040701744
modeling_gpt_optimized.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import CrossEntropyLoss, functional as F
4
+ from transformers import PreTrainedModel, GPT2PreTrainedModel
5
+ from .configuration_gpt_optimized import GPTOptimConfig
6
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
7
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
8
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
9
+ from typing import Optional, Tuple, Union
10
+
11
+ _CHECKPOINT_FOR_DOC = "openai-community/gpt2"
12
+ _CONFIG_FOR_DOC = "GPT2Config"
13
+
14
+ GPT2_INPUTS_DOCSTRING = r"""
15
+ Args:
16
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
17
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
18
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
19
+ sequence tokens in the vocabulary.
20
+
21
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
22
+ `input_ids`.
23
+
24
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
25
+ [`PreTrainedTokenizer.__call__`] for details.
26
+
27
+ [What are input IDs?](../glossary#input-ids)
28
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
29
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
30
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
31
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
32
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
33
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
34
+
35
+ - 1 for tokens that are **not masked**,
36
+ - 0 for tokens that are **masked**.
37
+
38
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
39
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
40
+ `len(past_key_values) + len(input_ids)`
41
+
42
+ [What are attention masks?](../glossary#attention-mask)
43
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
44
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
45
+ 1]`:
46
+
47
+ - 0 corresponds to a *sentence A* token,
48
+ - 1 corresponds to a *sentence B* token.
49
+
50
+ [What are token type IDs?](../glossary#token-type-ids)
51
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
52
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
53
+ config.max_position_embeddings - 1]`.
54
+
55
+ [What are position IDs?](../glossary#position-ids)
56
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
57
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
58
+
59
+ - 1 indicates the head is **not masked**,
60
+ - 0 indicates the head is **masked**.
61
+
62
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
63
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
64
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
65
+ model's internal embedding lookup matrix.
66
+
67
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
68
+ `past_key_values`).
69
+ use_cache (`bool`, *optional*):
70
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
71
+ `past_key_values`).
72
+ output_attentions (`bool`, *optional*):
73
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
74
+ tensors for more detail.
75
+ output_hidden_states (`bool`, *optional*):
76
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
77
+ more detail.
78
+ return_dict (`bool`, *optional*):
79
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
80
+ """
81
+
82
+ class CausalSelfAttention(nn.Module):
83
+
84
+ def __init__(self, config):
85
+ super().__init__()
86
+ assert config.n_embd % config.n_head == 0
87
+ # key, query, value projections for all heads, but in a batch
88
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
89
+ # output projection
90
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
91
+ self.c_proj.NANOGPT_SCALE_INIT = 1
92
+ # regularization
93
+ self.n_head = config.n_head
94
+ self.n_embd = config.n_embd
95
+
96
+ def forward(self, x):
97
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
98
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
99
+ # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
100
+ # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
101
+ qkv = self.c_attn(x)
102
+ q, k, v = qkv.split(self.n_embd, dim=2)
103
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
104
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
105
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
106
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
107
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
108
+ # output projection
109
+ y = self.c_proj(y)
110
+ return y
111
+
112
+ class MLP(nn.Module):
113
+
114
+ def __init__(self, config):
115
+ super().__init__()
116
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
117
+ self.gelu = nn.GELU(approximate='tanh')
118
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
119
+ self.c_proj.NANOGPT_SCALE_INIT = 1
120
+
121
+ def forward(self, x):
122
+ x = self.c_fc(x)
123
+ x = self.gelu(x)
124
+ x = self.c_proj(x)
125
+ return x
126
+
127
+ class Block(nn.Module):
128
+
129
+ def __init__(self, config):
130
+ super().__init__()
131
+ self.ln_1 = nn.LayerNorm(config.n_embd)
132
+ self.attn = CausalSelfAttention(config)
133
+ self.ln_2 = nn.LayerNorm(config.n_embd)
134
+ self.mlp = MLP(config)
135
+
136
+ def forward(self, x):
137
+ x = x + self.attn(self.ln_1(x))
138
+ x = x + self.mlp(self.ln_2(x))
139
+ return x
140
+
141
+ class GPT(nn.Module):
142
+
143
+ def __init__(self, config):
144
+ super().__init__()
145
+ self.config = config
146
+
147
+ self.transformer = nn.ModuleDict(dict(
148
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
149
+ wpe = nn.Embedding(config.block_size, config.n_embd),
150
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
151
+ ln_f = nn.LayerNorm(config.n_embd),
152
+ ))
153
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
154
+
155
+ # weight sharing scheme
156
+ self.transformer.wte.weight = self.lm_head.weight
157
+
158
+ # init params
159
+ self.apply(self._init_weights)
160
+
161
+ def _init_weights(self, module):
162
+ if isinstance(module, nn.Linear):
163
+ std = 0.02
164
+ if hasattr(module, 'NANOGPT_SCALE_INIT'):
165
+ std *= (2 * self.config.n_layer) ** -0.5
166
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
167
+ if module.bias is not None:
168
+ torch.nn.init.zeros_(module.bias)
169
+ elif isinstance(module, nn.Embedding):
170
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
171
+
172
+ class GPTOptim(GPT2PreTrainedModel):
173
+ config_class = GPTOptimConfig
174
+
175
+ def __init__(self, config):
176
+ super().__init__(config)
177
+ self.model = GPT(
178
+ config
179
+ )
180
+ self.config = config
181
+
182
+ def forward(self, input_ids, labels=None):
183
+ # input_ids is of shape (B, T)
184
+ B, T = input_ids.size()
185
+ assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
186
+ # forward the token and posisition embeddings
187
+ pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device) # shape (T)
188
+ pos_emb = self.model.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
189
+ tok_emb = self.model.transformer.wte(input_ids) # token embeddings of shape (B, T, n_embd)
190
+ x = tok_emb + pos_emb
191
+ # forward the blocks of the transformer
192
+ for block in self.model.transformer.h:
193
+ x = block(x)
194
+ # forward the final layernorm and the classifier
195
+ x = self.model.transformer.ln_f(x)
196
+ logits = self.model.lm_head(x) # (B, T, vocab_size)
197
+ loss = None
198
+ if labels is not None:
199
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=self.config.eos_token_id)
200
+ return logits, loss