chansung's picture
Update app.py
d82a572
raw
history blame
18.6 kB
import os
import re
import json
import copy
import gradio as gr
from llama2 import GradioLLaMA2ChatPPManager
from llama2 import gen_text
from styles import MODEL_SELECTION_CSS
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
from templates import templates
from constants import DEFAULT_GLOBAL_CTX
from pingpong import PingPong
from pingpong.context import CtxLastWindowStrategy
# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when the variable is unset). Passed to gen_text as hf_token.
TOKEN = os.getenv('HF_TOKEN')
# Model identifier passed to gen_text as hf_model.
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'
def build_prompts(ppmanager, global_context, win_size=3):
    """Build a prompt from the most recent part of a conversation.

    The conversation manager is deep-copied first, so the caller's object is
    left untouched; the copy's ``ctx`` is replaced with ``global_context``
    before a ``CtxLastWindowStrategy(win_size)`` is applied to it.

    Args:
        ppmanager: conversation manager to build the prompt from.
        global_context: value assigned to the copy's ``ctx`` attribute.
        win_size: window size handed to ``CtxLastWindowStrategy``.

    Returns:
        Whatever the window strategy returns for the copied manager
        (callers in this file use it as the prompt).
    """
    scratch_ppm = copy.deepcopy(ppmanager)
    scratch_ppm.ctx = global_context
    return CtxLastWindowStrategy(win_size)(scratch_ppm)
# Example prompts, one per line of examples.txt. The context manager closes
# the file handle as soon as the content has been read.
with open("examples.txt", "r") as ex_file:
    examples = ex_file.read().split("\n")
# Filled with one gr.Button per example while the UI is being built.
ex_btns = []

# Channel titles, one per line of channels.txt; a channel's position in this
# list is the index returned by channel_num().
with open("channels.txt", "r") as chl_file:
    channels = chl_file.read().split("\n")
# Filled with one gr.Button per channel while the UI is being built.
channel_btns = []
def get_placeholders(text):
    """Return the contents of every ``[...]`` span found in ``text``.

    Placeholders are written in square brackets, e.g. ``"Tell me about
    [topic]"`` yields ``["topic"]``. Matches are returned in order of
    appearance, duplicates included; an empty pair ``[]`` yields ``""``.

    Args:
        text: template string to scan.

    Returns:
        List of the strings found between ``[`` and ``]`` (brackets excluded).
    """
    # [^\]]* stops at the first closing bracket, so spans never nest or overlap.
    pattern = r"\[([^\]]*)\]"
    return re.findall(pattern, text)
def fill_up_placeholders(txt):
    """Prepare the template preview and placeholder inputs for a template.

    Args:
        txt: the chosen template text, possibly containing ``[...]`` spans.

    Returns:
        A 5-tuple: an update showing ``txt`` in the template preview, three
        updates for the placeholder textboxes (each visible only when a
        placeholder exists for its slot, with that placeholder's name as the
        hint text), and the instruction text, which is ``txt`` itself when
        the template has no placeholders and ``""`` otherwise.
    """
    found = get_placeholders(txt)
    count = len(found)

    template_update = gr.update(visible=True, value=txt)

    # Only the first three placeholders get an input box.
    slot_updates = [
        gr.update(
            visible=count > slot,
            placeholder=found[slot] if count > slot else "",
        )
        for slot in range(3)
    ]

    instruction_value = txt if count < 1 else ""
    return (template_update, *slot_updates, instruction_value)
async def rollback_last(
    idx, local_data, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
    internet_option=None, serper_api_key=None
):
    """Regenerate the answer to the last user message of the selected channel.

    The last exchange is removed, re-added with an empty response, and the
    response is streamed again from the model.

    Args:
        idx: index of the selected channel within ``local_data``.
        local_data: list of serialized conversations, one per channel.
        chat_state: dict whose ``"ppmanager_type"`` entry is the class used to
            rebuild each conversation via ``from_json``.
        global_context: context passed to ``build_prompts``.
        res_temp, res_topk, res_rpen, res_mnts, res_sample: generation
            parameters forwarded to ``gen_text``.
        ctx_num_lconv: window size passed to ``build_prompts``.
        internet_option, serper_api_key: accepted because the regenerate
            button wires these two components as extra inputs; they are not
            used here. Without these parameters the handler is called with
            too many positional arguments.

    Yields:
        ``(prompt, chat UI, str(conversations), regenerate-button update)``;
        the button is kept non-interactive while streaming and re-enabled by
        the final yield.
    """
    res = [
        chat_state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]

    ppm = res[idx]
    if not ppm.pingpongs:
        # Nothing to regenerate; avoid an IndexError on an empty history.
        yield "", ppm.build_uis(), str(res), gr.update(interactive=False)
        return

    # Drop the last exchange and re-add its user message with an empty reply.
    last_user_message = ppm.pingpongs[-1].ping
    ppm.pingpongs = ppm.pingpongs[:-1]
    ppm.add_pingpong(PingPong(last_user_message, ""))

    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
    async for result in gen_text(
        prompt, hf_model=MODEL_ID, hf_token=TOKEN,
        parameters={
            'max_new_tokens': res_mnts,
            'do_sample': res_sample,
            'return_full_text': False,
            'temperature': res_temp,
            'top_k': res_topk,
            'repetition_penalty': res_rpen
        }
    ):
        ppm.append_pong(result)
        yield prompt, ppm.build_uis(), str(res), gr.update(interactive=False)

    yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True)
def reset_chat(idx, ld, state):
    """Clear the conversation of the selected channel.

    Args:
        idx: index of the channel to clear.
        ld: list of serialized conversations, one per channel.
        state: dict whose ``"ppmanager_type"`` entry is the class used to
            rebuild each conversation via ``from_json``.

    Returns:
        A 5-tuple: empty instruction text, empty chatbot history, the
        stringified conversations after clearing, an update that shows the
        example block, and an update that disables the regenerate button.
    """
    managers = []
    for serialized in ld:
        managers.append(state["ppmanager_type"].from_json(json.dumps(serialized)))

    managers[idx].pingpongs = []

    show_examples = gr.update(visible=True)
    disable_regen = gr.update(interactive=False)
    return "", [], str(managers), show_examples, disable_regen
def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cpu"):
    """Run an internet search for the last user message and build a prompt from it.

    First asks the model to turn the last user message into a search query,
    then drives ``InternetSearchStrategy`` step by step, yielding intermediate
    chat UIs, and finally yields the prompt built from the search-augmented
    conversation.

    Args:
        ppmanager: conversation manager whose last ping is the user's question.
        serper_api_key: API key handed to ``InternetSearchStrategy``.
        global_context: context passed to ``build_prompts`` for the final prompt.
        ctx_num_lconv: window size passed to ``build_prompts``.
        device: device string passed to ``SimilaritySearcher.from_pretrained``.

    Yields:
        ``("", chat UI)`` for every search step, then a final
        ``(search_prompt, chat UI)`` pair.
    """
    # NOTE(review): gen_text_none_stream, SimilaritySearcher and
    # InternetSearchStrategy are not imported anywhere in the visible file;
    # they must be brought into scope for this function to run.

    # The question is the ping of the most recent exchange; the original code
    # referred to undefined names (`ppm`, `user_msg`) here.
    user_msg = ppmanager.pingpongs[-1].ping

    internet_search_ppm = copy.deepcopy(ppmanager)
    # Parenthesized so that all three fragments form one string; without the
    # parentheses only the first fragment was assigned.
    internet_search_prompt = (
        f"My question is '{user_msg}'. Based on the conversation history, "
        f"give me an appropriate query to answer my question for google search. "
        f"You should not say more than query. You should not say any words except the query."
    )
    internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
    internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
    instruction = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)

    searcher = SimilaritySearcher.from_pretrained(device=device)
    iss = InternetSearchStrategy(
        searcher,
        instruction=instruction,
        serper_api_key=serper_api_key
    )(ppmanager)

    step_ppm = None
    for step_ppm, _ in iss:
        yield "", step_ppm.build_uis()

    if step_ppm is None:
        # The strategy produced no steps; fall back to the unmodified conversation.
        step_ppm = ppmanager

    search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
    yield search_prompt, ppmanager.build_uis()
async def chat_stream(
    idx, local_data, instruction_txtbox, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
    internet_option, serper_api_key
):
    """Append the user's message to the selected channel and stream the reply.

    Args:
        idx: index of the selected channel within ``local_data``.
        local_data: list of serialized conversations, one per channel.
        instruction_txtbox: the user's message.
        chat_state: dict whose ``"ppmanager_type"`` entry is the class used to
            rebuild each conversation via ``from_json``.
        global_context: context passed to ``build_prompts``.
        res_temp, res_topk, res_rpen, res_mnts, res_sample: generation
            parameters forwarded to ``gen_text``.
        ctx_num_lconv: window size passed to ``build_prompts``.
        internet_option: ``"on"``/``"off"`` from the radio button (``True`` is
            also accepted as "on").
        serper_api_key: API key forwarded to ``internet_search``.

    Yields:
        5-tuples matching the wired outputs: cleared instruction text, the
        prompt for the context inspector, the chat UI, the stringified
        conversations, and an update for the regenerate button (disabled
        while streaming, enabled by the final yield).
    """
    res = [
        chat_state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]

    ppm = res[idx]
    ppm.add_pingpong(PingPong(instruction_txtbox, ""))
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)

    # The radio delivers the strings "on"/"off", which are both truthy, so a
    # plain truthiness test would always enable the search.
    if internet_option in (True, "on"):
        search_prompt = None
        for search_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
            # Same arity and order as the streaming yields below.
            yield "", prompt, uis, str(res), gr.update(interactive=False)
        if search_prompt:
            # Generate from the search-augmented prompt rather than the plain one.
            prompt = search_prompt

    async for result in gen_text(
        prompt, hf_model=MODEL_ID, hf_token=TOKEN,
        parameters={
            'max_new_tokens': res_mnts,
            'do_sample': res_sample,
            'return_full_text': False,
            'temperature': res_temp,
            'top_k': res_topk,
            'repetition_penalty': res_rpen
        }
    ):
        ppm.append_pong(result)
        yield "", prompt, ppm.build_uis(), str(res), gr.update(interactive=False)

    yield "", prompt, ppm.build_uis(), str(res), gr.update(interactive=True)
def channel_num(btn_title):
    """Map a channel button title to its index in ``channels``.

    Args:
        btn_title: title shown on the channel button.

    Returns:
        The index of the last entry of ``channels`` equal to ``btn_title``,
        or 0 when no entry matches.
    """
    hits = [pos for pos, title in enumerate(channels) if title == btn_title]
    return hits[-1] if hits else 0
def set_chatbot(btn, ld, state):
    """Switch the chat view to the channel whose button was clicked.

    Args:
        btn: title of the clicked channel button.
        ld: list of serialized conversations, one per channel.
        state: dict whose ``"ppmanager_type"`` entry is the class used to
            rebuild each conversation via ``from_json``.

    Returns:
        A 4-tuple: the channel's chat UI, the channel index, an update that
        shows the example block only for an empty conversation, and an update
        that enables the regenerate button only for a non-empty one.
    """
    selected = channel_num(btn)
    managers = [
        state["ppmanager_type"].from_json(json.dumps(serialized))
        for serialized in ld
    ]
    current = managers[selected]
    has_history = len(current.pingpongs) != 0
    chat_ui = current.build_uis()
    return (
        chat_ui,
        selected,
        gr.update(visible=not has_history),
        gr.update(interactive=has_history),
    )
def set_example(btn):
    """Copy an example button's text into the prompt box and hide the examples.

    Args:
        btn: text of the clicked example button.

    Returns:
        A pair: the example text unchanged, and an update that hides the
        example block.
    """
    hide_examples = gr.update(visible=False)
    return btn, hide_examples
def get_final_template(
    txt, placeholder_txt1, placeholder_txt2, placeholder_txt3
):
    """Substitute the user's placeholder values into a template.

    The first three placeholders found in ``txt`` are paired, in order, with
    the three input values. Each non-empty value replaces every occurrence of
    its ``[placeholder]``; empty values leave the placeholder as is.

    Args:
        txt: template text containing ``[...]`` placeholders.
        placeholder_txt1, placeholder_txt2, placeholder_txt3: values typed
            into the three placeholder textboxes.

    Returns:
        A 4-tuple: the filled-in prompt followed by three empty strings that
        clear the placeholder textboxes.
    """
    names = get_placeholders(txt)
    values = (placeholder_txt1, placeholder_txt2, placeholder_txt3)

    filled = txt
    # zip() stops at the shorter sequence, so at most three names are used.
    for name, value in zip(names, values):
        if value == "":
            continue
        filled = filled.replace(f"[{name}]", value)

    return (filled, "", "", "")
# Root container of the whole UI, styled with MODEL_SELECTION_CSS.
with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
with gr.Column() as chat_view:
# Index of the currently selected channel; set_chatbot writes to it.
idx = gr.State(0)
# Carries the conversation-manager class that handlers use to rebuild
# conversations from JSON (via from_json).
chat_state = gr.State({
"ppmanager_type": GradioLLaMA2ChatPPManager
})
# Hidden JSON component holding the serialized conversations; it is filled by
# the GET_LOCAL_STORAGE hook on load and pushed back through setStorage after
# each chat event.
local_data = gr.JSON({}, visible=False)
# Page title and feature overview.
gr.Markdown("## LLaMA2 70B with Gradio Chat and Hugging Face Inference API", elem_classes=["center"])
gr.Markdown(
"This space demonstrates how to build feature rich chatbot UI in [Gradio](https://www.gradio.app/). Supported features "
"include • multiple chatting channels, • chat history save/restoration, • stop generating text response, • regenerate the "
"last conversation, • clean the chat history, • dynamic kick-starting prompt templates, • adjusting text generation parameters, "
"• inspecting the actual prompt that the model sees. The underlying Large Language Model is the [Meta AI](https://ai.meta.com/)'s "
"[LLaMA2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) which is hosted as [Hugging Face Inference API](https://huggingface.co/inference-api), "
"and [Text Generation Inference](https://github.com/huggingface/text-generation-inference) is the underlying serving framework. ",
elem_classes=["center"]
)
# Note telling PRO subscribers how to run a duplicate of this space with their
# own token (fixes the "accorindg" typo in the displayed text).
gr.Markdown(
"***NOTE:*** If you are subscribing [PRO](https://huggingface.co/pricing#pro), you can simply duplicate this space and use your "
"Hugging Face Access Token to run the same application. Just add `HF_TOKEN` secret with the Token value according to [this guide](https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables)",
elem_classes=["center"]
)
# Main row: channel list on the left, chat area on the right.
with gr.Row():
with gr.Column(scale=1, min_width=180):
gr.Markdown("GradioChat", elem_id="left-top")
with gr.Column(elem_id="left-pane"):
with gr.Accordion("Histories", elem_id="chat-history-accordion", open=True):
# One button per channel; the first one starts out highlighted.
channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))
for channel in channels[1:]:
channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))
with gr.Column(scale=8, elem_id="right-pane"):
# Intro popup with example prompts; created hidden, and shown/hidden by
# the event handlers (set_chatbot, reset_chat, set_example).
with gr.Column(
elem_id="initial-popup", visible=False
) as example_block:
with gr.Row(scale=1):
with gr.Column(elem_id="initial-popup-left-pane"):
gr.Markdown("GradioChat", elem_id="initial-popup-title")
gr.Markdown("Making the community's best AI chat models available to everyone.")
with gr.Column(elem_id="initial-popup-right-pane"):
gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")
with gr.Column(scale=1):
gr.Markdown("Examples")
with gr.Row():
# One button per line of examples.txt.
for example in examples:
ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))
# Auxiliary controls: regenerate / clean buttons and the prompt inspector.
with gr.Column(elem_id="aux-btns-popup", visible=True):
with gr.Row():
# stop = gr.Button("Stop", elem_classes=["aux-btn"])
# Regen starts disabled; handlers enable it once there is something to regenerate.
regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
clean = gr.Button("Clean", elem_classes=["aux-btn"])
with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
# Receives the prompt yielded by chat_stream / rollback_last.
context_inspector = gr.Textbox(
"",
elem_id="aux-viewer-inspector",
label="",
lines=30,
max_lines=50,
)
chatbot = gr.Chatbot(elem_id='chatbot', label="LLaMA2-70B-Chat")
instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")
with gr.Accordion("Example Templates", open=False):
# template_txt holds the raw chosen template; template_md is its preview.
template_txt = gr.Textbox(visible=False)
template_md = gr.Markdown(label="Chosen Template", visible=False, elem_classes="template-txt")
with gr.Row():
# Up to three placeholder inputs, revealed by fill_up_placeholders as needed.
placeholder_txt1 = gr.Textbox(label="placeholder #1", visible=False, interactive=True)
placeholder_txt2 = gr.Textbox(label="placeholder #2", visible=False, interactive=True)
placeholder_txt3 = gr.Textbox(label="placeholder #3", visible=False, interactive=True)
# One tab per template group; clicking an example runs fill_up_placeholders.
for template in templates:
with gr.Tab(template['title']):
gr.Examples(
template['template'],
inputs=[template_txt],
outputs=[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3, instruction_txtbox],
run_on_click=True,
fn=fill_up_placeholders,
)
with gr.Accordion("Control Panel", open=False) as control_panel:
with gr.Column():
with gr.Column():
gr.Markdown("#### Global context")
with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=True):
global_context = gr.Textbox(
DEFAULT_GLOBAL_CTX,
lines=5,
max_lines=10,
interactive=True,
elem_id="global-context"
)
gr.Markdown("#### Internet search")
with gr.Row():
# The radio's value is the string "on" or "off", not a boolean.
internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
# Pre-filled from the SERPER_API_KEY environment variable; the box itself is hidden.
serper_api_key = gr.Textbox(
value= os.getenv("SERPER_API_KEY"),
placeholder="Get one by visiting serper.dev",
label="Serper api key",
visible=False
)
gr.Markdown("#### GenConfig for **response** text generation")
with gr.Row():
# Generation parameters forwarded to gen_text by the chat handlers.
res_temp = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="temp", interactive=True)
res_topk = gr.Slider(20, 1000, 50, step=1, label="top_k", interactive=True)
res_rpen = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="rep_penalty", interactive=True)
res_mnts = gr.Slider(64, 8192, 512, step=1, label="new_tokens", interactive=True)
res_sample = gr.Radio([True, False], value=True, label="sample", interactive=True)
with gr.Column():
gr.Markdown("#### Context managements")
with gr.Row():
# Passed to build_prompts as the window size.
ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)
# Submitting the prompt: hide the example block and enable Regen, stream the
# reply through chat_stream, then push local_data to the browser via setStorage.
send_event = instruction_txtbox.submit(
lambda: [
gr.update(visible=False),
gr.update(interactive=True)
],
None,
[example_block, regenerate]
).then(
chat_stream,
[idx, local_data, instruction_txtbox, chat_state,
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
internet_option, serper_api_key],
# chat_stream must yield one value per component listed here (five).
[instruction_txtbox, context_inspector, chatbot, local_data, regenerate]
).then(
None, local_data, None,
_js="(v)=>{ setStorage('local_data',v) }"
)
# regen_event1 = regenerate.click(
# rollback_last,
# [idx, local_data, chat_state],
# [instruction_txtbox, chatbot, local_data, regenerate]
# )
# regen_event2 = regen_event1.then(
# chat_stream,
# [idx, local_data, instruction_txtbox, chat_state,
# global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv],
# [context_inspector, chatbot, local_data]
# )
# regen_event3 = regen_event2.then(
# lambda: gr.update(interactive=True),
# None,
# regenerate
# )
# regen_event4 = regen_event3.then(
# None, local_data, None,
# _js="(v)=>{ setStorage('local_data',v) }"
# )
# Regen: re-stream the last answer, then persist local_data. Twelve inputs are
# passed positionally (including internet_option and serper_api_key), so
# rollback_last has to accept all of them.
regen_event = regenerate.click(
rollback_last,
[idx, local_data, chat_state,
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
internet_option, serper_api_key],
[context_inspector, chatbot, local_data, regenerate]
).then(
None, local_data, None,
_js="(v)=>{ setStorage('local_data',v) }"
)
# stop.click(
# lambda: gr.update(interactive=True), None, regenerate,
# cancels=[send_event, regen_event]
# )
# Channel buttons: load the chosen channel into the chatbot, then run the
# UPDATE_LEFT_BTNS_STATE script with the clicked button as input.
for btn in channel_btns:
btn.click(
set_chatbot,
[btn, local_data, chat_state],
[chatbot, idx, example_block, regenerate]
).then(
None, btn, None,
_js=UPDATE_LEFT_BTNS_STATE
)
# Example buttons: copy the example into the prompt box and hide the examples.
for btn in ex_btns:
btn.click(
set_example,
[btn],
[instruction_txtbox, example_block]
)
# Clean: clear the selected channel, then persist local_data.
clean.click(
reset_chat,
[idx, local_data, chat_state],
[instruction_txtbox, chatbot, local_data, example_block, regenerate]
).then(
None, local_data, None,
_js="(v)=>{ setStorage('local_data',v) }"
)
# Typing in any placeholder box runs the UPDATE_PLACEHOLDERS script (no Python
# function) with the template and the three values, writing to template_md.
placeholder_txt1.change(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[template_md],
show_progress=False,
_js=UPDATE_PLACEHOLDERS,
fn=None
)
placeholder_txt2.change(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[template_md],
show_progress=False,
_js=UPDATE_PLACEHOLDERS,
fn=None
)
placeholder_txt3.change(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[template_md],
show_progress=False,
_js=UPDATE_PLACEHOLDERS,
fn=None
)
# Submitting any placeholder box writes the filled-in template to the prompt
# box and clears all three placeholder inputs.
placeholder_txt1.submit(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
fn=get_final_template
)
placeholder_txt2.submit(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
fn=get_final_template
)
placeholder_txt3.submit(
inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
fn=get_final_template
)
# On page load, the GET_LOCAL_STORAGE script supplies the values for the
# chatbot and local_data components.
demo.load(
None,
inputs=None,
outputs=[chatbot, local_data],
_js=GET_LOCAL_STORAGE,
)
# Enable the request queue and start the server.
demo.queue(concurrency_count=5, max_size=256).launch()