File size: 2,380 Bytes
7bfa7e6 d708cb9 7bfa7e6 d708cb9 7bfa7e6 d708cb9 7bfa7e6 d708cb9 8ca00e0 d708cb9 7bfa7e6 d708cb9 7bfa7e6 d708cb9 7bfa7e6 d708cb9 7bfa7e6 d708cb9 7bfa7e6 d708cb9 8ca00e0 d708cb9 7bfa7e6 d708cb9 7bfa7e6 d708cb9 7bfa7e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import sys
import os
from contextlib import contextmanager
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from langchain_core.runnables.graph import MermaidDrawMethod
from tomlkit import document
from typing_extensions import TypedDict
from typing import List
from IPython.display import display, HTML, Image
from celsius_csrd_chatbot.chains.esrs_categorization import (
make_esrs_categorization_node,
)
from celsius_csrd_chatbot.chains.esrs_intent import (
make_esrs_intent_node,
)
from celsius_csrd_chatbot.chains.retriever import make_retriever_node
from celsius_csrd_chatbot.chains.answer_rag import make_rag_node
# State schema shared by every node in the ESRS chatbot graph
# (declared with the functional TypedDict syntax; field order is significant
# only for readability and matches the original declaration).
GraphState = TypedDict(
    "GraphState",
    {
        # The user's question.
        "query": str,
        # ESRS label inspected by route_intent ("none", "wrong_esrs", or any
        # other value); presumably written by the categorize_esrs node — confirm.
        "esrs_type": str,
        # Generated answer text; presumably filled by one of the RAG answer nodes.
        "answer": str,
        # Documents presumably produced by the retrieve_documents node.
        "documents": List[Document],
    },
)
def route_intent(state):
    """Choose the node that runs after ``categorize_esrs``.

    The decision depends solely on ``state["esrs_type"]``:

    * ``"none"``       -> ``"intent_esrs"``
    * ``"wrong_esrs"`` -> ``"answer_rag_wrong"``
    * any other value  -> ``"retrieve_documents"``

    Raises:
        KeyError: when ``state`` is a dict without an ``"esrs_type"`` key.
    """
    label = state["esrs_type"]
    special_routes = (
        ("none", "intent_esrs"),
        ("wrong_esrs", "answer_rag_wrong"),
    )
    for expected, next_node in special_routes:
        if label == expected:
            return next_node
    # Every remaining label (presumably a concrete ESRS category) goes
    # straight to document retrieval.
    return "retrieve_documents"
def make_graph_agent(llm, vectorstore):
    """Build and compile the ESRS question-answering LangGraph workflow.

    Routing:
        categorize_esrs -> (route_intent) -> one of:
            intent_esrs -> retrieve_documents
            retrieve_documents
            answer_rag_wrong
        retrieve_documents -> answer_rag -> END
        answer_rag_wrong -> END

    Args:
        llm: Model passed to the intent node factory and to both RAG answer
            node factories.
        vectorstore: Store passed to the retriever node factory.

    Returns:
        The result of ``StateGraph.compile()`` for the workflow above.
    """
    workflow = StateGraph(GraphState)

    # Node callables produced by the chain factories.
    categorize_esrs = make_esrs_categorization_node()
    intent_esrs = make_esrs_intent_node(llm)
    retrieve_documents = make_retriever_node(vectorstore)
    answer_rag = make_rag_node(llm, wrong_esrs=False)
    answer_rag_wrong = make_rag_node(llm, wrong_esrs=True)

    # Register the nodes under the names route_intent returns.
    workflow.add_node("categorize_esrs", categorize_esrs)
    workflow.add_node("intent_esrs", intent_esrs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_wrong", answer_rag_wrong)

    workflow.set_entry_point("categorize_esrs")

    # Declare every destination route_intent can return. Without a path map
    # LangGraph cannot know the possible targets, so the rendered graph shows
    # this branch leading to every node and the targets are not declared for
    # validation when the graph is compiled.
    workflow.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        {
            "intent_esrs": "intent_esrs",
            "answer_rag_wrong": "answer_rag_wrong",
            "retrieve_documents": "retrieve_documents",
        },
    )

    # Static edges.
    workflow.add_edge("intent_esrs", "retrieve_documents")
    workflow.add_edge("retrieve_documents", "answer_rag")
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_wrong", END)

    return workflow.compile()
def display_graph(app):
    """Render the graph of ``app`` as a Mermaid PNG and show it via IPython.

    The graph is obtained with ``app.get_graph(xray=True)``. It is rendered
    with ``MermaidDrawMethod.API``, which presumably calls a remote rendering
    service, so network access is likely required (confirm). Returns ``None``.
    """
    graph_view = app.get_graph(xray=True)
    png_data = graph_view.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_data))
|