Spaces:
Sleeping
Sleeping
import sys | |
import os | |
from contextlib import contextmanager | |
from langchain.schema import Document | |
from langgraph.graph import END, StateGraph | |
from langchain_core.runnables.graph import MermaidDrawMethod | |
from tomlkit import document | |
from typing_extensions import TypedDict | |
from typing import List | |
from IPython.display import display, HTML, Image | |
from celsius_csrd_chatbot.chains.esrs_categorization import ( | |
make_esrs_categorization_node, | |
) | |
from celsius_csrd_chatbot.chains.esrs_intent import ( | |
make_esrs_intent_node, | |
) | |
from celsius_csrd_chatbot.chains.retriever import make_retriever_node | |
from celsius_csrd_chatbot.chains.answer_rag import make_rag_node | |
class GraphState(TypedDict):
    """
    State schema for the ESRS chatbot graph.

    Passed to ``StateGraph`` in ``make_graph_agent`` to define the keys the
    graph's nodes share.
    """
    # The user's question. NOTE(review): presumably supplied by the caller
    # when the compiled graph is invoked — confirm.
    query: str
    # ESRS category label read by ``route_intent``: "none" and "wrong_esrs"
    # select special routes; any other value routes to document retrieval.
    esrs_type: str
    # Answer text. NOTE(review): presumably written by the RAG answer nodes — confirm.
    answer: str
    # Retrieved documents. NOTE(review): presumably populated by the
    # retriever node — confirm.
    documents: List[Document]
def route_intent(state):
    """
    Pick the node that should run after ESRS categorization.

    Args:
        state: Graph state mapping; only its ``"esrs_type"`` entry is read
            (a missing key raises ``KeyError``).

    Returns:
        ``"intent_esrs"`` when ``esrs_type`` equals ``"none"``,
        ``"answer_rag_wrong"`` when it equals ``"wrong_esrs"``,
        otherwise ``"retrieve_documents"``.
    """
    category = state["esrs_type"]
    # Sentinel labels (presumably emitted by the categorization node) map to
    # dedicated nodes. They are checked in order with ``==`` so comparison
    # semantics match a plain if/elif chain.
    special_routes = (
        ("none", "intent_esrs"),
        ("wrong_esrs", "answer_rag_wrong"),
    )
    for sentinel, node_name in special_routes:
        if category == sentinel:
            return node_name
    # Any concrete ESRS category goes straight to retrieval.
    return "retrieve_documents"
def make_graph_agent(llm, vectorstore):
    """
    Build and compile the ESRS question-answering graph.

    Flow:
        ``categorize_esrs`` runs first. ``route_intent`` then sends the state to
        ``intent_esrs`` (which continues to ``retrieve_documents``), to
        ``answer_rag_wrong``, or directly to ``retrieve_documents``.
        Retrieval feeds ``answer_rag``, and both answer nodes end the run.

    Args:
        llm: Language model passed to the intent and RAG node factories.
            NOTE(review): the expected type is defined by those factories — confirm.
        vectorstore: Store passed to ``make_retriever_node``.

    Returns:
        The compiled graph produced by ``workflow.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Build the node callables.
    categorize_esrs = make_esrs_categorization_node()
    intent_esrs = make_esrs_intent_node(llm)
    retrieve_documents = make_retriever_node(vectorstore)
    answer_rag = make_rag_node(llm, wrong_esrs=False)
    answer_rag_wrong = make_rag_node(llm, wrong_esrs=True)

    # Register the nodes under the names used by routing and edges below.
    workflow.add_node("categorize_esrs", categorize_esrs)
    workflow.add_node("intent_esrs", intent_esrs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_wrong", answer_rag_wrong)

    # Entry point
    workflow.set_entry_point("categorize_esrs")

    # Conditional branch after categorization. The explicit map lists every
    # value route_intent can return. Without it (per LangGraph docs), the
    # graph cannot know the branch targets, and get_graph()/visualization
    # assumes the branch may lead to any node. The map is passed positionally
    # because the keyword name differs across LangGraph versions.
    workflow.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        {
            "intent_esrs": "intent_esrs",
            "answer_rag_wrong": "answer_rag_wrong",
            "retrieve_documents": "retrieve_documents",
        },
    )

    # Fixed edges
    workflow.add_edge("intent_esrs", "retrieve_documents")
    workflow.add_edge("retrieve_documents", "answer_rag")
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_wrong", END)

    # Compile
    app = workflow.compile()
    return app
def display_graph(app, *, xray=True, draw_method=MermaidDrawMethod.API):
    """
    Render a compiled graph as a Mermaid PNG and display it inline.

    Args:
        app: Compiled graph exposing ``get_graph`` (e.g. the value returned
            by ``make_graph_agent``).
        xray: Forwarded to ``app.get_graph``. Defaults to ``True``, as before.
        draw_method: ``MermaidDrawMethod`` used by ``draw_mermaid_png``.
            Defaults to ``MermaidDrawMethod.API``, as before.
            NOTE(review): per LangChain docs, API rendering calls a remote web
            service; pass another method (e.g. ``MermaidDrawMethod.PYPPETEER``)
            to render locally.
    """
    graph = app.get_graph(xray=xray)
    png_bytes = graph.draw_mermaid_png(draw_method=draw_method)
    display(Image(png_bytes))