File size: 9,341 Bytes
1b8dab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5299cfa
1b8dab4
 
 
 
 
 
 
5299cfa
1b8dab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5299cfa
1b8dab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5299cfa
1b8dab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import sys
import os
import uuid
from dotenv import load_dotenv
from typing import Annotated, List, Tuple
from typing_extensions import TypedDict
from langchain.tools import tool, BaseTool
from langchain.schema import Document
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, AIMessagePromptTemplate, HumanMessagePromptTemplate
# from langchain.schema import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
from langchain.retrievers.multi_query import MultiQueryRetriever
import json
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
import re
from fuzzywuzzy import fuzz



# Make the project root importable when this file is run from its own
# directory (presumably a notebook/script context — confirm invocation).
sys.path.append(os.path.abspath('..'))


import src.utils.qdrant_manager as qm
import prompts.system_prompts as sp
import prompts.quote_finder_prompts as qfp
# TODO(review): hard-coded absolute path to one developer's machine — prefer
# load_dotenv() so the .env file is resolved relative to the working directory.
load_dotenv('/Users/nadaa/Documents/code/py_innovations/srf_chatbot_v2/.env')


class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph pipeline."""
    # Conversation history; the `add_messages` reducer appends new messages
    # instead of overwriting the list when a node returns a partial update.
    messages: Annotated[list[BaseMessage], add_messages]
    # Documents retrieved by the search tools for the current query.
    documents: list[Document]
    # The user's original question (captured from the first message).
    query: str
    # Formatted answer assembled from the highlighted documents.
    final_response: str


class ToolManager:
    """Builds and owns the retrieval tools backed by a Qdrant vector store."""

    def __init__(self, collection_name="openai_large_chunks_1500char"):
        self.qdrant = qm.QdrantManager(collection_name=collection_name)
        self.vectorstore = self.qdrant.get_vectorstore()
        self.tools = []
        self.add_tools()

    def get_tools(self):
        """Return the list of registered tools."""
        return self.tools

    def add_tools(self):
        """Define and register the two vector-search tools.

        The inner functions' docstrings are the tool descriptions shown to
        the LLM, so they are part of runtime behavior.
        """

        @tool
        def vector_search(query: str, k: int = 10) -> list[Document]:
            """Useful for simple queries. This tool will search a vector database for passages from the teachings of Paramhansa Yogananda and other publications from the Self Realization Fellowship (SRF).
            The user has the option to specify the number of passages they want the search to return, otherwise the number of passages will be set to the default value."""
            search_retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
            return search_retriever.invoke(query)

        @tool
        def multiple_query_vector_search(query: str, k: int = 10) -> list[Document]:
            """Useful when the user's query is vague, complex, or involves multiple concepts. 
            This tool will write multiple versions of the user's query and search the vector database for relevant passages. 
            Use this tool when the user asks for an in depth answer to their question."""
            helper_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
            multi_retriever = MultiQueryRetriever.from_llm(
                retriever=self.vectorstore.as_retriever(), llm=helper_llm
            )
            return multi_retriever.invoke(query)

        self.tools.extend([vector_search, multiple_query_vector_search])

class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        # Index the tools by name so tool calls can be dispatched directly.
        self.tools_by_name = {}
        for available_tool in tools:
            self.tools_by_name[available_tool.name] = available_tool

    def __call__(self, inputs: dict):
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        last_message = messages[-1]

        tool_messages = []
        retrieved_docs = []
        for call in last_message.tool_calls:
            result = self.tools_by_name[call["name"]].invoke(call["args"])
            tool_messages.append(
                ToolMessage(
                    content=str(result),
                    name=call["name"],
                    tool_call_id=call["id"],
                )
            )
            # Accumulate the raw documents alongside the stringified messages
            # so downstream nodes can work with the Document objects directly.
            retrieved_docs += result

        return {"messages": tool_messages, "documents": retrieved_docs}

# Create the Pydantic Model for the quote finder
class Quote(BaseModel):
    # NOTE: this docstring is sent to the LLM as the schema description by
    # `with_structured_output`, so its wording affects extraction behavior.
    '''Most relevant quotes to the user's query strictly pulled verbatim from the context provided. Quotes can be up to three sentences long.'''
    # A single verbatim quote extracted from a passage.
    quote: str

class QuoteList(BaseModel):
    # Container model so the structured-output LLM call can return several
    # quotes at once; stored per-document in `doc.metadata["quotes"]`.
    quotes: List[Quote]


class QuoteFinder:
    """Extracts verbatim quotes from retrieved documents and marks them up.

    Provides three LangGraph nodes (each takes and returns partial AgentState):
      1. find_quotes_per_document   — LLM pulls relevant quotes per passage.
      2. highlight_quotes_in_document — fuzzy-matches each quote back into its
         passage and wraps the match in <mark> tags.
      3. final_response_formatter   — concatenates the annotated passages.
    """

    def __init__(self, model: str = 'gpt-4o-mini', temperature: float = 0.5):
        self.quotes_prompt = qfp.quote_finder_prompt
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        # Constrain the LLM output to the QuoteList schema.
        self.llm_with_quotes_output = self.llm.with_structured_output(QuoteList)
        self.quote_finder_chain = self.quotes_prompt | self.llm_with_quotes_output

    def find_quotes_per_document(self, state: AgentState):
        """For each retrieved document, ask the LLM for verbatim quotes
        relevant to the query and stash the result in the doc's metadata."""
        docs = state["documents"]
        query = state["query"]
        for doc in docs:
            quotes = self.quote_finder_chain.invoke(
                {"query": query, "passage": doc.page_content}
            )
            doc.metadata["quotes"] = quotes

        return {"documents": docs}

    def _highlight_quotes(self, document, quotes):
        """Fuzzy-locate each quote inside the document text and wrap the best
        match in <mark> tags.

        Returns (highlighted_content, matched_quotes) where matched_quotes is
        the subset of `quotes.quotes` that matched above the threshold.
        """
        highlighted_content = document.page_content
        matched_quotes = []
        for quote in quotes.quotes:
            # Slide a window of the quote's length across the text and keep
            # the closest fuzzy match (the LLM may alter punctuation/case).
            target = quote.quote.lower()
            window = len(quote.quote)
            best_match = None
            best_ratio = 0
            for i in range(len(highlighted_content)):
                substring = highlighted_content[i:i + window]
                ratio = fuzz.ratio(substring.lower(), target)
                if ratio > best_ratio:
                    best_ratio = ratio
                    best_match = substring

            if best_match and best_ratio > 90:  # Adjust threshold as needed
                # Escape special regex characters in the best match
                escaped_match = re.escape(best_match)
                # BUGFIX: use a callable replacement. A plain replacement
                # string interprets backslashes in the matched text as regex
                # escape sequences (raising re.error or corrupting output),
                # and with IGNORECASE it would also overwrite the original
                # casing of every other occurrence with best_match's casing.
                highlighted_content = re.sub(
                    escaped_match,
                    lambda m: f"<mark>{m.group(0)}</mark>",
                    highlighted_content,
                    flags=re.IGNORECASE,
                )
                matched_quotes.append(quote)

        return highlighted_content, matched_quotes

    def highlight_quotes_in_document(self, state: AgentState):
        """Annotate each document with its highlighted content and the quotes
        that actually matched; both stored in metadata for the formatter."""
        docs = state["documents"]
        for doc in docs:
            quotes = doc.metadata["quotes"]
            annotated_passage, matched_quotes = self._highlight_quotes(doc, quotes)
            doc.metadata["highlighted_content"] = annotated_passage
            doc.metadata["matched_quotes"] = matched_quotes

        return {"documents": docs}

    def final_response_formatter(self, state: AgentState):
        """Concatenate every annotated passage into the final response string.

        NOTE(review): assumes each document's metadata carries
        "publication_name" and "chapter_name" set by the ingestion pipeline;
        a missing key raises KeyError here — confirm upstream guarantees.
        """
        docs = state["documents"]
        sections = [
            doc.metadata["publication_name"] + ": " + doc.metadata["chapter_name"]
            + "\n" + doc.metadata["highlighted_content"] + "\n\n"
            for doc in docs
        ]
        return {"final_response": "".join(sections)}


class PassageFinder:
    """End-to-end agent: retrieves passages, extracts and highlights quotes,
    and formats the final answer via a LangGraph state machine."""

    def __init__(
        self,
        model: str = 'gpt-4o-mini',
        temperature: float = 0.5,
    ):

        self.llm = ChatOpenAI(model=model, temperature=temperature)
        self.tools = ToolManager().get_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.quote_finder = QuoteFinder()

        # Build the graph
        self.graph = self.build_graph()

    def get_configurable(self):
        """Mint a fresh thread id so each conversation is checkpointed
        under its own key by the MemorySaver."""
        self.thread_id = str(uuid.uuid4())
        return {"configurable": {"thread_id": self.thread_id}}

    def chatbot(self, state: AgentState):
        """LLM node: may emit tool calls; also records the original query
        (the first message's content) into the shared state."""
        history = state["messages"]
        response = self.llm_with_tools.invoke(history)
        return {"messages": [response], "query": history[0].content}

    def build_graph(self):
        """Wire the nodes and edges together and compile the graph with an
        in-memory checkpointer."""
        builder = StateGraph(AgentState)

        # Register every node.
        builder.add_node("chatbot", self.chatbot)
        builder.add_node("tools", BasicToolNode(tools=self.tools))
        builder.add_node("quote_finder", self.quote_finder.find_quotes_per_document)
        builder.add_node("quote_highlighter", self.quote_finder.highlight_quotes_in_document)
        builder.add_node("final_response_formatter", self.quote_finder.final_response_formatter)

        # The chatbot decides whether to call the tools or finish.
        builder.add_conditional_edges("chatbot", tools_condition)

        # Fixed pipeline: retrieval -> quote extraction -> highlighting -> formatting.
        builder.add_edge(START, "chatbot")
        builder.add_edge("tools", "quote_finder")
        builder.add_edge("quote_finder", "quote_highlighter")
        builder.add_edge("quote_highlighter", "final_response_formatter")
        builder.add_edge("final_response_formatter", END)

        return builder.compile(checkpointer=MemorySaver())