"""Gradio multi-channel chat UI backed by LLaMA2-70B-Chat via the HF Inference API."""
import os
import re
import json
import copy

import gradio as gr

from llama2 import GradioLLaMA2ChatPPManager
from llama2 import gen_text

from styles import MODEL_SELECTION_CSS
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
from templates import templates
from constants import DEFAULT_GLOBAL_CTX

from pingpong import PingPong
from pingpong.context import CtxLastWindowStrategy

TOKEN = os.getenv('HF_TOKEN')
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'

# Browser-side snippet that stores the serialized histories in local storage.
# NOTE(review): `setStorage` is assumed to be defined by the GET_LOCAL_STORAGE
# script run on page load - confirm in js.py.
SAVE_LOCAL_DATA_JS = "(v)=>{ setStorage('local_data',v) }"

# Number of placeholder textboxes available in the "Example Templates" panel.
MAX_PLACEHOLDERS = 3

_PLACEHOLDER_PATTERN = re.compile(r"\[([^\]]*)\]")


def build_prompts(ppmanager, global_context, win_size=3):
    """Render the prompt text the model will receive for ``ppmanager``.

    A deep copy is used so the caller's manager is not modified; the copy's
    ``ctx`` is replaced by ``global_context`` and the result is rendered by
    ``CtxLastWindowStrategy(win_size)`` (presumably keeping only the most
    recent ``win_size`` exchanges - see the "number of recent talks to keep"
    slider).
    """
    dummy_ppm = copy.deepcopy(ppmanager)
    dummy_ppm.ctx = global_context
    lws = CtxLastWindowStrategy(win_size)
    return lws(dummy_ppm)


def _read_lines(path):
    """Return the non-blank lines of the UTF-8 text file at ``path``.

    Blank lines (including the one a trailing newline would produce with
    ``split("\\n")``) are dropped so they don't become empty buttons.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line for line in f.read().splitlines() if line.strip()]


examples = _read_lines("examples.txt")
ex_btns = []

channels = _read_lines("channels.txt")
channel_btns = []


def get_placeholders(text):
    """Return every substring enclosed in square brackets, in order.

    >>> get_placeholders("Write a [genre] story about [topic]")
    ['genre', 'topic']
    """
    return _PLACEHOLDER_PATTERN.findall(text)


def fill_up_placeholders(txt):
    """Show the chosen template and reveal one textbox per placeholder.

    Returns updates for (template_md, placeholder_txt1..3, instruction_txtbox):
    the template markdown is shown with ``txt``; placeholder box *i* is shown
    with the i-th placeholder name as its hint (hidden otherwise); the
    instruction box receives ``txt`` directly only when it has no placeholders.
    """
    placeholders = get_placeholders(txt)
    placeholder_updates = [
        gr.update(
            visible=i < len(placeholders),
            placeholder=placeholders[i] if i < len(placeholders) else "",
        )
        for i in range(MAX_PLACEHOLDERS)
    ]
    return (
        gr.update(visible=True, value=txt),
        *placeholder_updates,
        "" if placeholders else txt,
    )


def _load_ppms(local_data, chat_state):
    """Deserialize every channel's stored history into ppmanager objects."""
    ppm_type = chat_state["ppmanager_type"]
    return [ppm_type.from_json(json.dumps(ppm)) for ppm in local_data]


def _gen_parameters(res_temp, res_topk, res_rpen, res_mnts, res_sample):
    """Build the generation-parameter dict passed to ``gen_text``."""
    return {
        'max_new_tokens': res_mnts,
        'do_sample': res_sample,
        'return_full_text': False,
        'temperature': res_temp,
        'top_k': res_topk,
        'repetition_penalty': res_rpen,
    }


async def rollback_last(
    idx, local_data, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
):
    """Regenerate the reply to the last user message of channel ``idx``.

    Yields updates for the Regen outputs
    ``(context_inspector, chatbot, local_data, regenerate)``: the rendered
    prompt, the chat transcript, the serialized histories, and the Regen
    button (left unchanged while streaming, re-enabled when done).
    """
    res = _load_ppms(local_data, chat_state)
    ppm = res[idx]

    if not ppm.pingpongs:
        # Nothing to regenerate: keep the view as-is and disable Regen.
        yield gr.update(), gr.update(), gr.update(), gr.update(interactive=False)
        return

    # Replace the last exchange with the same user message and an empty reply.
    last_user_message = ppm.pingpongs[-1].ping
    ppm.pingpongs = ppm.pingpongs[:-1]
    ppm.add_pingpong(PingPong(last_user_message, ""))

    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
    parameters = _gen_parameters(res_temp, res_topk, res_rpen, res_mnts, res_sample)

    async for result in gen_text(
        prompt, hf_model=MODEL_ID, hf_token=TOKEN, parameters=parameters
    ):
        ppm.append_pong(result)
        yield prompt, ppm.build_uis(), str(res), gr.update()

    yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True)


def reset_chat(idx, ld, state):
    """Clear the history of channel ``idx``.

    Returns updates for (instruction_txtbox, chatbot, local_data,
    example_block, regenerate): empty prompt and transcript, the updated
    serialized histories, the example popup shown, and Regen disabled.
    """
    res = _load_ppms(ld, state)
    res[idx].pingpongs = []
    return (
        "",
        [],
        str(res),
        gr.update(visible=True),
        gr.update(interactive=False),
    )


async def chat_stream(
    idx, local_data, instruction_txtbox, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
):
    """Send ``instruction_txtbox`` on channel ``idx`` and stream the reply.

    Yields updates for (instruction_txtbox, context_inspector, chatbot,
    local_data): a cleared input box, the rendered prompt, the growing chat
    transcript, and the serialized histories.
    """
    res = _load_ppms(local_data, chat_state)
    ppm = res[idx]

    ppm.add_pingpong(PingPong(instruction_txtbox, ""))
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
    parameters = _gen_parameters(res_temp, res_topk, res_rpen, res_mnts, res_sample)

    async for result in gen_text(
        prompt, hf_model=MODEL_ID, hf_token=TOKEN, parameters=parameters
    ):
        ppm.append_pong(result)
        yield "", prompt, ppm.build_uis(), str(res)


def channel_num(btn_title):
    """Return the index in ``channels`` of ``btn_title`` (last match; 0 if absent)."""
    choice = 0
    for i, channel in enumerate(channels):
        if channel == btn_title:
            choice = i
    return choice


def set_chatbot(btn, ld, state):
    """Switch to the channel named ``btn``.

    Returns updates for (chatbot, idx, example_block, regenerate): that
    channel's transcript and index, the example popup shown only when the
    channel is empty, and Regen enabled only when it is not.
    """
    choice = channel_num(btn)
    res = _load_ppms(ld, state)
    empty = len(res[choice].pingpongs) == 0
    return (
        res[choice].build_uis(),
        choice,
        gr.update(visible=empty),
        gr.update(interactive=not empty),
    )


def set_example(btn):
    """Copy the clicked example into the prompt box and hide the popup."""
    return btn, gr.update(visible=False)


def get_final_template(
    txt, placeholder_txt1, placeholder_txt2, placeholder_txt3
):
    """Substitute the entered values into the template's placeholders.

    The i-th ``[placeholder]`` is replaced by the i-th textbox value unless
    that value is empty. Returns updates for (instruction_txtbox,
    placeholder_txt1..3): the filled-in prompt followed by three cleared boxes.
    """
    example_prompt = txt
    values = (placeholder_txt1, placeholder_txt2, placeholder_txt3)
    # zip() stops at the shorter sequence, so extra placeholders stay as-is.
    for name, value in zip(get_placeholders(txt), values):
        if value:
            example_prompt = example_prompt.replace(f"[{name}]", value)
    return example_prompt, "", "", ""


with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
    with gr.Column() as chat_view:
        idx = gr.State(0)
        chat_state = gr.State({
            "ppmanager_type": GradioLLaMA2ChatPPManager
        })
        local_data = gr.JSON({}, visible=False)

        gr.Markdown("## LLaMA2 70B with Gradio Chat and Hugging Face Inference API", elem_classes=["center"])
        gr.Markdown(
            "This space demonstrates how to build feature rich chatbot UI in [Gradio](https://www.gradio.app/). Supported features "
            "include • multiple chatting channels, • chat history save/restoration, • stop generating text response, • regenerate the "
            "last conversation, • clean the chat history, • dynamic kick-starting prompt templates, • adjusting text generation parameters, "
            "• inspecting the actual prompt that the model sees. The underlying Large Language Model is the [Meta AI](https://ai.meta.com/)'s "
            "[LLaMA2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) which is hosted as [Hugging Face Inference API](https://huggingface.co/inference-api), "
            "and [Text Generation Inference](https://github.com/huggingface/text-generation-inference) is the underlying serving framework.",
            elem_classes=["center"]
        )

        with gr.Row():
            with gr.Column(scale=1, min_width=180):
                gr.Markdown("GradioChat", elem_id="left-top")

                with gr.Column(elem_id="left-pane"):
                    chat_back_btn = gr.Button("Back", elem_id="chat-back-btn")

                    with gr.Accordion("Histories", elem_id="chat-history-accordion", open=True):
                        # The first channel starts highlighted as the active one.
                        channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))

                        for channel in channels[1:]:
                            channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))

            with gr.Column(scale=8, elem_id="right-pane"):
                with gr.Column(
                    elem_id="initial-popup", visible=False
                ) as example_block:
                    with gr.Row(scale=1):
                        with gr.Column(elem_id="initial-popup-left-pane"):
                            gr.Markdown("GradioChat", elem_id="initial-popup-title")
                            gr.Markdown("Making the community's best AI chat models available to everyone.")
                        with gr.Column(elem_id="initial-popup-right-pane"):
                            gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
                            gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")

                    with gr.Column(scale=1):
                        gr.Markdown("Examples")
                        with gr.Row():
                            for example in examples:
                                ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))

                with gr.Column(elem_id="aux-btns-popup", visible=True):
                    with gr.Row():
                        stop = gr.Button("Stop", elem_classes=["aux-btn"])
                        regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
                        clean = gr.Button("Clean", elem_classes=["aux-btn"])

                with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
                    context_inspector = gr.Textbox(
                        "",
                        elem_id="aux-viewer-inspector",
                        label="",
                        lines=30,
                        max_lines=50,
                    )

                chatbot = gr.Chatbot(elem_id='chatbot', label="LLaMA2-70B-Chat")
                instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")

                with gr.Accordion("Example Templates", open=False):
                    template_txt = gr.Textbox(visible=False)
                    template_md = gr.Markdown(label="Chosen Template", visible=False, elem_classes="template-txt")

                    with gr.Row():
                        placeholder_txt1 = gr.Textbox(label="placeholder #1", visible=False, interactive=True)
                        placeholder_txt2 = gr.Textbox(label="placeholder #2", visible=False, interactive=True)
                        placeholder_txt3 = gr.Textbox(label="placeholder #3", visible=False, interactive=True)

                    for template in templates:
                        with gr.Tab(template['title']):
                            gr.Examples(
                                template['template'],
                                inputs=[template_txt],
                                outputs=[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3, instruction_txtbox],
                                run_on_click=True,
                                fn=fill_up_placeholders,
                            )

                with gr.Accordion("Control Panel", open=False) as control_panel:
                    with gr.Column():
                        with gr.Column():
                            gr.Markdown("#### Global context")
                            with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=True):
                                global_context = gr.Textbox(
                                    DEFAULT_GLOBAL_CTX,
                                    lines=5,
                                    max_lines=10,
                                    interactive=True,
                                    elem_id="global-context"
                                )

                        # TODO: an "Internet search" section (serper.dev API key)
                        # was sketched here but never wired up.

                        gr.Markdown("#### GenConfig for **response** text generation")
                        with gr.Row():
                            res_temp = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="temp", interactive=True)
                            res_topk = gr.Slider(20, 1000, 50, step=1, label="top_k", interactive=True)
                            res_rpen = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="rep_penalty", interactive=True)
                            res_mnts = gr.Slider(64, 8192, 512, step=1, label="new_tokens", interactive=True)
                            # Default must be one of the choices (was the int 0).
                            res_sample = gr.Radio([True, False], value=False, label="sample", interactive=True)

                        with gr.Column():
                            gr.Markdown("#### Context managements")
                            with gr.Row():
                                ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)

    gen_inputs = [global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv]

    # `send_event` / `regen_event` must reference the streaming events
    # themselves (not their follow-up steps) so that Stop cancels generation.
    send_event = instruction_txtbox.submit(
        chat_stream,
        [idx, local_data, instruction_txtbox, chat_state, *gen_inputs],
        [instruction_txtbox, context_inspector, chatbot, local_data]
    )
    # After a reply has streamed, allow regenerating it and persist histories.
    send_event.then(
        lambda: gr.update(interactive=True), None, regenerate
    ).then(
        None, local_data, None,
        _js=SAVE_LOCAL_DATA_JS
    )

    regen_event = regenerate.click(
        rollback_last,
        [idx, local_data, chat_state, *gen_inputs],
        [context_inspector, chatbot, local_data, regenerate]
    )
    regen_event.then(
        None, local_data, None,
        _js=SAVE_LOCAL_DATA_JS
    )

    stop.click(
        None, None, None,
        cancels=[send_event, regen_event]
    )

    for btn in channel_btns:
        btn.click(
            set_chatbot,
            [btn, local_data, chat_state],
            [chatbot, idx, example_block, regenerate]
        ).then(
            None, btn, None,
            _js=UPDATE_LEFT_BTNS_STATE
        )

    for btn in ex_btns:
        btn.click(
            set_example,
            [btn],
            [instruction_txtbox, example_block]
        )

    clean.click(
        reset_chat,
        [idx, local_data, chat_state],
        [instruction_txtbox, chatbot, local_data, example_block, regenerate]
    ).then(
        None, local_data, None,
        _js=SAVE_LOCAL_DATA_JS
    )

    placeholder_boxes = [placeholder_txt1, placeholder_txt2, placeholder_txt3]
    template_inputs = [template_txt, *placeholder_boxes]
    for box in placeholder_boxes:
        # Browser-only handler (fn=None): UPDATE_PLACEHOLDERS refreshes template_md.
        box.change(
            inputs=template_inputs,
            outputs=[template_md],
            show_progress=False,
            _js=UPDATE_PLACEHOLDERS,
            fn=None
        )
        # Enter writes the filled-in template into the prompt box.
        box.submit(
            inputs=template_inputs,
            outputs=[instruction_txtbox, *placeholder_boxes],
            fn=get_final_template
        )

    demo.load(
        None,
        inputs=None,
        outputs=[chatbot, local_data],
        _js=GET_LOCAL_STORAGE,
    )

demo.queue().launch()