YeungNLP committed f59231a (parent: cd59fd4)

Create README.md

Files changed (1): README.md (+150 -0)
This model is internlm-7b fine-tuned with the [Firefly](https://github.com/yangjianxin1/Firefly) project. The training data consists of roughly one million multi-turn dialogues: the MOSS data shared by the project plus 20k school-math samples.

Training loss:
![training loss](internlm-loss.jpg)

For more details, see the [Firefly](https://github.com/yangjianxin1/Firefly) project.
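
The two inference scripts below assume a CUDA device and a `transformers` release recent enough to load `trust_remote_code` models; this card does not pin exact versions, so treat the quick check below as an assumption rather than a stated requirement:

```python
# Minimal environment check (a sketch; the version expectation is an assumption,
# not something pinned by this model card).
import torch
import transformers

print(transformers.__version__)   # a reasonably recent release is assumed
print(torch.cuda.is_available())  # the scripts below load the model on 'cuda'
```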

Single-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
"""
Single-turn dialogue: the script keeps no memory of previous turns.
"""


def main():
    model_name = 'YeungNLP/firefly-internlm-7b'

    max_new_tokens = 500
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0
    device = 'cuda'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # QWenTokenizer is a special case: its pad_token_id, bos_token_id and eos_token_id
    # are all None, and eod_id corresponds to the <|endoftext|> token.
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    text = input('User:')
    while True:
        text = text.strip()
        # chatglm uses its official prompt format
        if model.config.model_type == 'chatglm':
            text = '[Round 1]\n\n问:{}\n\n答:'.format(text)
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
        # Otherwise wrap the input as [bos, input, eos]. The special ids are concatenated
        # manually for compatibility with qwen-7b, whose tokenizer treats eos_token as
        # plain text, so tokenizing it would not yield the eos_token_id.
        else:
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
            bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
            eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
            input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
                top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
                eos_token_id=tokenizer.eos_token_id
            )
        outputs = outputs.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs)
        response = response.strip().replace(tokenizer.eos_token, "").strip()
        print("Firefly:{}".format(response))
        text = input('User:')


if __name__ == '__main__':
    main()
```
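
The non-chatglm branch above implies the inference format this checkpoint was trained on: a user turn is wrapped as `[bos] input [eos]`, and the model generates the reply after it. A minimal helper isolating just that wrapping (a sketch derived from the script above; `build_single_turn_input` is a hypothetical name, not a project API):

```python
import torch


def build_single_turn_input(tokenizer, text, device='cuda'):
    """Wrap one user turn as [bos, input tokens, eos], as in the script above.

    Hypothetical helper for illustration; assumes the tokenizer has valid
    bos_token_id / eos_token_id (the QWenTokenizer fix above ensures this).
    """
    input_ids = tokenizer(text, return_tensors='pt', add_special_tokens=False).input_ids
    bos = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    eos = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
    return torch.concat([bos, input_ids, eos], dim=1).to(device)
```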

Multi-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def main():
    model_name = 'YeungNLP/firefly-internlm-7b'

    device = 'cuda'
    max_new_tokens = 500    # maximum number of tokens generated per turn
    history_max_len = 1000  # maximum number of history tokens the model sees
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0

    # load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # QWenTokenizer is a special case: its pad_token_id, bos_token_id and eos_token_id
    # are all None, and eod_id corresponds to the <|endoftext|> token.
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    # accumulate the whole conversation history here
    if model.config.model_type != 'chatglm':
        history_token_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    else:
        history_token_ids = torch.tensor([[]], dtype=torch.long)

    # start the conversation
    utterance_id = 0  # current turn index, needed for chatglm's prompt format
    user_input = input('User:')
    while True:
        utterance_id += 1
        # chatglm uses its official prompt format
        if model.config.model_type == 'chatglm':
            user_input = '[Round {}]\n\n问:{}\n\n答:'.format(utterance_id, user_input)
            user_input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
        # Firefly's own format: append [input, eos] to the history. The eos id is
        # concatenated manually for compatibility with qwen-7b, whose tokenizer treats
        # eos_token as plain text, so tokenizing it would not yield the eos_token_id.
        else:
            input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
            eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
            user_input_ids = torch.concat([input_ids, eos_token_id], dim=1)
        history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
        model_input_ids = history_token_ids[:, -history_max_len:].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
                temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
            )
        model_input_ids_len = model_input_ids.size(1)
        response_ids = outputs[:, model_input_ids_len:]
        history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
        response = tokenizer.batch_decode(response_ids)
        print("Firefly:" + response[0].strip().replace(tokenizer.eos_token, ""))
        user_input = input('User:')


if __name__ == '__main__':
    main()
```
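
Note that the multi-turn script stores the full conversation in `history_token_ids` but only feeds the last `history_max_len` tokens to `generate`, so long conversations are silently left-truncated (the oldest tokens, including the initial bos, fall out of the window). A toy illustration of that windowing, isolated from the script:

```python
import torch

# Pretend the accumulated history is 1500 tokens long.
history_token_ids = torch.arange(1500, dtype=torch.long).unsqueeze(0)
history_max_len = 1000

# Same slicing as in the script: keep only the most recent tokens.
model_input_ids = history_token_ids[:, -history_max_len:]
print(model_input_ids.shape)  # torch.Size([1, 1000]) -- the oldest 500 tokens are dropped
```

Raising `history_max_len` extends the model's memory at the cost of more GPU memory and slower generation, up to the context length the base model supports.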