# Carto-RSE / chat_te.py
# Author: Ilyas KHIAT — commit 8df1e9f ("ajout et big update"), 3.13 kB
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.chat_models import ChatOpenAI
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from download_chart import construct_plot
from langchain_core.runnables import RunnablePassthrough
from langchain import hub
from langchain_core.prompts.prompt import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenAIEmbeddings
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the OpenAI credentials that
# ChatOpenAI / OpenAIEmbeddings read from the environment — confirm.
load_dotenv()
def format_docs(docs):
    """Join the text of retrieved documents, separated by blank lines.

    Args:
        docs: iterable of objects exposing a ``page_content`` attribute
            (the documents returned by a LangChain retriever).

    Returns:
        A single string suitable for the prompt's ``{context}`` slot.
    """
    return "\n\n".join(doc.page_content for doc in docs)


def get_conversation_chain(vectorstore):
    """Build a retrieval-augmented generation (RAG) chain over *vectorstore*.

    The chain takes a question string and does four things:
    1. Retrieves related documents from the vector store.
    2. Formats those documents as plain text.
    3. Fills the ``rlm/rag-prompt`` template pulled from the LangChain Hub.
    4. Queries ``gpt-4o`` and parses the reply to a string.

    Args:
        vectorstore: a LangChain vector store exposing ``as_retriever()``.

    Returns:
        A runnable mapping a question string to the answer text; it can be
        invoked or streamed.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0.5, max_tokens=2048)
    retriever = vectorstore.as_retriever()
    # Fetched from the LangChain Hub on every call.
    prompt = hub.pull("rlm/rag-prompt")
    rag_chain = (
        # The retriever returns Document objects. Render them as plain text so
        # the prompt gets their content rather than their Python repr.
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
def get_response(user_query, chat_history, context=None):
    """Stream the RAG chain's answer to *user_query*.

    Loads the local FAISS index ``vectorstore_op`` and builds the chain with
    ``get_conversation_chain``. The question sent to the chain bundles three
    parts: the optional context, the rendered chat history and the user
    question.

    Args:
        user_query: the user's latest message.
        chat_history: earlier messages (LangChain message objects), rendered
            one ``role: content`` line each; a plain string is used as-is.
        context: optional extra text prepended to the question. It is
            omitted when empty.

    Returns:
        The chain's stream, i.e. an iterator of answer text chunks.
    """
    embeddings = OpenAIEmbeddings()
    # Recent langchain_community versions refuse to unpickle a saved FAISS
    # docstore unless this is opted into explicitly.
    # NOTE(review): only safe because "vectorstore_op" is presumably built by
    # this project — never point this at an untrusted directory.
    db = FAISS.load_local(
        "vectorstore_op", embeddings, allow_dangerous_deserialization=True
    )
    sections = []
    if context:
        sections.append(f"Context: {context}")
    sections.append(f"Chat history: {format_history(chat_history)}")
    sections.append(f"User question: {user_query}")
    question = "\n".join(sections)
    chain = get_conversation_chain(db)
    return chain.stream(question)


def format_history(chat_history):
    """Render chat messages as text, one ``role: content`` line per message.

    A string is returned unchanged. ``None`` or an empty sequence gives "".
    Items lacking ``type``/``content`` fall back to ``message`` / the item.
    """
    if isinstance(chat_history, str):
        return chat_history
    return "\n".join(
        f"{getattr(msg, 'type', 'message')}: {getattr(msg, 'content', msg)}"
        for msg in chat_history or ()
    )


def format_context(pp_grouped, brand_name):
    """Render the brand name and stakeholder data as plain text for the prompt.

    Empty or missing values are skipped, so the result may be "".

    NOTE(review): the original file called an undefined ``format_context``.
    This is a minimal stand-in. The structure of ``pp_grouped`` is not visible
    here, so it is rendered with ``str()`` — confirm the intended format.
    """
    parts = []
    if brand_name:
        parts.append(f"Marque : {brand_name}")
    if pp_grouped is not None and len(pp_grouped) > 0:
        parts.append(f"Parties prenantes : {pp_grouped}")
    return "\n".join(parts)


def display_chart():
    """Draw the stakeholder chart, or warn when no stakeholders are defined.

    Reads ``st.session_state["pp_grouped"]``. If it is missing, ``None`` or
    empty, a warning is shown and ``None`` is returned. Otherwise the figure
    from ``construct_plot()`` is rendered with ``st.plotly_chart``.
    """
    pp_grouped = st.session_state.get("pp_grouped")
    if pp_grouped is None or len(pp_grouped) == 0:
        st.warning("Aucune partie prenante n'a été définie")
        return None
    plot = construct_plot()
    st.plotly_chart(plot)


def display_chat():
    """Render the chatbot page: title, message history and input box.

    On first run the history is seeded with a greeting. Any AI message that
    mentions the stakeholder map is followed by the chart. A new user query
    is answered by streaming ``get_response``, and both turns are appended to
    ``st.session_state.chat_history``.
    """
    st.title("Chatbot")

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(content="Salut, voici votre cartographie des parties prenantes. Que puis-je faire pour vous?"),
        ]

    # Replay the conversation so far.
    for message in st.session_state.chat_history:
        if isinstance(message, AIMessage):
            with st.chat_message("AI"):
                st.write(message.content)
                if "cartographie des parties prenantes" in message.content:
                    display_chart()
        elif isinstance(message, HumanMessage):
            with st.chat_message("Moi"):
                st.write(message.content)

    user_query = st.chat_input("Par ici...")
    if not user_query:  # None before submission, "" for an empty submit
        return

    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Moi"):
        st.markdown(user_query)

    # .get(): these keys may not be set yet when the user starts chatting.
    context = format_context(
        st.session_state.get("pp_grouped"),
        st.session_state.get("Nom de la marque"),
    )
    with st.chat_message("AI"):
        # Exclude the just-appended query: it is passed separately as the question.
        response = st.write_stream(
            get_response(user_query, st.session_state.chat_history[:-1], context)
        )
        # Check the new answer, not the stale loop variable from the replay above.
        if "cartographie des parties prenantes" in response:
            display_chart()
    st.session_state.chat_history.append(AIMessage(content=response))