# srf_chatbot_v2/src/generic_bot.py
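"""Generic tool-calling chatbot over the SRF teachings corpus.

Builds retrieval tools on top of a Qdrant vector store, binds them to a
ChatOpenAI model, and compiles a two-node LangGraph (chatbot <-> tools)
with an in-memory checkpointer for multi-turn conversations.
"""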
import sys
import os
import uuid

from dotenv import load_dotenv
from typing import Annotated
from typing_extensions import TypedDict
from langchain.tools import tool
from langchain.schema import Document
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver

# Make project-root packages importable when running from inside src/.
sys.path.append(os.path.abspath('..'))

import src.utils.qdrant_manager as qm
import prompts.system_prompts as sp

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()

class ToolManager:
    """Builds the retrieval tools that the chatbot can call."""

    def __init__(self, collection_name: str = "openai_large_chunks_1000char"):
        self.tools = []
        self.qdrant = qm.QdrantManager(collection_name=collection_name)
        self.vectorstore = self.qdrant.get_vectorstore()
        self.add_tools()

    def get_tools(self):
        return self.tools
    def add_tools(self):
        @tool
        def vector_search(query: str, k: int = 5) -> list[Document]:
            """Useful for simple queries. Searches a vector database for passages
            from the teachings of Paramahansa Yogananda and other publications of
            the Self-Realization Fellowship (SRF). The user may specify the number
            of passages to return; otherwise the default value is used."""
            retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
            documents = retriever.invoke(query)
            return documents

        @tool
        def multiple_query_vector_search(query: str, k: int = 5) -> list[Document]:
            """Useful when the user's query is vague, complex, or involves multiple
            concepts, or when the user asks for an in-depth answer. Uses an LLM to
            write multiple versions of the user's query and searches the vector
            database with each of them."""
            llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
            retriever_from_llm = MultiQueryRetriever.from_llm(
                # Pass k through so the caller's requested passage count is honored.
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": k}),
                llm=llm,
            )
            documents = retriever_from_llm.invoke(query)
            return documents

        self.tools.append(vector_search)
        self.tools.append(multiple_query_vector_search)
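
# A minimal sketch of calling one of these tools directly (hypothetical query;
# assumes the Qdrant collection above exists and OPENAI_API_KEY is set):
#
#   tools = ToolManager().get_tools()
#   docs = tools[0].invoke({"query": "How should I begin meditating?", "k": 3})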

class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}
    def __call__(self, inputs: dict):
        # The last message should be an AIMessage carrying the tool calls.
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        outputs = []
        documents = []
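        # Each entry in message.tool_calls is a dict with "name", "args", and
        # "id" keys (LangChain's standard tool-call format).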
        for tool_call in message.tool_calls:
            tool_result = self.tools_by_name[tool_call["name"]].invoke(
                tool_call["args"]
            )
            outputs.append(
                ToolMessage(
                    content=str(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
            # Both retrieval tools return list[Document], so extend the running list.
            documents += tool_result
        return {"messages": outputs, "documents": documents}

class AgentState(TypedDict):
    """Graph state: the chat history plus retrieved documents and the
    active system message."""
    messages: Annotated[list, add_messages]
    documents: list[Document]
    system_message: list[SystemMessage]
    system_message_dropdown: list[str]
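
# `add_messages` is LangGraph's reducer for the messages channel: updates
# returned by nodes are appended/merged into the history rather than
# overwriting it.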

class GenericChatbot:
    def __init__(
        self,
        model: str = 'gpt-4o-mini',
        temperature: float = 0,
    ):
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        self.tools = ToolManager().get_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        # Build the graph
        self.graph = self.build_graph()

        # Build the per-conversation run config for the checkpointer
        self.config = self.get_configurable()
    def get_configurable(self):
        # The thread id identifies this conversation's checkpoints in the MemorySaver.
        self.thread_id = str(uuid.uuid4())
        return {"configurable": {"thread_id": self.thread_id}}
    # TODO: Refactor so that the state always holds the definitive system
    # message in use, and apply it to the LLM here.
    def chatbot(self, state: AgentState):
        messages = state["messages"]
        return {"messages": [self.llm_with_tools.invoke(messages)]}
    def build_graph(self):
        graph_builder = StateGraph(AgentState)

        # Add nodes: the custom tool executor and the LLM step
        tool_node = BasicToolNode(tools=self.tools)
        # tool_node = ToolNode(self.tools)  # prebuilt alternative
        graph_builder.add_node("tools", tool_node)
        graph_builder.add_node("chatbot", self.chatbot)

        # Conditional edge: after the chatbot speaks, tools_condition routes to
        # "tools" if the last AIMessage contains tool calls, otherwise to END
        graph_builder.add_conditional_edges(
            "chatbot",
            tools_condition,
        )

        # Add fixed edges
        graph_builder.add_edge(START, "chatbot")
        graph_builder.add_edge("tools", "chatbot")

        # Compile with an in-memory checkpointer so each thread_id keeps its history
        memory = MemorySaver()
        return graph_builder.compile(checkpointer=memory)
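

# A minimal, hypothetical smoke test, not part of the original module. It
# assumes OPENAI_API_KEY is set and that the Qdrant collection named above
# exists and is populated.
if __name__ == "__main__":
    bot = GenericChatbot()
    result = bot.graph.invoke(
        {"messages": [HumanMessage(content="What did Yogananda teach about concentration?")]},
        config=bot.config,
    )
    # The final message in the returned state is the model's reply.
    print(result["messages"][-1].content)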