---
license: apache-2.0
datasets:
- kunishou/oasst1-89k-ja
- kunishou/databricks-dolly-15k-ja
language:
- ja
---
# How to use

We write our prompts in the ChatML format.
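As a rough illustration of the layout, the sketch below renders an OpenAI-style list of `{"role": ..., "content": ...}` messages into the same ChatML string that `process_chat_history` in the examples below builds. The helper name `render_chatml` and the dict-based input are assumptions made for illustration; they are not part of this repository. The blank line after each `<|im_end|>` mirrors the `"<|im_end|>\n\n"` separator used in the examples.

```python
def render_chatml(messages):
    """Render a list of {"role", "content"} dicts as a ChatML prompt.

    Hypothetical helper: each message becomes
    <|im_start|>{role}\n{content}<|im_end|> followed by a blank line,
    matching the separator used in this card's examples. An open
    assistant header is appended so the model writes the next reply.
    """
    parts = []
    for message in messages:
        parts.append(
            f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n\n"
        )
    # Leave the assistant turn open for the model to complete
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


prompt = render_chatml([
    {"role": "system", "content": "You are a helpful assistant."},
    # "What is the highest mountain in Japan?"
    {"role": "user", "content": "日本の一番高い山は?"},
])
print(prompt)
```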

### With vLLM (recommended for much faster inference)

<details><summary>Install vLLM</summary>

[Reference](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)

```bash
pip install vllm
```
</details>

```python
from vllm import LLM, SamplingParams

model_name = "lightblue/jod"
llm = LLM(model=model_name)

SYSTEM_MESSAGE = "You are a helpful assistant."

def process_chat_history(next_user_msg, text_chat_history=None):
    """Build a ChatML prompt from earlier (user, assistant) pairs plus the next user message."""
    if text_chat_history is None:
        text_chat_history = []

    prompt_text = "<|im_start|>system\n"
    prompt_text += SYSTEM_MESSAGE
    prompt_text += "<|im_end|>\n\n"

    # Replay earlier turns of the conversation
    for user_msg, ai_msg in text_chat_history:
        prompt_text += "<|im_start|>user\n"
        prompt_text += user_msg
        prompt_text += "<|im_end|>\n\n"
        prompt_text += "<|im_start|>assistant\n"
        prompt_text += ai_msg
        prompt_text += "<|im_end|>\n\n"

    # Add the new user message and open an assistant turn for the model to complete
    prompt_text += "<|im_start|>user\n"
    prompt_text += next_user_msg
    prompt_text += "<|im_end|>\n\n"
    prompt_text += "<|im_start|>assistant\n"
    return prompt_text

user_prompt = "日本の一番高い山は?"  # "What is the highest mountain in Japan?"
prompt = process_chat_history(user_prompt)
sampling_params = SamplingParams(temperature=0, max_tokens=528)  # temperature=0: greedy decoding
outputs = llm.generate(prompt, sampling_params)
bot_message = outputs[0].outputs[0].text.strip()
print(bot_message)
```
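Depending on how end-of-turn tokens are handled, a completion could continue past the end of the assistant turn, for example by opening a new `<|im_start|>user` block. This is an assumption about possible behavior, not something documented for this model. If you see it, one option is to pass stop strings to vLLM (`SamplingParams` accepts a `stop` list of strings); another is to trim the text afterwards. The hypothetical `trim_to_assistant_turn` helper below sketches the second approach in plain Python.

```python
def trim_to_assistant_turn(text, markers=("<|im_end|>", "<|im_start|>")):
    """Cut generated text at the first ChatML marker, if any (hypothetical helper)."""
    cut = len(text)
    for marker in markers:
        index = text.find(marker)
        if index != -1:
            # Keep the earliest marker position seen so far
            cut = min(cut, index)
    return text[:cut].strip()


# Made-up example string, not real model output
raw = "富士山です。<|im_end|>\n\n<|im_start|>user\n..."
print(trim_to_assistant_turn(raw))  # 富士山です。
```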


### With Hugging Face Transformers

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

model_name = "lightblue/jod"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# 4-bit loading requires the bitsandbytes package
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
    ),
)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

SYSTEM_MESSAGE = "You are a helpful assistant."

def process_chat_history(next_user_msg, text_chat_history=None):
    """Build a ChatML prompt from earlier (user, assistant) pairs plus the next user message."""
    if text_chat_history is None:
        text_chat_history = []

    prompt_text = "<|im_start|>system\n"
    prompt_text += SYSTEM_MESSAGE
    prompt_text += "<|im_end|>\n\n"

    # Replay earlier turns of the conversation
    for user_msg, ai_msg in text_chat_history:
        prompt_text += "<|im_start|>user\n"
        prompt_text += user_msg
        prompt_text += "<|im_end|>\n\n"
        prompt_text += "<|im_start|>assistant\n"
        prompt_text += ai_msg
        prompt_text += "<|im_end|>\n\n"

    # Add the new user message and open an assistant turn for the model to complete
    prompt_text += "<|im_start|>user\n"
    prompt_text += next_user_msg
    prompt_text += "<|im_end|>\n\n"
    prompt_text += "<|im_start|>assistant\n"
    return prompt_text

user_prompt = "日本の一番高い山は?"  # "What is the highest mountain in Japan?"
prompt = process_chat_history(user_prompt)
# Greedy decoding; return_full_text=False drops the prompt from the returned text
bot_message = pipe(prompt, max_new_tokens=128, do_sample=False, return_full_text=False)[0]["generated_text"]
print(bot_message.strip())
```
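Both examples above send a single question. For a multi-turn conversation, the earlier `(user, assistant)` pairs go into `text_chat_history`. The sketch below shows one way to keep that history between turns. `chat_turn` and `fake_generate` are hypothetical names, and `fake_generate` only stands in for a real backend (for example, a small function wrapping `llm.generate` or `pipe` from the examples above).

```python
SYSTEM_MESSAGE = "You are a helpful assistant."


def process_chat_history(next_user_msg, text_chat_history=None):
    """Same ChatML layout as the examples above, written compactly."""
    turns = [f"<|im_start|>system\n{SYSTEM_MESSAGE}<|im_end|>\n\n"]
    for user_msg, ai_msg in text_chat_history or []:
        turns.append(f"<|im_start|>user\n{user_msg}<|im_end|>\n\n")
        turns.append(f"<|im_start|>assistant\n{ai_msg}<|im_end|>\n\n")
    turns.append(f"<|im_start|>user\n{next_user_msg}<|im_end|>\n\n")
    turns.append("<|im_start|>assistant\n")
    return "".join(turns)


def chat_turn(generate, history, user_msg):
    """Run one turn: build the prompt, generate a reply, and record the pair.

    `generate` is any callable mapping a prompt string to reply text
    (hypothetical; e.g. a wrapper around vLLM or a transformers pipeline).
    """
    prompt = process_chat_history(user_msg, history)
    reply = generate(prompt).strip()
    history.append((user_msg, reply))
    return reply


def fake_generate(prompt):
    # Stand-in backend for illustration: reports how many user turns it saw
    return f"(reply to user turn {prompt.count('<|im_start|>user')})"


history = []
chat_turn(fake_generate, history, "日本の一番高い山は?")  # "What is the highest mountain in Japan?"
chat_turn(fake_generate, history, "その山の高さは?")  # "How tall is that mountain?"
print(history)
```

Keeping the history as a plain list of pairs matches the `text_chat_history` argument the card's examples already accept, so the same list can be passed to either backend.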


# Training details

We fine-tuned our base checkpoint, [Open-Orca/Mistral-7B-SlimOrca](https://huggingface.co/Open-Orca/Mistral-7B-SlimOrca), on the following 3 datasets:
* (J) - [JASTER](https://github.com/llm-jp/llm-jp-eval)
* (O) - [kunishou/oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja/)
* (D) - [kunishou/databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja/)
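The Axolotl config further down reads its training data from `./data/jaster_plus.jsonl` with `type: sharegpt` and `conversation: chatml`. As an illustrative sketch only, the snippet below turns instruction/response pairs into ShareGPT-style JSONL lines of the form `{"conversations": [{"from": "human", "value": ...}, {"from": "gpt", "value": ...}]}`, which is the layout we assume that loader expects. The helper `to_sharegpt` and the example pair are hypothetical; this is not the preprocessing used to build `jaster_plus.jsonl`.

```python
import json


def to_sharegpt(instruction, response, system=None):
    """Wrap one instruction/response pair as a ShareGPT-style record (hypothetical helper)."""
    conversations = []
    if system is not None:
        conversations.append({"from": "system", "value": system})
    conversations.append({"from": "human", "value": instruction})
    conversations.append({"from": "gpt", "value": response})
    return {"conversations": conversations}


# Made-up example pair; real records would come from the datasets listed above
pairs = [
    ("日本の一番高い山は?", "富士山です。"),
]

# One JSON object per line; ensure_ascii=False keeps Japanese text readable
lines = [
    json.dumps(to_sharegpt(instruction, response), ensure_ascii=False)
    for instruction, response in pairs
]
print("\n".join(lines))
```

Writing `lines` to a file, one per line, gives a `.jsonl` file in the same shape.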

This model was trained on prompts in the ChatML format, so it should also be prompted in the ChatML chat format at inference time.
We chose this format because the base model ([Open-Orca/Mistral-7B-SlimOrca](https://huggingface.co/Open-Orca/Mistral-7B-SlimOrca)) was trained with it, and because we find a chat format more practical to use than the Alpaca-style instruction format.
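To make the contrast concrete, the sketch below renders the same request in both styles. The Alpaca template shown is the common no-input variant popularized by the Stanford Alpaca project and is included for comparison only; the ChatML layout follows this card's examples, and only the ChatML form matches what this model was trained on.

```python
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)


def alpaca_prompt(instruction):
    # Single instruction slot; no explicit roles or end-of-turn marker
    return ALPACA_TEMPLATE.format(instruction=instruction)


def chatml_prompt(instruction, system="You are a helpful assistant."):
    # Explicit roles, each message closed by <|im_end|>
    return (
        f"<|im_start|>system\n{system}<|im_end|>\n\n"
        f"<|im_start|>user\n{instruction}<|im_end|>\n\n"
        "<|im_start|>assistant\n"
    )


question = "日本の一番高い山は?"  # "What is the highest mountain in Japan?"
print(alpaca_prompt(question))
print("-" * 40)
print(chatml_prompt(question))
```

One practical difference: the Alpaca template has a single instruction slot, so earlier turns have no designated place, whereas ChatML repeats the same role block for every turn of a conversation.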

We trained for 1 epoch using the following Axolotl config; early stopping was not performed during training.
<details><summary>Axolotl config .yaml</summary>

```yaml
base_model: Open-Orca/Mistral-7B-SlimOrca
base_model_config: Open-Orca/Mistral-7B-SlimOrca
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: ./data/jaster_plus.jsonl
    ds_type: json # see other options below
    type: sharegpt
    conversation: chatml
dataset_prepared_path: false
val_set_size: 0.002
output_dir: ./train_output/openorca-mistral-jaster-1epoch

use_wandb: true
wandb_project: <HIDDEN>
wandb_entity: <HIDDEN>

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

gradient_accumulation_steps: 1
micro_batch_size: 10
eval_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience: 10
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 10
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps: 10
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
```

</details>
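A few quantities implied by the config can be worked out directly. The sketch below computes the LoRA scaling factor (`lora_alpha / lora_r`, the scaling used by standard LoRA implementations such as PEFT), the number of packed sequences per optimizer step, and the tokens per step given `sequence_len` with `sample_packing` and `pad_to_sequence_len`. The card does not state the GPU count, so `num_gpus` is an assumption to adjust.

```python
# Values copied from the Axolotl config above
lora_r = 32
lora_alpha = 16
micro_batch_size = 10
gradient_accumulation_steps = 1
sequence_len = 4096

# Assumption: the card does not say how many GPUs were used
num_gpus = 1

# Standard LoRA scales the low-rank update B @ A by alpha / r
lora_scaling = lora_alpha / lora_r

# Packed sequences processed per optimizer step across all GPUs
sequences_per_step = micro_batch_size * gradient_accumulation_steps * num_gpus

# With pad_to_sequence_len, every packed sequence is sequence_len tokens long,
# so this counts token positions per step, padding included
tokens_per_step = sequences_per_step * sequence_len

print(f"LoRA scaling: {lora_scaling}")              # LoRA scaling: 0.5
print(f"Sequences per step: {sequences_per_step}")  # Sequences per step: 10
print(f"Tokens per step: {tokens_per_step}")        # Tokens per step: 40960
```

With `lora_alpha` set to half of `lora_r`, the adapter update is scaled down by a factor of 0.5 before being added to the frozen weights.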

[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)