|
"""This script is adapted from the Streamlit chat example, the interactive
generation code of ChatGLM2, and Hugging Face Transformers.

We mainly modified parts of the code logic to adapt them to the
generation of our model.

Please refer to the links below for more information:
    1. Streamlit chat example:
        https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. ChatGLM2:
        https://github.com/THUDM/ChatGLM2-6B
    3. Transformers:
        https://github.com/huggingface/transformers

Please run with the command `streamlit run path/to/web_demo.py
    --server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""
|
|
|
import copy
import os
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.utils import logging
|
|
|
logger = logging.get_logger(__name__)

# Location of the checkpoint handed to ``from_pretrained`` in ``load_model``.
# Set the MODEL_NAME_OR_PATH environment variable to serve a different
# model; otherwise the original merged fine-tune directory is used.
model_name_or_path = os.environ.get(
    'MODEL_NAME_OR_PATH', '/root/finetune/work_dirs/assistTuner/merged')
|
|
|
@dataclass
class GenerationConfig:
    """Sampling knobs chosen in the sidebar.

    ``main`` expands an instance with ``asdict`` and passes the fields as
    keyword overrides to ``generate_interactive``.
    """

    # Ceiling on the total token count (prompt plus generated tokens).
    max_length: int = 32768
    # Nucleus-sampling probability mass threshold.
    top_p: float = 0.8
    # Temperature applied to the logits before sampling.
    temperature: float = 0.8
    # True: sample with torch.multinomial; False: pick the argmax token.
    do_sample: bool = True
    # Repetition penalty forwarded to the Transformers generation config.
    repetition_penalty: float = 1.005
|
|
|
|
|
def _normalize_eos_token_ids(eos_token_id, additional_eos_token_id=None):
    """Return the EOS id(s) as a new list, plus the optional extra id."""
    if eos_token_id is None:
        eos_ids = []
    elif isinstance(eos_token_id, int):
        eos_ids = [eos_token_id]
    else:
        # Copy so appending never mutates a list owned by the config.
        eos_ids = list(eos_token_id)
    if additional_eos_token_id is not None:
        eos_ids.append(additional_eos_token_id)
    return eos_ids


def _apply_length_config(generation_config, kwargs, input_ids_seq_length):
    """Resolve ``max_length`` / ``max_new_tokens`` on the config in place."""
    has_default_max_length = (kwargs.get('max_length') is None
                              and generation_config.max_length is not None)
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using 'max_length''s default "
            f'({generation_config.max_length!r}) to control the generation '
            'length. This behaviour is deprecated and will be removed from '
            'the config in v5 of Transformers -- we recommend using '
            '`max_new_tokens` to control the maximum length of the '
            'generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = (generation_config.max_new_tokens
                                        + input_ids_seq_length)
        if not has_default_max_length:
            # Previously ``logger.warn(msg, UserWarning)``: logging treated
            # UserWarning as a %-format argument and printed a formatting
            # error traceback instead of the message.
            warnings.warn(
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
                f"and 'max_length'(={generation_config.max_length}) seem to "
                "have been set. 'max_new_tokens' will take precedence. "
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/'
                'en/main_classes/text_generation)',
                UserWarning,
            )


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    """Stream a response to *prompt*, yielding the decoded text after each step.

    A hand-written token-by-token decoding loop so the UI can render partial
    output while the model is still generating.

    Args:
        model: Causal LM providing ``generation_config`` and the private
            Transformers helpers ``_get_logits_processor``,
            ``_get_stopping_criteria``, ``_get_logits_warper`` and
            ``_update_model_kwargs_for_generation``.
            NOTE(review): these are internal APIs whose availability varies
            across Transformers releases — pin the version.
        tokenizer: Tokenizer matching *model*.
        prompt: Fully formatted prompt text. It is tokenized as a batch of
            one and moved to CUDA.
        generation_config: Base config. Defaults to ``model.generation_config``.
            An instance of this module's ``GenerationConfig`` dataclass is
            also accepted; its fields become overrides on the model's config.
        logits_processor: Extra logits processors merged with the config's.
        stopping_criteria: Extra stopping criteria merged with the config's.
        prefix_allowed_tokens_fn: Forwarded to ``_get_logits_processor``.
        additional_eos_token_id: Extra token id that also ends generation.
        **kwargs: Applied with ``generation_config.update``. Keys it does not
            recognise are passed to the model as extra model inputs.

    Yields:
        str: The response decoded so far. The prompt is excluded and one
        trailing EOS token is stripped.
    """
    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
    for key, value in inputs.items():
        inputs[key] = value.cuda()
    input_ids = inputs['input_ids']
    input_ids_seq_length = input_ids.shape[-1]

    if isinstance(generation_config, GenerationConfig):
        # The local dataclass has no ``update``/token ids; layer its fields
        # over the model's config instead (explicit kwargs still win).
        kwargs = {**asdict(generation_config), **kwargs}
        generation_config = None
    if generation_config is None:
        generation_config = model.generation_config
    # Work on a copy so per-request overrides never leak into the caller's
    # (or the model's) config.
    generation_config = copy.deepcopy(generation_config)
    # ``update`` returns the kwargs it could not consume.
    model_kwargs = generation_config.update(**kwargs)

    eos_token_id = _normalize_eos_token_ids(generation_config.eos_token_id,
                                            additional_eos_token_id)
    _apply_length_config(generation_config, kwargs, input_ids_seq_length)

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'input_ids'
        logger.warning(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider'
            " increasing 'max_new_tokens'.")

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=(logits_processor if logits_processor is not None
                          else LogitsProcessorList()),
    )
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=(stopping_criteria if stopping_criteria is not None
                           else StoppingCriteriaList()),
    )
    logits_warper = model._get_logits_warper(generation_config)

    eos_tensor = torch.tensor(eos_token_id,
                              dtype=input_ids.dtype,
                              device=input_ids.device)
    unfinished_sequences = input_ids.new_ones(input_ids.shape[0])
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        # A sequence finishes once it emits any EOS id. ``isin`` also works
        # for batch > 1 and for an empty EOS list, unlike ``min()`` over
        # per-id comparison tensors.
        unfinished_sequences = unfinished_sequences.mul(
            (~torch.isin(next_tokens, eos_tensor)).long())

        output_token_ids = input_ids[0, input_ids_seq_length:].tolist()
        # Strip at most one trailing EOS; guarded so an immediate EOS (empty
        # response) cannot index into an empty list.
        if output_token_ids and output_token_ids[-1] in eos_token_id:
            output_token_ids = output_token_ids[:-1]
        yield tokenizer.decode(output_token_ids)

        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break
|
|
|
|
|
def on_btn_click():
    """Callback for the ``Clear Chat History`` button: drop the chat log.

    ``main`` re-creates an empty ``messages`` list on the next run. The
    membership check keeps a click that arrives while no history is stored
    (e.g. a rapid second click) from raising.
    """
    if 'messages' in st.session_state:
        del st.session_state.messages
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the model (bf16, on CUDA) and its tokenizer from ``model_name_or_path``.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of reloading the checkpoint.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # Ask ``from_pretrained`` for bf16 directly rather than loading in the
    # default precision and casting afterwards, which briefly needs memory
    # for a full-precision copy of the weights.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                              trust_remote_code=True)
    return model, tokenizer
|
|
|
|
|
def prepare_generation_config():
    """Draw the sidebar controls and return the selected sampling settings.

    The sidebar also hosts the ``Clear Chat History`` button, which is wired
    to ``on_btn_click``.

    Returns:
        GenerationConfig: ``max_length``, ``top_p`` and ``temperature`` taken
        from the sliders. A temperature of 0 selects greedy decoding
        (``do_sample=False``).
    """
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=32768)
        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    if temperature <= 0:
        # Transformers' temperature warper rejects non-positive temperatures
        # (and suggests ``do_sample=False``), so sliding to 0 would crash
        # generation. Decode greedily instead. Neither temperature nor top_p
        # can change the argmax, so both are reset to the neutral 1.0.
        return GenerationConfig(max_length=max_length,
                                top_p=1.0,
                                temperature=1.0,
                                do_sample=False)

    return GenerationConfig(max_length=max_length,
                            top_p=top_p,
                            temperature=temperature)
|
|
|
|
|
# ChatML-style turn templates consumed by ``combine_history``.
user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
# The pending user turn followed by an open assistant header for the model
# to complete. It is built by concatenation because a backslash continuation
# inside the string literal would splice the next source line's indentation
# into the prompt.
cur_query_prompt = user_prompt + '<|im_start|>assistant\n'
|
|
|
|
|
def combine_history(prompt, messages=None):
    """Render the stored conversation plus a new user query as one prompt.

    Args:
        prompt: The new user message, rendered as the pending turn.
        messages: Prior turns as dicts with ``'role'`` (``'user'`` or
            ``'robot'``) and ``'content'`` keys. Defaults to
            ``st.session_state.messages``.

    Returns:
        str: System header, every prior turn, then the pending user turn
        followed by an open assistant header.

    Raises:
        RuntimeError: If a message has a role other than ``'user'`` or
            ``'robot'``.
    """
    if messages is None:
        messages = st.session_state.messages
    meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
    # NOTE(review): '<s>' is presumably the BOS token; if the tokenizer also
    # prepends BOS automatically, the prompt starts with two — confirm.
    parts = [f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n']
    for message in messages:
        role = message['role']
        content = message['content']
        if role == 'user':
            parts.append(user_prompt.format(user=content))
        elif role == 'robot':
            parts.append(robot_prompt.format(robot=content))
        else:
            raise RuntimeError(f'Unknown message role: {role!r}')
    parts.append(cur_query_prompt.format(user=prompt))
    return ''.join(parts)
|
|
|
|
|
def main():
    """Render the chat page and stream a model reply to each new prompt."""
    st.title('internlm2_5-7b-chat-assistant')

    print('load model begin.')
    model, tokenizer = load_model()
    print('load model end.')

    generation_config = prepare_generation_config()

    # The chat log persists across reruns in session_state.
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Replay every stored turn.
    for turn in st.session_state.messages:
        with st.chat_message(turn['role'], avatar=turn.get('avatar')):
            st.markdown(turn['content'])

    prompt = st.chat_input('What is up?')
    if not prompt:
        return

    with st.chat_message('user', avatar='user'):
        st.markdown(prompt)
    # Build the model prompt before recording this turn: combine_history
    # appends the new query itself, so recording first would duplicate it.
    real_prompt = combine_history(prompt)
    st.session_state.messages.append({
        'role': 'user',
        'content': prompt,
        'avatar': 'user'
    })

    with st.chat_message('robot', avatar='assistant'):
        placeholder = st.empty()
        # NOTE(review): 92542 is presumably the '<|im_end|>' token id of the
        # InternLM2 tokenizer — confirm against the tokenizer in use.
        reply_stream = generate_interactive(
            model=model,
            tokenizer=tokenizer,
            prompt=real_prompt,
            additional_eos_token_id=92542,
            device='cuda:0',
            **asdict(generation_config),
        )
        for partial_reply in reply_stream:
            # Trailing block glyph acts as a typing cursor while streaming.
            placeholder.markdown(partial_reply + '▌')
        placeholder.markdown(partial_reply)

    st.session_state.messages.append({
        'role': 'robot',
        'content': partial_reply,
        'avatar': 'assistant',
    })
    torch.cuda.empty_cache()
|
|
|
|
|
# Script entry point; launch via ``streamlit run`` (see the module docstring).
if __name__ == "__main__":
    main()
|
|
|
|