|
import os |
|
import streamlit as st |
|
import tempfile |
|
|
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain_openai.chat_models import ChatOpenAI |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader |
|
from langchain_community.document_loaders.generic import GenericLoader |
|
from langchain_community.document_loaders.parsers import OpenAIWhisperParser |
|
from langchain_community.document_loaders.blob_loaders.youtube_audio import ( |
|
YoutubeAudioLoader, |
|
) |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_core.messages import HumanMessage, AIMessage |
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
from langchain.chains import create_history_aware_retriever, create_retrieval_chain |
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
|
|
|
|
# API key for the chat model; None when OPENAI_API_KEY is not set.
openai_api_key = os.getenv("OPENAI_API_KEY")

# The icon literal had been corrupted into mojibake; "\N{ROBOT FACE}" restores
# the intended emoji and keeps the source ASCII-safe.
st.set_page_config(page_title="Chat with your data", page_icon="\N{ROBOT FACE}")
st.title("Chat with your data")
st.header("Add your data for RAG")

# Which kind of source the user wants to index; drives the input widgets below.
data_type = st.radio(
    "Choose the type of data to add:", ("Text", "PDF", "Website", "YouTube")
)

if data_type == "YouTube":
    st.warning(
        "Note: Processing YouTube videos can be quite costly for me in terms of money. Please use this option sparingly. Thank you for your understanding!"
    )

# Make sure the vector-store slot exists so later code can test it for None.
if "vectordb" not in st.session_state:
    st.session_state.vectordb = None
|
|
|
|
|
def get_vectordb_from_text(text):
    """Index a raw text string in a Chroma vector store.

    The text is split with ``RecursiveCharacterTextSplitter``
    (``chunk_size=1000``, ``chunk_overlap=150``) and the resulting chunks are
    embedded with ``OpenAIEmbeddings``.

    Args:
        text: The text to split and index.

    Returns:
        The Chroma store built from the text chunks.
    """
    embedder = OpenAIEmbeddings()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    chunks = splitter.split_text(text)
    return Chroma.from_texts(texts=chunks, embedding=embedder)
|
|
|
|
|
def get_vectordb_from_pdf(uploaded_pdf):
    """Index an uploaded PDF in a Chroma vector store.

    The upload is written to a temporary ``.pdf`` file so it can be handed to
    ``PyPDFLoader`` by path. The temporary file is always removed afterwards,
    whether or not loading succeeds.

    Args:
        uploaded_pdf: A file-like object whose ``read()`` returns the PDF bytes.

    Returns:
        The Chroma store built from the split PDF pages.
    """
    tmp_file_path = None
    try:
        # delete=False keeps the file on disk after the handle is closed, so
        # the loader can reopen it by name; cleanup is done in ``finally``.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(uploaded_pdf.read())
        pages = PyPDFLoader(tmp_file_path).load()
    finally:
        # Previously the temp file was never removed and leaked on every upload.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)

    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(pages)
    return Chroma.from_documents(documents=docs, embedding=embeddings)
|
|
|
|
|
def get_vectordb_from_website(website_url):
    """Index the content of a web page in a Chroma vector store.

    The page is fetched with ``WebBaseLoader``, split with
    ``RecursiveCharacterTextSplitter`` (``chunk_size=1000``,
    ``chunk_overlap=150``) and embedded with ``OpenAIEmbeddings``.

    Args:
        website_url: URL of the page to load.

    Returns:
        The Chroma store built from the split page documents.
    """
    loader = WebBaseLoader(website_url)
    pages = loader.load()
    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(pages)
    vectordb = Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
    )
    # The store was previously built but never returned, so the caller stored
    # None and the chat kept reporting that no data had been added.
    return vectordb
|
|
|
|
|
def get_vectordb_from_youtube(youtube_url):
    """Index the transcript of a YouTube video in a Chroma vector store.

    The video is loaded through ``YoutubeAudioLoader`` (saving under
    ``docs/youtube``) and parsed with ``OpenAIWhisperParser``. The resulting
    documents are split and embedded, and the store is created with
    ``persist_directory="chroma"``.

    Args:
        youtube_url: URL of the video to process.

    Returns:
        The Chroma store built from the split transcript documents.
    """
    audio_dir = "docs/youtube"
    audio_source = YoutubeAudioLoader([youtube_url], audio_dir)
    transcripts = GenericLoader(audio_source, OpenAIWhisperParser()).load()

    embedder = OpenAIEmbeddings()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    return Chroma.from_documents(
        documents=splitter.split_documents(transcripts),
        embedding=embedder,
        persist_directory="chroma",
    )
|
|
|
|
|
# Show the input widget for the selected source type. When "Add" is clicked,
# build the vector store only if the user actually supplied something, so an
# empty submission produces a warning instead of an exception.
if data_type == "Text":
    user_text = st.text_area("Enter text data")
    if st.button("Add"):
        if user_text and user_text.strip():
            st.session_state.vectordb = get_vectordb_from_text(user_text)
        else:
            st.warning("Please enter some text before adding.")

elif data_type == "PDF":
    uploaded_pdf = st.file_uploader("Upload PDF", type="pdf")
    if st.button("Add"):
        # With no upload the value is None and calling .read() on it would fail.
        if uploaded_pdf is not None:
            st.session_state.vectordb = get_vectordb_from_pdf(uploaded_pdf)
        else:
            st.warning("Please upload a PDF before adding.")

elif data_type == "Website":
    website_url = st.text_input("Enter website URL")
    if st.button("Add"):
        if website_url and website_url.strip():
            st.session_state.vectordb = get_vectordb_from_website(website_url)
        else:
            st.warning("Please enter a website URL before adding.")

else:
    youtube_url = st.text_input("Enter YouTube URL")
    if st.button("Add"):
        if youtube_url and youtube_url.strip():
            st.session_state.vectordb = get_vectordb_from_youtube(youtube_url)
        else:
            st.warning("Please enter a YouTube URL before adding.")

# Chat model shared by the retrieval and answer chains defined below.
llm = ChatOpenAI(api_key=openai_api_key, temperature=0.2, model="gpt-3.5-turbo")
|
|
|
|
|
def get_context_retreiver_chain(vectordb):
    """Build a history-aware retriever over the given vector store.

    The prompt replays the chat history and the latest user input, then asks
    the model for a search query. It is combined with the module-level ``llm``
    and the store's retriever via ``create_history_aware_retriever``.

    Args:
        vectordb: Vector store exposing ``as_retriever()``.

    Returns:
        The chain produced by ``create_history_aware_retriever``.
    """
    search_prompt_messages = [
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        (
            "user",
            "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation",
        ),
    ]
    base_retriever = vectordb.as_retriever()
    search_prompt = ChatPromptTemplate.from_messages(search_prompt_messages)
    return create_history_aware_retriever(llm, base_retriever, search_prompt)
|
|
|
|
|
def get_conversational_rag_chain(retriever_chain):
    """Build the answer chain on top of a retriever chain.

    The answer prompt holds a system message with a ``{context}`` slot, then
    the chat history and the latest user input. It is wrapped with
    ``create_stuff_documents_chain`` using the module-level ``llm`` and joined
    to ``retriever_chain`` with ``create_retrieval_chain``.

    Args:
        retriever_chain: The retriever chain that supplies documents.

    Returns:
        The chain produced by ``create_retrieval_chain``.
    """
    answer_prompt_messages = [
        (
            "system",
            "Answer the user's questions based on the below context:\n\n{context}",
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ]
    answer_prompt = ChatPromptTemplate.from_messages(answer_prompt_messages)
    answer_chain = create_stuff_documents_chain(llm, answer_prompt)
    return create_retrieval_chain(retriever_chain, answer_chain)
|
|
|
|
|
def get_response(user_input):
    """Answer a user message using the vector store held in session state.

    Args:
        user_input: The message typed by the user.

    Returns:
        The ``"answer"`` entry of the chain's result, or the string
        ``"Please add data first"`` when no vector store has been set.
    """
    store = st.session_state.vectordb
    if store is None:
        return "Please add data first"

    # The chains are rebuilt on every call from the current store.
    rag_chain = get_conversational_rag_chain(get_context_retreiver_chain(store))
    payload = {"chat_history": st.session_state.chat_history, "input": user_input}
    result = rag_chain.invoke(payload)
    return result["answer"]
|
|
|
|
|
# Chat input box; the value is falsy when nothing has been submitted.
user_query = st.chat_input("Your message")

# Create the conversation log on first use.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Replay the stored conversation, labelling each entry by its message type.
for past_message in st.session_state.chat_history:
    speaker = "Human" if isinstance(past_message, HumanMessage) else "AI"
    with st.chat_message(speaker):
        st.markdown(past_message.content)

# A non-empty submission: show it, answer it, then record both turns.
if user_query:
    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        ai_response = get_response(user_query)
        st.markdown(ai_response)

    history = st.session_state.chat_history
    history.append(HumanMessage(user_query))
    history.append(AIMessage(ai_response))
|
|