File size: 2,380 Bytes
7bfa7e6
 
 
 
 
 
 
 
d708cb9
7bfa7e6
 
 
 
 
 
 
 
d708cb9
 
 
7bfa7e6
 
 
 
 
 
 
 
 
 
 
 
d708cb9
7bfa7e6
 
 
 
 
 
 
d708cb9
8ca00e0
d708cb9
7bfa7e6
 
 
 
 
 
 
 
d708cb9
 
7bfa7e6
d708cb9
 
7bfa7e6
 
 
d708cb9
7bfa7e6
 
d708cb9
7bfa7e6
 
 
 
d708cb9
8ca00e0
d708cb9
7bfa7e6
d708cb9
7bfa7e6
 
d708cb9
7bfa7e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import sys
import os
from contextlib import contextmanager

from langchain.schema import Document
from langgraph.graph import END, StateGraph
from langchain_core.runnables.graph import MermaidDrawMethod

from tomlkit import document
from typing_extensions import TypedDict
from typing import List

from IPython.display import display, HTML, Image

from celsius_csrd_chatbot.chains.esrs_categorization import (
    make_esrs_categorization_node,
)
from celsius_csrd_chatbot.chains.esrs_intent import (
    make_esrs_intent_node,
)
from celsius_csrd_chatbot.chains.retriever import make_retriever_node
from celsius_csrd_chatbot.chains.answer_rag import make_rag_node


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    This TypedDict is the state schema handed to ``StateGraph`` when the
    workflow is built in ``make_graph_agent``.
    """

    # NOTE(review): presumably the user's question — confirm against the nodes.
    query: str
    # Read by ``route_intent``: "none" and "wrong_esrs" select dedicated
    # branches; any other value routes to document retrieval.
    esrs_type: str
    # NOTE(review): presumably the generated answer text — confirm against
    # the RAG nodes.
    answer: str
    # NOTE(review): presumably filled by the retriever node — confirm.
    documents: List[Document]


def route_intent(state):
    """Choose the node that follows ESRS categorization.

    Reads ``state["esrs_type"]`` and returns the name of the next node:

    * ``"intent_esrs"`` when the value is ``"none"``
    * ``"answer_rag_wrong"`` when the value is ``"wrong_esrs"``
    * ``"retrieve_documents"`` for any other value

    Raises:
        KeyError: If ``state`` has no ``"esrs_type"`` entry.
    """
    category = state["esrs_type"]

    # Handle the two special markers first; everything else goes straight
    # to retrieval.
    if category == "none":
        return "intent_esrs"
    if category == "wrong_esrs":
        return "answer_rag_wrong"
    return "retrieve_documents"


def make_graph_agent(llm, vectorstore):
    """Build and compile the chatbot's LangGraph workflow.

    The graph enters at ``categorize_esrs`` and then branches on the value
    returned by ``route_intent``:

    * ``intent_esrs`` -> ``retrieve_documents`` -> ``answer_rag`` -> END
    * ``retrieve_documents`` -> ``answer_rag`` -> END
    * ``answer_rag_wrong`` -> END

    Args:
        llm: Language model passed to the intent node and to both RAG
            answer nodes.
        vectorstore: Vector store passed to the retriever node.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Define the node functions
    categorize_esrs = make_esrs_categorization_node()
    intent_esrs = make_esrs_intent_node(llm)
    retrieve_documents = make_retriever_node(vectorstore)
    answer_rag = make_rag_node(llm, wrong_esrs=False)
    answer_rag_wrong = make_rag_node(llm, wrong_esrs=True)

    # Define the nodes
    workflow.add_node("categorize_esrs", categorize_esrs)
    workflow.add_node("intent_esrs", intent_esrs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_wrong", answer_rag_wrong)

    # Entry point
    workflow.set_entry_point("categorize_esrs")

    # CONDITIONAL EDGES
    # Declare the branch targets explicitly. Without a path map (or a
    # Literal return annotation on the router) LangGraph cannot know which
    # nodes the router may return, so the rendered graph would show a
    # conditional edge from "categorize_esrs" to every node. The mapping is
    # passed positionally so it works regardless of the parameter's name.
    workflow.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        {
            "intent_esrs": "intent_esrs",
            "answer_rag_wrong": "answer_rag_wrong",
            "retrieve_documents": "retrieve_documents",
        },
    )

    # Define the edges
    workflow.add_edge("intent_esrs", "retrieve_documents")
    workflow.add_edge("retrieve_documents", "answer_rag")
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_wrong", END)

    # Compile
    return workflow.compile()


def display_graph(app):
    """Render the graph of ``app`` as a PNG and show it with IPython.

    The graph is obtained with ``xray=True`` and drawn using
    ``MermaidDrawMethod.API``.

    Args:
        app: Object exposing ``get_graph``, e.g. the compiled graph
            returned by ``make_graph_agent``.
    """
    graph = app.get_graph(xray=True)
    # NOTE(review): the API draw method presumably calls a remote rendering
    # service — confirm network access is available where this runs.
    png = graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png))