"""Streamlit app: index user-supplied data (text, PDF, web page or YouTube
audio) into a Chroma vector store and chat with it through a history-aware
retrieval-augmented generation (RAG) chain."""

import os
import tempfile
import uuid

import streamlit as st
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain_community.document_loaders.blob_loaders.youtube_audio import (
    YoutubeAudioLoader,
)
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import OpenAIWhisperParser
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import ChatOpenAI

# Chunking parameters shared by every data source so retrieval behaves the
# same regardless of where the text came from.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 150
# Directory where YoutubeAudioLoader stores the downloaded audio.
YOUTUBE_AUDIO_DIR = "docs/youtube"

openai_api_key = os.getenv("OPENAI_API_KEY")

st.set_page_config(page_title="Chat with your data", page_icon="🤖")
st.title("Chat with your data")
st.header("Add your data for RAG")

data_type = st.radio(
    "Choose the type of data to add:", ("Text", "PDF", "Website", "YouTube")
)
if data_type == "YouTube":
    st.warning(
        "Note: Processing YouTube videos can be quite costly for me in terms of money. Please use this option sparingly. Thank you for your understanding!"
    )

if "vectordb" not in st.session_state:
    st.session_state.vectordb = None


def _new_collection_name():
    """Return a fresh, valid Chroma collection name for one vector store.

    Chroma's default in-memory client may be shared across the whole process,
    so building every store under the default collection name could mix
    documents from earlier additions (or other sessions) into this one.
    """
    return f"rag-{uuid.uuid4().hex}"


def _text_splitter():
    """Build the text splitter used by every data source."""
    return RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )


def _vectordb_from_documents(pages):
    """Split loaded documents into chunks and embed them into a new Chroma store.

    Raises:
        ValueError: if no text chunks could be produced from ``pages``.
    """
    docs = _text_splitter().split_documents(pages)
    if not docs:
        raise ValueError("No text could be extracted from the provided source.")
    return Chroma.from_documents(
        documents=docs,
        embedding=OpenAIEmbeddings(),
        collection_name=_new_collection_name(),
    )


def get_vectordb_from_text(text):
    """Chunk raw ``text`` and return a new Chroma store of its embeddings.

    Raises:
        ValueError: if ``text`` yields no chunks (e.g. it is empty).
    """
    texts = _text_splitter().split_text(text)
    if not texts:
        raise ValueError("No text could be extracted from the provided source.")
    return Chroma.from_texts(
        texts=texts,
        embedding=OpenAIEmbeddings(),
        collection_name=_new_collection_name(),
    )


def get_vectordb_from_pdf(uploaded_pdf):
    """Load an uploaded PDF and return a new Chroma store of its chunks.

    ``uploaded_pdf`` is the in-memory file returned by ``st.file_uploader``.
    PyPDFLoader needs a path, so the bytes are copied to a temporary file that
    is always removed once loading finishes.
    """
    # delete=False: the loader re-opens the file by name, which fails on some
    # platforms while the NamedTemporaryFile handle is still open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        # getvalue() returns the full upload regardless of the stream position.
        tmp_file.write(uploaded_pdf.getvalue())
        tmp_file_path = tmp_file.name
    try:
        pages = PyPDFLoader(tmp_file_path).load()
    finally:
        os.remove(tmp_file_path)
    return _vectordb_from_documents(pages)


def get_vectordb_from_website(website_url):
    """Fetch ``website_url`` and return a new Chroma store of its chunks."""
    pages = WebBaseLoader(website_url).load()
    return _vectordb_from_documents(pages)


def get_vectordb_from_youtube(youtube_url):
    """Download a YouTube video's audio, transcribe it with Whisper and return
    a new Chroma store of the transcript chunks."""
    loader = GenericLoader(
        YoutubeAudioLoader([youtube_url], YOUTUBE_AUDIO_DIR), OpenAIWhisperParser()
    )
    return _vectordb_from_documents(loader.load())


def _add_data(build_vectordb, source, missing_message):
    """Build a vector store from ``source`` and store it in the session.

    Shows ``missing_message`` when ``source`` is empty, and reports loader or
    API failures in the UI instead of crashing the page with a traceback.
    """
    if source is None or source == "":
        st.warning(missing_message)
        return
    try:
        with st.spinner("Processing your data..."):
            vectordb = build_vectordb(source)
    except Exception as err:  # UI boundary: surface any failure to the user.
        st.error(f"Failed to add data: {err}")
        return
    st.session_state.vectordb = vectordb
    st.success("Data added. You can start chatting below.")


if data_type == "Text":
    user_text = st.text_area("Enter text data")
    if st.button("Add"):
        _add_data(
            get_vectordb_from_text, user_text.strip(), "Please enter some text first."
        )
elif data_type == "PDF":
    uploaded_pdf = st.file_uploader("Upload PDF", type="pdf")
    if st.button("Add"):
        _add_data(get_vectordb_from_pdf, uploaded_pdf, "Please upload a PDF first.")
elif data_type == "Website":
    website_url = st.text_input("Enter website URL")
    if st.button("Add"):
        _add_data(
            get_vectordb_from_website,
            website_url.strip(),
            "Please enter a website URL first.",
        )
else:
    youtube_url = st.text_input("Enter YouTube URL")
    if st.button("Add"):
        _add_data(
            get_vectordb_from_youtube,
            youtube_url.strip(),
            "Please enter a YouTube URL first.",
        )

llm = ChatOpenAI(api_key=openai_api_key, temperature=0.2, model="gpt-3.5-turbo")


def get_context_retreiver_chain(vectordb):
    """Return a retriever that rewrites the latest question into a standalone
    search query using the chat history, then queries ``vectordb``."""
    retriever = vectordb.as_retriever()
    prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
            (
                "user",
                "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation",
            ),
        ]
    )
    return create_history_aware_retriever(llm, retriever, prompt)


# Correctly spelled alias; the original (misspelled) name is kept for callers.
get_context_retriever_chain = get_context_retreiver_chain


def get_conversational_rag_chain(retriever_chain):
    """Combine ``retriever_chain`` with an LLM that answers from the retrieved
    documents; the resulting chain's output dict carries the reply under
    ``"answer"``."""
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "Answer the user's questions based on the below context:\n\n{context}",
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
        ]
    )
    answer_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever_chain, answer_chain)


def get_response(user_input):
    """Answer ``user_input`` against the session's vector store.

    Returns a fixed prompt to add data when no store has been built yet.
    """
    if st.session_state.vectordb is None:
        return "Please add data first"
    retriever_chain = get_context_retreiver_chain(st.session_state.vectordb)
    rag_chain = get_conversational_rag_chain(retriever_chain)
    response = rag_chain.invoke(
        {"chat_history": st.session_state.chat_history, "input": user_input}
    )
    return response["answer"]


user_query = st.chat_input("Your message")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Re-render the conversation so far (Streamlit reruns the script on each input).
for message in st.session_state.chat_history:
    role = "Human" if isinstance(message, HumanMessage) else "AI"
    with st.chat_message(role):
        st.markdown(message.content)

if user_query:
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        ai_response = get_response(user_query)
        st.markdown(ai_response)
    st.session_state.chat_history.append(HumanMessage(user_query))
    st.session_state.chat_history.append(AIMessage(ai_response))