Spaces:
Running
Running
File size: 9,341 Bytes
1b8dab4 5299cfa 1b8dab4 5299cfa 1b8dab4 5299cfa 1b8dab4 5299cfa 1b8dab4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
import sys
import os
import uuid
from dotenv import load_dotenv
from typing import Annotated, List, Tuple
from typing_extensions import TypedDict
from langchain.tools import tool, BaseTool
from langchain.schema import Document
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, AIMessagePromptTemplate, HumanMessagePromptTemplate
# from langchain.schema import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
from langchain.retrievers.multi_query import MultiQueryRetriever
import json
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
import re
from fuzzywuzzy import fuzz
# Make the project root importable so the src/ and prompts/ packages resolve
# (assumes this module runs from one level below the repo root -- TODO confirm
# this matches how the app is actually launched).
sys.path.append(os.path.abspath('..'))
import src.utils.qdrant_manager as qm  # Qdrant vector-store wrapper
import prompts.system_prompts as sp  # system prompt templates (not referenced in this file -- verify still needed)
import prompts.quote_finder_prompts as qfp  # prompt template for the quote-extraction chain
# Load API keys (e.g. OPENAI_API_KEY) from a developer-specific .env file.
# NOTE(review): machine-specific absolute path -- consider plain load_dotenv().
load_dotenv('/Users/nadaa/Documents/code/py_innovations/srf_chatbot_v2/.env')
class AgentState(TypedDict):
    """Shared state flowing between the LangGraph nodes."""
    # Conversation history; the add_messages reducer appends instead of overwriting.
    messages: Annotated[list[BaseMessage], add_messages]
    # Documents returned by the retrieval tools.
    documents: list[Document]
    # The user's original query (captured from the first message by the chatbot node).
    query: str
    # Final formatted, quote-highlighted answer produced by the last node.
    final_response: str
class ToolManager:
    """Owns the Qdrant vector store and registers the retrieval tools built on it."""

    def __init__(self, collection_name="openai_large_chunks_1500char"):
        self.tools = []
        self.qdrant = qm.QdrantManager(collection_name=collection_name)
        self.vectorstore = self.qdrant.get_vectorstore()
        self.add_tools()

    def get_tools(self):
        """Return the list of registered tools."""
        return self.tools

    def add_tools(self):
        """Define the search tools as closures over the vector store and register them."""
        store = self.vectorstore

        @tool
        def vector_search(query: str, k: int = 10) -> list[Document]:
            """Useful for simple queries. This tool will search a vector database for passages from the teachings of Paramhansa Yogananda and other publications from the Self Realization Fellowship (SRF).
            The user has the option to specify the number of passages they want the search to return, otherwise the number of passages will be set to the default value."""
            return store.as_retriever(search_kwargs={"k": k}).invoke(query)

        @tool
        def multiple_query_vector_search(query: str, k: int = 10) -> list[Document]:
            """Useful when the user's query is vague, complex, or involves multiple concepts.
            This tool will write multiple versions of the user's query and search the vector database for relevant passages.
            Use this tool when the user asks for an in depth answer to their question."""
            helper_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
            multi_retriever = MultiQueryRetriever.from_llm(
                retriever=store.as_retriever(), llm=helper_llm
            )
            return multi_retriever.invoke(query)

        self.tools.extend([vector_search, multiple_query_vector_search])
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        # Index tools by name so tool calls can be dispatched directly.
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        """Execute every tool call on the last message.

        Returns a partial state update: the ToolMessages produced plus the
        flat list of documents the tools retrieved.
        """
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        last_message = messages[-1]

        tool_messages = []
        retrieved_docs = []
        for call in last_message.tool_calls:
            result = self.tools_by_name[call["name"]].invoke(call["args"])
            tool_messages.append(
                ToolMessage(
                    content=str(result),
                    name=call["name"],
                    tool_call_id=call["id"],
                )
            )
            # Each tool returns a list of Documents; accumulate them all.
            retrieved_docs += result
        return {"messages": tool_messages, "documents": retrieved_docs}
# Create the Pydantic Model for the quote finder.
# NOTE: the docstring below doubles as the schema description sent to the LLM
# via with_structured_output, so it is runtime behavior -- do not reword it.
class Quote(BaseModel):
    '''Most relevant quotes to the user's query strictly pulled verbatim from the context provided. Quotes can be up to three sentences long.'''
    # A single quote, expected verbatim from the supplied passage.
    quote: str
# Wrapper schema so the structured-output call returns a JSON array of quotes.
# (Deliberately no docstring: it would be sent to the model as the schema
# description and change prompting behavior.)
class QuoteList(BaseModel):
    # All quotes the model extracted for one passage.
    quotes: List[Quote]
class QuoteFinder:
    """Extracts verbatim quotes from retrieved documents and marks them up.

    Graph pipeline (each method takes and returns a partial AgentState):
        find_quotes_per_document -> highlight_quotes_in_document
        -> final_response_formatter
    """

    def __init__(self, model: str = 'gpt-4o-mini', temperature: float = 0.5):
        self.quotes_prompt = qfp.quote_finder_prompt
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        # Constrain the LLM's output to the QuoteList schema.
        self.llm_with_quotes_output = self.llm.with_structured_output(QuoteList)
        self.quote_finder_chain = self.quotes_prompt | self.llm_with_quotes_output

    def find_quotes_per_document(self, state: AgentState):
        """Run the quote-extraction chain once per retrieved document.

        Stores the resulting QuoteList on each document's metadata["quotes"].
        """
        docs = state["documents"]
        query = state["query"]
        for doc in docs:
            quotes = self.quote_finder_chain.invoke(
                {"query": query, "passage": doc.page_content}
            )
            doc.metadata["quotes"] = quotes
        return {"documents": docs}

    def _highlight_quotes(self, document, quotes, match_threshold: int = 90):
        """Fuzzy-locate each quote in the document and wrap it in <mark> tags.

        Returns (highlighted_content, matched_quotes); quotes whose best fuzzy
        score is <= match_threshold are dropped. match_threshold generalizes
        the previously hard-coded cutoff of 90.

        FIX: matching now scans the ORIGINAL page_content instead of the
        progressively marked-up text, so earlier <mark> insertions can no
        longer corrupt the fuzzy matching of subsequent quotes.
        """
        original = document.page_content
        highlighted_content = original
        matched_quotes = []
        for quote in quotes.quotes:
            # Hoist loop invariants out of the character scan.
            target = quote.quote.lower()
            width = len(quote.quote)
            best_match = None
            best_ratio = 0
            # Slide a window of the quote's length across the original text.
            for i in range(len(original)):
                candidate = original[i:i + width]
                ratio = fuzz.ratio(candidate.lower(), target)
                if ratio > best_ratio:
                    best_ratio = ratio
                    best_match = candidate
                if best_ratio == 100:  # exact match; no better score possible
                    break
            if best_match and best_ratio > match_threshold:
                # Escape regex metacharacters in the matched text; replace only
                # the first occurrence so repeated passages (or text inside an
                # earlier <mark>) are not wrapped twice.
                highlighted_content = re.sub(
                    re.escape(best_match),
                    f"<mark>{best_match}</mark>",
                    highlighted_content,
                    count=1,
                    flags=re.IGNORECASE,
                )
                matched_quotes.append(quote)
        return highlighted_content, matched_quotes

    def highlight_quotes_in_document(self, state: AgentState):
        """Annotate every document with its highlighted content and matched quotes."""
        docs = state["documents"]
        for doc in docs:
            quotes = doc.metadata["quotes"]
            annotated_passage, matched_quotes = self._highlight_quotes(doc, quotes)
            doc.metadata["highlighted_content"] = annotated_passage
            doc.metadata["matched_quotes"] = matched_quotes
        return {"documents": docs}

    def final_response_formatter(self, state: AgentState):
        """Concatenate "<publication>: <chapter>" headers with each highlighted passage.

        Uses str.join instead of repeated += (avoids quadratic string builds)
        and .get guards so a document missing citation metadata does not crash
        the whole response.
        """
        sections = []
        for doc in state["documents"]:
            meta = doc.metadata
            # NOTE(review): assumes Qdrant payloads normally carry
            # publication_name / chapter_name -- verify against the ingester.
            sections.append(
                meta.get("publication_name", "Unknown publication")
                + ": "
                + meta.get("chapter_name", "Unknown chapter")
                + "\n"
                + meta["highlighted_content"]
                + "\n\n"
            )
        return {"final_response": "".join(sections)}
class PassageFinder:
    """Top-level agent: wires the chatbot, tools, and quote pipeline into a graph."""

    def __init__(
        self,
        model: str = 'gpt-4o-mini',
        temperature: float = 0.5,
    ):
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        self.tools = ToolManager().get_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.quote_finder = QuoteFinder()
        # Build the compiled LangGraph state machine up front.
        self.graph = self.build_graph()

    def get_configurable(self):
        """Mint a fresh thread id used to track this conversation's checkpoints."""
        self.thread_id = str(uuid.uuid4())
        return {"configurable": {"thread_id": self.thread_id}}

    def chatbot(self, state: AgentState):
        """LLM node: may answer directly or emit tool calls for retrieval."""
        history = state["messages"]
        # Assumes the first message holds the user's original query -- TODO confirm.
        response = self.llm_with_tools.invoke(history)
        return {"messages": [response], "query": history[0].content}

    def build_graph(self):
        """Assemble and compile the graph: chatbot -> tools -> quote pipeline."""
        builder = StateGraph(AgentState)

        # Register nodes (custom BasicToolNode instead of the prebuilt ToolNode
        # so retrieved documents land in state["documents"]).
        builder.add_node("tools", BasicToolNode(tools=self.tools))
        builder.add_node("chatbot", self.chatbot)
        builder.add_node("quote_finder", self.quote_finder.find_quotes_per_document)
        builder.add_node("quote_highlighter", self.quote_finder.highlight_quotes_in_document)
        builder.add_node("final_response_formatter", self.quote_finder.final_response_formatter)

        # The chatbot decides whether to call the tools or finish.
        builder.add_conditional_edges("chatbot", tools_condition)

        # Fixed pipeline edges.
        for src, dst in [
            (START, "chatbot"),
            ("tools", "quote_finder"),
            ("quote_finder", "quote_highlighter"),
            ("quote_highlighter", "final_response_formatter"),
            ("final_response_formatter", END),
        ]:
            builder.add_edge(src, dst)

        # In-memory checkpointer so conversations can resume by thread id.
        return builder.compile(checkpointer=MemorySaver())
|