|
import sys |
|
import os |
|
from contextlib import contextmanager |
|
|
|
from langchain.schema import Document |
|
from langgraph.graph import END, StateGraph |
|
from langchain_core.runnables.graph import MermaidDrawMethod |
|
|
|
from tomlkit import document |
|
from typing_extensions import TypedDict |
|
from typing import List |
|
|
|
from IPython.display import display, HTML, Image |
|
|
|
from celsius_csrd_chatbot.chains.esrs_categorization import ( |
|
make_esrs_categorization_node, |
|
) |
|
from celsius_csrd_chatbot.chains.esrs_intent import ( |
|
make_esrs_intent_node, |
|
) |
|
from celsius_csrd_chatbot.chains.retriever import make_retriever_node |
|
from celsius_csrd_chatbot.chains.answer_rag import make_rag_node |
|
|
|
|
|
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    This is the state schema given to ``StateGraph`` in ``make_graph_agent``;
    each annotated key below is one field of the state shared between nodes.
    """

    # Input query text (presumably the user's question — confirm with callers).
    query: str
    # ESRS label for the query. ``route_intent`` branches on the values
    # "none" and "wrong_esrs"; any other value goes to document retrieval.
    esrs_type: str
    # Answer text (presumably produced by the answer_rag nodes — confirm).
    answer: str
    # Documents used for answering (presumably set by the retriever node — confirm).
    documents: List[Document]
|
|
|
|
|
def route_intent(state):
    """Choose the node that follows ``categorize_esrs``.

    This function is the router for the conditional edges leaving
    ``categorize_esrs``. Every value it returns must be a key of that edge's
    path map: ``intent_esrs``, ``retrieve_documents`` or ``answer_rag_wrong``.

    Args:
        state: Graph state mapping; only ``state["esrs_type"]`` is read.

    Returns:
        ``"intent_esrs"`` when ``esrs_type`` is ``"none"``,
        ``"answer_rag_wrong"`` when it is ``"wrong_esrs"``, and
        ``"retrieve_documents"`` for any other value.

    Raises:
        KeyError: if ``state`` has no ``"esrs_type"`` entry.
    """
    esrs = state["esrs_type"]

    if esrs == "none":
        # No ESRS label: hand off to the intent node.
        return "intent_esrs"

    if esrs == "wrong_esrs":
        # Target the node built with ``wrong_esrs=True``. "answer_rag" is not a
        # key of this router's path map, so it cannot be routed from here.
        return "answer_rag_wrong"

    return "retrieve_documents"
|
|
|
|
|
def make_id_dict(values):
    """Build an identity mapping from an iterable.

    Each item of *values* becomes both a key and its own value. This is the
    shape expected for the path map of ``add_conditional_edges`` when
    router outputs are already node names.

    Args:
        values: Iterable of hashable items (node names here).

    Returns:
        A new dict ``{item: item, ...}``; later duplicates overwrite earlier ones.
    """
    identity = {}
    for item in values:
        identity[item] = item
    return identity
|
|
|
|
|
def make_graph_agent(llm, vectorstore):
    """Build and compile the ESRS question-answering graph.

    Topology:
        * ``categorize_esrs`` is the entry node.
        * After it, ``route_intent``'s result is looked up in the path map
          {``intent_esrs``, ``retrieve_documents``, ``answer_rag_wrong``}.
        * ``intent_esrs`` -> ``retrieve_documents`` -> ``answer_rag`` -> END.
        * ``answer_rag_wrong`` -> END.

    Args:
        llm: Language model passed to the intent node factory and both RAG
            answer node factories.
        vectorstore: Store passed to the retriever node factory.

    Returns:
        Whatever ``StateGraph.compile()`` returns for this workflow.
    """
    graph = StateGraph(GraphState)

    # Factories run in declaration order; dicts keep insertion order, so the
    # add_node calls below follow that same order.
    nodes = {
        "categorize_esrs": make_esrs_categorization_node(),
        "intent_esrs": make_esrs_intent_node(llm),
        "retrieve_documents": make_retriever_node(vectorstore),
        "answer_rag": make_rag_node(llm, wrong_esrs=False),
        "answer_rag_wrong": make_rag_node(llm, wrong_esrs=True),
    }
    for node_name, node in nodes.items():
        graph.add_node(node_name, node)

    graph.set_entry_point("categorize_esrs")

    # Every value route_intent returns must be a key of this path map.
    graph.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        make_id_dict(["intent_esrs", "retrieve_documents", "answer_rag_wrong"]),
    )

    static_edges = (
        ("intent_esrs", "retrieve_documents"),
        ("retrieve_documents", "answer_rag"),
        ("answer_rag", END),
        ("answer_rag_wrong", END),
    )
    for source, target in static_edges:
        graph.add_edge(source, target)

    return graph.compile()
|
|
|
|
|
def display_graph(app):
    """Render a compiled graph as a Mermaid PNG in an IPython display.

    The graph is fetched with ``xray=True`` and drawn via
    ``MermaidDrawMethod.API``. That method presumably uses a remote rendering
    service, so network access may be required; confirm against the
    LangChain docs.

    Args:
        app: Compiled graph object exposing ``get_graph(xray=...)``.
    """
    graph_view = app.get_graph(xray=True)
    png_bytes = graph_view.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_bytes))
|
|