momenaca's picture
add latest updates regarding agent mode
d708cb9
raw
history blame
2.54 kB
import sys
import os
from contextlib import contextmanager
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from langchain_core.runnables.graph import MermaidDrawMethod
from tomlkit import document
from typing_extensions import TypedDict
from typing import List
from IPython.display import display, HTML, Image
from celsius_csrd_chatbot.chains.esrs_categorization import (
make_esrs_categorization_node,
)
from celsius_csrd_chatbot.chains.esrs_intent import (
make_esrs_intent_node,
)
from celsius_csrd_chatbot.chains.retriever import make_retriever_node
from celsius_csrd_chatbot.chains.answer_rag import make_rag_node
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Shared state dictionary handed between the nodes registered in
    ``make_graph_agent``. Which node writes which key is not visible here;
    the per-field notes below mark those assumptions explicitly.
    """
    # The user's question; presumably supplied by the caller on invoke — verify.
    query: str
    # Routing label read by ``route_intent``: "none" and "wrong_esrs" are
    # special-cased there; any other value is treated as a usable ESRS
    # category (exact set of values produced upstream — TODO confirm).
    esrs_type: str
    # Final generated answer; presumably written by the RAG answer nodes — verify.
    answer: str
    # Retrieved context documents; presumably filled by the retriever node — verify.
    documents: List[Document]
def route_intent(state):
    """Choose the node that runs after ``categorize_esrs``.

    Args:
        state: Graph state mapping; only its ``"esrs_type"`` entry is read.

    Returns:
        ``"intent_esrs"`` when ``esrs_type`` is ``"none"``,
        ``"answer_rag_wrong"`` when it is ``"wrong_esrs"``, and
        ``"retrieve_documents"`` for any other value. Every label returned
        here must be a key of the path map passed to
        ``add_conditional_edges`` for the ``categorize_esrs`` node.

    Raises:
        KeyError: if ``state`` has no ``"esrs_type"`` entry.
    """
    esrs = state["esrs_type"]
    if esrs == "none":
        # No ESRS category identified: go through the intent node, whose
        # outgoing edge leads on to retrieval.
        return "intent_esrs"
    if esrs == "wrong_esrs":
        # Route to the node built with ``wrong_esrs=True``. Returning
        # "answer_rag" here (as before) names a label that is absent from the
        # conditional path map, so the graph would fail on this branch and
        # the "answer_rag_wrong" node could never be reached.
        return "answer_rag_wrong"
    return "retrieve_documents"
def make_id_dict(values):
    """Return a dict mapping each item of *values* to itself.

    Iteration order of *values* is preserved in the result; a repeated item
    simply keeps a single entry. Used to build identity path maps for
    conditional edges.
    """
    identity = {}
    for item in values:
        identity[item] = item
    return identity
def make_graph_agent(llm, vectorstore):
    """Build and compile the ESRS question-answering graph.

    Flow: ``categorize_esrs`` is the entry point; ``route_intent`` then picks
    one of ``intent_esrs``, ``retrieve_documents`` or ``answer_rag_wrong``.
    ``intent_esrs`` feeds ``retrieve_documents``, which feeds ``answer_rag``;
    both ``answer_rag`` and ``answer_rag_wrong`` end the run.

    Args:
        llm: Model passed to the intent node and to both RAG answer nodes.
        vectorstore: Store passed to the retriever node.

    Returns:
        The result of ``StateGraph.compile()`` for the assembled workflow.
    """
    graph = StateGraph(GraphState)

    # Node name -> node callable. Dict literals evaluate left to right, so the
    # factories run, and the nodes are registered, in this exact order.
    node_table = {
        "categorize_esrs": make_esrs_categorization_node(),
        "intent_esrs": make_esrs_intent_node(llm),
        "retrieve_documents": make_retriever_node(vectorstore),
        "answer_rag": make_rag_node(llm, wrong_esrs=False),
        "answer_rag_wrong": make_rag_node(llm, wrong_esrs=True),
    }
    for node_name, node_fn in node_table.items():
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("categorize_esrs")

    # Identity path map: each label the router may return names the node of
    # the same name. NOTE(review): every label route_intent can return must
    # appear here, or that branch fails at runtime.
    branch_targets = ("intent_esrs", "retrieve_documents", "answer_rag_wrong")
    graph.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        {label: label for label in branch_targets},
    )

    # Fixed (unconditional) transitions.
    fixed_edges = (
        ("intent_esrs", "retrieve_documents"),
        ("retrieve_documents", "answer_rag"),
        ("answer_rag", END),
        ("answer_rag_wrong", END),
    )
    for source, target in fixed_edges:
        graph.add_edge(source, target)

    return graph.compile()
def display_graph(app):
    """Render *app*'s graph as a Mermaid PNG and show it via IPython ``display``.

    ``get_graph`` is called with ``xray=True``. The PNG is produced with the
    ``MermaidDrawMethod.API`` backend, which in langchain-core presumably
    relies on a remote rendering service (network access needed — confirm).
    Returns nothing; the image is only displayed.
    """
    graph_view = app.get_graph(xray=True)
    png_data = graph_view.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_data))