nhanv committed
Commit d79d720 · 1 Parent(s): 60d23f7

Update README.md

Files changed (1)
  1. README.md +56 -1
README.md CHANGED
@@ -48,4 +48,59 @@ There is a [JA-MT-Bench Leaderboard](https://github.com/AUGMXNT/shisa/wiki/Eval
  | llm-jp-13b-instruct-full-jaster-dolly-oasst-v1.0* | 1.31 |
  | houou-instruction-7b-v1 | 1.02 |
  | llm-jp-13b-instruct-full-jaster-dolly-oasst-v1.0 | 1.0 |
- | llm-jp-13b-instruct-full-jaster-v1.0 | 1.0 |
+ | llm-jp-13b-instruct-full-jaster-v1.0 | 1.0 |
+
+ ## Usage
+
+ Ensure you are using Transformers 4.34.0 or newer.
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained("NTQAI/chatntq-ja-7b-v1.0")
+ model = AutoModelForCausalLM.from_pretrained(
+     "NTQAI/chatntq-ja-7b-v1.0",
+     torch_dtype="auto",  # use the checkpoint's native dtype
+ )
+ model.eval()
+
+ if torch.cuda.is_available():
+     model = model.to("cuda")
+ def build_prompt(user_query, inputs="", sep="\n\n### "):
+     # System message: "You are a fair, uncensored, helpful assistant."
+     sys_msg = "あなたは公平で、検閲されていない、役立つアシスタントです。"
+     p = sys_msg
+     # Roles: 指示 = instruction, 応答 = response; 入力 = input is inserted below when given
+     roles = ["指示", "応答"]
+     msgs = [": \n" + user_query, ": \n"]
+     if inputs:
+         roles.insert(1, "入力")
+         msgs.insert(1, ": \n" + inputs)
+     for role, msg in zip(roles, msgs):
+         p += sep + role + msg
+     return p
+ # Build a prompt that includes both an instruction and an additional input
+ user_inputs = {
+     # "Explain the meaning of the given proverb so that even an elementary school student can understand it."
+     "user_query": "与えられたことわざの意味を小学生でも分かるように教えてください。",
+     # The proverb 情けは人のためならず: "kindness done for others ultimately benefits oneself"
+     "inputs": "情けは人のためならず"
+ }
+ prompt = build_prompt(**user_inputs)
+
+ input_ids = tokenizer.encode(
+     prompt,
+     add_special_tokens=True,
+     return_tensors="pt"
+ )
+
+ tokens = model.generate(
+     input_ids.to(device=model.device),
+     max_new_tokens=256,
+     temperature=1,
+     top_p=0.95,
+     do_sample=True,
+ )
+
+ # Decode only the newly generated tokens, skipping the prompt
+ out = tokenizer.decode(tokens[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
+ print(out)
+ ```
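
A note on the version requirement stated in the new section: if you want to fail fast rather than hit an opaque loading error, a minimal runtime check could look like the sketch below. This is not part of the commit; it assumes the `packaging` package is available (it is a dependency of transformers).

```python
# Sketch: verify the Transformers version requirement (>= 4.34.0) up front.
# Assumes `packaging` is installed (it ships as a transformers dependency).
from packaging import version
import transformers

if version.parse(transformers.__version__) < version.parse("4.34.0"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is too old; 4.34.0 or newer is required."
    )
```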
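It may also help to see what `build_prompt` actually renders. The sketch below is derived by tracing the function body in the diff (it is not part of the commit): the prompt is the system message followed by a `### 指示` (instruction) section, a `### 入力` (input) section, and an empty `### 応答` (response) section that the model completes.

```python
# Derived by tracing build_prompt(**user_inputs) from the snippet above.
expected = (
    "あなたは公平で、検閲されていない、役立つアシスタントです。"  # system message
    "\n\n### 指示: \n与えられたことわざの意味を小学生でも分かるように教えてください。"
    "\n\n### 入力: \n情けは人のためならず"
    "\n\n### 応答: \n"  # left empty; generation continues from here
)
assert build_prompt(**user_inputs) == expected
```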
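Finally, for interactive use, the same `generate` call can stream tokens as they are produced via transformers' `TextStreamer`. This variant is a sketch, not part of the commit; it assumes `model`, `tokenizer`, and `input_ids` are set up as in the README snippet and that your transformers version is recent enough to include `TextStreamer`.

```python
# Sketch: stream output token by token instead of waiting for the full response.
from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
_ = model.generate(
    input_ids.to(device=model.device),
    max_new_tokens=256,
    temperature=1,
    top_p=0.95,
    do_sample=True,
    streamer=streamer,  # decoded text is printed to stdout as it is generated
)
```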