momenaca's picture
feature/major backend update with agent
8ca00e0
import sys
import os
from contextlib import contextmanager
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from langchain_core.runnables.graph import MermaidDrawMethod
from tomlkit import document
from typing_extensions import TypedDict
from typing import List
from IPython.display import display, HTML, Image
from celsius_csrd_chatbot.chains.esrs_categorization import (
make_esrs_categorization_node,
)
from celsius_csrd_chatbot.chains.esrs_intent import (
make_esrs_intent_node,
)
from celsius_csrd_chatbot.chains.retriever import make_retriever_node
from celsius_csrd_chatbot.chains.answer_rag import make_rag_node
class GraphState(TypedDict):
    """Shared state dictionary used as the schema of the agent's StateGraph."""

    # Text of the user request (presumably the raw question; confirm with callers).
    query: str
    # Category label read by route_intent; the values "none" and "wrong_esrs"
    # get special routing, any other value goes to document retrieval.
    esrs_type: str
    # Answer text (presumably written by the RAG nodes; confirm).
    answer: str
    # Documents attached to the state (presumably filled by the retriever node).
    documents: List[Document]
def route_intent(state):
    """Pick the name of the node that should run after ESRS categorization.

    Args:
        state: Mapping holding at least the "esrs_type" key.

    Returns:
        "intent_esrs" when the type is "none", "answer_rag_wrong" when it is
        "wrong_esrs", and "retrieve_documents" for every other value.

    Raises:
        KeyError: If "esrs_type" is missing from ``state``.
    """
    category = state["esrs_type"]

    # Guard clauses for the two special labels; everything else is retrieved.
    if category == "none":
        return "intent_esrs"
    if category == "wrong_esrs":
        return "answer_rag_wrong"
    return "retrieve_documents"
def make_graph_agent(llm, vectorstore):
    """Build and compile the ESRS question-answering state graph.

    Flow of the graph:
        categorize_esrs -> (route_intent) -> intent_esrs | retrieve_documents
                                             | answer_rag_wrong
        intent_esrs -> retrieve_documents -> answer_rag -> END
        answer_rag_wrong -> END

    Args:
        llm: Model object forwarded to make_esrs_intent_node and make_rag_node.
        vectorstore: Store object forwarded to make_retriever_node.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Node name -> node callable. Built in the same order as before so the
    # factories run, and the nodes are registered, in the original sequence.
    nodes = {
        "categorize_esrs": make_esrs_categorization_node(),
        "intent_esrs": make_esrs_intent_node(llm),
        "retrieve_documents": make_retriever_node(vectorstore),
        "answer_rag": make_rag_node(llm, wrong_esrs=False),
        "answer_rag_wrong": make_rag_node(llm, wrong_esrs=True),
    }
    for node_name, node_fn in nodes.items():
        workflow.add_node(node_name, node_fn)

    # Entry point
    workflow.set_entry_point("categorize_esrs")

    # Conditional edge: declare every target route_intent can return. Without
    # this mapping the graph cannot know the possible branches, so drawings
    # show categorize_esrs connected to every node and an unknown route name
    # is only detected while the graph is running. The mapping is passed
    # positionally so it works regardless of the keyword name of the parameter.
    workflow.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        {
            "intent_esrs": "intent_esrs",
            "answer_rag_wrong": "answer_rag_wrong",
            "retrieve_documents": "retrieve_documents",
        },
    )

    # Fixed edges
    workflow.add_edge("intent_esrs", "retrieve_documents")
    workflow.add_edge("retrieve_documents", "answer_rag")
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_wrong", END)

    return workflow.compile()
def display_graph(app):
    """Draw the compiled graph as a Mermaid PNG and show it with IPython.

    Args:
        app: Compiled graph exposing ``get_graph`` (e.g. the object returned
            by make_graph_agent).
    """
    # xray=True is forwarded to get_graph as in the original call
    # (presumably expands nested subgraphs; confirm in the LangGraph docs).
    graph_view = app.get_graph(xray=True)
    # NOTE(review): MermaidDrawMethod.API presumably renders through a remote
    # service and therefore needs network access; confirm.
    png_data = graph_view.draw_mermaid_png(
        draw_method=MermaidDrawMethod.API,
    )
    display(Image(png_data))