File size: 5,709 Bytes
b7e13eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cca07dd
b7e13eb
 
 
 
 
cca07dd
 
 
b7e13eb
 
 
 
 
 
 
 
 
621bc57
b7e13eb
 
 
 
 
 
 
 
621bc57
b7e13eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d49d2e
b7e13eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621bc57
b7e13eb
 
 
 
 
 
 
 
0ed257c
b7e13eb
 
 
 
 
 
 
 
 
 
 
 
 
f823183
b7e13eb
 
cca07dd
 
 
 
 
 
 
 
 
 
 
 
 
b7e13eb
 
 
 
3454a86
b7e13eb
 
 
 
cca07dd
b7e13eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cf149c
 
 
633c39a
 
 
 
 
 
 
b7e13eb
b7341cc
 
b7e13eb
b7341cc
b7e13eb
b7341cc
58cbd1c
b7e13eb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import RetrievalQA
from langchain.chains.conversation.memory import ConversationBufferMemory
from utils.ask_human import CustomAskHumanTool
from utils.model_params import get_model_params
from utils.prompts import create_agent_prompt, create_qa_prompt
from PyPDF2 import PdfReader
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain import HuggingFaceHub
import torch
import streamlit as st
from langchain.utilities import SerpAPIWrapper
from langchain.tools import DuckDuckGoSearchRun
import os
# --- Credentials and shared model clients (created once at import time) ---

# Hugging Face Hub API token. Required: a missing variable fails fast with a
# KeyError at startup, because both clients below need it.
hf_token = os.environ['HF_TOKEN']

# SerpAPI key. Only the optional SerpAPI search tool would need it, and this
# file does not enable that tool, so read it leniently: the value is None when
# the variable is unset instead of crashing the whole app on import.
serp_token = os.environ.get('SERP_TOKEN')

# Hub repository of the sentence-embedding model used for the PDF chunks.
repo_id = "sentence-transformers/all-mpnet-base-v2"

HUGGINGFACEHUB_API_TOKEN = hf_token

# Embedding client used to build the vector index.
hf = HuggingFaceHubEmbeddings(
    repo_id=repo_id,
    task="feature-extraction",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

# LLM shared by the retrieval QA chain and the agent.
llm = HuggingFaceHub(
    repo_id='mistralai/Mistral-7B-Instruct-v0.2',
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)



from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
     
### PAGE ELEMENTS

# st.set_page_config(
#     page_title="RAG Agent Demo",
#     page_icon="🦜",
#     layout="centered",
#     initial_sidebar_state="collapsed",
# )
# st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")


def main():
    """Render the Streamlit page: upload a PDF, index it, and answer a question.

    Each run reads the uploaded PDF, splits its text into chunks, embeds the
    chunks into a FAISS index, and wraps the index in a RetrievalQA chain.
    That chain and a DuckDuckGo search are handed to a zero-shot ReAct agent
    as tools; the agent's answer to the submitted question is written to the
    page.

    Uses the module-level ``llm`` and ``hf`` clients. Returns ``None``; it
    returns early when no PDF is uploaded, when the PDF yields no text, or
    when the submitted question is blank.
    """
    st.set_page_config(page_title="Ask your PDF powered by Search Agents")
    st.header("Ask your PDF powered by Search Agents 💬")

    # Upload file; nothing else can be built until one is provided.
    pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
    if pdf is None:
        return

    # Extract the text. `or ""` guards against pages for which the extractor
    # yields nothing (e.g. image-only pages), and join avoids repeated `+=`.
    pdf_reader = PdfReader(pdf)
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

    # Split the document into text snippets.
    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
    texts = text_splitter.split_text(text)
    if not texts:
        # An empty chunk list cannot be indexed; tell the user instead of
        # failing inside the vector store.
        st.warning("No extractable text was found in this PDF.")
        return

    knowledge_base = FAISS.from_texts(texts, hf)
    retriever = knowledge_base.as_retriever(search_kwargs={"k": 3})

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={
            "prompt": create_qa_prompt(),
        },
    )

    # NOTE(review): `k` looks like a ConversationBufferWindowMemory option;
    # confirm whether it has any effect on ConversationBufferMemory.
    conversational_memory = ConversationBufferMemory(
        memory_key="chat_history", k=3, return_messages=True
    )

    # Tool for document search. `qa_chain.run` returns the answer string,
    # whereas calling the chain object itself returns a dict of inputs/outputs.
    db_search_tool = Tool(
        name="dbRetrievalTool",
        func=qa_chain.run,
        description="""Use this tool to answer document related questions. The input to this tool should be the question.""",
    )

    # Tool for web search. A SerpAPI-backed alternative would be
    # SerpAPIWrapper(serpapi_api_key=serp_token).run under the name
    # "Current Search".
    search = DuckDuckGoSearchRun()
    search_tool = Tool(
        name="search",
        func=search.run,
        description="use this tool to answer real time or current search related questions.",
    )

    # CustomAskHumanTool is not registered with the agent; append
    # CustomAskHumanTool() to `tools` below to enable it.
    prefix, format_instructions, suffix = create_agent_prompt()

    # Initialize agent.
    agent = initialize_agent(
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        tools=[db_search_tool, search_tool],
        llm=llm,
        verbose=True,
        max_iterations=5,
        early_stopping_method="generate",
        memory=conversational_memory,
        agent_kwargs={
            "prefix": prefix,
            "format_instructions": format_instructions,
            "suffix": suffix,
        },
        handle_parsing_errors=True,
    )

    # Question form.
    with st.form(key="form"):
        user_input = st.text_input("Ask your question")
        submit_clicked = st.form_submit_button("Submit Question")

    if not submit_clicked:
        return

    if not user_input.strip():
        # Do not spend an agent run (and LLM calls) on a blank question.
        st.warning("Please enter a question.")
        return

    # To surface intermediate agent steps in the page, pass
    # callbacks=[StreamlitCallbackHandler(st.container())] to agent.run.
    response = agent.run(user_input)
    st.write(response)



# Script entry point: build the page only when this file is executed directly.
if __name__ == "__main__":
    main()