File size: 6,259 Bytes
b884933
 
8c922bb
7351c15
a007d8f
 
 
e3ada61
a007d8f
 
7f07a51
 
 
a007d8f
7f07a51
 
 
 
 
 
 
b884933
a007d8f
b884933
a007d8f
 
e3ada61
 
 
 
 
 
 
 
a007d8f
7f07a51
 
 
 
e3ada61
7f07a51
 
 
 
 
 
 
 
 
 
e3ada61
8c922bb
 
 
 
7f07a51
 
 
 
 
 
 
 
 
 
 
e3ada61
 
 
 
 
 
 
 
 
 
cce7e09
e3ada61
 
7f07a51
 
 
 
 
 
 
 
 
 
 
 
 
f16688d
a007d8f
 
 
e3ada61
f16688d
a007d8f
 
 
e3ada61
a007d8f
e3ada61
 
 
 
7f07a51
a007d8f
 
e3ada61
a007d8f
e3ada61
7f07a51
 
 
 
 
 
 
 
 
 
 
 
 
 
a007d8f
 
7f07a51
 
 
 
 
 
e3ada61
 
 
 
 
 
 
 
 
 
7f07a51
 
 
 
 
 
 
 
 
e3ada61
7f07a51
 
 
e3ada61
 
 
7f07a51
e3ada61
7f07a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import os
import streamlit as st
import tempfile

from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import OpenAIWhisperParser
from langchain_community.document_loaders.blob_loaders.youtube_audio import (
    YoutubeAudioLoader,
)
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain


# Read once at startup; None when OPENAI_API_KEY is not set.
openai_api_key = os.getenv("OPENAI_API_KEY")

# Page chrome: tab title/icon, then the visible headings.
st.set_page_config(page_title="Chat with your data", page_icon="🤖")
st.title("Chat with your data")
st.header("Add your data for RAG")

# Which ingestion path the controls further down will show.
data_type = st.radio(
    "Choose the type of data to add:",
    ("Text", "PDF", "Website", "YouTube"),
)

# The YouTube path is the expensive one, so nudge users away from overusing it.
if data_type == "YouTube":
    st.warning(
        "Note: Processing YouTube videos can be quite costly for me in terms of money. Please use this option sparingly. Thank you for your understanding!"
    )

# Active vector store; stays None until the user adds data.
if "vectordb" not in st.session_state:
    st.session_state["vectordb"] = None


def get_vectordb_from_text(text):
    """Chunk a block of plain text and index the chunks in Chroma.

    Args:
        text: Raw text entered by the user.

    Returns:
        A ``Chroma`` vector store built from chunks of at most 1000
        characters (150-character overlap), embedded with ``OpenAIEmbeddings``.
    """
    embedder = OpenAIEmbeddings()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    chunks = splitter.split_text(text)
    return Chroma.from_texts(texts=chunks, embedding=embedder)


def get_vectordb_from_pdf(uploaded_pdf):
    """Build a Chroma vector store from an uploaded PDF.

    ``PyPDFLoader`` takes a filesystem path, so the upload's bytes are first
    spooled to a temporary ``.pdf`` file. That file is deleted once the pages
    have been loaded.

    Args:
        uploaded_pdf: The object returned by ``st.file_uploader``; anything
            whose ``read()`` returns the PDF's bytes.

    Returns:
        A ``Chroma`` vector store built from the PDF's pages, split into
        chunks of at most 1000 characters (150-character overlap) and
        embedded with ``OpenAIEmbeddings``.
    """
    # delete=False: with the default, the file would vanish when the `with`
    # block closes it, before PyPDFLoader gets a chance to open it by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(uploaded_pdf.read())
        tmp_file_path = tmp_file.name
    try:
        pages = PyPDFLoader(tmp_file_path).load()
    finally:
        # Because delete=False, nothing removes the file automatically;
        # without this every upload would leave a stray .pdf on disk.
        os.remove(tmp_file_path)

    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(pages)
    vectordb = Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
    )
    return vectordb


def get_vectordb_from_website(website_url):
    """Fetch a web page, chunk its content and index it in Chroma.

    Args:
        website_url: Address of the page to load with ``WebBaseLoader``.

    Returns:
        A ``Chroma`` vector store of chunks of at most 1000 characters
        (150-character overlap), embedded with ``OpenAIEmbeddings``.
    """
    fetched_docs = WebBaseLoader(website_url).load()
    embedder = OpenAIEmbeddings()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    chunked_docs = splitter.split_documents(fetched_docs)
    return Chroma.from_documents(documents=chunked_docs, embedding=embedder)

def get_vectordb_from_youtube(youtube_url):
    """Transcribe a YouTube video's audio and index the transcript in Chroma.

    Audio for the video is handled by ``YoutubeAudioLoader`` (with
    ``docs/youtube`` as its save directory). It is then parsed by
    ``OpenAIWhisperParser``. The app warns users that this path is costly.

    Args:
        youtube_url: URL of the video to ingest.

    Returns:
        A ``Chroma`` vector store of transcript chunks of at most 1000
        characters (150-character overlap), embedded with ``OpenAIEmbeddings``.
    """
    save_dir = "docs/youtube"
    loader = GenericLoader(
        YoutubeAudioLoader([youtube_url], save_dir), OpenAIWhisperParser()
    )
    pages = loader.load()
    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(pages)
    # No persist_directory, matching the other loaders. An on-disk store
    # reopens the same default collection on every call and appends to it, so
    # chunks from previously processed videos would surface in later answers.
    vectordb = Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
    )
    return vectordb


# Show the input widget for the chosen source. On "Add", build a fresh vector
# store from it and replace the one held in session state. Empty input is
# rejected with a warning instead of being handed to the loaders.
if data_type == "Text":
    user_text = st.text_area("Enter text data")
    if st.button("Add"):
        # Blank text would leave the splitter nothing to chunk.
        if user_text.strip():
            st.session_state.vectordb = get_vectordb_from_text(user_text)
        else:
            st.warning("Please enter some text before adding.")

elif data_type == "PDF":
    uploaded_pdf = st.file_uploader("Upload PDF", type="pdf")
    if st.button("Add"):
        # file_uploader yields None until a file is chosen; .read() on None
        # would raise AttributeError inside get_vectordb_from_pdf.
        if uploaded_pdf is not None:
            st.session_state.vectordb = get_vectordb_from_pdf(uploaded_pdf)
        else:
            st.warning("Please upload a PDF file before adding.")

elif data_type == "Website":
    website_url = st.text_input("Enter website URL")
    if st.button("Add"):
        # Strip stray whitespace from pasted URLs before loading.
        website_url = website_url.strip()
        if website_url:
            st.session_state.vectordb = get_vectordb_from_website(website_url)
        else:
            st.warning("Please enter a website URL before adding.")
else:
    youtube_url = st.text_input("Enter YouTube URL")
    if st.button("Add"):
        youtube_url = youtube_url.strip()
        if youtube_url:
            st.session_state.vectordb = get_vectordb_from_youtube(youtube_url)
        else:
            st.warning("Please enter a YouTube URL before adding.")

# Shared chat model used by both the query-rewriting and answering chains.
llm = ChatOpenAI(api_key=openai_api_key, temperature=0.2, model="gpt-3.5-turbo")


def get_context_retreiver_chain(vectordb):
    """Wrap *vectordb*'s retriever so it searches with a history-aware query.

    The module-level ``llm`` is asked to turn the chat history plus the latest
    input into a search query. That query is then run against the retriever.

    Args:
        vectordb: A vector store exposing ``as_retriever()``.

    Returns:
        The runnable produced by ``create_history_aware_retriever``.
    """
    base_retriever = vectordb.as_retriever()

    rewrite_instruction = (
        "Given the above conversation, generate a search query to look up "
        "in order to get information relevant to the conversation"
    )
    query_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
            ("user", rewrite_instruction),
        ]
    )

    return create_history_aware_retriever(llm, base_retriever, query_prompt)


def get_conversational_rag_chain(retriever_chain):
    """Combine a retriever with an answering chain that stuffs documents into the prompt.

    Retrieved documents are inserted as ``{context}`` into a system prompt,
    followed by the chat history and the user's latest input.

    Args:
        retriever_chain: The retriever runnable that supplies documents,
            e.g. the output of ``get_context_retreiver_chain``.

    Returns:
        The runnable produced by ``create_retrieval_chain``.
    """
    system_message = (
        "system",
        "Answer the user's questions based on the below context:\n\n{context}",
    )
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            system_message,
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
        ]
    )

    answer_chain = create_stuff_documents_chain(llm, answer_prompt)
    return create_retrieval_chain(retriever_chain, answer_chain)


def get_response(user_input):
    """Answer *user_input* with the RAG chain over the current vector store.

    Args:
        user_input: The user's latest chat message.

    Returns:
        The chain's ``"answer"`` string. If no data has been added yet, the
        placeholder ``"Please add data first"`` is returned instead.
    """
    store = st.session_state.vectordb
    if store is None:
        return "Please add data first"

    rag_chain = get_conversational_rag_chain(get_context_retreiver_chain(store))
    chain_output = rag_chain.invoke(
        {"chat_history": st.session_state.chat_history, "input": user_input}
    )
    return chain_output["answer"]


user_query = st.chat_input("Your message")

# Conversation so far, as alternating HumanMessage / AIMessage objects.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Redraw every stored turn before handling any new message.
for past_message in st.session_state.chat_history:
    speaker = "Human" if isinstance(past_message, HumanMessage) else "AI"
    with st.chat_message(speaker):
        st.markdown(past_message.content)

# chat_input gives None when nothing was submitted; empty strings are skipped too.
if user_query:
    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        ai_response = get_response(user_query)
        st.markdown(ai_response)

    st.session_state.chat_history.extend(
        [HumanMessage(user_query), AIMessage(ai_response)]
    )