# NOTE(review): the following lines were UI status artifacts ("Spaces: Running")
# accidentally pasted into the source; commented out so the module parses.
import sys | |
import os | |
import uuid | |
from dotenv import load_dotenv | |
from typing import Annotated, List, Tuple | |
from typing_extensions import TypedDict | |
from langchain.tools import tool, BaseTool | |
from langchain.schema import Document | |
from langgraph.graph import StateGraph, START, END, MessagesState | |
from langgraph.graph.message import add_messages | |
from langgraph.prebuilt import ToolNode, tools_condition | |
from langgraph.checkpoint.memory import MemorySaver | |
from langchain_openai import ChatOpenAI | |
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, AIMessagePromptTemplate, HumanMessagePromptTemplate | |
# from langchain.schema import SystemMessage, HumanMessage, AIMessage, ToolMessage | |
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage | |
from langchain.retrievers.multi_query import MultiQueryRetriever | |
import json | |
# Make the project root importable so the src/ and prompts/ packages resolve
# when this module is run from a subdirectory (e.g. a notebook).
sys.path.append(os.path.abspath('..'))

import src.utils.qdrant_manager as qm
import prompts.system_prompts as sp

# Load environment variables (e.g. OPENAI_API_KEY). Prefer an explicit path
# from DOTENV_PATH; otherwise let python-dotenv search for a .env file from
# the current working directory upward. The previous hard-coded absolute
# path ('/Users/nadaa/...') only worked on one developer's machine.
load_dotenv(os.getenv('DOTENV_PATH') or None)
class ToolManager:
    """Builds and holds the retrieval tools backed by a Qdrant vector store.

    Attributes:
        tools: list of LangChain tools exposed to the chatbot.
        qdrant: QdrantManager wrapper for the configured collection.
        vectorstore: the underlying vector store used by the retrievers.
    """

    def __init__(self, collection_name: str = "openai_large_chunks_1000char"):
        self.tools = []
        self.qdrant = qm.QdrantManager(collection_name=collection_name)
        self.vectorstore = self.qdrant.get_vectorstore()
        self.add_tools()

    def get_tools(self) -> list:
        """Return the list of registered tools."""
        return self.tools

    def add_tools(self):
        """Register the retrieval tools.

        Each function is wrapped with the @tool decorator so it becomes a
        LangChain tool with a `.name` attribute and an `.invoke` method —
        BasicToolNode dispatches on `tool.name`, and `llm.bind_tools`
        expects tool objects, so bare functions would break both.
        """

        @tool
        def vector_search(query: str, k: int = 5) -> list[Document]:
            """Useful for simple queries. This tool will search a vector database for passages from the teachings of Paramhansa Yogananda and other publications from the Self Realization Fellowship (SRF).
            The user has the option to specify the number of passages they want the search to return, otherwise the number of passages will be set to the default value."""
            retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
            documents = retriever.invoke(query)
            return documents

        @tool
        def multiple_query_vector_search(query: str, k: int = 5) -> list[Document]:
            """Useful when the user's query is vague, complex, or involves multiple concepts.
            This tool will write multiple versions of the user's query and search the vector database for relevant passages.
            Use this tool when the user asks for an in depth answer to their question."""
            llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
            retriever_from_llm = MultiQueryRetriever.from_llm(
                retriever=self.vectorstore.as_retriever(), llm=llm
            )
            documents = retriever_from_llm.invoke(query)
            return documents

        self.tools.append(vector_search)
        self.tools.append(multiple_query_vector_search)
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        # Index tools by name for O(1) dispatch on each tool call.
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        """Execute every tool call on the last message.

        Returns a dict with the resulting ToolMessages and the flat list of
        retrieved documents, to be merged into the graph state.
        """
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        last_message = messages[-1]

        tool_messages = []
        retrieved_docs = []
        for call in last_message.tool_calls:
            result = self.tools_by_name[call["name"]].invoke(call["args"])
            tool_messages.append(
                ToolMessage(
                    content=str(result),
                    name=call["name"],
                    tool_call_id=call["id"],
                )
            )
            # Each retrieval tool returns a list of Documents; accumulate
            # them so downstream nodes can reference the sources.
            retrieved_docs += result

        return {"messages": tool_messages, "documents": retrieved_docs}
class AgentState(TypedDict):
    """Shared graph state passed between the chatbot and tool nodes."""

    # Conversation history; the add_messages reducer appends new messages
    # instead of overwriting the list.
    messages: Annotated[list, add_messages]
    # Documents returned by the most recent retrieval tool call(s).
    documents: list[Document]
    # The active system message(s) steering the assistant.
    system_message: list[SystemMessage]
    # Presumably the UI dropdown selection naming the active system prompt —
    # TODO confirm against the caller that populates this field.
    system_message_dropdown: list[str]
class GenericChatbot:
    """Tool-calling chatbot built on a LangGraph state graph.

    Wires a chat model (with the retrieval tools bound) into a two-node
    graph — chatbot -> (conditionally) tools -> chatbot — checkpointed
    in memory and keyed by a per-conversation thread id.
    """

    def __init__(
        self,
        model: str = 'gpt-4o-mini',
        temperature: float = 0,
    ):
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        self.tools = ToolManager().get_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # Build the graph, then the per-conversation run config.
        self.graph = self.build_graph()
        self.config = self.get_configurable()

    def get_configurable(self):
        """Create a run config with a fresh thread id.

        The thread id is what the MemorySaver checkpointer uses to keep
        track of this conversation's state across invocations.
        """
        self.thread_id = str(uuid.uuid4())
        return {"configurable": {"thread_id": self.thread_id}}

    ## THIS SHOULD BE REFACTORED SO THAT THE STATE ALWAYS HAS THE
    ## DEFINITIVE SYSTEM MESSAGE THAT SHOULD BE IN USE
    def chatbot(self, state: AgentState):
        """Single LLM step: respond to the accumulated conversation."""
        return {"messages": [self.llm_with_tools.invoke(state["messages"])]}

    def build_graph(self):
        """Assemble and compile the two-node tool-calling graph."""
        builder = StateGraph(AgentState)

        # Nodes: the custom tool executor and the LLM step.
        builder.add_node("tools", BasicToolNode(tools=self.tools))
        builder.add_node("chatbot", self.chatbot)

        # The chatbot decides whether to call tools; tools_condition routes
        # to "tools" when the last AI message has tool calls, else to END.
        builder.add_conditional_edges(
            "chatbot",
            tools_condition,
        )
        builder.add_edge(START, "chatbot")
        builder.add_edge("tools", "chatbot")

        # In-memory checkpointing, keyed by the thread id from
        # get_configurable().
        return builder.compile(checkpointer=MemorySaver())