FAYO committed on
Commit dda1f4a · 1 Parent(s): 3f8757d

add:xtuner

.ipynb_checkpoints/change_script-checkpoint.py ADDED
@@ -0,0 +1,47 @@
import json
import argparse
from tqdm import tqdm

def process_line(line, old_text, new_text):
    # Parse the JSON line
    data = json.loads(line)

    # Recursive helper that handles nested dicts and lists
    def replace_text(obj):
        if isinstance(obj, dict):
            return {k: replace_text(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [replace_text(item) for item in obj]
        elif isinstance(obj, str):
            return obj.replace(old_text, new_text)
        else:
            return obj

    # Process the whole JSON object
    processed_data = replace_text(data)

    # Serialize the processed object back to a JSON string
    return json.dumps(processed_data, ensure_ascii=False)

def main(input_file, output_file, old_text, new_text):
    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8') as outfile:

        # Count the total number of lines for the progress bar
        total_lines = sum(1 for _ in infile)
        infile.seek(0)  # Reset the file pointer to the beginning

        # Use tqdm to create a progress bar
        for line in tqdm(infile, total=total_lines, desc="Processing"):
            processed_line = process_line(line.strip(), old_text, new_text)
            outfile.write(processed_line + '\n')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Replace text in a JSONL file.")
    parser.add_argument("input_file", help="Input JSONL file to process")
    parser.add_argument("output_file", help="Output file for processed JSONL")
    parser.add_argument("--old_text", default="尖米", help="Text to be replaced")
    parser.add_argument("--new_text", default="FAYO", help="Text to replace with")
    args = parser.parse_args()

    main(args.input_file, args.output_file, args.old_text, args.new_text)
.ipynb_checkpoints/xtuner_streamlit_demo-checkpoint.py ADDED
@@ -0,0 +1,292 @@
"""This script refers to the dialogue example of streamlit, the interactive
generation code of chatglm2 and transformers.

We mainly modified part of the code logic to adapt to the
generation of our model.
Please refer to these links below for more information:
    1. streamlit chat example:
        https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. chatglm2:
        https://github.com/THUDM/ChatGLM2-6B
    3. transformers:
        https://github.com/huggingface/transformers
Please run with the command `streamlit run path/to/web_demo.py
    --server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""
# isort: skip_file
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

logger = logging.get_logger(__name__)
model_name_or_path = "/root/finetune/work_dirs/assistTuner/merged"


@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
    input_length = len(inputs['input_ids'][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs['input_ids']
    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using 'max_length''s default \
                ({repr(generation_config.max_length)}) \
                to control the generation length. "
            'This behaviour is deprecated and will be removed from the \
                config in v5 of Transformers -- we'
            ' recommend using `max_new_tokens` to control the maximum \
                length of the generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + \
            input_ids_seq_length
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
                f"and 'max_length'(={generation_config.max_length}) seem to "
                "have been set. 'max_new_tokens' will take precedence. "
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/'
                'en/main_classes/text_generation)',
                UserWarning,
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'input_ids'
        logger.warning(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider'
            " increasing 'max_new_tokens'.")

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None \
        else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None \
        else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul(
            (min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished
        # or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


@st.cache_resource
def load_model():
    model = (AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True).to(torch.bfloat16).cuda())
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                              trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=32768)
        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature)

    return generation_config


user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
    <|im_start|>assistant\n'


def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            cur_prompt = user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    st.title('internlm2_5-7b-chat-assistant')

    # torch.cuda.empty_cache()
    print('load model begin.')
    model, tokenizer = load_model()
    print('load model end.')

    generation_config = prepare_generation_config()

    # Initialize chat history
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])

    # Accept user input
    if prompt := st.chat_input('What is up?'):
        # Display user message in chat message container

        with st.chat_message('user', avatar='user'):

            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
            'avatar': 'user'
        })

        with st.chat_message('robot', avatar='assistant'):

            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    additional_eos_token_id=92542,
                    device='cuda:0',
                    **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append({
            'role': 'robot',
            'content': cur_response,  # pylint: disable=undefined-loop-variable
            'avatar': 'assistant',
        })
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
assistant_Tuner.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
assistant_Tuner_change.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
change_script.py ADDED
@@ -0,0 +1,47 @@
import json
import argparse
from tqdm import tqdm

def process_line(line, old_text, new_text):
    # Parse the JSON line
    data = json.loads(line)

    # Recursive helper that handles nested dicts and lists
    def replace_text(obj):
        if isinstance(obj, dict):
            return {k: replace_text(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [replace_text(item) for item in obj]
        elif isinstance(obj, str):
            return obj.replace(old_text, new_text)
        else:
            return obj

    # Process the whole JSON object
    processed_data = replace_text(data)

    # Serialize the processed object back to a JSON string
    return json.dumps(processed_data, ensure_ascii=False)

def main(input_file, output_file, old_text, new_text):
    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8') as outfile:

        # Count the total number of lines for the progress bar
        total_lines = sum(1 for _ in infile)
        infile.seek(0)  # Reset the file pointer to the beginning

        # Use tqdm to create a progress bar
        for line in tqdm(infile, total=total_lines, desc="Processing"):
            processed_line = process_line(line.strip(), old_text, new_text)
            outfile.write(processed_line + '\n')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Replace text in a JSONL file.")
    parser.add_argument("input_file", help="Input JSONL file to process")
    parser.add_argument("output_file", help="Output file for processed JSONL")
    parser.add_argument("--old_text", default="尖米", help="Text to be replaced")
    parser.add_argument("--new_text", default="FAYO", help="Text to replace with")
    args = parser.parse_args()

    main(args.input_file, args.output_file, args.old_text, args.new_text)
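
For reference, this is how a single record is transformed; a minimal sketch, assuming change_script.py is importable and using a hypothetical conversation-style record (the real schema is whatever assistant_Tuner.jsonl contains):

from change_script import process_line

# Hypothetical JSONL record; only string values are touched, nesting is preserved
line = '{"conversation": [{"input": "请介绍一下你自己", "output": "我是尖米的小助手"}]}'
print(process_line(line, "尖米", "FAYO"))
# -> {"conversation": [{"input": "请介绍一下你自己", "output": "我是FAYO的小助手"}]}

Over a whole file the script is driven by the argparse interface above, presumably something like `python change_script.py assistant_Tuner.jsonl assistant_Tuner_change.jsonl` (the defaults already replace 尖米 with FAYO).
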
config.json ADDED
@@ -0,0 +1,37 @@
{
  "_name_or_path": "/root/finetune/models/internlm2_5-7b-chat",
  "architectures": [
    "InternLM2ForCausalLM"
  ],
  "attn_implementation": "eager",
  "auto_map": {
    "AutoConfig": "configuration_internlm2.InternLM2Config",
    "AutoModel": "modeling_internlm2.InternLM2ForCausalLM",
    "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM"
  },
  "bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "internlm2",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pad_token_id": 2,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 2.0,
    "type": "dynamic"
  },
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.39.0",
  "use_cache": true,
  "vocab_size": 92544
}
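
A quick way to confirm this config resolves correctly is to load it through transformers; a minimal sketch, assuming the merged model directory used by xtuner_streamlit_demo.py is pointed at, so that the configuration_internlm2.py module referenced in auto_map can be loaded with trust_remote_code:

from transformers import AutoConfig

# Path is illustrative; point it at the directory that holds this config.json
cfg = AutoConfig.from_pretrained("/root/finetune/work_dirs/assistTuner/merged",
                                 trust_remote_code=True)
print(cfg.model_type, cfg.max_position_embeddings, cfg.rope_scaling)
# expected: internlm2 32768 {'factor': 2.0, 'type': 'dynamic'}
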
internlm2_5_chat_7b_qlora_alpaca_e3_copy.py ADDED
@@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/root/finetune/models/internlm2_5-7b-chat'
use_varlen_attn = False

# Data
alpaca_en_path = '/root/finetune/data/assistant_Tuner_change.jsonl'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 1
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请介绍一下你自己', 'Please introduce yourself'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
+
xtuner_streamlit_demo.py ADDED
@@ -0,0 +1,292 @@
"""This script refers to the dialogue example of streamlit, the interactive
generation code of chatglm2 and transformers.

We mainly modified part of the code logic to adapt to the
generation of our model.
Please refer to these links below for more information:
    1. streamlit chat example:
        https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. chatglm2:
        https://github.com/THUDM/ChatGLM2-6B
    3. transformers:
        https://github.com/huggingface/transformers
Please run with the command `streamlit run path/to/web_demo.py
    --server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""
# isort: skip_file
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

logger = logging.get_logger(__name__)
model_name_or_path = "/root/finetune/work_dirs/assistTuner/merged"


@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
    input_length = len(inputs['input_ids'][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs['input_ids']
    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using 'max_length''s default \
                ({repr(generation_config.max_length)}) \
                to control the generation length. "
            'This behaviour is deprecated and will be removed from the \
                config in v5 of Transformers -- we'
            ' recommend using `max_new_tokens` to control the maximum \
                length of the generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + \
            input_ids_seq_length
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
                f"and 'max_length'(={generation_config.max_length}) seem to "
                "have been set. 'max_new_tokens' will take precedence. "
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/'
                'en/main_classes/text_generation)',
                UserWarning,
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'input_ids'
        logger.warning(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider'
            " increasing 'max_new_tokens'.")

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None \
        else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None \
        else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul(
            (min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished
        # or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


@st.cache_resource
def load_model():
    model = (AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True).to(torch.bfloat16).cuda())
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                              trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=32768)
        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature)

    return generation_config


user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
    <|im_start|>assistant\n'


def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            cur_prompt = user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    st.title('internlm2_5-7b-chat-assistant')

    # torch.cuda.empty_cache()
    print('load model begin.')
    model, tokenizer = load_model()
    print('load model end.')

    generation_config = prepare_generation_config()

    # Initialize chat history
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])

    # Accept user input
    if prompt := st.chat_input('What is up?'):
        # Display user message in chat message container

        with st.chat_message('user', avatar='user'):

            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
            'avatar': 'user'
        })

        with st.chat_message('robot', avatar='assistant'):

            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    additional_eos_token_id=92542,
                    device='cuda:0',
                    **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append({
            'role': 'robot',
            'content': cur_response,  # pylint: disable=undefined-loop-variable
            'avatar': 'assistant',
        })
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
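
One hard-coded value worth noting: additional_eos_token_id=92542 is expected to be the id of the <|im_end|> marker used in the prompt templates above, so generation stops at the end of the assistant turn. A quick check against the actual checkpoint (the path is taken from model_name_or_path in the demo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/root/finetune/work_dirs/assistTuner/merged",
                                    trust_remote_code=True)
print(tok.convert_tokens_to_ids("<|im_end|>"))  # expected to print 92542

The app itself is launched as described in the module docstring, e.g. `streamlit run xtuner_streamlit_demo.py --server.address=0.0.0.0 --server.port 7860`.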