|
import streamlit as st |
|
from langchain_core.messages import AIMessage, HumanMessage |
|
from langchain_community.chat_models import ChatOpenAI |
|
from dotenv import load_dotenv |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from download_chart import construct_plot |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain import hub |
|
from langchain_core.prompts.prompt import PromptTemplate |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.embeddings import OpenAIEmbeddings |
|
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()
|
|
|
def get_conversation_chain(vectorstore):
    """Build a streaming RAG chain over the given vectorstore.

    Pipeline: retrieve documents as context -> hub "rlm/rag-prompt"
    -> GPT-4o (temperature 0.5, max 2048 tokens) -> plain-text output.

    Args:
        vectorstore: A vectorstore exposing ``as_retriever()`` (e.g. FAISS).

    Returns:
        A LangChain runnable; invoke/stream it with the question string.
    """
    model = ChatOpenAI(model="gpt-4o", temperature=0.5, max_tokens=2048)
    rag_prompt = hub.pull("rlm/rag-prompt")

    # The dict step feeds retrieved docs as "context" and passes the raw
    # input through as "question", matching the rag-prompt's variables.
    chain = (
        {"context": vectorstore.as_retriever(), "question": RunnablePassthrough()}
        | rag_prompt
        | model
        | StrOutputParser()
    )
    return chain
|
|
|
def get_response(user_query, chat_history, context=None):
    """Stream a RAG answer to ``user_query`` grounded in the saved vectorstore.

    Args:
        user_query: The user's latest question.
        chat_history: Prior AIMessage/HumanMessage turns, interpolated into
            the prompt verbatim.
        context: Optional pre-formatted extra context. Added (with a default,
            so existing two-argument callers are unaffected) because the UI
            code calls this function with a third positional argument, which
            previously raised TypeError.

    Returns:
        An iterator of response chunks, suitable for ``st.write_stream``.
    """
    template = """
    Chat history: {chat_history}

    User question: {user_question}
    """
    if context is not None:
        # Prepend caller-supplied context so the model can ground on it.
        template = "Context: {context}\n" + template

    embeddings = OpenAIEmbeddings()
    # NOTE(review): recent langchain releases require
    # allow_dangerous_deserialization=True for FAISS.load_local; left
    # unchanged because the installed version is unknown — confirm against
    # the project's requirements.
    db = FAISS.load_local("vectorstore_op", embeddings)

    prompt = ChatPromptTemplate.from_template(template)
    format_kwargs = {"chat_history": chat_history, "user_question": user_query}
    if context is not None:
        format_kwargs["context"] = context
    question = prompt.format(**format_kwargs)

    chain = get_conversation_chain(db)
    return chain.stream(question)
|
|
|
def display_chart():
    """Render the stakeholder plot, or warn when no stakeholders are defined.

    Reads ``st.session_state["pp_grouped"]``; emits a warning and returns
    None when it is missing, None, or empty.
    """
    stakeholders = st.session_state.get("pp_grouped")
    if not stakeholders:
        st.warning("Aucune partie prenante n'a été définie")
        return None
    st.plotly_chart(construct_plot())
|
|
|
|
|
def display_chat():
    """Render the chatbot page: replay history, accept input, stream answers.

    Uses Streamlit session state:
      - ``chat_history``: list of AIMessage/HumanMessage, seeded on first run.

    Side effects: appends the new user/AI turns to ``chat_history`` and
    re-renders the stakeholder chart under messages that mention it.
    """
    st.title("Chatbot")

    # Seed the conversation with a greeting on the first page load.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(content="Salut, voici votre cartographie des parties prenantes. Que puis-je faire pour vous?"),
        ]

    # Replay the stored conversation.
    for message in st.session_state.chat_history:
        if isinstance(message, AIMessage):
            with st.chat_message("AI"):
                st.write(message.content)
                # Show the chart again under answers that reference it.
                if "cartographie des parties prenantes" in message.content:
                    display_chart()
        elif isinstance(message, HumanMessage):
            with st.chat_message("Moi"):
                st.write(message.content)

    user_query = st.chat_input("Par ici...")
    if user_query is not None and user_query != "":
        st.session_state.chat_history.append(HumanMessage(content=user_query))

        with st.chat_message("Moi"):
            st.markdown(user_query)

        with st.chat_message("AI"):
            # BUG FIX: the original passed a third argument built with the
            # undefined name `format_context` to a two-parameter function,
            # raising NameError/TypeError at runtime.
            response = st.write_stream(
                get_response(user_query, st.session_state.chat_history)
            )
            # BUG FIX: the original tested `message.content` — the stale loop
            # variable left over from the history replay above — instead of
            # the freshly streamed response.
            if "cartographie des parties prenantes" in response:
                display_chart()

        st.session_state.chat_history.append(AIMessage(content=response))
|
|