import streamlit as st
from dotenv import load_dotenv
from langchain import hub
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

from download_chart import construct_plot

load_dotenv()

# Phrase whose presence in an AI message triggers rendering of the chart.
_CHART_TRIGGER = "cartographie des parties prenantes"


def get_conversation_chain(vectorstore):
    """Build the retrieval-augmented generation chain for *vectorstore*.

    The chain retrieves documents from the vector store for the incoming
    question, fills the "rlm/rag-prompt" prompt pulled from the LangChain
    hub, calls the chat model and parses the result into plain text.

    Args:
        vectorstore: A vector store exposing ``as_retriever()``.

    Returns:
        A runnable that takes the question and yields the answer text.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0.5, max_tokens=2048)
    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")

    # The raw input is used both as the retrieval query and as the question.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain


def get_response(user_query, chat_history):
    """Stream the model's answer to *user_query*.

    The chat history and the user question are rendered into a single prompt
    string, which is then sent through the RAG chain built on the local
    "vectorstore_op" FAISS index.

    Args:
        user_query: The text typed by the user.
        chat_history: The conversation so far (interpolated into the prompt).

    Returns:
        An iterator over the chunks of the generated answer.
    """
    template = """
    Chat history: {chat_history}

    User question: {user_question}
    """

    embeddings = OpenAIEmbeddings()
    # NOTE(review): load_local unpickles the stored index; recent
    # langchain_community releases appear to require an explicit
    # allow_dangerous_deserialization=True opt-in -- confirm against the
    # installed version before adding it.
    db = FAISS.load_local("vectorstore_op", embeddings)

    question = ChatPromptTemplate.from_template(template).format(
        chat_history=chat_history,
        user_question=user_query,
    )
    return get_conversation_chain(db).stream(question)


def display_chart():
    """Render the stakeholder chart, or a warning when no data is available.

    Returns:
        None in every case; the chart or warning is written to the page.
    """
    stakeholders = st.session_state.get("pp_grouped")
    # len() rather than truthiness: the stored object may be a container
    # (e.g. a DataFrame) whose boolean value is ambiguous.
    if stakeholders is None or len(stakeholders) == 0:
        st.warning("Aucune partie prenante n'a été définie")
        return None
    st.plotly_chart(construct_plot())
    return None


def display_chat():
    """Render the chatbot page: history, input box and streamed answer.

    The conversation is kept in ``st.session_state.chat_history``. Whenever
    an AI message mentions the stakeholder map, the chart is rendered right
    below that message.
    """
    st.title("Chatbot")

    # Seed the conversation with a greeting on first load.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(
                content="Salut, voici votre cartographie des parties prenantes. \nQue puis-je faire pour vous?"
            ),
        ]

    # Replay the stored conversation.
    for message in st.session_state.chat_history:
        if isinstance(message, AIMessage):
            with st.chat_message("AI"):
                st.write(message.content)
                if _CHART_TRIGGER in message.content:
                    display_chart()
        elif isinstance(message, HumanMessage):
            with st.chat_message("Moi"):
                st.write(message.content)

    user_query = st.chat_input("Par ici...")
    if not user_query:
        return

    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Moi"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        # get_response accepts exactly (user_query, chat_history). The former
        # third argument called an undefined format_context(...) helper and
        # crashed on every query; reintroduce the stakeholder context only
        # together with a matching parameter and helper.
        response = st.write_stream(
            get_response(user_query, st.session_state.chat_history)
        )
        # Test the freshly generated answer, not the stale loop variable left
        # over from replaying the history.
        if isinstance(response, str) and _CHART_TRIGGER in response:
            display_chart()

    st.session_state.chat_history.append(AIMessage(content=response))