DavidGF committed on
Commit
072a327
·
verified ·
1 Parent(s): 85679c7

Upload folder using huggingface_hub

Browse files
configuration_kraken.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class KrakenConfig(PretrainedConfig):
    """Configuration holder for the "kraken" model type.

    All Kraken-specific settings live in a single ``config_dict`` attribute;
    every other keyword argument is forwarded to ``PretrainedConfig``.
    """

    model_type = "kraken"

    def __init__(self, config_dict=None, **kwargs):
        """Store ``config_dict``, substituting an empty dict for a falsy value.

        Args:
            config_dict: Mapping of Kraken settings, or ``None``.
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        # A fresh dict per instance avoids sharing one mutable default.
        self.config_dict = config_dict if config_dict else {}
kraken_model/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "KrakenForCausalLM"
4
+ ],
5
+ "config_dict": {
6
+ "class_indices": {
7
+ "LABEL_0": 0,
8
+ "LABEL_1": 1,
9
+ "LABEL_2": 2,
10
+ "LABEL_3": 3
11
+ },
12
+ "model_type": "kraken",
13
+ "models": {
14
+ "expert1": "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct",
15
+ "expert2": "mii-community/zefiro-7b-dpo-ITA",
16
+ "expert3": "paulml/Hermes-2-Pro-French",
17
+ "expert4": "norallm/normistral-7b-warm-instruct"
18
+ },
19
+ "quantization": {
20
+ "expert1": null,
21
+ "expert2": null,
22
+ "expert3": null,
23
+ "expert4": null
24
+ },
25
+ "router": "kraken_router",
26
+ "tokenizers": {
27
+ "expert1": "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct",
28
+ "expert2": "mii-community/zefiro-7b-dpo-ITA",
29
+ "expert3": "paulml/Hermes-2-Pro-French",
30
+ "expert4": "norallm/normistral-7b-warm-instruct"
31
+ }
32
+ },
33
+ "model_type": "kraken",
34
+ "torch_dtype": "float32",
35
+ "transformers_version": "4.41.0"
36
+ }
kraken_model/generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.41.0"
4
+ }
kraken_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7904498ee4fe7684ea8b1e5e4b898d68bf5874b8b99a6b3a29d217f2f762cb0d
3
+ size 1856003896
kraken_router/added_tokens.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "<|endoftext|>": 151643,
3
+ "<|im_end|>": 151645,
4
+ "<|im_start|>": 151644
5
+ }
kraken_router/config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "Qwen/Qwen1.5-0.5B",
3
+ "architectures": [
4
+ "Qwen2ForSequenceClassification"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "eos_token_id": 151643,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 1024,
11
+ "id2label": {
12
+ "0": "LABEL_0",
13
+ "1": "LABEL_1",
14
+ "2": "LABEL_2",
15
+ "3": "LABEL_3"
16
+ },
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 2816,
19
+ "label2id": {
20
+ "LABEL_0": 0,
21
+ "LABEL_1": 1,
22
+ "LABEL_2": 2,
23
+ "LABEL_3": 3
24
+ },
25
+ "max_position_embeddings": 32768,
26
+ "max_window_layers": 21,
27
+ "model_type": "qwen2",
28
+ "num_attention_heads": 16,
29
+ "num_hidden_layers": 24,
30
+ "num_key_value_heads": 16,
31
+ "pad_token_id": 151643,
32
+ "problem_type": "single_label_classification",
33
+ "rms_norm_eps": 1e-06,
34
+ "rope_theta": 1000000.0,
35
+ "sliding_window": 32768,
36
+ "tie_word_embeddings": true,
37
+ "torch_dtype": "float32",
38
+ "transformers_version": "4.41.0",
39
+ "use_cache": true,
40
+ "use_sliding_window": false,
41
+ "vocab_size": 151936
42
+ }
kraken_router/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
kraken_router/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a977f3e5f94dc95b07300f21e38c86fc85906081fde22526de296c32c0e686e4
3
+ size 1856000112
kraken_router/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52c203605e5f0e424fd7f097eacdd937670347c777796d884eba1eaa2cc2a2f7
3
+ size 3712178682
kraken_router/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c062f7f375beded48b5337f5a3f3a5cb38807fa3e85dbf3e294c0ab6b627bfc2
3
+ size 14244
kraken_router/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:460f54983d240fe33e7a94fb335adfd576cb487f5d497e2b899d4e773a2ce84c
3
+ size 1064
kraken_router/special_tokens_map.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>"
5
+ ],
6
+ "eos_token": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false
12
+ },
13
+ "pad_token": "<|endoftext|>"
14
+ }
kraken_router/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
kraken_router/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "additional_special_tokens": [
30
+ "<|im_start|>",
31
+ "<|im_end|>"
32
+ ],
33
+ "bos_token": null,
34
+ "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
35
+ "clean_up_tokenization_spaces": false,
36
+ "eos_token": "<|endoftext|>",
37
+ "errors": "replace",
38
+ "model_max_length": 32768,
39
+ "pad_token": "<|endoftext|>",
40
+ "split_special_tokens": false,
41
+ "tokenizer_class": "Qwen2Tokenizer",
42
+ "unk_token": null
43
+ }
kraken_router/trainer_state.json ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 5.283331606754377,
5
+ "eval_steps": 500,
6
+ "global_step": 51000,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.05179736869367036,
13
+ "grad_norm": 3.8771300836515366e-08,
14
+ "learning_rate": 1.9852007518018084e-05,
15
+ "loss": 0.0996,
16
+ "step": 500
17
+ },
18
+ {
19
+ "epoch": 0.10359473738734072,
20
+ "grad_norm": 9.604158321963041e-07,
21
+ "learning_rate": 1.970401503603617e-05,
22
+ "loss": 0.256,
23
+ "step": 1000
24
+ },
25
+ {
26
+ "epoch": 0.1553921060810111,
27
+ "grad_norm": 1.8179751350544393e-05,
28
+ "learning_rate": 1.9556022554054253e-05,
29
+ "loss": 0.1367,
30
+ "step": 1500
31
+ },
32
+ {
33
+ "epoch": 0.20718947477468144,
34
+ "grad_norm": 8.53914680192247e-05,
35
+ "learning_rate": 1.9408030072072343e-05,
36
+ "loss": 0.063,
37
+ "step": 2000
38
+ },
39
+ {
40
+ "epoch": 0.2589868434683518,
41
+ "grad_norm": 4.6534645662177354e-05,
42
+ "learning_rate": 1.9260037590090425e-05,
43
+ "loss": 0.0201,
44
+ "step": 2500
45
+ },
46
+ {
47
+ "epoch": 0.3107842121620222,
48
+ "grad_norm": 6.902104843220513e-08,
49
+ "learning_rate": 1.911204510810851e-05,
50
+ "loss": 0.106,
51
+ "step": 3000
52
+ },
53
+ {
54
+ "epoch": 0.36258158085569253,
55
+ "grad_norm": 2.192794745781157e-08,
56
+ "learning_rate": 1.8964052626126594e-05,
57
+ "loss": 0.0797,
58
+ "step": 3500
59
+ },
60
+ {
61
+ "epoch": 0.4143789495493629,
62
+ "grad_norm": 4.58152558859698e-13,
63
+ "learning_rate": 1.881606014414468e-05,
64
+ "loss": 0.0264,
65
+ "step": 4000
66
+ },
67
+ {
68
+ "epoch": 0.46617631824303324,
69
+ "grad_norm": 4.8503436119062826e-05,
70
+ "learning_rate": 1.8668067662162763e-05,
71
+ "loss": 0.0329,
72
+ "step": 4500
73
+ },
74
+ {
75
+ "epoch": 0.5179736869367036,
76
+ "grad_norm": 9.135746950050816e-05,
77
+ "learning_rate": 1.852007518018085e-05,
78
+ "loss": 0.023,
79
+ "step": 5000
80
+ },
81
+ {
82
+ "epoch": 0.569771055630374,
83
+ "grad_norm": 2.3920888381212535e-08,
84
+ "learning_rate": 1.8372082698198932e-05,
85
+ "loss": 0.044,
86
+ "step": 5500
87
+ },
88
+ {
89
+ "epoch": 0.6215684243240444,
90
+ "grad_norm": 1.1770172932301648e-05,
91
+ "learning_rate": 1.8224090216217018e-05,
92
+ "loss": 0.0097,
93
+ "step": 6000
94
+ },
95
+ {
96
+ "epoch": 0.6733657930177147,
97
+ "grad_norm": 5.798747224616818e-05,
98
+ "learning_rate": 1.8076097734235104e-05,
99
+ "loss": 0.0628,
100
+ "step": 6500
101
+ },
102
+ {
103
+ "epoch": 0.7251631617113851,
104
+ "grad_norm": 0.0055756960064172745,
105
+ "learning_rate": 1.7928105252253187e-05,
106
+ "loss": 0.0167,
107
+ "step": 7000
108
+ },
109
+ {
110
+ "epoch": 0.7769605304050554,
111
+ "grad_norm": 6.234566535567865e-05,
112
+ "learning_rate": 1.7780112770271273e-05,
113
+ "loss": 0.0342,
114
+ "step": 7500
115
+ },
116
+ {
117
+ "epoch": 0.8287578990987258,
118
+ "grad_norm": 2.065628723357804e-05,
119
+ "learning_rate": 1.7632120288289356e-05,
120
+ "loss": 0.0278,
121
+ "step": 8000
122
+ },
123
+ {
124
+ "epoch": 0.8805552677923961,
125
+ "grad_norm": 0.00034213648177683353,
126
+ "learning_rate": 1.7484127806307442e-05,
127
+ "loss": 0.0777,
128
+ "step": 8500
129
+ },
130
+ {
131
+ "epoch": 0.9323526364860665,
132
+ "grad_norm": 0.0024548424407839775,
133
+ "learning_rate": 1.7336135324325525e-05,
134
+ "loss": 0.0117,
135
+ "step": 9000
136
+ },
137
+ {
138
+ "epoch": 0.9841500051797368,
139
+ "grad_norm": 2.010272328334395e-05,
140
+ "learning_rate": 1.718814284234361e-05,
141
+ "loss": 0.0149,
142
+ "step": 9500
143
+ },
144
+ {
145
+ "epoch": 1.0359473738734073,
146
+ "grad_norm": 3.7440368032548577e-05,
147
+ "learning_rate": 1.7040150360361697e-05,
148
+ "loss": 0.0059,
149
+ "step": 10000
150
+ },
151
+ {
152
+ "epoch": 1.0877447425670776,
153
+ "grad_norm": 9.011640031530987e-06,
154
+ "learning_rate": 1.689215787837978e-05,
155
+ "loss": 0.0245,
156
+ "step": 10500
157
+ },
158
+ {
159
+ "epoch": 1.139542111260748,
160
+ "grad_norm": 3.9126422052504495e-05,
161
+ "learning_rate": 1.6744165396397866e-05,
162
+ "loss": 0.0001,
163
+ "step": 11000
164
+ },
165
+ {
166
+ "epoch": 1.1913394799544184,
167
+ "grad_norm": 8.866464668244589e-06,
168
+ "learning_rate": 1.659617291441595e-05,
169
+ "loss": 0.0,
170
+ "step": 11500
171
+ },
172
+ {
173
+ "epoch": 1.2431368486480887,
174
+ "grad_norm": 0.00013718219997826964,
175
+ "learning_rate": 1.6448180432434035e-05,
176
+ "loss": 0.0181,
177
+ "step": 12000
178
+ },
179
+ {
180
+ "epoch": 1.294934217341759,
181
+ "grad_norm": 3.808485416811891e-06,
182
+ "learning_rate": 1.6300187950452117e-05,
183
+ "loss": 0.0,
184
+ "step": 12500
185
+ },
186
+ {
187
+ "epoch": 1.3467315860354294,
188
+ "grad_norm": 7.217061011033366e-07,
189
+ "learning_rate": 1.6152195468470203e-05,
190
+ "loss": 0.0266,
191
+ "step": 13000
192
+ },
193
+ {
194
+ "epoch": 1.3985289547290998,
195
+ "grad_norm": 7.131046731956303e-05,
196
+ "learning_rate": 1.600420298648829e-05,
197
+ "loss": 0.0266,
198
+ "step": 13500
199
+ },
200
+ {
201
+ "epoch": 1.4503263234227701,
202
+ "grad_norm": 9.411406608705875e-06,
203
+ "learning_rate": 1.5856210504506372e-05,
204
+ "loss": 0.0079,
205
+ "step": 14000
206
+ },
207
+ {
208
+ "epoch": 1.5021236921164405,
209
+ "grad_norm": 7.976142660481855e-05,
210
+ "learning_rate": 1.570821802252446e-05,
211
+ "loss": 0.022,
212
+ "step": 14500
213
+ },
214
+ {
215
+ "epoch": 1.5539210608101108,
216
+ "grad_norm": 1.0580498610579525e-06,
217
+ "learning_rate": 1.556022554054254e-05,
218
+ "loss": 0.0093,
219
+ "step": 15000
220
+ },
221
+ {
222
+ "epoch": 1.6057184295037812,
223
+ "grad_norm": 6.298066182353068e-06,
224
+ "learning_rate": 1.5412233058560627e-05,
225
+ "loss": 0.0,
226
+ "step": 15500
227
+ },
228
+ {
229
+ "epoch": 1.6575157981974515,
230
+ "grad_norm": 2.102418066030065e-12,
231
+ "learning_rate": 1.526424057657871e-05,
232
+ "loss": 0.0024,
233
+ "step": 16000
234
+ },
235
+ {
236
+ "epoch": 1.709313166891122,
237
+ "grad_norm": 3.009020701938425e-06,
238
+ "learning_rate": 1.5116248094596794e-05,
239
+ "loss": 0.009,
240
+ "step": 16500
241
+ },
242
+ {
243
+ "epoch": 1.7611105355847922,
244
+ "grad_norm": 6.2723142946197186e-06,
245
+ "learning_rate": 1.4968255612614882e-05,
246
+ "loss": 0.023,
247
+ "step": 17000
248
+ },
249
+ {
250
+ "epoch": 1.8129079042784626,
251
+ "grad_norm": 5.638932634610683e-06,
252
+ "learning_rate": 1.4820263130632967e-05,
253
+ "loss": 0.0099,
254
+ "step": 17500
255
+ },
256
+ {
257
+ "epoch": 1.8647052729721332,
258
+ "grad_norm": 3.8804391806479543e-05,
259
+ "learning_rate": 1.4672270648651051e-05,
260
+ "loss": 0.0185,
261
+ "step": 18000
262
+ },
263
+ {
264
+ "epoch": 1.9165026416658035,
265
+ "grad_norm": 3.445857828410226e-06,
266
+ "learning_rate": 1.4524278166669134e-05,
267
+ "loss": 0.0,
268
+ "step": 18500
269
+ },
270
+ {
271
+ "epoch": 1.9683000103594739,
272
+ "grad_norm": 0.0029530602041631937,
273
+ "learning_rate": 1.4376285684687218e-05,
274
+ "loss": 0.0243,
275
+ "step": 19000
276
+ },
277
+ {
278
+ "epoch": 2.0200973790531442,
279
+ "grad_norm": 7.133132271519571e-07,
280
+ "learning_rate": 1.4228293202705303e-05,
281
+ "loss": 0.0189,
282
+ "step": 19500
283
+ },
284
+ {
285
+ "epoch": 2.0718947477468146,
286
+ "grad_norm": 0.00042451228364370763,
287
+ "learning_rate": 1.4080300720723387e-05,
288
+ "loss": 0.0062,
289
+ "step": 20000
290
+ },
291
+ {
292
+ "epoch": 2.123692116440485,
293
+ "grad_norm": 1.8360736930844723e-06,
294
+ "learning_rate": 1.3932308238741471e-05,
295
+ "loss": 0.0067,
296
+ "step": 20500
297
+ },
298
+ {
299
+ "epoch": 2.1754894851341553,
300
+ "grad_norm": 0.0001334488915745169,
301
+ "learning_rate": 1.378431575675956e-05,
302
+ "loss": 0.006,
303
+ "step": 21000
304
+ },
305
+ {
306
+ "epoch": 2.2272868538278257,
307
+ "grad_norm": 4.610120413417462e-06,
308
+ "learning_rate": 1.3636323274777644e-05,
309
+ "loss": 0.0061,
310
+ "step": 21500
311
+ },
312
+ {
313
+ "epoch": 2.279084222521496,
314
+ "grad_norm": 2.7200339900446124e-06,
315
+ "learning_rate": 1.3488330792795728e-05,
316
+ "loss": 0.0,
317
+ "step": 22000
318
+ },
319
+ {
320
+ "epoch": 2.3308815912151664,
321
+ "grad_norm": 3.3594403703318676e-06,
322
+ "learning_rate": 1.3340338310813813e-05,
323
+ "loss": 0.009,
324
+ "step": 22500
325
+ },
326
+ {
327
+ "epoch": 2.3826789599088367,
328
+ "grad_norm": 5.500828137883218e-06,
329
+ "learning_rate": 1.3192345828831897e-05,
330
+ "loss": 0.0083,
331
+ "step": 23000
332
+ },
333
+ {
334
+ "epoch": 2.434476328602507,
335
+ "grad_norm": 1103.414306640625,
336
+ "learning_rate": 1.304435334684998e-05,
337
+ "loss": 0.0058,
338
+ "step": 23500
339
+ },
340
+ {
341
+ "epoch": 2.4862736972961774,
342
+ "grad_norm": 1.7569537931194645e-06,
343
+ "learning_rate": 1.2896360864868064e-05,
344
+ "loss": 0.0028,
345
+ "step": 24000
346
+ },
347
+ {
348
+ "epoch": 2.5380710659898478,
349
+ "grad_norm": 9.93580897556967e-07,
350
+ "learning_rate": 1.2748368382886152e-05,
351
+ "loss": 0.0052,
352
+ "step": 24500
353
+ },
354
+ {
355
+ "epoch": 2.589868434683518,
356
+ "grad_norm": 2.71925017225616e-12,
357
+ "learning_rate": 1.2600375900904236e-05,
358
+ "loss": 0.0,
359
+ "step": 25000
360
+ },
361
+ {
362
+ "epoch": 2.6416658033771885,
363
+ "grad_norm": 1.8420889318804257e-05,
364
+ "learning_rate": 1.245238341892232e-05,
365
+ "loss": 0.0195,
366
+ "step": 25500
367
+ },
368
+ {
369
+ "epoch": 2.693463172070859,
370
+ "grad_norm": 8.20615071006614e-07,
371
+ "learning_rate": 1.2304390936940405e-05,
372
+ "loss": 0.008,
373
+ "step": 26000
374
+ },
375
+ {
376
+ "epoch": 2.745260540764529,
377
+ "grad_norm": 6.169057451188564e-05,
378
+ "learning_rate": 1.215639845495849e-05,
379
+ "loss": 0.0005,
380
+ "step": 26500
381
+ },
382
+ {
383
+ "epoch": 2.7970579094581995,
384
+ "grad_norm": 2.9846903544239467e-06,
385
+ "learning_rate": 1.2008405972976574e-05,
386
+ "loss": 0.0037,
387
+ "step": 27000
388
+ },
389
+ {
390
+ "epoch": 2.84885527815187,
391
+ "grad_norm": 8.764583071751986e-06,
392
+ "learning_rate": 1.1860413490994659e-05,
393
+ "loss": 0.0156,
394
+ "step": 27500
395
+ },
396
+ {
397
+ "epoch": 2.9006526468455403,
398
+ "grad_norm": 5.1639810408232734e-05,
399
+ "learning_rate": 1.1712421009012743e-05,
400
+ "loss": 0.0,
401
+ "step": 28000
402
+ },
403
+ {
404
+ "epoch": 2.9524500155392106,
405
+ "grad_norm": 8.454779163002968e-06,
406
+ "learning_rate": 1.1564428527030829e-05,
407
+ "loss": 0.0,
408
+ "step": 28500
409
+ },
410
+ {
411
+ "epoch": 3.004247384232881,
412
+ "grad_norm": 1.0601724653724887e-07,
413
+ "learning_rate": 1.1416436045048913e-05,
414
+ "loss": 0.0,
415
+ "step": 29000
416
+ },
417
+ {
418
+ "epoch": 3.0560447529265513,
419
+ "grad_norm": 3.302725417597685e-06,
420
+ "learning_rate": 1.1268443563066998e-05,
421
+ "loss": 0.0,
422
+ "step": 29500
423
+ },
424
+ {
425
+ "epoch": 3.1078421216202217,
426
+ "grad_norm": 8.728113243705593e-06,
427
+ "learning_rate": 1.1120451081085082e-05,
428
+ "loss": 0.0067,
429
+ "step": 30000
430
+ },
431
+ {
432
+ "epoch": 3.159639490313892,
433
+ "grad_norm": 2.4715068320801947e-06,
434
+ "learning_rate": 1.0972458599103167e-05,
435
+ "loss": 0.0,
436
+ "step": 30500
437
+ },
438
+ {
439
+ "epoch": 3.2114368590075624,
440
+ "grad_norm": 6.171033419377636e-06,
441
+ "learning_rate": 1.0824466117121251e-05,
442
+ "loss": 0.0,
443
+ "step": 31000
444
+ },
445
+ {
446
+ "epoch": 3.2632342277012327,
447
+ "grad_norm": 2.5147855922114104e-06,
448
+ "learning_rate": 1.0676473635139336e-05,
449
+ "loss": 0.0,
450
+ "step": 31500
451
+ },
452
+ {
453
+ "epoch": 3.315031596394903,
454
+ "grad_norm": 2.676899021025747e-05,
455
+ "learning_rate": 1.052848115315742e-05,
456
+ "loss": 0.006,
457
+ "step": 32000
458
+ },
459
+ {
460
+ "epoch": 3.3668289650885734,
461
+ "grad_norm": 2.081859747704584e-05,
462
+ "learning_rate": 1.0380488671175506e-05,
463
+ "loss": 0.0028,
464
+ "step": 32500
465
+ },
466
+ {
467
+ "epoch": 3.418626333782244,
468
+ "grad_norm": 2.3868110474722926e-06,
469
+ "learning_rate": 1.023249618919359e-05,
470
+ "loss": 0.0001,
471
+ "step": 33000
472
+ },
473
+ {
474
+ "epoch": 3.470423702475914,
475
+ "grad_norm": 2.7923347261094023e-06,
476
+ "learning_rate": 1.0084503707211675e-05,
477
+ "loss": 0.0,
478
+ "step": 33500
479
+ },
480
+ {
481
+ "epoch": 3.5222210711695845,
482
+ "grad_norm": 4.678757704823511e-06,
483
+ "learning_rate": 9.93651122522976e-06,
484
+ "loss": 0.0,
485
+ "step": 34000
486
+ },
487
+ {
488
+ "epoch": 3.574018439863255,
489
+ "grad_norm": 3.305537575215567e-06,
490
+ "learning_rate": 9.788518743247844e-06,
491
+ "loss": 0.0,
492
+ "step": 34500
493
+ },
494
+ {
495
+ "epoch": 3.625815808556925,
496
+ "grad_norm": 2.4619773739686934e-06,
497
+ "learning_rate": 9.640526261265928e-06,
498
+ "loss": 0.0,
499
+ "step": 35000
500
+ },
501
+ {
502
+ "epoch": 3.6776131772505956,
503
+ "grad_norm": 2.973723951527063e-07,
504
+ "learning_rate": 9.492533779284013e-06,
505
+ "loss": 0.0,
506
+ "step": 35500
507
+ },
508
+ {
509
+ "epoch": 3.729410545944266,
510
+ "grad_norm": 5.624352183986048e-07,
511
+ "learning_rate": 9.344541297302097e-06,
512
+ "loss": 0.012,
513
+ "step": 36000
514
+ },
515
+ {
516
+ "epoch": 3.7812079146379363,
517
+ "grad_norm": 1.8933849332825048e-07,
518
+ "learning_rate": 9.196548815320182e-06,
519
+ "loss": 0.0,
520
+ "step": 36500
521
+ },
522
+ {
523
+ "epoch": 3.8330052833316066,
524
+ "grad_norm": 3.9811013266444206e-05,
525
+ "learning_rate": 9.048556333338268e-06,
526
+ "loss": 0.0,
527
+ "step": 37000
528
+ },
529
+ {
530
+ "epoch": 3.884802652025277,
531
+ "grad_norm": 2.1272378944559023e-05,
532
+ "learning_rate": 8.900563851356352e-06,
533
+ "loss": 0.0073,
534
+ "step": 37500
535
+ },
536
+ {
537
+ "epoch": 3.9366000207189473,
538
+ "grad_norm": 8.419656296609901e-07,
539
+ "learning_rate": 8.752571369374436e-06,
540
+ "loss": 0.0045,
541
+ "step": 38000
542
+ },
543
+ {
544
+ "epoch": 3.9883973894126177,
545
+ "grad_norm": 1.4807918660153518e-06,
546
+ "learning_rate": 8.604578887392521e-06,
547
+ "loss": 0.0,
548
+ "step": 38500
549
+ },
550
+ {
551
+ "epoch": 4.0401947581062885,
552
+ "grad_norm": 2.789050199680787e-07,
553
+ "learning_rate": 8.456586405410605e-06,
554
+ "loss": 0.0,
555
+ "step": 39000
556
+ },
557
+ {
558
+ "epoch": 4.091992126799958,
559
+ "grad_norm": 7.712332603659888e-07,
560
+ "learning_rate": 8.30859392342869e-06,
561
+ "loss": 0.0,
562
+ "step": 39500
563
+ },
564
+ {
565
+ "epoch": 4.143789495493629,
566
+ "grad_norm": 1.126294500863878e-05,
567
+ "learning_rate": 8.160601441446774e-06,
568
+ "loss": 0.0,
569
+ "step": 40000
570
+ },
571
+ {
572
+ "epoch": 4.195586864187299,
573
+ "grad_norm": 1.1078836905653588e-05,
574
+ "learning_rate": 8.01260895946486e-06,
575
+ "loss": 0.0,
576
+ "step": 40500
577
+ },
578
+ {
579
+ "epoch": 4.24738423288097,
580
+ "grad_norm": 4.333252491051098e-06,
581
+ "learning_rate": 7.864616477482945e-06,
582
+ "loss": 0.0,
583
+ "step": 41000
584
+ },
585
+ {
586
+ "epoch": 4.29918160157464,
587
+ "grad_norm": 7.190360065578716e-06,
588
+ "learning_rate": 7.71662399550103e-06,
589
+ "loss": 0.0,
590
+ "step": 41500
591
+ },
592
+ {
593
+ "epoch": 4.350978970268311,
594
+ "grad_norm": 6.172657776915003e-06,
595
+ "learning_rate": 7.568631513519114e-06,
596
+ "loss": 0.0,
597
+ "step": 42000
598
+ },
599
+ {
600
+ "epoch": 4.4027763389619805,
601
+ "grad_norm": 1.1028377144839396e-07,
602
+ "learning_rate": 7.420639031537199e-06,
603
+ "loss": 0.0,
604
+ "step": 42500
605
+ },
606
+ {
607
+ "epoch": 4.454573707655651,
608
+ "grad_norm": 7.63295773253958e-08,
609
+ "learning_rate": 7.272646549555283e-06,
610
+ "loss": 0.0,
611
+ "step": 43000
612
+ },
613
+ {
614
+ "epoch": 4.506371076349321,
615
+ "grad_norm": 8.308877568197204e-07,
616
+ "learning_rate": 7.124654067573368e-06,
617
+ "loss": 0.0,
618
+ "step": 43500
619
+ },
620
+ {
621
+ "epoch": 4.558168445042992,
622
+ "grad_norm": 8.788915550894671e-08,
623
+ "learning_rate": 6.976661585591452e-06,
624
+ "loss": 0.0,
625
+ "step": 44000
626
+ },
627
+ {
628
+ "epoch": 4.609965813736662,
629
+ "grad_norm": 5.980305104458239e-07,
630
+ "learning_rate": 6.828669103609537e-06,
631
+ "loss": 0.0,
632
+ "step": 44500
633
+ },
634
+ {
635
+ "epoch": 4.661763182430333,
636
+ "grad_norm": 0.00010543836833676323,
637
+ "learning_rate": 6.680676621627622e-06,
638
+ "loss": 0.0034,
639
+ "step": 45000
640
+ },
641
+ {
642
+ "epoch": 4.713560551124003,
643
+ "grad_norm": 2.6961990442941897e-05,
644
+ "learning_rate": 6.532684139645706e-06,
645
+ "loss": 0.0,
646
+ "step": 45500
647
+ },
648
+ {
649
+ "epoch": 4.765357919817673,
650
+ "grad_norm": 2.6214322133455426e-05,
651
+ "learning_rate": 6.384691657663791e-06,
652
+ "loss": 0.0006,
653
+ "step": 46000
654
+ },
655
+ {
656
+ "epoch": 4.817155288511343,
657
+ "grad_norm": 6.838554782007122e-06,
658
+ "learning_rate": 6.236699175681876e-06,
659
+ "loss": 0.0,
660
+ "step": 46500
661
+ },
662
+ {
663
+ "epoch": 4.868952657205014,
664
+ "grad_norm": 1.3388408660830464e-05,
665
+ "learning_rate": 6.08870669369996e-06,
666
+ "loss": 0.0,
667
+ "step": 47000
668
+ },
669
+ {
670
+ "epoch": 4.920750025898684,
671
+ "grad_norm": 1.1914085007447284e-06,
672
+ "learning_rate": 5.940714211718045e-06,
673
+ "loss": 0.0,
674
+ "step": 47500
675
+ },
676
+ {
677
+ "epoch": 4.972547394592355,
678
+ "grad_norm": 1.9319197235745378e-05,
679
+ "learning_rate": 5.792721729736129e-06,
680
+ "loss": 0.0,
681
+ "step": 48000
682
+ },
683
+ {
684
+ "epoch": 5.024344763286025,
685
+ "grad_norm": 7.528370815634844e-07,
686
+ "learning_rate": 5.6447292477542145e-06,
687
+ "loss": 0.0058,
688
+ "step": 48500
689
+ },
690
+ {
691
+ "epoch": 5.0761421319796955,
692
+ "grad_norm": 1.075523073268414e-06,
693
+ "learning_rate": 5.496736765772299e-06,
694
+ "loss": 0.0,
695
+ "step": 49000
696
+ },
697
+ {
698
+ "epoch": 5.1279395006733655,
699
+ "grad_norm": 4.6377437001865474e-07,
700
+ "learning_rate": 5.348744283790383e-06,
701
+ "loss": 0.0,
702
+ "step": 49500
703
+ },
704
+ {
705
+ "epoch": 5.179736869367036,
706
+ "grad_norm": 6.992227667979023e-07,
707
+ "learning_rate": 5.2007518018084694e-06,
708
+ "loss": 0.0,
709
+ "step": 50000
710
+ },
711
+ {
712
+ "epoch": 5.231534238060706,
713
+ "grad_norm": 2.332795929760323e-06,
714
+ "learning_rate": 5.052759319826553e-06,
715
+ "loss": 0.0,
716
+ "step": 50500
717
+ },
718
+ {
719
+ "epoch": 5.283331606754377,
720
+ "grad_norm": 5.32125454810739e-07,
721
+ "learning_rate": 4.9047668378446374e-06,
722
+ "loss": 0.0,
723
+ "step": 51000
724
+ }
725
+ ],
726
+ "logging_steps": 500,
727
+ "max_steps": 67571,
728
+ "num_input_tokens_seen": 0,
729
+ "num_train_epochs": 7,
730
+ "save_steps": 500,
731
+ "stateful_callbacks": {
732
+ "TrainerControl": {
733
+ "args": {
734
+ "should_epoch_stop": false,
735
+ "should_evaluate": false,
736
+ "should_log": false,
737
+ "should_save": true,
738
+ "should_training_stop": false
739
+ },
740
+ "attributes": {}
741
+ }
742
+ },
743
+ "total_flos": 1.4080058755834675e+17,
744
+ "train_batch_size": 4,
745
+ "trial_name": null,
746
+ "trial_params": null
747
+ }
kraken_router/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49df63d21c21735b6ffb8853fe9bbbebb0e5df9f446e7a960be8881b2f46ec03
3
+ size 5048
kraken_router/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_kraken.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PreTrainedModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TextClassificationPipeline
3
+ from configuration_kraken import KrakenConfig
4
+ import tokenizer_template_switch
5
+
6
class KrakenForCausalLM(PreTrainedModel):
    """Router-plus-experts wrapper that forwards generation to one expert.

    A sequence-classification router labels the incoming prompt, the label is
    mapped to an expert key through ``config.config_dict['class_indices']``,
    and the prompt is re-rendered with that expert's chat template before the
    expert's own ``generate`` is called.
    """

    config_class = KrakenConfig

    def __init__(self, config):
        """Load every expert tokenizer and model, plus the router.

        Args:
            config: A ``KrakenConfig`` whose ``config_dict`` provides the
                ``tokenizers``, ``models``, ``quantization``, ``router`` and
                ``class_indices`` entries read below.
        """
        super().__init__(config)
        settings = config.config_dict
        # NOTE(review): device_map is a model-loading argument; confirm the
        # tokenizer loaders actually need it.
        self.tokenizers = {
            key: AutoTokenizer.from_pretrained(name, device_map="auto")
            for key, name in settings['tokenizers'].items()
        }
        self.models = self.load_expert_models(settings['models'], settings['quantization'])
        self.router_model = AutoModelForSequenceClassification.from_pretrained(
            settings['router'], trust_remote_code=True, device_map="auto"
        )
        # The router's tokenizer is also the one used to decode the caller's
        # input_ids in generate().
        self.tokenizer = AutoTokenizer.from_pretrained(
            settings['router'], trust_remote_code=True, device_map="auto"
        )
        self.router = TextClassificationPipeline(model=self.router_model, tokenizer=self.tokenizer)
        self.models_indices = settings['class_indices']

    def load_expert_models(self, models_dict, quantization_dict):
        """Instantiate one causal LM per entry of ``models_dict``.

        Args:
            models_dict: Mapping of expert key to model name or path.
            quantization_dict: Mapping of expert key to ``"8bit"``, ``"4bit"``,
                ``"awq"`` or any other value (including a missing key) for an
                unquantized load.

        Returns:
            Dict mapping each expert key to its loaded model.
        """
        models = {}
        for key, name in models_dict.items():
            quantization = quantization_dict.get(key)
            if quantization == "8bit":
                models[key] = AutoModelForCausalLM.from_pretrained(
                    name, trust_remote_code=True, device_map="auto", load_in_8bit=True, torch_dtype="auto"
                )
            elif quantization == "4bit":
                models[key] = AutoModelForCausalLM.from_pretrained(
                    name, trust_remote_code=True, device_map="auto", load_in_4bit=True, torch_dtype="auto"
                )
            elif quantization == "awq":
                models[key] = self.load_awq_model(name)
            else:
                models[key] = AutoModelForCausalLM.from_pretrained(
                    name, trust_remote_code=True, device_map="auto", torch_dtype="auto"
                )
        return models

    def load_awq_model(self, name):
        """Load an AWQ checkpoint; no explicit ``torch_dtype`` is passed here."""
        return AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto")

    def tokenize_inputs(self, text, model_key):
        """Tokenize ``text`` with the tokenizer registered under ``model_key``."""
        return self.tokenizers[model_key](text, return_tensors="pt")

    def determine_model(self, text):
        """Return the expert key (``"expert<N>"``) the router picks for ``text``.

        The router's top label is looked up in ``self.models_indices`` to get a
        zero-based index, which maps to the one-based key ``expert<index + 1>``.
        Deriving the key from the index replaces a fixed five-entry list, so
        the number of experts is bounded only by the configuration.

        Raises:
            KeyError: If the predicted label is missing from ``models_indices``.
            IndexError: If the configured index is negative.
        """
        prediction = self.router(text)[0]["label"]
        model_decision_index = self.models_indices[prediction]
        if model_decision_index < 0:
            raise IndexError(
                f"class index for label {prediction!r} must be non-negative, got {model_decision_index}"
            )
        return f"expert{model_decision_index + 1}"

    def expert_tokenizer(self, text):
        """Return the tokenizer of the expert the router selects for ``text``."""
        model_key = self.determine_model(text)
        return self.tokenizers[model_key]

    def generate(self, input_ids, **generate_kwargs):
        """Route the prompt to an expert and return that expert's generation.

        Only the first sequence of ``input_ids`` is decoded and used.

        Args:
            input_ids: Token ids produced with the router tokenizer.
            **generate_kwargs: Forwarded to the selected expert's ``generate``.

        Returns:
            Whatever the selected expert's ``generate`` returns.
        """
        # Decode back to text so the conversation can be re-templated for the
        # chosen expert; special tokens are kept because the parser relies on them.
        text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0]

        msgs = tokenizer_template_switch.recover_chat_messages(text, self.tokenizer)
        # Drop a leading system entry that contains only the role header.
        if msgs and msgs[0]['role'] == 'system' and msgs[0]['content'] == '<|im_start|>system':
            msgs.pop(0)
        # Drop a trailing assistant entry; the generation prompt is re-added
        # by apply_chat_template below.
        if msgs and msgs[-1]['role'] == 'assistant':
            msgs.pop()

        # Determine the model key using the existing routing logic
        model_key = self.determine_model(text)
        # Show the routing result
        print(f"Choosing {model_key} ..")
        # Retrieve the model from the dictionary
        model = self.models[model_key]

        mod_txt = self.tokenizers[model_key].apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'

        # Tokenize with the selected expert's tokenizer and keep the tensors on
        # the same device as the caller's input.
        tok = self.tokenizers[model_key](mod_txt, return_tensors="pt")
        tok_input_ids = tok.input_ids.to(current_device)
        tok_attention_mask = tok.attention_mask.to(current_device)

        # Generate text using the retrieved model
        return model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
80
+
81
+
82
+
tokenizer_template_switch.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from transformers import AutoTokenizer
3
+
4
def extract_separators(template):
    """
    Collect the text a chat template emits right before each message body.

    Every ``{{ ... + message['content']`` expression in ``template``
    contributes the (stripped) expression text found between ``{{`` and
    ``+ message['content']``. When any of those pieces references
    ``message['role']``, the first piece is expanded once per role
    ("system", "user", "assistant") with all single quotes removed, and
    that three-item list is returned instead.
    """
    role_marker = "message['role']"
    found = re.findall(r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]", template)
    prefixes = [piece.strip() for piece in found]

    # Templates with literal, role-independent prefixes are returned as-is.
    if not any(role_marker in piece for piece in prefixes):
        return prefixes

    base = prefixes[0]
    return [
        base.replace(" + message['role'] + ", role_name).replace("'", "")
        for role_name in ("system", "user", "assistant")
    ]
22
+
23
def detect_eos_token(jinja_template, tokenizer):
    """
    Pick the end-of-message marker to split on for the given chat template.

    Checks, in order: a literal "<|im_end|>" or "</s>" in the template
    (returned verbatim); a reference to "eos_token" or a literal
    "<|endoftext|>" (both return ``tokenizer.eos_token``); otherwise the
    fallback string "<|endoftext|>".
    """
    for literal in ("<|im_end|>", "</s>"):
        if literal in jinja_template:
            return literal
    if "eos_token" in jinja_template or "<|endoftext|>" in jinja_template:
        return tokenizer.eos_token
    return "<|endoftext|>"
34
+
35
def recover_messages(formatted_message, separators, eos_token):
    """
    Split a rendered chat string back into ``{"role", "content"}`` dicts.

    The string is cut at every ``eos_token``; a final chunk that is empty or
    whitespace-only is discarded. Roles are assigned purely by position: the
    first chunk is "system", later chunks alternate "user", "assistant".
    Each chunk is stripped, and for every separator the first occurrence
    (surrounding single quotes removed from the separator) is deleted.
    """
    chunks = formatted_message.split(eos_token)
    # A trailing eos token leaves an empty tail chunk; drop it.
    if chunks and not chunks[-1].strip():
        del chunks[-1]

    turn_roles = ("user", "assistant")
    messages = []
    for position, chunk in enumerate(chunks):
        role = "system" if position == 0 else turn_roles[(position - 1) % 2]

        body = chunk.strip()
        for marker in separators:
            # Re-strip after each removal so later markers see trimmed text.
            body = body.replace(marker.strip("'"), '', 1).strip()

        messages.append({"role": role, "content": body})

    return messages
70
+
71
def recover_chat_messages(tokenized_chat, tokenizer):
    """
    Parse a chat-template-rendered string back into message dictionaries.

    Args:
        tokenized_chat: The rendered chat as a string (it is split as text,
            not as token ids).
        tokenizer: Object exposing ``chat_template`` and ``eos_token``.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts as produced by
        ``recover_messages``.
    """
    jinja_template = tokenizer.chat_template
    separators = extract_separators(jinja_template)
    eos_token = detect_eos_token(jinja_template, tokenizer)
    return recover_messages(tokenized_chat, separators, eos_token)
80
+
81
+ # Example usage
82
if __name__ == "__main__":
    # Round-trip demo: render a short conversation with the tokenizer's chat
    # template, print it, then parse it back and print the recovered messages.
    demo_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")

    system_turn = {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in the style of a pirate",
    }
    user_turn = {"role": "user", "content": "How many helicopters can a human eat in one sitting?"}

    rendered = demo_tokenizer.apply_chat_template([system_turn, user_turn], tokenize=False)
    print(rendered)

    print(recover_chat_messages(rendered, demo_tokenizer))