# ShawnAI's picture
# Update app.py
# 3f8dc6c
import gradio as gr
import random
import time
from langchain import PromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, OpenAIEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
import pinecone
import os
# Force TOKENIZERS_PARALLELISM off for this process (presumably for the
# HuggingFace tokenizers used by the embedding models -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# --- OpenAI -----------------------------------------------------------------
# API key pre-filled into the UI; empty when the environment does not set it.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# Sampling temperature handed to every OpenAI model.
OPENAI_TEMP = 1
OPENAI_API_LINK = "[OpenAI API Key](https://platform.openai.com/account/api-keys)"
OPENAI_LINK = "[OpenAI](https://openai.com)"

# --- Pinecone ---------------------------------------------------------------
PINECONE_KEY = os.getenv("PINECONE_KEY", "")
PINECONE_ENV = os.getenv("PINECONE_ENV", "asia-northeast1-gcp")
PINECONE_INDEX = os.getenv("PINECONE_INDEX", "3gpp-r16")
PINECONE_LINK = "[Pinecone](https://www.pinecone.io)"

LANGCHAIN_LINK = "[LangChain](https://python.langchain.com/en/latest/index.html)"

# --- Embeddings -------------------------------------------------------------
# Model name and the name of the LangChain embedding class that loads it.
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "hkunlp/instructor-large")
EMBEDDING_LOADER = os.getenv("EMBEDDING_LOADER", "HuggingFaceInstructEmbeddings")
EMBEDDING_LIST = [
    "HuggingFaceInstructEmbeddings",
    "HuggingFaceEmbeddings",
    "OpenAIEmbeddings",
]

# --- Retrieval defaults -----------------------------------------------------
# Number of text chunks requested from the vector store (default / slider max).
TOP_K_DEFAULT = 15
TOP_K_MAX = 30
# Default minimum similarity score a chunk must reach to be kept.
SCORE_DEFAULT = 0.33

# --- UI ---------------------------------------------------------------------
BUTTON_MIN_WIDTH = 215

# Badge specs that get_logo() inserts into the shields.io badge URL.
LLM_NULL = "LLM-UNLOAD-critical"
LLM_DONE = "LLM-LOADED-9cf"
DB_NULL = "DB-UNLOAD-critical"
DB_DONE = "DB-LOADED-9cf"
FORK_BADGE = "Fork-HuggingFace Space-9cf"
def get_logo(inputs, logo) -> str:
    """Return the shields.io badge image URL for a badge spec and logo name.

    ``inputs`` is placed verbatim in the badge path and ``logo`` verbatim in
    the ``logo`` query parameter; neither value is URL-encoded.
    """
    badge_query = f"style=flat&logo={logo}&logoColor=white"
    return f"https://img.shields.io/badge/{inputs}?{badge_query}"
def get_status(inputs, logo, pos) -> str:
    """Return an ``<img>`` HTML snippet that shows a status badge.

    The image source is the URL built by ``get_logo(inputs, logo)`` and
    ``pos`` is used as the CSS ``float`` value of the image.
    """
    badge_url = get_logo(inputs, logo)
    return (
        "<img\n"
        f'src = "{badge_url}";\n'
        f'style = "margin: 0 auto;float:{pos};border: 2px solid transparent;";\n'
        ">"
    )
# Button captions.
KEY_INIT = "Initialize Model"
KEY_SUBMIT = "Submit"
KEY_CLEAR = "Clear"

# Pre-rendered status badges (HTML) for the LLM and the vector database.
MODEL_NULL = get_status(LLM_NULL, "openai", "right")
MODEL_DONE = get_status(LLM_DONE, "openai", "right")
DOCS_NULL = get_status(DB_NULL, "processingfoundation", "right")
DOCS_DONE = get_status(DB_DONE, "processingfoundation", "right")

# Tab titles.
TAB_1 = "Chatbot"
TAB_2 = "Details"
TAB_3 = "Database"
TAB_4 = "TODO"

FAVICON = "./icon.svg"

# Models offered in the "LLM Selection" dropdown.
LLM_LIST = ["gpt-3.5-turbo", "text-davinci-003"]

# Reference document sets; only DOC_1 is offered and it is ticked by default.
DOC_1 = "3GPP"
DOC_2 = "HTTP2"
DOC_SUPPORTED = [DOC_1]
DOC_DEFAULT = [DOC_1]
DOC_LABEL = "Reference Docs"

# Chat replies shown when the LLM or the vector database is not loaded.
MODEL_WARNING = (
    "Please paste your **" + OPENAI_API_LINK + "** and then **" + KEY_INIT + "**"
)
DOCS_WARNING = "\n".join([
    "Database Unloaded",
    f"Please check your **{TAB_3}** config and then **{KEY_INIT}**",
    f"Or you could uncheck **{DOC_LABEL}** to ask LLM directly",
])

# Markdown / HTML shown at the top of the page.
webui_title = "\n# OpenAI Chatbot Based on Vector Database\n"
dup_link = (
    '<a href="https://huggingface.co/spaces/ShawnAI/VectorDB-ChatBot?duplicate=true"\n'
    'style="display:grid; width: 200px;">\n'
    f'<img src="{get_logo(FORK_BADGE, "addthis")}"></a>'
)
init_message = (
    "This demonstration website is based on "
    f"**{OPENAI_LINK}** with **{LANGCHAIN_LINK}** and **{PINECONE_LINK}**\n"
    f"1. Insert your **{OPENAI_API_LINK}** and click `{KEY_INIT}`\n"
    f"2. Insert your **Question** and click `{KEY_SUBMIT}`\n"
)
# Prompt used when reference documents are enabled: retrieved context first,
# then the prior chat turns, then the new question.
PROMPT_DOC = PromptTemplate(
    template=(
        "Context:\n"
        "##\n"
        "{context}\n"
        "##\n"
        "Chat History:\n"
        "##\n"
        "{chat_history}\n"
        "##\n"
        "Question:\n"
        "{question}\n"
        "Answer:"
    ),
    input_variables=["context", "chat_history", "question"],
)
# Prompt used when no reference documents are selected: only the prior chat
# turns and the new question are sent to the model.
PROMPT_BASE = PromptTemplate(
    template=(
        "Chat History:\n"
        "##\n"
        "{chat_history}\n"
        "##\n"
        "Question:\n"
        "##\n"
        "{question}\n"
        "##\n"
        "Answer:"
    ),
    input_variables=["question", "chat_history"],
)
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------
def init_rwkv():
    """Report whether the optional ``rwkv`` package can be imported.

    Returns:
        ``True`` when ``import rwkv`` succeeds; otherwise prints a notice and
        returns ``False``. Any ``Exception`` raised by the import counts as
        "not available".
    """
    try:
        import rwkv  # noqa: F401 -- imported only to probe availability
    except Exception:
        print("RWKV not found, skip local llm")
        return False
    return True
def init_model(api_key, emb_name, emb_loader, db_api_key, db_env, db_index):
    """Create the OpenAI LLM clients and, when configured, open the Pinecone index.

    Args:
        api_key: OpenAI API key; must start with ``"sk-"`` and be longer than
            50 characters, otherwise nothing is initialized.
        emb_name: Embedding model name passed to the embedding loader.
        emb_loader: Name of the embedding class; must be one of the entries
            in ``EMBEDDING_LIST``.
        db_api_key: Pinecone API key.
        db_env: Pinecone environment.
        db_index: Name of an existing Pinecone index.

    Returns:
        A 6-tuple consumed positionally by the ``init_output`` components:
        ``(api_key, status_html, llm_dict, None, db, None)``.

        * Invalid key or any exception: ``(None, MODEL_NULL + DOCS_NULL,
          None, None, None, None)``.
        * LLMs created but database settings incomplete: the key,
          ``MODEL_DONE + DOCS_NULL``, the LLM dict and ``None`` for the rest.
        * Everything loaded: the key, ``MODEL_DONE + DOCS_DONE``, the LLM
          dict and the vector store.
    """
    init_rwkv()

    failure = (None, MODEL_NULL + DOCS_NULL, None, None, None, None)
    try:
        if not (api_key and api_key.startswith("sk-") and len(api_key) > 50):
            return failure

        # One client per selectable model; the chat model needs the chat wrapper.
        llm_dict = {}
        for llm_name in LLM_LIST:
            llm_class = ChatOpenAI if llm_name == "gpt-3.5-turbo" else OpenAI
            llm_dict[llm_name] = llm_class(model_name=llm_name,
                                           temperature=OPENAI_TEMP,
                                           openai_api_key=api_key)

        if not (emb_name and db_api_key and db_env and db_index):
            return api_key, MODEL_DONE + DOCS_NULL, llm_dict, None, None, None

        # SECURITY: the loader name arrives from the browser. It used to be
        # passed to eval(), which would execute arbitrary client-supplied
        # code; resolve it through an explicit allow-list instead.
        embedding_loaders = {
            "HuggingFaceInstructEmbeddings": HuggingFaceInstructEmbeddings,
            "HuggingFaceEmbeddings": HuggingFaceEmbeddings,
            "OpenAIEmbeddings": OpenAIEmbeddings,
        }
        if emb_loader not in embedding_loaders:
            # Handled by the except below, like the failed eval() was before.
            raise ValueError(f"Unsupported embedding loader: {emb_loader!r}")
        loader_class = embedding_loaders[emb_loader]

        if emb_loader == "OpenAIEmbeddings":
            embeddings = loader_class(openai_api_key=api_key)
        else:
            embeddings = loader_class(model_name=emb_name)

        pinecone.init(api_key=db_api_key,
                      environment=db_env)
        db = Pinecone.from_existing_index(index_name=db_index,
                                          embedding=embeddings)
        return api_key, MODEL_DONE + DOCS_DONE, llm_dict, None, db, None
    except Exception as e:
        # Top-level UI boundary: report the problem and reset to "unloaded".
        print(e)
        return failure
def get_chat_history(inputs) -> str:
    """Render chat turns as text for the prompt.

    Each ``(human, ai)`` pair in ``inputs`` becomes a ``Q: ...`` line followed
    by an ``A: ...`` line; turns are joined with newlines. An empty history
    yields an empty string.
    """
    return "\n".join(f"Q: {human}\nA: {ai}" for human, ai in inputs)
def remove_duplicates(documents, score_min):
    """Keep the first sufficiently-scored document for each distinct text.

    Args:
        documents: Iterable of ``(doc, score)`` pairs; ``doc`` must expose
            ``page_content``.
        score_min: Minimum score (inclusive) a pair needs to be kept.

    Returns:
        The kept documents (without scores) in their original order. A text
        is only marked as seen once a copy of it has been kept, so a
        low-scored copy does not block a later copy that scores high enough.
    """
    kept = []
    kept_texts = set()
    for candidate, candidate_score in documents:
        text = candidate.page_content
        if text in kept_texts or not (candidate_score >= score_min):
            continue
        kept_texts.add(text)
        kept.append(candidate)
    return kept
def doc_similarity(query, db, top_k, score):
    """Search the vector store and return the filtered, de-duplicated hits.

    Requests ``top_k`` scored matches for ``query`` from
    ``db.similarity_search_with_score`` and passes them through
    ``remove_duplicates`` with ``score`` as the minimum score.
    """
    scored_matches = db.similarity_search_with_score(query=query, k=top_k)
    return remove_duplicates(scored_matches, score)
def user(user_message, history):
    """Queue the submitted question as a new, unanswered chat turn.

    Args:
        user_message: Text taken from the question box.
        history: Current chat turns as ``[question, answer]`` pairs. ``None``
            is treated as an empty history, so a chat value that was reset to
            ``None`` does not raise ``TypeError``.

    Returns:
        ``("", new_history)``: the empty string clears the question box and
        ``new_history`` is a new list ending with ``[user_message, None]``.
        The list passed in is not modified.
    """
    previous_turns = [] if history is None else history
    return "", previous_turns + [[user_message, None]]
def _doc_source(doc) -> str:
    """Return the document's ``source`` metadata, or a placeholder if absent."""
    return (doc.metadata or {}).get("source", "unknown source")


def _format_context(docs) -> str:
    """Render retrieved documents as plain text (source line + content) for the prompt."""
    return "\n\n".join(
        f"Source: {_doc_source(doc)}\n{doc.page_content}" for doc in docs
    )


def bot(box_message, ref_message,
        llm_dropdown, llm_dict, doc_list,
        db, top_k, score):
    """Answer the pending question in the chat, optionally using reference docs.

    Args:
        box_message: Chat turns as ``[question, answer]`` pairs; the last
            pair holds the pending question and is filled in place.
        ref_message: Optional text used for the vector search instead of the
            question.
        llm_dropdown: Key of the selected model in ``llm_dict``.
        llm_dict: Model name -> LLM client mapping; falsy when not initialized.
        doc_list: Selected reference document sets.
        db: Vector store; falsy when not initialized.
        top_k: Number of chunks to request from the vector store.
        score: Minimum similarity score for a chunk to be used.

    Returns:
        A 3-tuple ``(chat, reference_box_value, details)``. ``chat`` is
        ``box_message`` with the last answer filled in (the model reply, or a
        warning when the model or database is not loaded). The second item
        is always ``""``. ``details`` is ``""`` on the warning paths,
        otherwise a one-turn chat list pairing the query summary with the
        reply plus the collapsible source documents.
    """
    # In each turn, index 0 is the user question and index 1 the bot response.
    question = box_message[-1][0]
    history = box_message[:-1]

    if not llm_dict:
        box_message[-1][1] = MODEL_WARNING
        return box_message, "", ""

    if not ref_message:
        # No explicit reference text: search with the question itself.
        ref_message = question
        details = f"Q: {question}"
    else:
        details = f"Q: {question}\nR: {ref_message}"

    llm = llm_dict[llm_dropdown]

    if DOC_1 in doc_list:
        if not db:
            box_message[-1][1] = DOCS_WARNING
            return box_message, "", ""
        docs = doc_similarity(ref_message, db, top_k, score)
        # Filtering may leave fewer than top_k chunks; search once more with
        # a wider window to try to make up the shortfall.
        delta_top_k = top_k - len(docs)
        if delta_top_k > 0:
            docs = doc_similarity(ref_message, db, top_k + delta_top_k, score)
        prompt = PROMPT_DOC
    else:
        prompt = PROMPT_BASE
        docs = []

    chain = LLMChain(llm=llm,
                     prompt=prompt,
                     output_key='output_text')

    # Pass the context as text. Handing over the Document list made the
    # prompt contain its Python repr (escaped newlines, field names).
    all_output = chain({"question": question,
                        "context": _format_context(docs),
                        "chat_history": get_chat_history(history)
                        })

    bot_message = all_output['output_text']

    # One collapsible HTML section per source document for the details panel.
    source = "".join(
        f"<details> <summary>{_doc_source(doc)}</summary>\n"
        f"{doc.page_content}\n"
        "</details>"
        for doc in docs
    )

    box_message[-1][1] = bot_message
    return box_message, "", [[details, bot_message + '\n\nMetadata:\n' + source]]
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------
# Gradio UI: API-key row with status badges, then the Chatbot / Details /
# Database / TODO tabs, followed by the event wiring.
# NOTE(review): the container nesting below is reconstructed from the order
# of the components -- confirm it against the rendered layout.
with gr.Blocks(
    title=TAB_1,
    theme="Base",
    css=""".bigbox {
min-height:250px;
}
""") as demo:
    # Per-session state filled by init_model.
    llm = gr.State()
    chain_2 = gr.State()  # not in use
    vector_db = gr.State()

    gr.Markdown(webui_title)
    gr.Markdown(dup_link)
    gr.Markdown(init_message)

    with gr.Row():
        with gr.Column(scale=10):
            llm_api_textbox = gr.Textbox(
                label="OpenAI API Key",
                value=OPENAI_API_KEY,
                placeholder="Paste Your OpenAI API Key (sk-...) and Hit ENTER",
                lines=1,
                type='password')
        with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
            init = gr.Button(KEY_INIT)
            model_statusbox = gr.HTML(MODEL_NULL + DOCS_NULL)

    with gr.Tab(TAB_1):
        with gr.Row():
            with gr.Column(scale=10):
                chatbot = gr.Chatbot(elem_classes="bigbox")
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                doc_check = gr.CheckboxGroup(choices=DOC_SUPPORTED,
                                             value=DOC_DEFAULT,
                                             label=DOC_LABEL,
                                             interactive=True)
                llm_dropdown = gr.Dropdown(LLM_LIST,
                                           value=LLM_LIST[0],
                                           multiselect=False,
                                           interactive=True,
                                           label="LLM Selection",
                                           )
        with gr.Row():
            with gr.Column(scale=10):
                query = gr.Textbox(label="Question:",
                                   lines=2)
                ref = gr.Textbox(label="Reference(optional):")
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                clear = gr.Button(KEY_CLEAR)
                submit = gr.Button(KEY_SUBMIT, variant="primary")

    with gr.Tab(TAB_2):
        with gr.Row():
            with gr.Column():
                top_k = gr.Slider(1,
                                  TOP_K_MAX,
                                  value=TOP_K_DEFAULT,
                                  step=1,
                                  label="Vector similarity top_k",
                                  interactive=True)
            with gr.Column():
                score = gr.Slider(0.01,
                                  0.99,
                                  value=SCORE_DEFAULT,
                                  step=0.01,
                                  label="Vector similarity score",
                                  interactive=True)
        detail_panel = gr.Chatbot(label="Related Docs")

    with gr.Tab(TAB_3):
        with gr.Row():
            with gr.Column():
                # Plain text field: the value is a model repo name, not an
                # e-mail address, so the 'email' input type was wrong here.
                emb_textbox = gr.Textbox(
                    label="Embedding Model",
                    value=EMBEDDING_MODEL,
                    placeholder="Paste Your Embedding Model Repo on HuggingFace",
                    lines=1,
                    interactive=True,
                    type='text')
            with gr.Column():
                emb_dropdown = gr.Dropdown(
                    EMBEDDING_LIST,
                    value=EMBEDDING_LOADER,
                    multiselect=False,
                    interactive=True,
                    label="Embedding Loader")
        with gr.Accordion("Pinecone Database for " + DOC_1):
            with gr.Row():
                db_api_textbox = gr.Textbox(
                    label="Pinecone API Key",
                    value=PINECONE_KEY,
                    placeholder="Paste Your Pinecone API Key (xx-xx-xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='password')
            with gr.Row():
                # Environment and index names are plain text, not e-mails.
                db_env_textbox = gr.Textbox(
                    label="Pinecone Environment",
                    value=PINECONE_ENV,
                    placeholder="Paste Your Pinecone Environment (xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='text')
                db_index_textbox = gr.Textbox(
                    label="Pinecone Index",
                    value=PINECONE_INDEX,
                    placeholder="Paste Your Pinecone Index (xxxx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='text')

    with gr.Tab(TAB_4):
        # A bare string statement renders nothing; show the placeholder text.
        gr.Markdown("TODO")

    # Inputs and outputs of init_model; the outputs match its 6-tuple return.
    init_input = [llm_api_textbox, emb_textbox, emb_dropdown,
                  db_api_textbox, db_env_textbox, db_index_textbox]
    init_output = [llm_api_textbox, model_statusbox,
                   llm, chain_2,
                   vector_db, chatbot]

    llm_api_textbox.submit(init_model, init_input, init_output)
    init.click(init_model, init_input, init_output)
    # The Pinecone fields also tell the user to "Hit ENTER", so honor that.
    for enter_box in (db_api_textbox, db_env_textbox, db_index_textbox):
        enter_box.submit(init_model, init_input, init_output)

    # Submit: first move the question into the chat, then generate the answer.
    submit.click(user,
                 [query, chatbot],
                 [query, chatbot],
                 queue=False).then(
        bot,
        [chatbot, ref,
         llm_dropdown, llm, doc_check,
         vector_db, top_k, score],
        [chatbot, ref, detail_panel]
    )

    # Clear: reset the question box, the reference box and the chat.
    clear.click(lambda: (None, None, None), None, [query, ref, chatbot], queue=False)
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Serve locally only (no public share link), open a browser tab on start
    # and use the bundled favicon.
    launch_options = {
        "share": False,
        "inbrowser": True,
        "favicon_path": FAVICON,
    }
    demo.launch(**launch_options)