# NOTE(review): the original first three lines ("Spaces:", "Running", "Running")
# are notebook/hosting export artifacts, not Python — commented out so the file parses.
import sys | |
import os | |
import uuid | |
from dotenv import load_dotenv | |
from typing import Annotated, List, Tuple | |
from typing_extensions import TypedDict | |
from langchain.tools import tool, BaseTool | |
from langchain.schema import Document | |
from langgraph.graph import StateGraph, START, END, MessagesState | |
from langgraph.graph.message import add_messages | |
from langgraph.prebuilt import ToolNode, tools_condition | |
from langgraph.checkpoint.memory import MemorySaver | |
from langchain_openai import ChatOpenAI | |
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, AIMessagePromptTemplate, HumanMessagePromptTemplate | |
# from langchain.schema import SystemMessage, HumanMessage, AIMessage, ToolMessage | |
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage | |
from langchain.retrievers.multi_query import MultiQueryRetriever | |
import json | |
from langchain_core.messages import BaseMessage | |
from pydantic import BaseModel | |
import re | |
from fuzzywuzzy import fuzz | |
# Make the project root importable so the src/ and prompts/ packages resolve.
sys.path.append(os.path.abspath('..'))
import src.utils.qdrant_manager as qm
import prompts.system_prompts as sp
import prompts.quote_finder_prompts as qfp
# NOTE(review): hard-coded absolute user path — this breaks on any other machine.
# Consider load_dotenv() with an ambient/relative .env; left unchanged here.
load_dotenv('/Users/nadaa/Documents/code/py_innovations/srf_chatbot_v2/.env')
class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph pipeline."""
    # Conversation history; the add_messages reducer appends instead of overwriting.
    messages: Annotated[list[BaseMessage], add_messages]
    # Documents returned by the retrieval tools (see BasicToolNode.__call__).
    documents: list[Document]
    # The user's original query, captured from the first message in `chatbot`.
    query: str
    # Human-readable output assembled by final_response_formatter.
    final_response: str
class ToolManager:
    """Builds and holds the vector-search tools backed by a Qdrant vector store.

    The tool functions' docstrings are sent to the LLM as tool descriptions
    (via bind_tools), so their wording is part of the runtime behavior.
    """

    def __init__(self, collection_name="openai_large_chunks_1000char"):
        self.tools = []
        self.qdrant = qm.QdrantManager(collection_name=collection_name)
        self.vectorstore = self.qdrant.get_vectorstore()
        self.add_tools()

    def get_tools(self):
        """Return the list of registered tool callables."""
        return self.tools

    def add_tools(self):
        """Define the retrieval tools as closures over the vector store and register them."""

        def vector_search(query: str, k: int = 10) -> list[Document]:
            """Useful for simple queries. This tool will search a vector database for passages from the teachings of Paramhansa Yogananda and other publications from the Self Realization Fellowship (SRF).
            The user has the option to specify the number of passages they want the search to return, otherwise the number of passages will be set to the default value."""
            retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
            documents = retriever.invoke(query)
            return documents

        def multiple_query_vector_search(query: str, k: int = 10) -> list[Document]:
            """Useful when the user's query is vague, complex, or involves multiple concepts.
            This tool will write multiple versions of the user's query and search the vector database for relevant passages.
            Use this tool when the user asks for an in depth answer to their question."""
            llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
            # BUG FIX: `k` was previously accepted but never used, so the
            # advertised passage count was silently ignored on this path.
            # Forward it to the underlying retriever as in vector_search.
            retriever_from_llm = MultiQueryRetriever.from_llm(
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": k}),
                llm=llm,
            )
            documents = retriever_from_llm.invoke(query)
            return documents

        self.tools.append(vector_search)
        self.tools.append(multiple_query_vector_search)
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        # Index the tools by name for O(1) dispatch on each tool call.
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        """Execute every tool call on the latest message.

        Returns a partial state update: the ToolMessage results plus the
        concatenated documents produced by the tools.
        """
        history = inputs.get("messages", [])
        if not history:
            raise ValueError("No message found in input")
        last_message = history[-1]

        tool_messages = []
        gathered_documents = []
        for call in last_message.tool_calls:
            result = self.tools_by_name[call["name"]].invoke(call["args"])
            tool_messages.append(
                ToolMessage(
                    content=str(result),
                    name=call["name"],
                    tool_call_id=call["id"],
                )
            )
            # Tools return lists of Documents; accumulate them for later nodes.
            gathered_documents += result
        return {"messages": tool_messages, "documents": gathered_documents}
# Create the Pydantic Model for the quote finder | |
class Quote(BaseModel):
    # NOTE(review): this docstring appears to be passed to the LLM as the schema
    # description via with_structured_output, so treat its wording as part of
    # the prompt — confirm before editing it.
    '''Most relevant quotes to the user's query strictly pulled verbatim from the context provided. Quotes can be up to three sentences long.'''
    # A single verbatim quote extracted from a passage.
    quote: str
class QuoteList(BaseModel):
    # Container model: the structured LLM output is a list of Quote objects.
    quotes: List[Quote]
class QuoteFinder:
    """Extracts verbatim quotes from retrieved documents and highlights them.

    Pipeline: find_quotes_per_document -> highlight_quotes_in_document ->
    final_response_formatter (each method is a LangGraph node taking AgentState).
    """

    def __init__(self, model: str = 'gpt-4o-mini', temperature: float = 0.5):
        self.quotes_prompt = qfp.quote_finder_prompt
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        # Constrain the model's output to the QuoteList schema.
        self.llm_with_quotes_output = self.llm.with_structured_output(QuoteList)
        self.quote_finder_chain = self.quotes_prompt | self.llm_with_quotes_output

    def find_quotes_per_document(self, state: AgentState):
        """Graph node: attach a QuoteList to each document's metadata under "quotes"."""
        docs = state["documents"]
        query = state["query"]
        for doc in docs:
            quotes = self.quote_finder_chain.invoke(
                {"query": query, "passage": doc.page_content}
            )
            doc.metadata["quotes"] = quotes
        return {"documents": docs}

    def _highlight_quotes(self, document, quotes):
        """Fuzzy-locate each quote in the document and wrap matches in <mark> tags.

        Returns (highlighted_content, matched_quotes) where matched_quotes holds
        only the quotes that cleared the similarity threshold.
        """
        highlighted_content = document.page_content
        matched_quotes = []
        for quote in quotes.quotes:
            # Hoist loop invariants out of the character-by-character scan.
            quote_len = len(quote.quote)
            target = quote.quote.lower()
            # Brute-force sliding-window fuzzy match over the document.
            best_match = None
            best_ratio = 0
            for i in range(len(highlighted_content)):
                substring = highlighted_content[i:i + quote_len]
                ratio = fuzz.ratio(substring.lower(), target)
                if ratio > best_ratio:
                    best_ratio = ratio
                    best_match = substring
            if best_match and best_ratio > 90:  # Adjust threshold as needed
                # Escape special regex characters in the best match
                escaped_match = re.escape(best_match)
                # BUG FIX: use a callable replacement. Previously the raw matched
                # text was passed as the replacement *string*, so backslashes or
                # group references appearing in the passage were interpreted by
                # re.sub (potentially raising re.error or corrupting output), and
                # with IGNORECASE every differently-cased occurrence was rewritten
                # with best_match's casing. The callable inserts each occurrence's
                # own text verbatim.
                highlighted_content = re.sub(
                    escaped_match,
                    lambda m: f"<mark>{m.group(0)}</mark>",
                    highlighted_content,
                    flags=re.IGNORECASE,
                )
                matched_quotes.append(quote)
        return highlighted_content, matched_quotes

    def highlight_quotes_in_document(self, state: AgentState):
        """Graph node: store highlighted text and matched quotes on each document."""
        docs = state["documents"]
        for doc in docs:
            annotated_passage, matched_quotes = self._highlight_quotes(
                doc, doc.metadata["quotes"]
            )
            doc.metadata["highlighted_content"] = annotated_passage
            doc.metadata["matched_quotes"] = matched_quotes
        return {"documents": docs}

    def final_response_formatter(self, state: AgentState):
        """Graph node: assemble the final user-facing text from the documents."""
        # str.join avoids quadratic string concatenation in the loop.
        sections = [
            doc.metadata["publication_name"] + ": " + doc.metadata["chapter_name"]
            + "\n" + doc.metadata["highlighted_content"] + "\n\n"
            for doc in state["documents"]
        ]
        return {"final_response": "".join(sections)}
class PassageFinder:
    """Wires the retrieval tools, quote finder, and formatter into a LangGraph app."""

    def __init__(
        self,
        model: str = 'gpt-4o-mini',
        temperature: float = 0.5,
    ):
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        self.tools = ToolManager().get_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.quote_finder = QuoteFinder()
        # Assemble and compile the graph once, up front.
        self.graph = self.build_graph()

    def get_configurable(self):
        """Mint a fresh thread id and return it in LangGraph config form."""
        # The thread id keys the checkpointer's per-conversation state.
        self.thread_id = str(uuid.uuid4())
        return {"configurable": {"thread_id": self.thread_id}}

    def chatbot(self, state: AgentState):
        """Graph node: run the tool-bound LLM over the conversation history."""
        history = state["messages"]
        # The first message in the history is the user's original query.
        response = self.llm_with_tools.invoke(history)
        return {"messages": [response], "query": history[0].content}

    def build_graph(self):
        """Construct the state graph and compile it with an in-memory checkpointer."""
        builder = StateGraph(AgentState)

        # Register every node up front.
        builder.add_node("chatbot", self.chatbot)
        builder.add_node("tools", BasicToolNode(tools=self.tools))
        builder.add_node("quote_finder", self.quote_finder.find_quotes_per_document)
        builder.add_node("quote_highlighter", self.quote_finder.highlight_quotes_in_document)
        builder.add_node("final_response_formatter", self.quote_finder.final_response_formatter)

        # The chatbot decides whether to route to the tools node.
        builder.add_conditional_edges("chatbot", tools_condition)

        # Fixed pipeline: retrieve -> find quotes -> highlight -> format.
        builder.add_edge(START, "chatbot")
        builder.add_edge("tools", "quote_finder")
        builder.add_edge("quote_finder", "quote_highlighter")
        builder.add_edge("quote_highlighter", "final_response_formatter")
        builder.add_edge("final_response_formatter", END)

        # MemorySaver persists state per thread id so conversations can resume.
        return builder.compile(checkpointer=MemorySaver())