import os
import gc
from string import Template
from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer
# Hugging Face Hub repository that supplies BOTH the tokenizer and the weights.
# Kept in one constant so the two can never drift apart.
MODEL_ID = "PY007/LiteChat-Preview"

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
)
# trust_remote_code: allow the repo's own modelling code to be executed.
# device_map="auto": let transformers/accelerate decide weight placement, so
# inputs must later be moved to ``model.device`` rather than a fixed device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()  # inference only
# Maximum number of positions the model supports; used by ``bot`` to budget
# system prompt + chat history + newly generated tokens.
max_context_length = model.config.max_position_embeddings
# Upper limit on tokens generated per reply (passed to ``model.generate``).
max_new_tokens = 1024

# Alpaca-style turn format; ``$bot`` is left empty for the turn being answered.
prompt_template = Template("### Instruction: $human\n### Response: $bot")

# Preamble placed before the whole conversation on every request.
system_prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)

# Encoded once at start-up. The tokenizer's default special tokens are kept here,
# and the chat history is later encoded without them.
system_prompt_tokens = tokenizer([system_prompt + "\n\n"], return_tensors="pt")

# Token length of the system prompt, reserved out of the context budget.
max_sys_tokens = system_prompt_tokens["input_ids"].shape[-1]
def bot(history):
    """Stream the model's reply to the latest user turn into ``history``.

    Generator used as a Gradio event handler. It does the following:
    - formats the whole chat with ``prompt_template``;
    - prepends the pre-tokenized system prompt;
    - keeps only the most recent history tokens that fit in the context window;
    - runs ``model.generate`` on a background thread while consuming a
      ``TextIteratorStreamer``.

    Args:
        history: list of ``[human, bot]`` pairs from the Chatbot component.
            The last pair is the turn being answered; its bot entry is
            overwritten in place with the growing reply.

    Yields:
        The (mutated) ``history`` list after each streamed update.

    Returns:
        The final reply text. Gradio ignores this; it is kept for direct callers.
    """
    history = history or []
    if not history:
        # Nothing to answer, and ``history[-1]`` below would raise IndexError.
        yield history
        return ""

    # Re-render every turn with the prompt template. Stored replies may
    # contain HTML ``<br>`` line breaks (presumably from the Chatbot display);
    # turn them back into real newlines before feeding them to the model.
    prompt_history = []
    for human, bot_reply in history:
        if bot_reply is not None:
            bot_reply = bot_reply.replace("<br>", "\n").rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot_reply if bot_reply is not None else ""
            )
        )

    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        # Special tokens (e.g. BOS) come from the system-prompt encoding instead.
        add_special_tokens=False,
    )

    # Keep only the most recent history tokens such that
    # system prompt + history + max_new_tokens fits in the context window.
    budget = max(max_context_length - max_new_tokens - max_sys_tokens, 0)
    merged = {}
    for key, msg_ids in msg_tokens.items():
        keep_from = max(msg_ids.size(-1) - budget, 0)
        merged[key] = torch.concat(
            [system_prompt_tokens[key], msg_ids[:, keep_from:]], dim=-1
        )
    # device_map="auto" decides placement, so follow the model instead of
    # hard-coding a CUDA device.
    inputs = BatchEncoding(merged).to(model.device)

    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
    inputs.pop("token_type_ids", None)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
    )
    # generate() blocks until finished, so run it on a worker thread and
    # consume the streamer on this one.
    # NOTE(review): after the "###" break below, the worker keeps generating
    # until max_new_tokens; a StoppingCriteria on "###" would avoid that waste.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        new_text = new_text.replace("<br>", "\n")
        if "###" in new_text:
            # The model started a new "###" section: keep only what precedes
            # the marker, publish it, and stop streaming.
            partial_text = (partial_text + new_text.split("###")[0]).rstrip()
            history[-1][1] = partial_text
            yield history
            break
        # Filter chunks that are nothing but a newline.
        if new_text == "\n":
            new_text = ""
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
    return partial_text
def user(user_message, history):
    """Queue the user's message as a new, not-yet-answered chat turn.

    Returns a pair ``(textbox_value, new_history)``. The textbox is cleared,
    and the history is extended with ``[user_message, None]`` without
    mutating the list that was passed in.
    """
    pending_turn = [user_message, None]
    cleared_textbox = ""
    return cleared_textbox, history + [pending_turn]
with gr.Blocks() as demo:
    # Page header.
    # NOTE(review): the header names "SLM-Alpaca-Finetuned-Preview", but the
    # model loaded above is "PY007/LiteChat-Preview"; confirm which is intended.
    gr.Markdown("# SLM-Chat by Zhang Peiyuan, StatNLP")
    gr.HTML("PY007/SLM-Alpaca-Finetuned-Preview<br>")

    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    # NOTE(review): `state` is not wired into any event below.
    state = gr.State([])

    # Input row: message box on the left, action buttons on the right.
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

    # Enter key and the Send button both follow the same flow. First, `user`
    # appends the message and clears the box, unqueued so it is instant. Then
    # `bot` streams the reply through the queue.
    submit_event = msg.submit(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)

    # Stop cancels whichever submit chain is running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clear resets the chatbot display.
    clear.click(lambda: None, None, [chatbot], queue=True)
# Enable the request queue (needed for the streaming `bot` generator), allowing
# at most 32 pending requests, then start the web server.
demo.queue(max_size=32).launch()