feat: add Metharme prompt strategy (#446)
Browse files* Add Metharme tokenizing strategy
This strategy accounts for how the Metharme JSONLs are formatted as well as adds duplicated EOS tokens which can help trim model output length.
I haven't gotten the chance to test this yet, and probably won't have the chance for quite a bit, so I'm committing this now.
* Redo Metharme tokenizing strategy
lol
* fix: oops
* Rearrange a conditional
* chore: reformat code in accordance with linter
* chore: Make lint not freak out
* chore: fix lint
---------
Co-authored-by: NanoCode012 <[email protected]>
- README.md +4 -0
- src/axolotl/prompt_strategies/metharme.py +76 -0
README.md
CHANGED
@@ -257,6 +257,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|
257 |
```json
|
258 |
{"conversations": [{"role": "...", "value": "..."}]}
|
259 |
```
|
|
|
|
|
|
|
|
|
260 |
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
261 |
```json
|
262 |
{"conversations": [{"role": "...", "value": "..."}]}
|
|
|
257 |
```json
|
258 |
{"conversations": [{"role": "...", "value": "..."}]}
|
259 |
```
|
260 |
+
- `metharme`: instruction, adds additional eos tokens
|
261 |
+
```json
|
262 |
+
{"prompt": "...", "generation": "..."}
|
263 |
+
```
|
264 |
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
265 |
```json
|
266 |
{"conversations": [{"role": "...", "value": "..."}]}
|
src/axolotl/prompt_strategies/metharme.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
7 |
+
from axolotl.prompters import AlpacaPrompter
|
8 |
+
|
9 |
+
LOG = logging.getLogger("axolotl")
|
10 |
+
|
11 |
+
IGNORE_TOKEN_ID = -100
|
12 |
+
|
13 |
+
# pylint: disable=duplicate-code
|
14 |
+
|
15 |
+
|
16 |
+
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
17 |
+
"""
|
18 |
+
Tokenizing strategy for the Metharme models
|
19 |
+
"""
|
20 |
+
|
21 |
+
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
22 |
+
return (prompt["prompt"], "", prompt["generation"])
|
23 |
+
|
24 |
+
def _tokenize(
|
25 |
+
self,
|
26 |
+
prompt: str,
|
27 |
+
add_eos_token: bool = True,
|
28 |
+
strip_bos_token: bool = False,
|
29 |
+
num_eos_tokens: int = 3,
|
30 |
+
):
|
31 |
+
result = self.tokenizer(
|
32 |
+
prompt,
|
33 |
+
truncation=True,
|
34 |
+
max_length=self.sequence_len,
|
35 |
+
padding=False,
|
36 |
+
return_tensors=None,
|
37 |
+
)
|
38 |
+
if len(result["input_ids"]) == 0:
|
39 |
+
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
40 |
+
# If there's already an EOS token there, subtract from the number added
|
41 |
+
if result["input_ids"][-1] == self.tokenizer.eos_token_id:
|
42 |
+
num_eos_tokens -= 1
|
43 |
+
|
44 |
+
if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
|
45 |
+
for _ in range(num_eos_tokens):
|
46 |
+
if len(result["input_ids"]) < self.sequence_len:
|
47 |
+
result["input_ids"].append(self.tokenizer.eos_token_id)
|
48 |
+
result["attention_mask"].append(1)
|
49 |
+
|
50 |
+
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
51 |
+
result["input_ids"] = result["input_ids"][1:]
|
52 |
+
result["attention_mask"] = result["attention_mask"][1:]
|
53 |
+
|
54 |
+
result["labels"] = result["input_ids"].copy()
|
55 |
+
return result
|
56 |
+
|
57 |
+
|
58 |
+
class MetharmePrompter(AlpacaPrompter):
|
59 |
+
"""
|
60 |
+
Prompter for the Metharme models.
|
61 |
+
"""
|
62 |
+
|
63 |
+
system_prompt = ""
|
64 |
+
system_no_input_prompt = ""
|
65 |
+
system_format = ""
|
66 |
+
turn_format = "{instruction}"
|
67 |
+
turn_no_input_format = "{instruction}"
|
68 |
+
|
69 |
+
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
|
70 |
+
pass
|
71 |
+
|
72 |
+
|
73 |
+
def load(tokenizer, cfg):
|
74 |
+
return MetharmePromptTokenizingStrategy(
|
75 |
+
MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
76 |
+
)
|