"""Gradio streaming chat demo for the PY007/LiteChat-Preview causal LM.

The model is loaded once at import time. Each chat turn re-renders the whole
conversation with an Alpaca-style template, prepends a fixed system prompt,
and streams the generated reply back into the Chatbot component.
"""
import os
import gc
from string import Template
from threading import Event, Thread

import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

tokenizer = AutoTokenizer.from_pretrained(
    "PY007/LiteChat-Preview",
)
model = AutoModelForCausalLM.from_pretrained(
    "PY007/LiteChat-Preview",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()

# Prompt tokens + generated tokens must fit within the model's positions.
max_context_length = model.config.max_position_embeddings
max_new_tokens = 1024

# One Alpaca-style conversation turn; $bot is empty for the pending turn so
# the model continues right after "### Response:".
prompt_template = Template("""\
### Instruction:
$human

### Response:
$bot\
""")
system_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
# Encoded with the tokenizer's default special tokens; the message tokens
# below are encoded without them so they only appear once, at the front.
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends `generate` once `event` has been set."""

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        return self.event.is_set()


def bot(history):
    """Stream the model's reply to the last user message into ``history``.

    Args:
        history: Chatbot value, a list of ``[user_text, bot_text]`` pairs; the
            last pair is expected to be pending (``bot_text is None``), as
            appended by ``user``.

    Yields:
        ``history`` with its last reply updated as new text arrives.

    Returns:
        The final reply text (Gradio ignores a generator's return value).
    """
    history = history or []

    # Render every turn with the prompt template.
    prompt_history = []
    for human, bot_msg in history:
        if bot_msg is not None:
            # Presumably undoes HTML line breaks introduced by the Chatbot.
            bot_msg = bot_msg.replace("<br>", "\n").rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot_msg if bot_msg is not None else ""
            )
        )

    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False,  # Special tokens come from the system prompt.
    )

    # Keep only the most recent message tokens that fit alongside the system
    # prompt and the generation budget, then prepend the system prompt.
    # Clamp both ends so a non-positive budget or a short history cannot
    # produce a wrong slice.
    max_msg_tokens = max(max_context_length - max_new_tokens - max_sys_tokens, 0)
    start = max(msg_tokens["input_ids"].size(-1) - max_msg_tokens, 0)
    inputs = BatchEncoding({
        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, start:]], dim=-1)
        for k in msg_tokens
    }).to(model.device)

    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
    inputs.pop("token_type_ids", None)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    # Set when the consumer stops reading (separator found, Stop pressed, or
    # the generator closed) so the background generation thread halts too.
    stop_event = Event()
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    try:
        for new_text in streamer:
            new_text = new_text.replace("<br>", "\n")
            if "###" in new_text:
                # The model began a new "### ..." section: keep only the text
                # before it, publish that final fragment, and stop.
                partial_text += new_text.split("###")[0].strip()
                history[-1][1] = partial_text
                yield history
                break
            # Filter empty trailing new lines
            if new_text == "\n":
                new_text = ""
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
    finally:
        stop_event.set()
    return partial_text


def user(user_message, history):
    """Clear the textbox and append ``user_message`` as a pending turn."""
    return "", (history or []) + [[user_message, None]]


with gr.Blocks() as demo:
    gr.Markdown("# SLM-Chat by Zhang Peiyuan, StatNLP")
    # NOTE(review): this label differs from the loaded checkpoint
    # ("PY007/LiteChat-Preview") — confirm which name is intended.
    gr.HTML("PY007/SLM-Alpaca-Finetuned-Preview")
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

    submit_event = msg.submit(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)

    # Cancelling the generator runs bot's `finally`, which also stops generation.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, [chatbot], queue=True)

demo.queue(max_size=32)
demo.launch()