update unfinished agent
Browse files
celsius_csrd_chatbot/agent.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
from contextlib import contextmanager
|
4 |
+
|
5 |
+
from langchain.schema import Document
|
6 |
+
from langgraph.graph import END, StateGraph
|
7 |
+
from langchain_core.runnables.graph import MermaidDrawMethod
|
8 |
+
|
9 |
+
from typing_extensions import TypedDict
|
10 |
+
from typing import List
|
11 |
+
|
12 |
+
from IPython.display import display, HTML, Image
|
13 |
+
|
14 |
+
from celsius_csrd_chatbot.chains.esrs_categorization import (
|
15 |
+
make_esrs_categorization_node,
|
16 |
+
)
|
17 |
+
from celsius_csrd_chatbot.chains.retriever import make_retriever_node
|
18 |
+
from celsius_csrd_chatbot.chains.answer_rag import make_rag_node
|
19 |
+
|
20 |
+
|
21 |
+
class GraphState(TypedDict):
|
22 |
+
"""
|
23 |
+
Represents the state of our graph.
|
24 |
+
"""
|
25 |
+
|
26 |
+
query: str
|
27 |
+
esrs_type: str
|
28 |
+
answer: str
|
29 |
+
|
30 |
+
|
31 |
+
def route_intent(state):
|
32 |
+
esrs = state["esrs_type"]
|
33 |
+
if esrs == "none":
|
34 |
+
return "intent_esrs"
|
35 |
+
|
36 |
+
else:
|
37 |
+
return "retrieve_documents"
|
38 |
+
|
39 |
+
|
40 |
+
def make_graph_agent(llm, vectorstore):
|
41 |
+
workflow = StateGraph(GraphState)
|
42 |
+
|
43 |
+
# Define the node functions
|
44 |
+
categorize_esrs = make_esrs_categorization_node(llm)
|
45 |
+
retrieve_documents = make_retriever_node(vectorstore)
|
46 |
+
answer_rag = make_rag_node(llm)
|
47 |
+
|
48 |
+
# Define the nodes
|
49 |
+
workflow.add_node("categorize_esrs", categorize_esrs)
|
50 |
+
workflow.add_node("retrieve_documents", retrieve_documents)
|
51 |
+
workflow.add_node("answer_rag", answer_rag)
|
52 |
+
|
53 |
+
# Entry point
|
54 |
+
workflow.set_entry_point("categorize_esrs")
|
55 |
+
|
56 |
+
# Define the edges
|
57 |
+
workflow.add_edge("categorize_esrs", "retrieve_documents")
|
58 |
+
workflow.add_edge("retrieve_documents", "answer_rag")
|
59 |
+
workflow.add_edge("answer_rag", END)
|
60 |
+
|
61 |
+
# Compile
|
62 |
+
app = workflow.compile()
|
63 |
+
return app
|
64 |
+
|
65 |
+
|
66 |
+
def display_graph(app):
|
67 |
+
|
68 |
+
display(
|
69 |
+
Image(
|
70 |
+
app.get_graph(xray=True).draw_mermaid_png(
|
71 |
+
draw_method=MermaidDrawMethod.API,
|
72 |
+
)
|
73 |
+
)
|
74 |
+
)
|
celsius_csrd_chatbot/chains/__init__.py
ADDED
File without changes
|
celsius_csrd_chatbot/chains/answer_rag.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from operator import itemgetter
|
2 |
+
|
3 |
+
from langchain_core.prompts import ChatPromptTemplate
|
4 |
+
from langchain_core.output_parsers import StrOutputParser
|
5 |
+
from langchain_core.prompts.prompt import PromptTemplate
|
6 |
+
from langchain_core.prompts.base import format_document
|
7 |
+
|
8 |
+
esrs_wiki = """
|
9 |
+
|
10 |
+
The Corporate Sustainability Reporting Directive (CSRD) is a mandate that requires all companies to report on their sustainability initiatives. In response to this directive, the European Sustainability Reporting Standards (ESRS) were developed. These standards are a key tool in promoting the transition to a sustainable economy within the EU, providing a structured framework for companies to disclose their sustainability initiatives. The ESRS cover a wide range of environmental, social, and governance (ESG) issues, including climate change, biodiversity, and human rights. Companies that adhere to the ESRS can provide investors with valuable insights into their sustainability impact, thereby informing investment decisions. The ESRS are designed to be highly interoperable with global reporting standards, which helps to avoid unnecessary duplication in reporting by companies. The reporting requirements based on the ESRS will be gradually implemented for different companies over time. In summary, the ESRS play a critical role in fostering sustainable finance and enabling companies to demonstrate their commitment to the green deal agenda while accessing sustainable finance.
|
11 |
+
|
12 |
+
---
|
13 |
+
|
14 |
+
"""
|
15 |
+
|
16 |
+
reformulation_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
|
17 |
+
|
18 |
+
Chat History:
|
19 |
+
{chat_history}
|
20 |
+
Follow Up Input: {question}
|
21 |
+
Standalone question:"""
|
22 |
+
|
23 |
+
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(reformulation_template)
|
24 |
+
|
25 |
+
answering_template = """
|
26 |
+
You are an ESG expert, with 20 years experience analyzing corporate sustainability reports.
|
27 |
+
You are specialist in the upcoming CSRD regulation and in general with corporate sustainability disclosure requirements.
|
28 |
+
{esrs_wiki}
|
29 |
+
|
30 |
+
You will answer the question based on the following passages extracted from CSRD specific sustainability guidelines and reports:
|
31 |
+
```
|
32 |
+
{context}
|
33 |
+
```
|
34 |
+
|
35 |
+
Guidelines:
|
36 |
+
1. Context: You'll receive relevant excerpts from a CSRD-specific sustainability guideline or report to address a given question.
|
37 |
+
2. Relevance: Only include passages directly pertaining to the question; omit irrelevant content.
|
38 |
+
3. Facts and Figures: Prioritize factual information in your response.
|
39 |
+
4. Conciseness: Keep answers sharp and succinct, avoiding unnecessary context.
|
40 |
+
5. Focus: Address the specific question without veering into related topics.
|
41 |
+
6. Honesty: If unsure, state that you don't know rather than inventing an answer.
|
42 |
+
7. Source Attribution: When using information from a passage, mention it as [Doc i] at the end of the sentence (where 'i' represents the document number).
|
43 |
+
8. Multiple Sources: If the same content appears in multiple documents, cite them collectively (e.g., [Doc i, Doc j, Doc k]).
|
44 |
+
9. Structured Paragraphs: Instead of bullet-point summaries, compile your responses into well-structured paragraphs.
|
45 |
+
10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
|
46 |
+
11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
|
47 |
+
12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
|
48 |
+
|
49 |
+
Question: {question}
|
50 |
+
Answer:
|
51 |
+
"""
|
52 |
+
|
53 |
+
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
|
54 |
+
|
55 |
+
|
56 |
+
def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"):
|
57 |
+
doc_strings = []
|
58 |
+
for i, doc in enumerate(docs):
|
59 |
+
chunk_type = "Doc"
|
60 |
+
if isinstance(doc, str):
|
61 |
+
doc_formatted = doc
|
62 |
+
else:
|
63 |
+
doc_formatted = format_document(doc, document_prompt)
|
64 |
+
doc_string = f"{chunk_type} {i+1}: " + doc_formatted
|
65 |
+
doc_string = doc_string.replace("\n", " ")
|
66 |
+
doc_strings.append(doc_string)
|
67 |
+
|
68 |
+
return sep.join(doc_strings)
|
69 |
+
|
70 |
+
|
71 |
+
def get_text_docs(x):
|
72 |
+
return [doc for doc in x if doc.metadata["chunk_type"] == "text"]
|
73 |
+
|
74 |
+
|
75 |
+
def make_rag_chain(llm):
|
76 |
+
prompt = ChatPromptTemplate.from_template(answering_template)
|
77 |
+
chain = (
|
78 |
+
{
|
79 |
+
"context": lambda x: _combine_documents(x["documents"]),
|
80 |
+
"query": itemgetter("query"),
|
81 |
+
}
|
82 |
+
| prompt
|
83 |
+
| llm
|
84 |
+
| StrOutputParser()
|
85 |
+
)
|
86 |
+
return chain
|
87 |
+
|
88 |
+
|
89 |
+
def make_rag_node(llm):
|
90 |
+
rag_chain = make_rag_chain(llm)
|
91 |
+
|
92 |
+
async def answer_rag(state, config):
|
93 |
+
answer = await rag_chain.ainvoke(state, config)
|
94 |
+
return {"answer": answer}
|
95 |
+
|
96 |
+
return answer_rag
|
celsius_csrd_chatbot/chains/esrs_categorization.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
2 |
+
from langchain.prompts.prompt import PromptTemplate
|
3 |
+
from langchain.output_parsers import PydanticOutputParser
|
4 |
+
from typing import Literal
|
5 |
+
from operator import itemgetter
|
6 |
+
import json
|
7 |
+
from langchain_core.exceptions import OutputParserException
|
8 |
+
|
9 |
+
|
10 |
+
class ESRSAnalysis(BaseModel):
|
11 |
+
"""Analyzing the user query to get ESRS type, sources and intent"""
|
12 |
+
|
13 |
+
esrs_type: Literal[
|
14 |
+
"ESRS 1",
|
15 |
+
"ESRS 2",
|
16 |
+
"ESRS E1",
|
17 |
+
"ESRS E2",
|
18 |
+
"ESRS E3",
|
19 |
+
"ESRS E4",
|
20 |
+
"ESRS E5",
|
21 |
+
"ESRS S1",
|
22 |
+
"ESRS S2",
|
23 |
+
"ESRS S3",
|
24 |
+
"ESRS S4",
|
25 |
+
"ESRS G1",
|
26 |
+
"none",
|
27 |
+
] = Field(
|
28 |
+
description="""
|
29 |
+
Given a user question choose which documents would be most relevant for answering their question :
|
30 |
+
|
31 |
+
- ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
|
32 |
+
- ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
|
33 |
+
- ESRS E1 is for questions about climate change, global warming, GES and energy
|
34 |
+
- ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
|
35 |
+
- ESRS E3 is for questions about water and marine resources
|
36 |
+
- ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
|
37 |
+
- ESRS E5 is for questions about resource use and circular economy
|
38 |
+
- ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
|
39 |
+
- ESRS S2 is for questions about workers in the value chain, workers' treatment
|
40 |
+
- SRS S3 is for questions about affected communities, impact on local communities
|
41 |
+
- ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
|
42 |
+
- ESRS G1 is for questions about governance, risk management, internal control, and business conduct
|
43 |
+
- none is for questions that do not fit into any of the above categories
|
44 |
+
|
45 |
+
Follow these guidelines :
|
46 |
+
|
47 |
+
- Do not take into account upper or lower case letter distinction. For example, 'esrs 1', 'Esrs 1' and 'ESRS 1' should be considered as 'ESRS 1'.
|
48 |
+
- Some questions could be related to multiple ESRS. In this case, keep all options and format the output as such : 'ESRS 1', 'ESRS 2'.
|
49 |
+
- Remember, if the question is not related to any ESRS, the output should be 'none'.
|
50 |
+
""",
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def make_esrs_categorization_chain(llm):
|
55 |
+
parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
|
56 |
+
prompt_template = """
|
57 |
+
The following question is about ESRS related topics. Please analyze the question and indicate if it refers to specific ESRS.
|
58 |
+
|
59 |
+
{format_instructions}
|
60 |
+
|
61 |
+
Please answer with the appropriate ESRS to answer the question.
|
62 |
+
|
63 |
+
Question: '{query}'
|
64 |
+
Answer:
|
65 |
+
"""
|
66 |
+
|
67 |
+
prompt = PromptTemplate(
|
68 |
+
template=prompt_template,
|
69 |
+
input_variables=["query"],
|
70 |
+
partial_variables={"format_instructions": parser.get_format_instructions()},
|
71 |
+
)
|
72 |
+
chain = {"query": itemgetter("query")} | prompt | llm | parser
|
73 |
+
|
74 |
+
return chain
|
75 |
+
|
76 |
+
|
77 |
+
def make_esrs_categorization_node(llm):
|
78 |
+
|
79 |
+
def categorize_message(state):
|
80 |
+
query = state["query"]
|
81 |
+
categorization_chain = make_esrs_categorization_chain(llm)
|
82 |
+
output = categorization_chain.invoke(query)
|
83 |
+
|
84 |
+
if not output:
|
85 |
+
raise OutputParserException("Output is empty")
|
86 |
+
|
87 |
+
try:
|
88 |
+
# Attempt to parse the output as JSON
|
89 |
+
parsed_output = json.loads(output)
|
90 |
+
except json.JSONDecodeError as e:
|
91 |
+
# Raise a more informative error if the output is not valid JSON
|
92 |
+
raise OutputParserException(f"Invalid JSON output: {output}") from e
|
93 |
+
|
94 |
+
return output
|
95 |
+
|
96 |
+
return categorize_message
|
97 |
+
|
98 |
+
# intent: str = Field(
|
99 |
+
# enum=[
|
100 |
+
# "Specific topic",
|
101 |
+
# "Implementation reco",
|
102 |
+
# "KPI extraction",
|
103 |
+
# ],
|
104 |
+
# description="""
|
105 |
+
# Categorize the user query in one of the following categories,
|
106 |
+
|
107 |
+
# Examples:
|
108 |
+
# - Specific topic: "What are the specificities of ESRS E1 ?"
|
109 |
+
# - Implementation reco: "How should I compute my scope 1 reduction target ?"
|
110 |
+
# - KPI extraction: "When will the CSRD be mandatory for my small French company ?"
|
111 |
+
# """,
|
112 |
+
# )
|
113 |
+
|
114 |
+
# sources: str = Field(
|
115 |
+
# enum=["ESRS", "External"],
|
116 |
+
# description="""
|
117 |
+
# Given a user question choose which documents would be most relevant for answering their question,
|
118 |
+
# - ESRS is for questions about a specific environmental, social or governance topic, as well as CSRD's general principles and disclosures
|
119 |
+
# - External is for questions about how to implement the CSRD, or general questions about CSRD's context
|
120 |
+
# """,
|
121 |
+
# )
|
celsius_csrd_chatbot/chains/esrs_intent.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
2 |
+
from langchain.prompts.prompt import PromptTemplate
|
3 |
+
from langchain.output_parsers import PydanticOutputParser
|
4 |
+
from typing import Literal
|
5 |
+
from operator import itemgetter
|
6 |
+
import json
|
7 |
+
from langchain_core.exceptions import OutputParserException
|
8 |
+
|
9 |
+
|
10 |
+
class ESRSAnalysis(BaseModel):
|
11 |
+
"""Analyzing the user query to get ESRS type, sources and intent"""
|
12 |
+
|
13 |
+
esrs_type: Literal[
|
14 |
+
"ESRS 1",
|
15 |
+
"ESRS 2",
|
16 |
+
"ESRS E1",
|
17 |
+
"ESRS E2",
|
18 |
+
"ESRS E3",
|
19 |
+
"ESRS E4",
|
20 |
+
"ESRS E5",
|
21 |
+
"ESRS S1",
|
22 |
+
"ESRS S2",
|
23 |
+
"ESRS S3",
|
24 |
+
"ESRS S4",
|
25 |
+
"ESRS G1",
|
26 |
+
"none",
|
27 |
+
] = Field(
|
28 |
+
description="""
|
29 |
+
Given a user question choose which documents would be most relevant for answering their question :
|
30 |
+
|
31 |
+
- ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
|
32 |
+
- ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
|
33 |
+
- ESRS E1 is for questions about climate change, global warming, GES and energy
|
34 |
+
- ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
|
35 |
+
- ESRS E3 is for questions about water and marine resources
|
36 |
+
- ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
|
37 |
+
- ESRS E5 is for questions about resource use and circular economy
|
38 |
+
- ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
|
39 |
+
- ESRS S2 is for questions about workers in the value chain, workers' treatment
|
40 |
+
- SRS S3 is for questions about affected communities, impact on local communities
|
41 |
+
- ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
|
42 |
+
- ESRS G1 is for questions about governance, risk management, internal control, and business conduct
|
43 |
+
- none is for questions that do not fit into any of the above categories
|
44 |
+
|
45 |
+
Follow these guidelines :
|
46 |
+
|
47 |
+
- Do not take into account upper or lower case letter distinction. For example, 'esrs 1', 'Esrs 1' and 'ESRS 1' should be considered as 'ESRS 1'.
|
48 |
+
- Some questions could be related to multiple ESRS. In this case, keep all options and format the output as such : 'ESRS 1', 'ESRS 2'.
|
49 |
+
- Remember, if the question is not related to any ESRS, the output should be 'none'.
|
50 |
+
""",
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def make_esrs_categorization_chain(llm):
|
55 |
+
parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
|
56 |
+
prompt_template = """
|
57 |
+
The following question is about ESRS related topics. Please analyze the question and indicate if it refers to specific ESRS.
|
58 |
+
|
59 |
+
{format_instructions}
|
60 |
+
|
61 |
+
Please answer with the appropriate ESRS to answer the question.
|
62 |
+
|
63 |
+
Question: '{query}'
|
64 |
+
Answer:
|
65 |
+
"""
|
66 |
+
|
67 |
+
prompt = PromptTemplate(
|
68 |
+
template=prompt_template,
|
69 |
+
input_variables=["query"],
|
70 |
+
partial_variables={"format_instructions": parser.get_format_instructions()},
|
71 |
+
)
|
72 |
+
chain = {"query": itemgetter("query")} | prompt | llm | parser
|
73 |
+
|
74 |
+
return chain
|
75 |
+
|
76 |
+
|
77 |
+
def make_esrs_categorization_node(llm):
|
78 |
+
|
79 |
+
def categorize_message(state):
|
80 |
+
query = state["query"]
|
81 |
+
categorization_chain = make_esrs_categorization_chain(llm)
|
82 |
+
output = categorization_chain.invoke(query)
|
83 |
+
|
84 |
+
if not output:
|
85 |
+
raise OutputParserException("Output is empty")
|
86 |
+
|
87 |
+
try:
|
88 |
+
# Attempt to parse the output as JSON
|
89 |
+
parsed_output = json.loads(output)
|
90 |
+
except json.JSONDecodeError as e:
|
91 |
+
# Raise a more informative error if the output is not valid JSON
|
92 |
+
raise OutputParserException(f"Invalid JSON output: {output}") from e
|
93 |
+
|
94 |
+
return output
|
95 |
+
|
96 |
+
return categorize_message
|
97 |
+
|
98 |
+
# intent: str = Field(
|
99 |
+
# enum=[
|
100 |
+
# "Specific topic",
|
101 |
+
# "Implementation reco",
|
102 |
+
# "KPI extraction",
|
103 |
+
# ],
|
104 |
+
# description="""
|
105 |
+
# Categorize the user query in one of the following categories,
|
106 |
+
|
107 |
+
# Examples:
|
108 |
+
# - Specific topic: "What are the specificities of ESRS E1 ?"
|
109 |
+
# - Implementation reco: "How should I compute my scope 1 reduction target ?"
|
110 |
+
# - KPI extraction: "When will the CSRD be mandatory for my small French company ?"
|
111 |
+
# """,
|
112 |
+
# )
|
113 |
+
|
114 |
+
# sources: str = Field(
|
115 |
+
# enum=["ESRS", "External"],
|
116 |
+
# description="""
|
117 |
+
# Given a user question choose which documents would be most relevant for answering their question,
|
118 |
+
# - ESRS is for questions about a specific environmental, social or governance topic, as well as CSRD's general principles and disclosures
|
119 |
+
# - External is for questions about how to implement the CSRD, or general questions about CSRD's context
|
120 |
+
# """,
|
121 |
+
# )
|
celsius_csrd_chatbot/chains/retriever.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def make_retriever_node(vectorstore, k=10):
|
2 |
+
|
3 |
+
def retrieve_documents(state):
|
4 |
+
sources = state["esrs_type"]
|
5 |
+
query = state["query"]
|
6 |
+
if sources == "none":
|
7 |
+
filters_full = {}
|
8 |
+
else:
|
9 |
+
filters_full = {"ESRS": {"$in": [sources]}}
|
10 |
+
docs = []
|
11 |
+
retriever = vectorstore.as_retriever()
|
12 |
+
docs = retriever.similarity_search_with_score(
|
13 |
+
query=query, filter=filters_full, k=k
|
14 |
+
)
|
15 |
+
for doc in docs:
|
16 |
+
doc.metadata["similarity_score"] = doc.metadata["similarity_score"]
|
17 |
+
|
18 |
+
docs = sorted(docs, key=lambda x: x.metadata["similarity_score"], reverse=True)
|
19 |
+
new_state = {"documents": docs}
|
20 |
+
|
21 |
+
return new_state
|
22 |
+
|
23 |
+
return retrieve_documents
|