# Page header captured with this file: last commit d708cb9 by momenaca,
# "add latest updates regarding agent mode" (file size 7.06 kB).
import os
from datetime import datetime
import gradio as gr
from pinecone import Pinecone
from huggingface_hub import whoami
from langchain.prompts import ChatPromptTemplate
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_openai import AzureChatOpenAI
from langchain.prompts.prompt import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_community.vectorstores import Pinecone as PineconeVectorstore
from celsius_csrd_chatbot.utils import (
make_html_source,
make_pairs,
_format_chat_history,
_combine_documents,
get_llm,
init_env,
parse_output_llm_with_sources,
)
from celsius_csrd_chatbot.agent import make_graph_agent, display_graph
# init_env() runs first: presumably it loads the environment variables (API
# keys, index name, ...) that are read below — verify in celsius_csrd_chatbot.utils.
init_env()
# NOTE(review): chat_model_init is not referenced anywhere else in this file;
# the agent below is built from `llm` instead. Confirm before removing.
chat_model_init = get_llm()
demo_name = "ESRS_QA"


def _require_env(name):
    """Return the value of the environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset or empty. Failing here gives a
            clear start-up error instead of handing ``None`` to the Pinecone
            client.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value


# Embedding model used by the vector store to embed queries; it presumably
# has to match the model the Pinecone index was built with — TODO confirm.
hf_model = "BAAI/bge-base-en-v1.5"
embeddings = HuggingFaceBgeEmbeddings(
    model_name=hf_model,
    encode_kwargs={"normalize_embeddings": True},
)

pc = Pinecone(api_key=_require_env("PINECONE_API_KEY"))
index = pc.Index(_require_env("PINECONE_API_INDEX"))
# Third argument is the text key; presumably the metadata field holding each
# chunk's text in the index — verify against the ingestion code.
vectorstore = PineconeVectorstore(index, embeddings, "page_content")

# Constructed with no arguments, so its Azure settings must come from the
# environment / library defaults.
llm = AzureChatOpenAI()
agent = make_graph_agent(llm, vectorstore)
# NOTE(review): `memory` is not used elsewhere in this file (chat() sends only
# the current query to the agent) — confirm whether it is still needed.
memory = ConversationBufferMemory(
    return_messages=True, output_key="answer", input_key="question"
)
def _docs_to_html(docs):
    """Render retrieved documents as one HTML string, numbering them from 1."""
    return "".join(make_html_source(doc, i) for i, doc in enumerate(docs, 1))


async def chat(query, history):
    """Run the graph agent on ``query`` and stream progress into the chat history.

    The module-level ``agent`` is driven through ``astream_events``. While it
    runs, the last chat turn is updated with:

    * a status message when the ``categorize_esrs`` or ``retrieve_documents``
      node starts;
    * the answer text, token by token, from ``on_chat_model_stream`` events,
      passed through ``parse_output_llm_with_sources``.

    When ``retrieve_documents`` finishes, its documents are rendered to HTML
    for the "Sources" panel.

    Args:
        query: The user's question.
        history: Gradio chatbot history, a list of ``(user, bot)`` pairs whose
            last entry is the pending turn appended by ``start_chat``.

    Yields:
        ``(history, docs_html)`` after every agent event: ``history`` as a list
        of tuples, ``docs_html`` as the concatenated source HTML ("" until the
        documents have been retrieved).

    Raises:
        gr.Error: Wraps any error raised while streaming so Gradio shows it in
            the UI.
    """
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    # Status text shown in the pending bubble when the named graph node starts.
    steps_display = {
        "categorize_esrs": "🔄️ Analyzing user query",
        "retrieve_documents": "🔄️ Searching in the knowledge base",
    }

    history = [tuple(turn) for turn in history]
    if not history:
        # start_chat normally appends the pending turn; be safe if called directly.
        history.append((query, None))

    docs_html = ""
    # Raw streamed answer. The full raw text is re-parsed on every token rather
    # than appending tokens to text that already went through the parser.
    raw_answer = None

    try:
        async for event in agent.astream_events({"query": query}, version="v1"):
            kind = event["event"]
            name = event["name"]

            if kind == "on_chat_model_stream":
                raw_answer = (raw_answer or "") + event["data"]["chunk"].content
                history[-1] = (query, parse_output_llm_with_sources(raw_answer))
            elif name == "retrieve_documents" and kind == "on_chain_end":
                try:
                    docs_html = _docs_to_html(event["data"]["output"]["documents"])
                except Exception as e:
                    # Best effort: keep answering even if sources can't be shown.
                    print(f"Error getting documents: {e}")
                    print(event)

            if kind == "on_chain_start" and name in steps_display:
                history[-1] = (query, steps_display[name])
                # If tokens streamed before this step (e.g. from an earlier
                # node's model call), start the next answer fresh instead of
                # appending it to the status text.
                raw_answer = None

            yield [tuple(turn) for turn in history], docs_html
    except Exception as e:
        raise gr.Error(f"{e}") from e
# Custom CSS for the app. Read explicitly as UTF-8 so the result does not
# depend on the platform's default encoding.
with open("./assets/style.css", "r", encoding="utf-8") as f:
    css = f.read()

# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# Greeting shown as the first bot message of the chatbot (Markdown-formatted).
init_prompt = """
Hello, I am ESRS Q&A, a conversational assistant designed to help you understand the content of European Sustainability Reporting Standards (ESRS). I will answer your questions based **on the official definition of each ESRS as well as complementary guidelines**.
⚠️ Limitations
*Please note that this chatbot is in an early stage phase, it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
What do you want to learn ?
"""
with gr.Blocks(title=demo_name, css=css, theme=theme) as demo:
    with gr.Column(visible=True) as bloc_2:
        with gr.Tab("ESRS Q&A"):
            with gr.Row():
                # Left side: conversation and question box.
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        value=[(None, init_prompt)],
                        show_copy_button=True,
                        show_label=False,
                        elem_id="chatbot",
                        layout="panel",
                        avatar_images=(
                            None,
                            "https://i.ibb.co/cN0czLp/celsius-logo.png",
                        ),
                    )
                    state = gr.State([])
                    with gr.Row(elem_id="input-message"):
                        ask = gr.Textbox(
                            placeholder="Ask me anything here!",
                            show_label=False,
                            scale=7,
                            lines=1,
                            interactive=True,
                            elem_id="input-textbox",
                        )
                # Right side: HTML rendering of the retrieved sources.
                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.Tab("Sources", elem_id="tab-citations", id=1):
                        sources_textbox = gr.HTML(
                            show_label=False, elem_id="sources-textbox"
                        )
                        docs_textbox = gr.State("")
        with gr.Tab("About", elem_classes="max-height other-tabs"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("WIP")

    def start_chat(query, history):
        """Append the pending ``(query, None)`` turn and lock the textbox."""
        pending = [tuple(turn) for turn in history]
        pending.append((query, None))
        return gr.update(interactive=False), pending

    def finish_chat():
        """Unlock and clear the textbox once the answer has been streamed."""
        return gr.update(value="", interactive=True)

    # Submit: add the turn (unqueued, instant) -> stream the agent answer ->
    # re-enable the input box.
    submit_event = ask.submit(
        fn=start_chat,
        inputs=[ask, chatbot],
        outputs=[ask, chatbot],
        queue=False,
        api_name="start_chat_textbox",
    )
    answer_event = submit_event.then(
        fn=chat,
        inputs=[ask, chatbot],
        outputs=[chatbot, sources_textbox],
    )
    answer_event.then(
        fn=finish_chat,
        inputs=None,
        outputs=[ask],
        api_name="finish_chat_textbox",
    )
if __name__ == "__main__":
    # Start the server only when this file is run as a script (e.g.
    # `python app.py`), so importing the module to reuse `demo` does not
    # launch a server as a side effect.
    demo.launch(
        share=True,
        debug=True,
    )