ChatVector
Japanese LLMs built purely by adding and subtracting weights between models
Tora-7B-v0.1 = NTQAI/chatntq-ja-7b-v1.0 + (openchat/openchat-3.5-0106 - mistralai/Mistral-7B-v0.1)
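Here is a toy illustration of the formula above. The sketch applies target + (inst - base) to small, made-up parameter lists instead of real checkpoints. The helper name `apply_chat_vector`, the list-based "state dicts", and all numbers are hypothetical. The skip rule (an exact match against `skip_layers`, or any key containing `layernorm`) mirrors the one in the build script in this card.

```python
# Hypothetical toy "state dicts": parameter name -> flat list of weights.
# Real checkpoints hold torch tensors; plain lists keep the sketch dependency-free.
StateDict = dict[str, list[float]]


def apply_chat_vector(
    target: StateDict,
    base: StateDict,
    inst: StateDict,
    skip_layers: list[str],
) -> StateDict:
    """Return target + (inst - base) for every key not excluded by the skip rule."""
    merged: StateDict = {}
    for key, weights in target.items():
        if key in skip_layers or "layernorm" in key:
            # Excluded parameters keep the target model's values unchanged
            merged[key] = list(weights)
            continue
        # Element-wise: target weight plus the chat vector (inst - base)
        merged[key] = [t + (i - b) for t, i, b in zip(weights, inst[key], base[key])]
    return merged


if __name__ == "__main__":
    # Mistral-style parameter names; the weight values are invented for the demo
    keys = [
        "model.embed_tokens.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.input_layernorm.weight",
        "model.norm.weight",
        "lm_head.weight",
    ]
    base = {k: [1.0, 2.0] for k in keys}
    inst = {k: [1.5, 2.5] for k in keys}  # base + 0.5 everywhere
    target = {k: [3.0, 4.0] for k in keys}
    skip = ["model.embed_tokens.weight", "lm_head.weight"]

    merged = apply_chat_vector(target, base, inst, skip)
    for k in keys:
        print(k, merged[k])
```

Running it prints `[3.5, 4.5]` for the `q_proj` key and for `model.norm.weight`. The other three keys stay at `[3.0, 4.0]`. The rule matches on the substring `layernorm`, so Mistral's final norm `model.norm.weight` is not excluded. It receives the chat vector along with the attention and MLP weights.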
I built the model with the code below, based on @jovyan's implementation.
```python
import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
):
    # English base model that the chat vector is measured from
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Instruction-tuned model derived from the same base
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Japanese continually pre-trained model that receives the chat vector
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    # Inspect parameter names and shapes: English base model
    for k, v in base_model.state_dict().items():
        print(k, v.shape)
    # Inspect parameter names and shapes: Japanese continually pre-trained model
    for k, v in target_model.state_dict().items():
        print(k, v.shape)

    # Fetch each state dict once instead of on every loop iteration
    base_state = base_model.state_dict()
    inst_state = inst_model.state_dict()

    for k, v in target_model.state_dict().items():
        # Skip the layers passed in skip_layers, and also every layernorm
        if (k in skip_layers) or ("layernorm" in k):
            continue
        # Chat vector = instruction-tuned weights - base weights
        chat_vector = inst_state[k] - base_state[k]
        # state_dict() tensors share storage with the model parameters,
        # so copy_ updates target_model in place
        v.copy_(v + chat_vector.to(v.device))

    target_model.save_pretrained("./chat_model")


if __name__ == "__main__":
    base_model_name = "mistralai/Mistral-7B-v0.1"
    inst_model_name = "openchat/openchat-3.5-0106"
    target_model_name = "NTQAI/chatntq-ja-7b-v1.0"
    # Layers excluded from the chat vector
    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers,
    )
```
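The script prints every parameter name and shape of the base and target models so that they can be compared by eye. The element-wise update only works if each non-skipped key exists in all three models with the same shape. Below is a minimal sketch that automates this comparison. It assumes the shapes have already been collected into plain dicts, for example with `{k: tuple(v.shape) for k, v in model.state_dict().items()}`. The helper `compare_shapes` is hypothetical, and so are the shapes in the demo, which were not read from the real checkpoints.

```python
# Hypothetical mapping: parameter name -> tensor shape as a tuple of ints
Shapes = dict[str, tuple[int, ...]]


def compare_shapes(base: Shapes, target: Shapes) -> dict[str, list[str]]:
    """Report keys present in only one model and shared keys whose shapes differ."""
    only_base = sorted(base.keys() - target.keys())
    only_target = sorted(target.keys() - base.keys())
    mismatched = sorted(k for k in base.keys() & target.keys() if base[k] != target[k])
    return {
        "only_base": only_base,
        "only_target": only_target,
        "mismatched": mismatched,
    }


if __name__ == "__main__":
    # Illustrative shapes only: the target here is a made-up model with a larger vocabulary
    base = {
        "model.embed_tokens.weight": (32000, 4096),
        "model.layers.0.self_attn.q_proj.weight": (4096, 4096),
        "lm_head.weight": (32000, 4096),
    }
    target = {
        "model.embed_tokens.weight": (48000, 4096),
        "model.layers.0.self_attn.q_proj.weight": (4096, 4096),
        "lm_head.weight": (48000, 4096),
    }
    print(compare_shapes(base, target))
```

In this made-up case, only `lm_head.weight` and `model.embed_tokens.weight` are reported as mismatched. Both are already listed in `skip_layers`. Any other key reported in `mismatched` should be investigated before merging, and the same check applies to the instruction-tuned model.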
| model | category | score | ver |
|---|---|---|---|
| Tora-7B-v0.1 | Writing | 5.4 | single-turn |
| Tora-7B-v0.1 | Roleplay | 6.6 | single-turn |
| Tora-7B-v0.1 | Reasoning | 7.3 | single-turn |
| Tora-7B-v0.1 | Math | 3.5 | single-turn |
| Tora-7B-v0.1 | Coding | 4.7 | single-turn |
| Tora-7B-v0.1 | Extraction | 6.3 | single-turn |
| Tora-7B-v0.1 | STEM | 7.2 | single-turn |
| Tora-7B-v0.1 | Humanities | 8.5 | single-turn |

| model | category | score |
|---|---|---|
| Tora-7B-v0.1 | Writing | 7.55 |
| Tora-7B-v0.1 | Roleplay | 7.5 |
| Tora-7B-v0.1 | Reasoning | 4.35 |
| Tora-7B-v0.1 | Math | 2.95 |
| Tora-7B-v0.1 | Coding | 3.7 |
| Tora-7B-v0.1 | Extraction | 7.0 |
| Tora-7B-v0.1 | STEM | 7.85 |
| Tora-7B-v0.1 | Humanities | 9.65 |
| Tora-7B-v0.1 | AVG_mtbench | 6.319 |

| model | category | score |
|---|---|---|
| Tora-7B-v0.1 | NLI | 0.588 |
| Tora-7B-v0.1 | QA | 0.1708 |
| Tora-7B-v0.1 | RC | 0.798 |
| Tora-7B-v0.1 | MC | 0.25 |
| Tora-7B-v0.1 | EL | 0.0 |
| Tora-7B-v0.1 | FA | 0.1359 |
| Tora-7B-v0.1 | MR | 0.2 |
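The `AVG_mtbench` row appears to be the arithmetic mean of the eight category scores listed in the same table. Here is a quick check under that assumption, with the scores copied from the table:

```python
# Category scores copied from the results table that ends with AVG_mtbench
scores = {
    "Writing": 7.55,
    "Roleplay": 7.5,
    "Reasoning": 4.35,
    "Math": 2.95,
    "Coding": 3.7,
    "Extraction": 7.0,
    "STEM": 7.85,
    "Humanities": 9.65,
}

# Unweighted mean over the eight categories
avg = sum(scores.values()) / len(scores)
print(f"AVG_mtbench = {avg:.5f}")  # prints "AVG_mtbench = 6.31875"
```

The eight scores sum to 50.55, and 50.55 / 8 = 6.31875, which rounds to the reported 6.319.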
My sincere thanks to @jovyan for writing the ChatVector article.