Spaces:
Sleeping
Sleeping
import os | |
import re | |
import json | |
import copy | |
import gradio as gr | |
from llama2 import GradioLLaMA2ChatPPManager | |
from llama2 import gen_text | |
from styles import MODEL_SELECTION_CSS | |
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS | |
from templates import templates | |
from constants import DEFAULT_GLOBAL_CTX | |
from pingpong import PingPong | |
from pingpong.context import CtxLastWindowStrategy | |
# Hub repository id of the chat model queried through the Inference API.
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'
# Hugging Face access token (None when HF_TOKEN is not set); it is passed to
# gen_text as ``hf_token`` on every generation request.
TOKEN = os.environ.get('HF_TOKEN')
def build_prompts(ppmanager, global_context, win_size=3):
    """Render the model prompt for *ppmanager* with *global_context* on top.

    The manager is deep-copied first so that setting ``ctx`` here never leaks
    into the caller's object. The copy is then rendered by a
    ``CtxLastWindowStrategy`` sized by *win_size* (the number of recent
    exchanges to keep, per the "number of recent talks to keep" slider).
    """
    scratch = copy.deepcopy(ppmanager)
    scratch.ctx = global_context
    return CtxLastWindowStrategy(win_size)(scratch)
def _read_nonempty_lines(path):
    """Return the non-blank lines of the UTF-8 text file at *path*, in file order.

    Uses ``splitlines`` plus a blank filter, so a trailing newline (or a stray
    empty line) no longer produces an empty entry and hence an empty button.
    The file is closed deterministically by the ``with`` block.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line for line in f.read().splitlines() if line.strip()]


# Example prompts; each becomes a one-click button on the initial popup.
examples = _read_nonempty_lines("examples.txt")
ex_btns = []

# Chat channel names; one history button is created per entry, and a
# channel's position here is its index into the serialized histories.
channels = _read_nonempty_lines("channels.txt")
channel_btns = []
# Matches a "[name]" slot and captures "name"; compiled once at import time.
_PLACEHOLDER_RE = re.compile(r"\[([^\]]*)\]")


def get_placeholders(text):
    """Return the names of all ``[placeholder]`` slots in *text*, in order.

    A placeholder is any run of characters other than ``]`` enclosed in
    square brackets, e.g. ``"Write a [genre] poem"`` -> ``["genre"]``.
    Duplicates are preserved, and ``"[]"`` yields an empty name.
    """
    return _PLACEHOLDER_RE.findall(text)
def fill_up_placeholders(txt):
    """Prepare the template UI for the chosen example template *txt*.

    Returns a 5-tuple of updates for
    ``[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3,
    instruction_txtbox]``:

    - the template markdown is shown with *txt* as its value;
    - each of the three placeholder boxes is shown only if the template has
      that many ``[...]`` slots, with the slot name as its hint;
    - the instruction box is cleared when there is anything to fill in.
      Otherwise it receives *txt* verbatim.
    """
    found = get_placeholders(txt)
    template_view = gr.update(visible=True, value=txt)

    slot_updates = []
    for slot in range(3):
        has_slot = slot < len(found)
        hint = found[slot] if has_slot else ""
        slot_updates.append(gr.update(visible=has_slot, placeholder=hint))

    if found:
        instruction = ""
    else:
        instruction = txt
    return (template_view, *slot_updates, instruction)
async def rollback_last(
    idx, local_data, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
):
    """Regenerate the model's reply to the latest user message in channel *idx*.

    The channel's last ping-pong is dropped, and its user message is re-added
    with an empty answer that is streamed from ``gen_text``.

    Args:
        idx: index of the active channel within *local_data*.
        local_data: serialized histories, one entry per channel.
        chat_state: session state whose ``"ppmanager_type"`` is the manager
            class used to rebuild each history.
        global_context: text placed at the top of the prompt.
        res_temp, res_topk, res_rpen, res_mnts, res_sample: generation
            parameters forwarded to ``gen_text``.
        ctx_num_lconv: number of recent exchanges kept in the prompt window.

    Yields:
        4-tuples ``(prompt, chatbot_messages, serialized_histories,
        regen_button_update)`` matching the ``regenerate.click`` outputs
        ``[context_inspector, chatbot, local_data, regenerate]``.
        ``serialized_histories`` is ``str()`` of the rebuilt managers
        (presumably their JSON form; TODO confirm against the manager class).
    """
    res = [
        chat_state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]
    ppm = res[idx]

    if not ppm.pingpongs:
        # Nothing to regenerate: leave the inspector untouched and keep Regen
        # disabled, matching set_chatbot's rule for empty channels.
        yield gr.update(), ppm.build_uis(), str(res), gr.update(interactive=False)
        return

    # Replace the last exchange with its user message and an empty answer.
    last_user_message = ppm.pingpongs[-1].ping
    ppm.pingpongs = ppm.pingpongs[:-1]
    ppm.add_pingpong(
        PingPong(last_user_message, "")
    )
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)

    async for result in gen_text(
        prompt, hf_model=MODEL_ID, hf_token=TOKEN,
        parameters={
            'max_new_tokens': res_mnts,
            'do_sample': res_sample,
            'return_full_text': False,
            'temperature': res_temp,
            'top_k': res_topk,
            'repetition_penalty': res_rpen
        }
    ):
        ppm.append_pong(result)
        # Leave the Regen button as is while streaming. Disabling it here would
        # strand it disabled if the user presses Stop (no final yield then).
        yield prompt, ppm.build_uis(), str(res), gr.update()

    yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True)
def reset_chat(idx, ld, state):
    """Erase the conversation of channel *idx*.

    Returns updates for
    ``[instruction_txtbox, chatbot, local_data, example_block, regenerate]``:

    - the prompt box and the chat view are emptied;
    - the histories are re-serialized with channel *idx* cleared;
    - the example popup is shown;
    - the Regen button is disabled.
    """
    manager_cls = state["ppmanager_type"]
    managers = []
    for raw in ld:
        managers.append(manager_cls.from_json(json.dumps(raw)))
    managers[idx].pingpongs = []

    cleared_prompt = ""
    cleared_chat = []
    show_examples = gr.update(visible=True)
    disable_regen = gr.update(interactive=False)
    return (cleared_prompt, cleared_chat, str(managers), show_examples, disable_regen)
async def chat_stream(
    idx, local_data, instruction_txtbox, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
):
    """Send *instruction_txtbox* to channel *idx* and stream the model's answer.

    The histories in *local_data* are rebuilt with
    ``chat_state["ppmanager_type"]``. A new ping-pong with an empty answer is
    appended to the active channel, and every chunk produced by ``gen_text``
    is appended to that answer.

    Yields:
        4-tuples ``("", prompt, chatbot_messages, serialized_histories)`` for
        ``[instruction_txtbox, context_inspector, chatbot, local_data]``. The
        empty string clears the prompt box.
    """
    manager_cls = chat_state["ppmanager_type"]
    managers = [manager_cls.from_json(json.dumps(raw)) for raw in local_data]

    current = managers[idx]
    current.add_pingpong(PingPong(instruction_txtbox, ""))
    prompt = build_prompts(current, global_context, ctx_num_lconv)

    generation_params = dict(
        max_new_tokens=res_mnts,
        do_sample=res_sample,
        return_full_text=False,
        temperature=res_temp,
        top_k=res_topk,
        repetition_penalty=res_rpen,
    )
    stream = gen_text(
        prompt, hf_model=MODEL_ID, hf_token=TOKEN, parameters=generation_params
    )
    async for chunk in stream:
        current.append_pong(chunk)
        yield "", prompt, current.build_uis(), str(managers)
def channel_num(btn_title):
    """Return the index of *btn_title* in the module-level ``channels`` list.

    If the title occurs more than once the last position wins. If it does not
    occur at all, 0 (the first channel) is returned.
    """
    positions = [pos for pos, name in enumerate(channels) if name == btn_title]
    return positions[-1] if positions else 0
def set_chatbot(btn, ld, state):
    """Switch the chat view to the channel whose button title is *btn*.

    Returns updates for ``[chatbot, idx, example_block, regenerate]``:

    - the selected channel's messages;
    - its index;
    - the example popup, shown only when the channel is empty;
    - the Regen button, enabled only when the channel is non-empty.
    """
    selected = channel_num(btn)
    manager_cls = state["ppmanager_type"]
    managers = [manager_cls.from_json(json.dumps(raw)) for raw in ld]
    chosen = managers[selected]

    is_empty = len(chosen.pingpongs) == 0
    messages = chosen.build_uis()
    return (
        messages,
        selected,
        gr.update(visible=is_empty),
        gr.update(interactive=not is_empty),
    )
def set_example(btn):
    """Copy the clicked example's text into the prompt box and hide the example popup."""
    hide_popup = gr.update(visible=False)
    return (btn, hide_popup)
def get_final_template(
    txt, placeholder_txt1, placeholder_txt2, placeholder_txt3
):
    """Substitute the user's placeholder values into template *txt*.

    The first three ``[...]`` slots of *txt* are paired, in order, with the
    three placeholder boxes. A non-empty value replaces every occurrence of
    its ``[slot]``, and an empty value leaves the slot untouched.

    Returns the finished prompt followed by three empty strings, which clear
    the placeholder boxes.
    """
    final_prompt = txt
    fills = (placeholder_txt1, placeholder_txt2, placeholder_txt3)
    for name, fill in zip(get_placeholders(txt), fills):
        if fill != "":
            final_prompt = final_prompt.replace(f"[{name}]", fill)
    return final_prompt, "", "", ""
# Client-side snippet that writes the serialized histories back to the
# browser through the page's setStorage helper.
_SAVE_LOCAL_DATA_JS = "(v)=>{ setStorage('local_data',v) }"

with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
    with gr.Column() as chat_view:
        # Per-session state: the index of the active channel, and the manager
        # class used to (de)serialize every channel's history.
        idx = gr.State(0)
        chat_state = gr.State({
            "ppmanager_type": GradioLLaMA2ChatPPManager
        })
        # Hidden holder of all serialized histories. It is filled by the
        # GET_LOCAL_STORAGE script in demo.load below and saved back with
        # _SAVE_LOCAL_DATA_JS after edits.
        local_data = gr.JSON({}, visible=False)

        gr.Markdown("## LLaMA2 70B with Gradio Chat and Hugging Face Inference API", elem_classes=["center"])
        gr.Markdown(
            "This space demonstrates how to build feature rich chatbot UI in [Gradio](https://www.gradio.app/). Supported features "
            "include • multiple chatting channels, • chat history save/restoration, • stop generating text response, • regenerate the "
            "last conversation, • clean the chat history, • dynamic kick-starting prompt templates, • adjusting text generation parameters, "
            "• inspecting the actual prompt that the model sees. The underlying Large Language Model is the [Meta AI](https://ai.meta.com/)'s "
            "[LLaMA2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) which is hosted as [Hugging Face Inference API](https://huggingface.co/inference-api), "
            "and [Text Generation Inference](https://github.com/huggingface/text-generation-inference) is the underlying serving framework.",
            elem_classes=["center"])

        with gr.Row():
            # Left pane: one history button per channel (the first is highlighted).
            with gr.Column(scale=1, min_width=180):
                gr.Markdown("GradioChat", elem_id="left-top")

                with gr.Column(elem_id="left-pane"):
                    chat_back_btn = gr.Button("Back", elem_id="chat-back-btn")

                    with gr.Accordion("Histories", elem_id="chat-history-accordion", open=True):
                        channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))
                        for channel in channels[1:]:
                            channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))

            # Right pane: example popup, aux buttons, chat, templates, settings.
            with gr.Column(scale=8, elem_id="right-pane"):
                with gr.Column(
                    elem_id="initial-popup", visible=False
                ) as example_block:
                    with gr.Row(scale=1):
                        with gr.Column(elem_id="initial-popup-left-pane"):
                            gr.Markdown("GradioChat", elem_id="initial-popup-title")
                            gr.Markdown("Making the community's best AI chat models available to everyone.")
                        with gr.Column(elem_id="initial-popup-right-pane"):
                            gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
                            gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")

                    with gr.Column(scale=1):
                        gr.Markdown("Examples")
                        with gr.Row():
                            for example in examples:
                                ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))

                with gr.Column(elem_id="aux-btns-popup", visible=True):
                    with gr.Row():
                        stop = gr.Button("Stop", elem_classes=["aux-btn"])
                        regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
                        clean = gr.Button("Clean", elem_classes=["aux-btn"])

                # Shows the exact prompt that was sent to the model.
                with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
                    context_inspector = gr.Textbox(
                        "",
                        elem_id="aux-viewer-inspector",
                        label="",
                        lines=30,
                        max_lines=50,
                    )

                chatbot = gr.Chatbot(elem_id='chatbot', label="LLaMA2-70B-Chat")
                instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")

                with gr.Accordion("Example Templates", open=False):
                    template_txt = gr.Textbox(visible=False)
                    template_md = gr.Markdown(label="Chosen Template", visible=False, elem_classes="template-txt")

                    with gr.Row():
                        placeholder_txt1 = gr.Textbox(label="placeholder #1", visible=False, interactive=True)
                        placeholder_txt2 = gr.Textbox(label="placeholder #2", visible=False, interactive=True)
                        placeholder_txt3 = gr.Textbox(label="placeholder #3", visible=False, interactive=True)

                    for template in templates:
                        with gr.Tab(template['title']):
                            gr.Examples(
                                template['template'],
                                inputs=[template_txt],
                                outputs=[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3, instruction_txtbox],
                                run_on_click=True,
                                fn=fill_up_placeholders,
                            )

                with gr.Accordion("Control Panel", open=False) as control_panel:
                    with gr.Column():
                        with gr.Column():
                            gr.Markdown("#### Global context")
                            with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=True):
                                global_context = gr.Textbox(
                                    DEFAULT_GLOBAL_CTX,
                                    lines=5,
                                    max_lines=10,
                                    interactive=True,
                                    elem_id="global-context"
                                )

                            gr.Markdown("#### GenConfig for **response** text generation")
                            with gr.Row():
                                res_temp = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="temp", interactive=True)
                                res_topk = gr.Slider(20, 1000, 50, step=1, label="top_k", interactive=True)
                                res_rpen = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="rep_penalty", interactive=True)
                                res_mnts = gr.Slider(64, 8192, 512, step=1, label="new_tokens", interactive=True)
                                # The default must be one of the choices. The
                                # previous value=0 only compared equal to False
                                # in Python and left no option selected in the UI.
                                res_sample = gr.Radio([True, False], value=False, label="sample", interactive=True)

                        with gr.Column():
                            gr.Markdown("#### Context managements")
                            with gr.Row():
                                ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)

    # --- Sending a message -------------------------------------------------
    send_event = instruction_txtbox.submit(
        chat_stream,
        [idx, local_data, instruction_txtbox, chat_state,
         global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv],
        [instruction_txtbox, context_inspector, chatbot, local_data]
    )
    # After a reply the channel is non-empty, so enable Regen (the same rule
    # set_chatbot applies). Then persist the histories, as Regen and Clean do.
    # Chained off send_event so that Stop still cancels the streaming step.
    send_event.then(
        lambda: gr.update(interactive=True), None, regenerate
    ).then(
        None, local_data, None,
        _js=_SAVE_LOCAL_DATA_JS
    )

    # --- Regenerating the last answer --------------------------------------
    # regen_event must be the streaming rollback_last event itself (not the
    # trailing .then), otherwise the Stop button cannot cancel it.
    regen_event = regenerate.click(
        rollback_last,
        [idx, local_data, chat_state,
         global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv],
        [context_inspector, chatbot, local_data, regenerate]
    )
    regen_event.then(
        None, local_data, None,
        _js=_SAVE_LOCAL_DATA_JS
    )

    stop.click(
        None, None, None,
        cancels=[send_event, regen_event]
    )

    # --- Channel / example selection ----------------------------------------
    for btn in channel_btns:
        btn.click(
            set_chatbot,
            [btn, local_data, chat_state],
            [chatbot, idx, example_block, regenerate]
        ).then(
            None, btn, None,
            _js=UPDATE_LEFT_BTNS_STATE
        )

    for btn in ex_btns:
        btn.click(
            set_example,
            [btn],
            [instruction_txtbox, example_block]
        )

    clean.click(
        reset_chat,
        [idx, local_data, chat_state],
        [instruction_txtbox, chatbot, local_data, example_block, regenerate]
    ).then(
        None, local_data, None,
        _js=_SAVE_LOCAL_DATA_JS
    )

    # --- Template placeholders ----------------------------------------------
    placeholder_txts = [placeholder_txt1, placeholder_txt2, placeholder_txt3]

    # Editing any placeholder refreshes template_md client-side (fn=None,
    # UPDATE_PLACEHOLDERS script). Registered before the submit handlers to
    # keep the original event order.
    for ph_txt in placeholder_txts:
        ph_txt.change(
            inputs=[template_txt, *placeholder_txts],
            outputs=[template_md],
            show_progress=False,
            _js=UPDATE_PLACEHOLDERS,
            fn=None
        )

    # Pressing Enter in any placeholder box fills the template, moves the
    # result into the prompt box and clears the placeholder boxes.
    for ph_txt in placeholder_txts:
        ph_txt.submit(
            inputs=[template_txt, *placeholder_txts],
            outputs=[instruction_txtbox, *placeholder_txts],
            fn=get_final_template
        )

    demo.load(
        None,
        inputs=None,
        outputs=[chatbot, local_data],
        _js=GET_LOCAL_STORAGE,
    )

demo.queue().launch()