# Hugging Face file-viewer residue (kept as a comment so the file parses):
# author "kastan", commit f38e35a "remove ray dependency", file size 10.6 kB.
import os
import gradio as gr
import retrieval
from text_generation import Client, InferenceAPIClient
# Load API keys from a globally available .env file (currently disabled).
# SECRETS_FILEPATH = "/mnt/project/chatbotai/huggingface_cache/internal_api_keys.env"
# load_dotenv(dotenv_path=SECRETS_FILEPATH, override=True)
# Persona transcript prepended to every prompt for the OpenChatKit
# (GPT-NeoXT-Chat-Base-20B) model; it uses that model's "<human>:" / "<bot>:"
# turn tags. Split one sentence per line; adjacent literals concatenate.
openchat_preprompt = (
    "\n<human>: Hi!\n"
    "<bot>: My name is Bot, model version is 0.15, part of an open-source kit for fine-tuning new bots! "
    "I was created by Together, LAION, and Ontocord.ai and the open-source community. "
    "I am not human, not evil and not alive, and thus have no thoughts and feelings, "
    "but I am programmed to be helpful, polite, honest, and friendly. "
    "I'm really smart at answering electrical engineering questions.\n"
)
# LOAD MODELS
ta = retrieval.Retrieval()
NUM_ANSWERS_GENERATED = 3
def clip_img_search(img):
    """Gallery handler for the reverse-image-search upload.

    Passes *img* to ``ta.reverse_img_search`` and returns its result for the
    lecture-slide gallery. When the image is cleared (``None``) an empty
    list is returned instead, which empties the gallery.
    """
    return [] if img is None else ta.reverse_img_search(img)
def get_client(model: str):
    """Return a text-generation client for *model*.

    Two models are reached through a ``Client`` whose base URL is read from
    an environment variable (``JOI_API_URL`` / ``OPENCHAT_API_URL``). Every
    other model id is handed to ``InferenceAPIClient``, with ``HF_TOKEN``
    (or ``None`` when unset) passed as the token.
    """
    url_env_by_model = {
        "Rallio67/joi2_20Be_instruct_alpha": "JOI_API_URL",
        "togethercomputer/GPT-NeoXT-Chat-Base-20B": "OPENCHAT_API_URL",
    }
    url_env = url_env_by_model.get(model)
    if url_env is not None:
        return Client(os.getenv(url_env))
    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
def get_usernames(model: str):
    """Describe the chat-transcript format used to prompt *model*.

    Returns:
        (str, str, str, str): the pre-prompt prepended to the whole
        transcript, the tag that starts a user turn, the tag that starts a
        bot turn, and the separator written between turns.
    """
    # The OpenChatKit model is the only one that needs the persona pre-prompt.
    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
    formats = {
        "OpenAssistant/oasst-sft-1-pythia-12b": ("", "<|prompter|>", "<|assistant|>", "<|endoftext|>"),
        "Rallio67/joi2_20Be_instruct_alpha": ("", "User: ", "Joi: ", "\n\n"),
    }
    # Any other model gets a plain "User: / Assistant: " transcript.
    return formats.get(model, ("", "User: ", "Assistant: ", "\n"))
def _strip_suffix(text: str, suffix: str) -> str:
    """Return *text* with one trailing occurrence of *suffix* removed, if present.

    Unlike ``str.rstrip(suffix)`` -- which strips any trailing run of the
    *characters* in ``suffix`` and can therefore eat real answer text -- this
    removes only the exact suffix string.
    """
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


def _history_to_chat(history):
    """Pair a flat ``[user, bot, user, bot, ...]`` history into stripped (user, bot) tuples."""
    return [(history[k].strip(), history[k + 1].strip()) for k in range(0, len(history) - 1, 2)]


def _pad_contexts(contexts, count):
    """Return exactly *count* context entries, padding with ``None`` when fewer were retrieved."""
    padded = [] if contexts is None else list(contexts)[:count]
    padded.extend([None] * (count - len(padded)))
    return padded


def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream an answer from *model*, then attach retrieved passages and slides.

    Used as the ``inputs.submit`` handler. Each yield is a 6-tuple matching
    the outputs ``[chatbot, state, context1, context2, context3, lec_gallery]``.

    Args:
        model: Model id selected in the "Model choices" radio.
        inputs: The user's new message (raw text).
        typical_p: Typical-p value; only sent for the OpenAssistant model.
        top_p: Nucleus-sampling value; ``None`` is sent when it is >= 1.0.
        temperature: Sampling temperature (not sent for OpenAssistant).
        top_k: Top-k value (not sent for OpenAssistant).
        repetition_penalty: Repetition penalty (not sent for OpenAssistant).
        watermark: Forwarded as ``watermark`` to ``generate_stream``.
        chatbot: Previous ``(user, bot)`` turns currently shown in the chat.
        history: Flat ``[user, bot, ...]`` list held in ``gr.State``; it is
            mutated in place and always left with an even length.

    Yields:
        ``(chat, history, None, None, None, [])`` for every streamed token,
        then the same with the three retrieved contexts filled in (``None``
        pads missing ones), then again with the CLIP image results.
    """
    num_context_boxes = 3  # the UI has context1, context2, context3
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    # Keep the raw question for retrieval/CLIP; the prompt gets the role tag.
    question = inputs
    history.append(inputs)

    # Re-serialize earlier turns in the model's transcript format.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    prompt_input = inputs if inputs.startswith(user_name) else user_name + inputs
    total_inputs = preprompt + "".join(past) + prompt_input + sep + assistant_name.rstrip()

    if model == "OpenAssistant/oasst-sft-1-pythia-12b":
        # This model is driven by typical-p only.
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=1000,
            watermark=watermark,
            max_new_tokens=500,
        )
    else:
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if top_p < 1.0 else None,
            top_k=top_k,
            truncate=1000,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    user_stop = user_name.rstrip()
    assistant_stop = assistant_name.rstrip()
    partial_words = ""
    reply_started = False
    for response in iterator:
        if response.token.special:
            continue
        partial_words += response.token.text
        # Drop a role tag the model emitted at the end of its text.
        partial_words = _strip_suffix(partial_words, user_stop)
        partial_words = _strip_suffix(partial_words, assistant_stop)
        if not reply_started:
            # First visible token opens the bot's slot in the history. A flag
            # (not the token index) is used so a leading special token cannot
            # make the next token overwrite the user's message.
            history.append(" " + partial_words)
            reply_started = True
        elif response.token.text not in user_name:
            # Tokens that are substrings of the user tag do not refresh the
            # display (presumably to avoid flashing a partial tag).
            history[-1] = partial_words
        yield _history_to_chat(history), history, None, None, None, []

    if not reply_started:
        # Nothing visible was generated: keep history strictly paired so the
        # next turn's (user, bot) pairing is not shifted.
        history.append("")
    chat = _history_to_chat(history)

    # Pinecone context retrieval
    top_context_list = ta.retrieve_contexts_from_pinecone(user_question=question, topk=NUM_ANSWERS_GENERATED)
    contexts = _pad_contexts(top_context_list, num_context_boxes)
    yield (chat, history, *contexts, [])

    # run CLIP
    images_list = ta.clip_text_to_image(question)
    yield (chat, history, *contexts, images_list)
def reset_textbox():
    """Return a Gradio update that empties the message textbox after a submit."""
    cleared = gr.update(value="")
    return cleared
def radio_on_change(
    value: str,
    disclaimer,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Compute component updates when the selected model changes.

    The component arguments are the Gradio components themselves; each one's
    ``update(...)`` result is returned, in the same order as the arguments
    (disclaimer first, watermark last).

    * OpenAssistant: show only typical-p (0.2), hide the other sliders and
      the disclaimer, watermark off.
    * OpenChatKit (GPT-NeoXT): hide typical-p, show top-p/top-k/temperature/
      repetition-penalty with its defaults, show the disclaimer, watermark off.
    * Anything else: same sliders with different defaults, no disclaimer,
      watermark on.
    """
    if value == "OpenAssistant/oasst-sft-1-pythia-12b":
        return (
            disclaimer.update(visible=False),
            typical_p.update(value=0.2, visible=True),
            top_p.update(visible=False),
            top_k.update(visible=False),
            temperature.update(visible=False),
            repetition_penalty.update(visible=False),
            watermark.update(False),
        )

    if value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        p_value, k_value, temp_value, penalty_value = 0.25, 50, 0.6, 1.01
        show_disclaimer, use_watermark = True, False
    else:
        p_value, k_value, temp_value, penalty_value = 0.95, 4, 0.5, 1.03
        show_disclaimer, use_watermark = False, True

    return (
        disclaimer.update(visible=show_disclaimer),
        typical_p.update(visible=False),
        top_p.update(value=p_value, visible=True),
        top_k.update(value=k_value, visible=True),
        temperature.update(value=temp_value, visible=True),
        repetition_penalty.update(value=penalty_value, visible=True),
        watermark.update(use_watermark),
    )
# Page header rendered with gr.HTML at the top of the app. The emoji was
# previously stored mis-decoded ("πŸ”₯") and the <h1> was never closed.
title = """<h1 align="center">🔥Teaching Assistant Chatbot</h1>"""
# Markdown rendered at the bottom of the page (currently just a newline).
description = """
"""
# Link shown only while the OpenChatKit model is selected.
openchat_disclaimer = """
<div align="center">Checkout the official <a href=https://huggingface.co/spaces/togethercomputer/OpenChatKit>OpenChatKit feedback app</a> for the full experience.</div>
"""
# ---------------------------------------------------------------------------
# UI layout. Inside a gr.Blocks context, components are rendered in the order
# they are created within each ``with`` container, so statement order here
# defines the page layout -- do not reorder casually.
# ---------------------------------------------------------------------------
with gr.Blocks(css="""#col_container {margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML(title)

    # Model picker (collapsed by default).
    with gr.Row():
        with gr.Accordion("Model choices", open=False, visible=True):
            model = gr.Radio(
                value="OpenAssistant/oasst-sft-1-pythia-12b",
                choices=[
                    "OpenAssistant/oasst-sft-1-pythia-12b",
                    # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
                    "Rallio67/joi2_20Be_instruct_alpha",
                    "google/flan-t5-xxl",
                    "google/flan-ul2",
                    "bigscience/bloom",
                    "bigscience/bloomz",
                    "EleutherAI/gpt-neox-20b",
                ],
                label="",
                interactive=True,
            )

    # with gr.Row():
    #     with gr.Column():
    #         use_gpt3_checkbox = gr.Checkbox(label="Include GPT-3 (paid)?")
    #     with gr.Column():
    #         use_equation_checkbox = gr.Checkbox(label="Prioritize equations?")

    # Conversation history as a flat [user, bot, user, bot, ...] list; read and
    # extended by predict().
    state = gr.State([])

    # Chat window, message box and clickable example questions.
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(elem_id="chatbot")
            inputs = gr.Textbox(placeholder="Ask an Electrical Engineering question!", label="Send a message...")
            examples = gr.Examples(
                examples=[
                    "What is a Finite State Machine?",
                    "How do you design a functional Two-Bit Gray Code Counter?",
                    "How can we compare an 8-bit 2's complement number to the value -1 using AND, OR, and NOT?",
                    "What does the uninterrupted counting cycle label mean?",
                ],
                inputs=[inputs],
                outputs=[],
            )

    # Retrieved passages; predict() fills these after the answer has streamed.
    gr.Markdown("## Relevant Textbook Passages & Lecture Transcripts")
    with gr.Row():
        with gr.Column():
            context1 = gr.Textbox(label="Context 1")
        with gr.Column():
            context2 = gr.Textbox(label="Context 2")
        with gr.Column():
            context3 = gr.Textbox(label="Context 3")

    # Slide gallery: filled by predict() (text-to-image search) or by uploading
    # an image for reverse image search (clip_img_search).
    gr.Markdown("## Relevant Lecture Slides")
    with gr.Row():
        with gr.Column(scale=2.6):
            lec_gallery = gr.Gallery(label="Lecture images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
        with gr.Column(scale=1):
            inp_image = gr.Image(type="pil", label="Reverse Image Search (optional)", shape=(224, 398))
            inp_image.change(fn=clip_img_search, inputs=inp_image, outputs=lec_gallery, scroll_to_output=True)

    # Hidden by default; radio_on_change shows it only for the OpenChatKit model.
    disclaimer = gr.Markdown(openchat_disclaimer, visible=False)

    # Sampling parameters. Initial values/visibility correspond to the default
    # OpenAssistant model; radio_on_change adjusts them when the model changes.
    with gr.Row():
        with gr.Accordion("Parameters", open=False, visible=True):
            typical_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                interactive=True,
                label="Typical P mass",
            )
            top_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
                visible=False,
            )
            temperature = gr.Slider(
                minimum=0,
                maximum=5.0,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
                visible=False,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=False,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
                visible=False,
            )
            watermark = gr.Checkbox(value=False, label="Text watermarking")

    # radio_on_change receives the component objects themselves (captured by
    # the lambda), not their current values, and returns one update per output.
    model.change(
        lambda value: radio_on_change(
            value,
            disclaimer,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ),
        inputs=model,
        outputs=[
            disclaimer,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ],
    )
    # Stream the answer, then the retrieved contexts and lecture slides. The
    # input order must match predict()'s parameter order.
    inputs.submit(
        predict,
        [
            model,
            inputs,
            typical_p,
            top_p,
            temperature,
            top_k,
            repetition_penalty,
            watermark,
            chatbot,
            state,
        ],
        [chatbot, state, context1, context2, context3, lec_gallery],
    )
    # A second listener on the same submit event clears the textbox.
    inputs.submit(reset_textbox, [], [inputs])
    gr.Markdown(description)

# Queueing is required for generator (streaming) handlers such as predict().
demo.queue(concurrency_count=16).launch(share=True, debug=True)