chatbot / app.py
hail75's picture
add website option
e3ada61
raw
history blame
6.24 kB
import os
import streamlit as st
import tempfile
from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import OpenAIWhisperParser
from langchain_community.document_loaders.blob_loaders.youtube_audio import (
YoutubeAudioLoader,
)
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
# API key for the chat model; None when OPENAI_API_KEY is not set in the environment.
openai_api_key = os.getenv("OPENAI_API_KEY")

# Page chrome. The icon is the robot emoji (U+1F916).
st.set_page_config(page_title="Chat with your data", page_icon="🤖")
st.title("Chat with your data")

# Data-source picker: the selected label decides which input widget and
# which loader are used further down in the script.
st.header("Add your data for RAG")
data_type = st.radio(
    "Choose the type of data to add:", ("Text", "PDF", "Website", "YouTube")
)
if data_type == "YouTube":
    st.warning(
        "Note: Processing YouTube videos can be quite costly for me in terms of money. Please use this option sparingly. Thank you for your understanding!"
    )

# The vector store lives in session state so it survives Streamlit reruns;
# None means no data has been added yet.
if "vectordb" not in st.session_state:
    st.session_state.vectordb = None
def get_vectordb_from_text(text):
    """Build a Chroma vector store from a plain text string.

    The text is split with a RecursiveCharacterTextSplitter
    (chunk_size=1000, chunk_overlap=150) and the resulting chunks are
    embedded with OpenAIEmbeddings.

    Args:
        text: The raw text to index.

    Returns:
        The Chroma vector store created from the text chunks.
    """
    embedder = OpenAIEmbeddings()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    chunks = splitter.split_text(text)
    return Chroma.from_texts(texts=chunks, embedding=embedder)
def get_vectordb_from_pdf(uploaded_pdf):
    """Build a Chroma vector store from an uploaded PDF.

    PyPDFLoader is given a filesystem path, so the upload is first copied
    to a temporary ``.pdf`` file. That file is removed again once the pages
    have been loaded (or if loading fails), so repeated uploads do not
    accumulate temp files on disk.

    Args:
        uploaded_pdf: A file-like object exposing ``read()`` that returns
            the PDF bytes (here: the value returned by ``st.file_uploader``).

    Returns:
        The Chroma vector store created from the PDF's page chunks.
    """
    # delete=False: the file must outlive the `with` block so the loader can
    # reopen it by name; cleanup is done explicitly in `finally`.
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    try:
        with tmp_file:
            tmp_file.write(uploaded_pdf.read())
        pages = PyPDFLoader(tmp_file.name).load()
    finally:
        os.remove(tmp_file.name)

    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(pages)
    return Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
    )
def get_vectordb_from_website(website_url):
    """Build a Chroma vector store from the content of a web page.

    The page is fetched with WebBaseLoader, split with a
    RecursiveCharacterTextSplitter (chunk_size=1000, chunk_overlap=150)
    and embedded with OpenAIEmbeddings.

    Args:
        website_url: URL of the page to load.

    Returns:
        The Chroma vector store created from the page chunks.
    """
    loader = WebBaseLoader(website_url)
    pages = loader.load()
    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(pages)
    vectordb = Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
    )
    # The store must be returned: the caller assigns the result to
    # st.session_state.vectordb, and None there means "no data added".
    return vectordb
def get_vectordb_from_youtube(youtube_url):
    """Build a Chroma vector store from a YouTube video.

    A GenericLoader combines a YoutubeAudioLoader (saving into
    ``docs/youtube``) with an OpenAIWhisperParser; the loaded documents are
    then split (chunk_size=1000, chunk_overlap=150) and embedded with
    OpenAIEmbeddings.

    Args:
        youtube_url: URL of the video to process.

    Returns:
        The Chroma vector store created from the loaded documents. Unlike
        the other loaders in this file, it is created with
        ``persist_directory="chroma"``.
    """
    audio_dir = "docs/youtube"
    audio_source = YoutubeAudioLoader([youtube_url], audio_dir)
    transcript_loader = GenericLoader(audio_source, OpenAIWhisperParser())
    transcripts = transcript_loader.load()

    embedder = OpenAIEmbeddings()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    chunks = splitter.split_documents(transcripts)
    return Chroma.from_documents(
        documents=chunks, embedding=embedder, persist_directory="chroma"
    )
# Render the input widget that matches the selected data type. When "Add" is
# clicked, the input is validated first: an empty value (or no uploaded file)
# produces a warning instead of being passed to a loader, and the previously
# stored vector store is left untouched.
if data_type == "Text":
    user_text = st.text_area("Enter text data")
    if st.button("Add"):
        if user_text and user_text.strip():
            st.session_state.vectordb = get_vectordb_from_text(user_text)
        else:
            st.warning("Please enter some text before adding.")
elif data_type == "PDF":
    uploaded_pdf = st.file_uploader("Upload PDF", type="pdf")
    if st.button("Add"):
        # The uploader yields None until a file has been chosen.
        if uploaded_pdf is not None:
            st.session_state.vectordb = get_vectordb_from_pdf(uploaded_pdf)
        else:
            st.warning("Please upload a PDF before adding.")
elif data_type == "Website":
    website_url = st.text_input("Enter website URL")
    if st.button("Add"):
        if website_url and website_url.strip():
            st.session_state.vectordb = get_vectordb_from_website(website_url)
        else:
            st.warning("Please enter a website URL before adding.")
else:
    youtube_url = st.text_input("Enter YouTube URL")
    if st.button("Add"):
        if youtube_url and youtube_url.strip():
            st.session_state.vectordb = get_vectordb_from_youtube(youtube_url)
        else:
            st.warning("Please enter a YouTube URL before adding.")

# Chat model shared by the retriever chain and the answer chain below.
llm = ChatOpenAI(api_key=openai_api_key, temperature=0.2, model="gpt-3.5-turbo")
def get_context_retreiver_chain(vectordb):
    """Create a history-aware retriever on top of the given vector store.

    The prompt passes the chat history and the latest user input to the
    module-level ``llm`` and asks it to generate a search query, which is
    then used with the vector store's retriever.

    Args:
        vectordb: Vector store exposing ``as_retriever()``.

    Returns:
        The chain produced by ``create_history_aware_retriever``.
    """
    base_retriever = vectordb.as_retriever()
    query_prompt_messages = [
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        (
            "user",
            "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation",
        ),
    ]
    query_prompt = ChatPromptTemplate.from_messages(query_prompt_messages)
    return create_history_aware_retriever(llm, base_retriever, query_prompt)
def get_conversational_rag_chain(retriever_chain):
    """Create the retrieval chain that answers questions from retrieved context.

    The answer prompt consists of a system message carrying ``{context}``,
    the chat history, and the user's input. It is combined with the
    module-level ``llm`` into a stuff-documents chain, which is then paired
    with the given retriever chain.

    Args:
        retriever_chain: Retriever passed as the first argument to
            ``create_retrieval_chain``.

    Returns:
        The chain produced by ``create_retrieval_chain``.
    """
    answer_prompt_messages = [
        (
            "system",
            "Answer the user's questions based on the below context:\n\n{context}",
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ]
    answer_prompt = ChatPromptTemplate.from_messages(answer_prompt_messages)
    answer_chain = create_stuff_documents_chain(llm, answer_prompt)
    return create_retrieval_chain(retriever_chain, answer_chain)
def get_response(user_input):
    """Answer a user message using the vector store held in session state.

    Args:
        user_input: The user's latest chat message.

    Returns:
        The ``"answer"`` entry of the RAG chain's result, or the string
        ``"Please add data first"`` when no vector store has been created.
    """
    vectordb = st.session_state.vectordb
    if vectordb is None:
        return "Please add data first"

    # The chains are rebuilt on each call from the current vector store.
    rag_chain = get_conversational_rag_chain(get_context_retreiver_chain(vectordb))
    chain_inputs = {
        "chat_history": st.session_state.chat_history,
        "input": user_input,
    }
    return rag_chain.invoke(chain_inputs)["answer"]
# Chat input box; yields a falsy value on runs where nothing was submitted.
user_query = st.chat_input("Your message")

# The conversation is kept in session state so it survives Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Replay the stored conversation: HumanMessage entries are shown under the
# "Human" role, everything else under "AI".
for past_message in st.session_state.chat_history:
    speaker = "Human" if isinstance(past_message, HumanMessage) else "AI"
    with st.chat_message(speaker):
        st.markdown(past_message.content)

# Handle a newly submitted message: show it, generate and show the answer,
# then record both turns in the history.
if user_query:
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        ai_response = get_response(user_query)
        st.markdown(ai_response)
    st.session_state.chat_history.extend(
        [HumanMessage(content=user_query), AIMessage(content=ai_response)]
    )