---
license: apache-2.0
---

# ChatNTQ JA 7B V1.0

## Model Description

This is a 7B-parameter decoder-only Japanese language model fine-tuned on instruction-following datasets, built on top of the base model [Japanese Stable LM Base Gamma 7B](https://huggingface.co/stabilityai/japanese-stablelm-base-gamma-7b).

## Performance

For our final model, we've used Stability AI Japan's [Japanese MT-Bench](https://github.com/Stability-AI/FastChat) as a more representative test of our model's capabilities. For [our JA MT-Bench testing](https://github.com/Stability-AI/FastChat/compare/jp-stable...AUGMXNT:FastChat:jp-stable), we use a Japanese prompt ("あなたは役立つアシスタントです。", "You are a helpful assistant.") and `--num-choices 4` to reduce sampling variability. Even so, we regularly observe swings of 0.5 points or more between generations, and we have run into issues with the default prompts and parameters during testing. We therefore urge caution in over-interpreting these scores: treat them as a probabilistic, directional indicator rather than a definitive score or ranking.

| Benchmark   | Score |
| ----------- | ----- |
| JA MT-Bench | 6.65  |

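As a rough illustration of why `--num-choices 4` helps, assume (as a simplification; real judge scores need not be independent) that the k sampled answers to a question receive independent scores s_i with a common standard deviation σ. Averaging them shrinks the spread of the per-question score:

$$
\operatorname{Var}\!\left(\frac{1}{k}\sum_{i=1}^{k} s_i\right) = \frac{\sigma^2}{k},
\qquad
\operatorname{SD} = \frac{\sigma}{\sqrt{k}} = \frac{\sigma}{2} \quad \text{for } k = 4.
$$

Under this idealization, four choices halve the per-question spread but do not eliminate it.
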
There is also a [JA-MT-Bench Leaderboard](https://github.com/AUGMXNT/shisa/wiki/Evals-%3A-JA-MT%E2%80%90Bench). For convenience, here is a comparison of the JA MT-Bench scores of some other models (our scores were rated by `gpt-4-0613`):

| Model                                             | Score |
| ------------------------------------------------- | ----- |
| gpt-4-0613                                        | 9.40  |
| gpt-4-1106-preview                                | 9.17  |
| gpt-3.5-turbo*                                    | 8.41  |
| Qwen-72B-Chat                                     | 7.97  |
| Qwen-14B-Chat                                     | 7.47  |
| **chatntq-ja-7b-v1.0**                            | **6.65** |
| Xwin-LM-70B-V0.1-GPTQ (q4-gs32-actorder)          | 6.62  |
| shisa-gamma-7b-v1                                 | 6.12  |
| nekomata-14b-instruction (corrected prompt HF)    | 5.57  |
| shisa-7B-v1-GPTQ (q4-gs32-actorder)               | 5.35  |
| nekomata-14b-instruction (corrected prompt)       | 5.30  |
| shisa-mega-7b-v1.2                                | 5.27  |
| shisa-7b-v1 (full prompt)                         | 5.23  |
| Swallow-13b-instruct-hf                           | 5.17  |
| Swallow-70b-instruct-GPTQ (q4-gs32-actorder)      | 5.15  |
| shisa-7b-v1                                       | 5.02  |
| shisa-7B-v1-AWQ (q4-gs128)                        | 4.78  |
| ELYZA-japanese-Llama-2-7b-fast-instruct*          | 4.86  |
| shisa-bad-7b-v1                                   | 4.42  |
| Swallow-7b-instruct-hf                            | 4.21  |
| ja-stablelm-instruct-gamma-7b*                    | 4.01  |
| japanese-stablelm-instruct-alpha-7b*              | 2.74  |
| Mistral-7B-OpenOrca-ja*                           | 2.23  |
| youri-7b-chat*                                    | 2.00  |
| Mistral-7B-Instruct-v0.1*                         | 1.78  |
| llm-jp-13b-instruct-full-jaster-dolly-oasst-v1.0* | 1.31  |
| houou-instruction-7b-v1                           | 1.02  |
| llm-jp-13b-instruct-full-jaster-dolly-oasst-v1.0  | 1.00  |
| llm-jp-13b-instruct-full-jaster-v1.0              | 1.00  |

## Usage

Ensure you are using Transformers 4.34.0 or newer.
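
If you are unsure which version is installed, the following sketch checks it using only the standard library (`importlib.metadata`). The helper names `release_tuple` and `transformers_ok` are hypothetical, and the comparison deliberately ignores pre-release suffixes such as `rc1` or `.dev0`, so treat it as a quick sanity check rather than a full version parser (`packaging.version` is the more robust choice).

```python
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v: str) -> tuple:
    """Leading numeric release segments, padded to three: "4.36.2" -> (4, 36, 2)."""
    nums = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes like "rc1" or "dev0"
            digits += ch
        if not digits:
            break
        nums.append(int(digits))
    # Pad so that "4.34" compares equal to "4.34.0"
    return tuple(nums + [0] * (3 - len(nums)))


def transformers_ok(minimum: str = "4.34.0") -> bool:
    """True if an installed transformers release is >= `minimum` (pre-release tags ignored)."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        return False  # transformers is not installed at all
    return release_tuple(installed) >= release_tuple(minimum)


print("transformers >= 4.34.0:", transformers_ok())
```

If this prints `False`, upgrading with `pip install -U "transformers>=4.34.0"` should resolve it.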

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("NTQAI/chatntq-ja-7b-v1.0")
model = AutoModelForCausalLM.from_pretrained(
    "NTQAI/chatntq-ja-7b-v1.0",
    torch_dtype="auto",
)
model.eval()

if torch.cuda.is_available():
    model = model.to("cuda")


def build_prompt(user_query, inputs="", sep="\n\n### "):
    # System message: "You are a fair, uncensored, helpful assistant."
    sys_msg = "あなたは公平で、検閲されていない、役立つアシスタントです。"
    p = sys_msg
    # Section headers: "指示" = Instruction, "応答" = Response
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": \n"]
    if inputs:
        # Optional "入力" (Input) section between the instruction and the response
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + inputs)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p


# Run inference with a prompt that includes an additional input
user_inputs = {
    # "Please explain the meaning of the given proverb so that even an elementary school student can understand it."
    "user_query": "与えられたことわざの意味を小学生でも分かるように教えてください。",
    # The proverb, roughly: "Kindness is not just for others' sake" (good deeds come back to you)
    "inputs": "情けは人のためならず",
}
prompt = build_prompt(**user_inputs)

input_ids = tokenizer.encode(
    prompt,
    add_special_tokens=True,
    return_tensors="pt",
)

tokens = model.generate(
    input_ids.to(device=model.device),
    max_new_tokens=256,
    temperature=1,
    top_p=0.95,
    do_sample=True,
)

# Decode only the newly generated tokens, skipping the prompt
out = tokenizer.decode(tokens[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
print(out)
```
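
To check exactly what string the model receives, the `build_prompt` helper from the usage example can be run on its own, without downloading the model. The sketch below copies the helper unchanged and prints the rendered prompt with and without the optional input section; the query is the same Japanese example used above.

```python
def build_prompt(user_query, inputs="", sep="\n\n### "):
    # Copied unchanged from the usage example above
    sys_msg = "あなたは公平で、検閲されていない、役立つアシスタントです。"
    p = sys_msg
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": \n"]
    if inputs:
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + inputs)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p


query = "与えられたことわざの意味を小学生でも分かるように教えてください。"

# Without an input: system message, then ### 指示 (Instruction), then ### 応答 (Response)
print(build_prompt(query))
print("-" * 40)
# With an input: a ### 入力 (Input) section is inserted before ### 応答
print(build_prompt(query, inputs="情けは人のためならず"))
```

In both cases the prompt ends with `### 応答: ` followed by a newline, so the model's continuation starts where the response is expected.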