File size: 5,224 Bytes
b884933
 
7351c15
a007d8f
 
 
 
 
 
7f07a51
 
 
a007d8f
7f07a51
 
 
 
 
 
 
 
b884933
a007d8f
b884933
a007d8f
 
 
 
7f07a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f16688d
a007d8f
 
 
7f07a51
f16688d
a007d8f
 
 
7f07a51
a007d8f
7f07a51
a007d8f
 
7f07a51
a007d8f
 
7f07a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a007d8f
 
7f07a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import streamlit as st

from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import OpenAIWhisperParser
from langchain_community.document_loaders.blob_loaders.youtube_audio import (
    YoutubeAudioLoader,
)
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain


# API key for the chat model, read from the environment (None when unset).
openai_api_key = os.environ.get("OPENAI_API_KEY")

# Page configuration and headings.
st.set_page_config(page_title="Chat with your data", page_icon="🤖")
st.title("Chat with your data")
st.header("Add your data for RAG")

# Kind of source the user wants to ingest; drives the input widgets below.
data_type = st.radio(
    "Choose the type of data to add:",
    ("Text", "PDF", "YouTube URL"),
)

# Make sure the vector-store slot exists in the session before anything
# reads it; it stays None until some data has been added.
if "vectordb" not in st.session_state:
    st.session_state["vectordb"] = None


def add_text_to_chroma(text):
    """Split raw text into chunks and index them in a new Chroma store.

    Args:
        text: The text to index.

    Returns:
        The Chroma vector store created from the chunks (1000-character
        chunks with a 150-character overlap, embedded with OpenAIEmbeddings).
    """
    embedder = OpenAIEmbeddings()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=150,
    ).split_text(text)
    return Chroma.from_texts(texts=chunks, embedding=embedder)


def add_pdf_to_chroma(uploaded_pdf):
    """Load a PDF, split it into chunks and index them in a new Chroma store.

    ``PyPDFLoader`` works on a file path, while the caller in this file passes
    the in-memory object returned by ``st.file_uploader``. File-like inputs
    are therefore written to a temporary ``.pdf`` file first; plain paths are
    still accepted and used directly.

    Args:
        uploaded_pdf: A filesystem path (``str`` / ``os.PathLike``) or a
            binary file-like object exposing ``getvalue()`` or ``read()``.

    Returns:
        The Chroma vector store built from the PDF's pages (1000-character
        chunks with a 150-character overlap, embedded with OpenAIEmbeddings).

    Raises:
        ValueError: If ``uploaded_pdf`` is ``None`` (nothing was uploaded).
    """
    # Local imports keep this block self-contained.
    import contextlib
    import tempfile

    if uploaded_pdf is None:
        raise ValueError("No PDF provided: upload a file before adding it.")

    if isinstance(uploaded_pdf, (str, os.PathLike)):
        pages = PyPDFLoader(os.fspath(uploaded_pdf)).load()
    else:
        if hasattr(uploaded_pdf, "getvalue"):
            # getvalue() returns the full buffer regardless of read position.
            pdf_bytes = uploaded_pdf.getvalue()
        else:
            pdf_bytes = uploaded_pdf.read()

        # delete=False so the file can be reopened by the loader after we
        # close our handle (required on Windows); removed manually below.
        tmp = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
        try:
            with tmp:
                tmp.write(pdf_bytes)
            pages = PyPDFLoader(tmp.name).load()
        finally:
            # Best-effort cleanup: a leftover temp file must not mask the
            # real outcome of loading the PDF.
            with contextlib.suppress(OSError):
                os.unlink(tmp.name)

    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.split_documents(pages)
    return Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
    )


def add_youtube_to_chroma(youtube_url):
    """Index the content loaded from a YouTube URL in a Chroma store.

    The URL is loaded through ``YoutubeAudioLoader`` (save directory
    ``docs/youtube``) combined with ``OpenAIWhisperParser``; the resulting
    documents are split into 1000-character chunks with a 150-character
    overlap and embedded with OpenAIEmbeddings.

    Args:
        youtube_url: URL of the video to load.

    Returns:
        The Chroma vector store built from the loaded documents.
    """
    audio_source = YoutubeAudioLoader([youtube_url], "docs/youtube")
    transcripts = GenericLoader(audio_source, OpenAIWhisperParser()).load()

    embedder = OpenAIEmbeddings()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    chunks = splitter.split_documents(transcripts)

    # NOTE(review): this is the only loader in the file that passes
    # persist_directory; confirm that on-disk persistence is intended here.
    return Chroma.from_documents(
        documents=chunks,
        embedding=embedder,
        persist_directory="chroma",
    )


# Show the input widget matching the chosen data type. On "Add", build a
# vector store from the input and keep it in the session. Empty input is
# rejected with a warning instead of being forwarded to the loaders.
if data_type == "Text":
    user_text = st.text_area("Enter text data")
    if st.button("Add"):
        if user_text and user_text.strip():
            st.session_state.vectordb = add_text_to_chroma(user_text)
        else:
            st.warning("Please enter some text before adding.")

elif data_type == "PDF":
    uploaded_pdf = st.file_uploader("Upload PDF", type="pdf")
    if st.button("Add"):
        # uploaded_pdf is None until the user has picked a file.
        if uploaded_pdf is not None:
            st.session_state.vectordb = add_pdf_to_chroma(uploaded_pdf)
        else:
            st.warning("Please upload a PDF before adding.")

else:
    youtube_url = st.text_input("Enter YouTube URL")
    if st.button("Add"):
        cleaned_url = youtube_url.strip() if youtube_url else ""
        if cleaned_url:
            st.session_state.vectordb = add_youtube_to_chroma(cleaned_url)
        else:
            st.warning("Please enter a YouTube URL before adding.")

# Chat model used by both chain builders defined below.
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0.2,
    api_key=openai_api_key,
)


def get_context_retreiver_chain(vectordb):
    """Build a history-aware retriever on top of the given vector store.

    The prompt feeds the chat history and the latest user input to the
    module-level ``llm`` and asks it for a search query, which is then run
    against the store's retriever.

    Args:
        vectordb: Vector store exposing ``as_retriever()``.

    Returns:
        The chain produced by ``create_history_aware_retriever``.
    """
    query_instruction = (
        "Given the above conversation, generate a search query to look up "
        "in order to get information relevant to the conversation"
    )
    search_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
            ("user", query_instruction),
        ]
    )
    return create_history_aware_retriever(
        llm,
        vectordb.as_retriever(),
        search_prompt,
    )


def get_conversational_rag_chain(retriever_chain):
    """Combine a retriever chain with an answering prompt.

    Args:
        retriever_chain: Chain that supplies the documents filled into the
            prompt's ``{context}`` slot.

    Returns:
        The chain produced by ``create_retrieval_chain``; it is invoked with
        ``chat_history`` and ``input``.
    """
    system_template = (
        "Answer the user's questions based on the below context:\n\n{context}"
    )
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_template),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
        ]
    )
    combine_docs_chain = create_stuff_documents_chain(llm, answer_prompt)
    return create_retrieval_chain(retriever_chain, combine_docs_chain)


def get_response(user_input):
    """Answer a user message with the conversational RAG chain.

    Args:
        user_input: The message the user just submitted.

    Returns:
        The answer text as a string, or a prompt asking the user to add data
        when no vector store has been built in this session yet.
    """
    if st.session_state.vectordb is None:
        return "Please add data first"

    retriever_chain = get_context_retreiver_chain(st.session_state.vectordb)
    conversational_rag_chain = get_conversational_rag_chain(retriever_chain)

    response = conversational_rag_chain.invoke({
        "chat_history": st.session_state.chat_history,
        "input": user_input
    })

    # The retrieval chain returns a dict (input, context, answer, ...). The
    # caller renders the result as markdown and stores it as an AIMessage,
    # so only the answer text is returned.
    return response["answer"]


# Chat input box; its value is falsy when there is no new message to process.
user_query = st.chat_input("Your message")

# The conversation lives in the session so it persists across reruns.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Replay the stored conversation, choosing the bubble by message type.
for past_message in st.session_state.chat_history:
    speaker = "Human" if isinstance(past_message, HumanMessage) else "AI"
    with st.chat_message(speaker):
        st.markdown(past_message.content)

# Handle a newly submitted message: echo it, answer it, then record both.
if user_query:
    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        ai_response = get_response(user_query)
        st.markdown(ai_response)

    history = st.session_state.chat_history
    history.append(HumanMessage(user_query))
    history.append(AIMessage(ai_response))