|
import streamlit as st |
|
from langchain_core.messages import AIMessage, HumanMessage |
|
from langchain_community.chat_models import ChatOpenAI |
|
from dotenv import load_dotenv |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from download_chart import construct_plot |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain import hub |
|
from langchain_core.prompts.prompt import PromptTemplate |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.embeddings import OpenAIEmbeddings |
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain_experimental.text_splitter import SemanticChunker |
|
# Load variables from a local .env file into the process environment
# (presumably the OpenAI API key used by the embedding/chat models below).
load_dotenv()
|
|
|
def get_docs_from_pdf(file):
    """Read a PDF and return it as a list of LangChain documents.

    Args:
        file: Path to the PDF, passed unchanged to ``PyPDFLoader``.

    Returns:
        The documents produced by ``PyPDFLoader.load_and_split()``.
    """
    return PyPDFLoader(file).load_and_split()
|
|
|
def get_doc_chunks(docs):
    """Re-split ``docs`` into semantically coherent chunks.

    ``SemanticChunker`` uses the ``text-embedding-3-small`` embedding model
    to decide where to cut.

    Args:
        docs: Documents to split.

    Returns:
        The chunked documents from ``SemanticChunker.split_documents``.
    """
    embedder = OpenAIEmbeddings(model="text-embedding-3-small")
    splitter = SemanticChunker(embedder)
    return splitter.split_documents(docs)
|
|
|
def get_vectorstore_from_docs(doc_chunks):
    """Embed ``doc_chunks`` and index them in a new FAISS vector store.

    Args:
        doc_chunks: Documents to embed with ``text-embedding-3-small``.

    Returns:
        The FAISS store built by ``FAISS.from_documents``.
    """
    return FAISS.from_documents(
        documents=doc_chunks,
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    )
|
|
|
def _format_docs(docs):
    """Join retrieved documents' text into one context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)


def get_conversation_chain(vectorstore):
    """Build a retrieval-augmented QA chain over ``vectorstore``.

    The chain takes a single question string. That string is sent both to
    the vector-store retriever (to fetch context) and, unchanged, to the
    ``rlm/rag-prompt`` prompt pulled from the LangChain hub.

    Args:
        vectorstore: Any LangChain vector store exposing ``as_retriever()``.

    Returns:
        A runnable mapping a question string to the answer string; it also
        supports ``.stream`` for incremental output.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0.5, max_tokens=2048)
    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")

    return (
        {
            # The retriever yields Document objects; without formatting, the
            # prompt would receive their repr (metadata included) instead of
            # the passage text.
            "context": retriever | _format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
|
|
|
def create_db(file):
    """Load a previously saved FAISS vector store from the directory ``file``.

    Despite the name, nothing is built here: the index is read from disk and
    paired with the ``text-embedding-3-small`` embedding model for queries.

    Args:
        file: Directory the FAISS index was saved to.

    Returns:
        The loaded FAISS vector store.
    """
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    # NOTE(review): allow_dangerous_deserialization=True opts in to unpickling
    # the stored index data; only point this at directories you produced
    # yourself, never at user-supplied paths.
    return FAISS.load_local(
        file,
        embeddings,
        allow_dangerous_deserialization=True,
    )
|
|
|
def get_response(chain, user_query, chat_history):
    """Stream the chain's answer to ``user_query`` given the conversation so far.

    The chain takes a single string, so the history and the question are
    folded into one plain-text prompt.

    Args:
        chain: A runnable exposing ``stream(str)``.
        user_query: The question the user just submitted.
        chat_history: Sequence of chat messages with ``type`` ("human"/"ai")
            and ``content`` attributes, e.g. ``HumanMessage``/``AIMessage``.
            It is not modified.

    Returns:
        The iterator returned by ``chain.stream``, yielding answer chunks.
    """
    role_labels = {"human": "Human", "ai": "AI"}

    history = list(chat_history)
    # The caller appends the current question to the history before calling;
    # drop that trailing copy so the question is not sent twice.
    if (
        history
        and getattr(history[-1], "type", None) == "human"
        and history[-1].content == user_query
    ):
        history = history[:-1]

    # Render "Role: text" lines instead of interpolating message objects,
    # whose repr would leak field names and metadata into the prompt.
    history_text = "\n".join(
        f"{role_labels.get(getattr(msg, 'type', ''), getattr(msg, 'type', ''))}: {msg.content}"
        for msg in history
    )

    # NOTE(review): the chain sends this whole string to the retriever as
    # well, so a long history also shapes which documents are retrieved.
    question = f"Chat history:\n{history_text}\n\nUser question: {user_query}"
    return chain.stream(question)
|
|
|
|
|
@st.experimental_dialog("Cast your vote")
def vote(item):
    """Modal dialog asking the user why ``item`` is their favorite.

    On Submit, the choice and the typed reason are stored in
    ``st.session_state.vote`` as ``{"item": ..., "reason": ...}``. The app is
    then rerun, which closes the dialog.

    Args:
        item: The item being voted for; shown in the dialog text.
    """
    st.write(f"Why is {item} your favorite?")
    reason = st.text_input("Because...")
    if st.button("Submit"):
        # Persist the answer so the rest of the app can read it after rerun.
        st.session_state.vote = {"item": item, "reason": reason}
        st.rerun()
|
|
|
def _init_chat_state():
    """Create the session's chat history and RAG chain if they don't exist yet."""
    if "chat_history_te" not in st.session_state:
        st.session_state.chat_history_te = [
            AIMessage(
                content="Salut, posez-moi vos questions sur la transition écologique."
            ),
        ]
    if "chain" not in st.session_state:
        # Streamlit reruns the script on every interaction; keep the loaded
        # index and chain in session state so they are built once per session.
        vectorstore = create_db("./DATA_bziiit/vectorstore_op")
        st.session_state.chain = get_conversation_chain(vectorstore)


def _render_history(messages):
    """Redraw past turns: AI messages under "AI", user messages under "Moi"."""
    for message in messages:
        if isinstance(message, AIMessage):
            author = "AI"
        elif isinstance(message, HumanMessage):
            author = "Moi"
        else:
            # Other message kinds are not displayed.
            continue
        with st.chat_message(author):
            st.write(message.content)


def display_chat_te():
    """Render the chatbot page about the ecological transition.

    Shows the conversation stored in ``st.session_state.chat_history_te``.
    If the user submitted a question, it echoes the question, streams the
    RAG chain's answer, and appends both turns to the history.
    """
    st.title("Chatbot")
    _init_chat_state()
    _render_history(st.session_state.chat_history_te)

    user_query = st.chat_input(placeholder="c'est quoi la transition écologique ?")
    # chat_input returns None until the user submits; also ignore empty text.
    if not user_query:
        return

    st.session_state.chat_history_te.append(HumanMessage(content=user_query))
    with st.chat_message("Moi"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        response = st.write_stream(
            get_response(
                st.session_state.chain,
                user_query,
                st.session_state.chat_history_te,
            )
        )
    st.session_state.chat_history_te.append(AIMessage(content=response))
|
|