import sys
import os
from contextlib import contextmanager
from typing import List

from IPython.display import display, HTML, Image
from langchain.schema import Document
from langchain_core.runnables.graph import MermaidDrawMethod
from langgraph.graph import END, StateGraph
from tomlkit import document
from typing_extensions import TypedDict

from celsius_csrd_chatbot.chains.answer_rag import make_rag_node
from celsius_csrd_chatbot.chains.esrs_categorization import (
    make_esrs_categorization_node,
)
from celsius_csrd_chatbot.chains.esrs_intent import (
    make_esrs_intent_node,
)
from celsius_csrd_chatbot.chains.retriever import make_retriever_node

# NOTE(review): `sys`, `os`, `contextmanager`, `document` and `HTML` appear
# unused in this module; they are kept in case other code relies on them
# being importable from here -- confirm before removing.


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Keys:
        query: string field; presumably the user's question -- TODO confirm.
        esrs_type: string read by ``route_intent`` to pick the next node.
            The values ``"none"`` and ``"wrong_esrs"`` are handled
            specially; any other value goes straight to retrieval.
        answer: string field; presumably the generated answer -- TODO
            confirm.
        documents: list of ``Document`` objects; presumably the retrieved
            context -- TODO confirm.
    """

    query: str
    esrs_type: str
    answer: str
    documents: List[Document]


def route_intent(state: GraphState) -> str:
    """Return the name of the node that should run after categorization.

    Args:
        state: Current graph state; only ``state["esrs_type"]`` is read.

    Returns:
        ``"intent_esrs"`` when ``esrs_type`` is ``"none"``,
        ``"answer_rag_wrong"`` when it is ``"wrong_esrs"``, and
        ``"retrieve_documents"`` for every other value.

    Raises:
        KeyError: If ``state`` has no ``"esrs_type"`` key.
    """
    esrs = state["esrs_type"]
    if esrs == "none":
        return "intent_esrs"
    if esrs == "wrong_esrs":
        return "answer_rag_wrong"
    return "retrieve_documents"


def make_graph_agent(llm, vectorstore):
    """Build and compile the chatbot workflow graph.

    Topology::

        categorize_esrs --(route_intent)--> intent_esrs
                                          | answer_rag_wrong
                                          | retrieve_documents
        intent_esrs        -> retrieve_documents
        retrieve_documents -> answer_rag
        answer_rag         -> END
        answer_rag_wrong   -> END

    Args:
        llm: Model object forwarded to ``make_esrs_intent_node`` and
            ``make_rag_node``.
        vectorstore: Store object forwarded to ``make_retriever_node``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Define the node functions
    categorize_esrs = make_esrs_categorization_node()
    intent_esrs = make_esrs_intent_node(llm)
    retrieve_documents = make_retriever_node(vectorstore)
    answer_rag = make_rag_node(llm, wrong_esrs=False)
    answer_rag_wrong = make_rag_node(llm, wrong_esrs=True)

    # Define the nodes
    workflow.add_node("categorize_esrs", categorize_esrs)
    workflow.add_node("intent_esrs", intent_esrs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_wrong", answer_rag_wrong)

    # Entry point
    workflow.set_entry_point("categorize_esrs")

    # CONDITIONAL EDGES
    # The explicit path map lists every value `route_intent` can return.
    # Without it the graph cannot know the possible targets, so the rendered
    # diagram would draw a branch from "categorize_esrs" to every node.
    workflow.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        {
            "intent_esrs": "intent_esrs",
            "answer_rag_wrong": "answer_rag_wrong",
            "retrieve_documents": "retrieve_documents",
        },
    )

    # Define the edges
    workflow.add_edge("intent_esrs", "retrieve_documents")
    workflow.add_edge("retrieve_documents", "answer_rag")
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_wrong", END)

    # Compile
    app = workflow.compile()
    return app


def display_graph(app):
    """Render the compiled graph as a PNG and show it via IPython.

    Args:
        app: Compiled graph exposing ``get_graph``; typically the value
            returned by ``make_graph_agent``.

    Note:
        Rendering uses ``MermaidDrawMethod.API``; presumably this calls a
        remote rendering service and needs network access -- TODO confirm.
    """
    png_bytes = app.get_graph(xray=True).draw_mermaid_png(
        draw_method=MermaidDrawMethod.API,
    )
    display(Image(png_bytes))