momenaca commited on
Commit
d708cb9
·
1 Parent(s): 7bfa7e6

add latest updates regarding agent mode

Browse files
app.py CHANGED
@@ -1,11 +1,9 @@
1
  import os
 
2
  import gradio as gr
3
- from operator import itemgetter
4
  from pinecone import Pinecone
5
  from huggingface_hub import whoami
6
  from langchain.prompts import ChatPromptTemplate
7
- from langchain.schema.output_parser import StrOutputParser
8
- from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
9
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
10
  from langchain_openai import AzureChatOpenAI
11
  from langchain.prompts.prompt import PromptTemplate
@@ -18,7 +16,9 @@ from celsius_csrd_chatbot.utils import (
18
  _combine_documents,
19
  get_llm,
20
  init_env,
 
21
  )
 
22
 
23
  init_env()
24
  chat_model_init = get_llm()
@@ -33,165 +33,86 @@ embeddings = HuggingFaceBgeEmbeddings(
33
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
34
  index = pc.Index(os.getenv("PINECONE_API_INDEX"))
35
  vectorstore = PineconeVectorstore(index, embeddings, "page_content")
36
- retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
37
- chat_model = AzureChatOpenAI()
38
- esrs_wiki = """
39
 
40
- The Corporate Sustainability Reporting Directive (CSRD) is a mandate that requires all companies to report on their sustainability initiatives. In response to this directive, the European Sustainability Reporting Standards (ESRS) were developed. These standards are a key tool in promoting the transition to a sustainable economy within the EU, providing a structured framework for companies to disclose their sustainability initiatives. The ESRS cover a wide range of environmental, social, and governance (ESG) issues, including climate change, biodiversity, and human rights. Companies that adhere to the ESRS can provide investors with valuable insights into their sustainability impact, thereby informing investment decisions. The ESRS are designed to be highly interoperable with global reporting standards, which helps to avoid unnecessary duplication in reporting by companies. The reporting requirements based on the ESRS will be gradually implemented for different companies over time. In summary, the ESRS play a critical role in fostering sustainable finance and enabling companies to demonstrate their commitment to the green deal agenda while accessing sustainable finance.
41
-
42
- ---
43
-
44
- """
45
-
46
- reformulation_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
47
-
48
- Chat History:
49
- {chat_history}
50
- Follow Up Input: {question}
51
- Standalone question:"""
52
-
53
- CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(reformulation_template)
54
-
55
- answering_template = """
56
- You are an ESG expert, with 20 years experience analyzing corporate sustainability reports.
57
- You are specialist in the upcoming CSRD regulation and in general with corporate sustainability disclosure requirements.
58
- {esrs_wiki}
59
-
60
- You will answer the question based on the following passages extracted from CSRD specific sustainability guidelines and reports:
61
- ```
62
- {context}
63
- ```
64
-
65
- Guidelines:
66
- 1. Context: You'll receive relevant excerpts from a CSRD-specific sustainability guideline or report to address a given question.
67
- 2. Relevance: Only include passages directly pertaining to the question; omit irrelevant content.
68
- 3. Facts and Figures: Prioritize factual information in your response.
69
- 4. Conciseness: Keep answers sharp and succinct, avoiding unnecessary context.
70
- 5. Focus: Address the specific question without veering into related topics.
71
- 6. Honesty: If unsure, state that you don't know rather than inventing an answer.
72
- 7. Source Attribution: When using information from a passage, mention it as [Doc i] at the end of the sentence (where 'i' represents the document number).
73
- 8. Multiple Sources: If the same content appears in multiple documents, cite them collectively (e.g., [Doc i, Doc j, Doc k]).
74
- 9. Structured Paragraphs: Instead of bullet-point summaries, compile your responses into well-structured paragraphs.
75
- 10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
76
- 11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
77
- 12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
78
-
79
- Question: {question}
80
- Answer:
81
- """
82
-
83
- ANSWER_PROMPT = ChatPromptTemplate.from_template(answering_template)
84
-
85
- DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
86
  memory = ConversationBufferMemory(
87
  return_messages=True, output_key="answer", input_key="question"
88
  )
89
 
90
- # First we add a step to load memory
91
- # This adds a "memory" key to the input object
92
- loaded_memory = RunnablePassthrough.assign(
93
- chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
94
- )
95
- # Now we calculate the standalone question
96
- standalone_question = {
97
- "standalone_question": {
98
- "question": lambda x: x["question"],
99
- "chat_history": lambda x: _format_chat_history(x["chat_history"]),
100
- }
101
- | CONDENSE_QUESTION_PROMPT
102
- | chat_model
103
- | StrOutputParser(),
104
- }
105
- # Now we retrieve the documents
106
- retrieved_documents = {
107
- "docs": itemgetter("standalone_question") | retriever,
108
- "question": lambda x: x["standalone_question"],
109
- }
110
- # Now we construct the inputs for the final prompt
111
- final_inputs = {
112
- "context": lambda x: _combine_documents(x["docs"], DEFAULT_DOCUMENT_PROMPT),
113
- "question": itemgetter("question"),
114
- "esrs_wiki": lambda x: esrs_wiki,
115
- }
116
- # And finally, we do the part that returns the answers
117
- answer = {
118
- "answer": final_inputs | ANSWER_PROMPT | chat_model,
119
- "docs": itemgetter("docs"),
120
- }
121
- # And now we put it all together!
122
- final_chain = loaded_memory | standalone_question | retrieved_documents | answer
123
-
124
-
125
- async def chat(
126
- query: str,
127
- history: list = [],
128
- ):
129
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
130
  (messages in gradio format, messages in langchain format, source documents)"""
131
- source_string = ""
132
- gradio_format = make_pairs([a.content for a in history]) + [(query, "")]
133
-
134
- # reset memory
135
- memory.clear()
136
- for message in history:
137
- memory.chat_memory.add_message(message)
138
-
139
- inputs = {"question": query}
140
- result = final_chain.astream_log({"question": query})
141
- reformulated_question_path_id = "/logs/AzureChatOpenAI/streamed_output_str/-"
142
- retriever_path_id = "/logs/Retriever/final_output"
143
- final_answer_path_id = "/logs/AzureChatOpenAI:2/streamed_output_str/-"
144
-
145
- async for op in result:
146
- op = op.ops[0]
147
- if op["path"] == reformulated_question_path_id: # reforulated question
148
- new_token = op["value"] # str
149
 
150
- elif op["path"] == retriever_path_id: # documents
151
- sources = op["value"]["documents"] # List[Document]
152
- source_string = "\n\n".join(
153
- [(make_html_source(i, doc)) for i, doc in enumerate(sources, 1)]
154
- )
155
 
156
- elif op["path"] == final_answer_path_id: # final answer
157
- new_token = op["value"] # str
158
- answer_yet = gradio_format[-1][1]
159
- gradio_format[-1] = (query, answer_yet + new_token)
160
 
161
- yield "", gradio_format, history, source_string
 
 
 
162
 
163
- memory.save_context(inputs, {"answer": gradio_format[-1][1]})
164
- yield "", gradio_format, memory.load_memory_variables({})["history"], source_string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  with open("./assets/style.css", "r") as f:
168
  css = f.read()
169
 
170
-
171
- def update_visible(oauth_token: gr.OAuthToken | None):
172
- if oauth_token is None:
173
- return {
174
- bloc_1: gr.update(visible=True),
175
- bloc_2: gr.update(visible=False),
176
- bloc_3: gr.update(visible=False),
177
- }
178
-
179
- org_names = [org["name"] for org in whoami(oauth_token.token)["orgs"]]
180
- if "ekimetrics-esrsqa" in org_names: # logged in group
181
- return {
182
- bloc_1: gr.update(visible=False),
183
- bloc_2: gr.update(visible=True),
184
- bloc_3: gr.update(visible=False),
185
- }
186
-
187
- else: # logged but not in group
188
- return {
189
- bloc_1: gr.update(visible=False),
190
- bloc_2: gr.update(visible=False),
191
- bloc_3: gr.update(visible=True),
192
- }
193
-
194
-
195
  # Set up Gradio Theme
196
  theme = gr.themes.Base(
197
  primary_hue="blue",
@@ -211,18 +132,6 @@ What do you want to learn ?
211
 
212
 
213
  with gr.Blocks(title=f"{demo_name}", css=css, theme=theme) as demo:
214
- # with gr.Row():
215
- # with gr.Column(scale=1):
216
- # login = gr.LoginButton()
217
-
218
- # with gr.Column() as bloc_1:
219
- # textbox_1 = gr.Textbox("You are not logged to Hugging Face !", show_label=False)
220
-
221
- # with gr.Column(visible=False) as bloc_3:
222
- # textbox_3 = gr.Textbox(
223
- # "You are not part of the ESRS Q&A Project, if interested ask access here : https://huggingface.co/ekimetrics-esrsqa"
224
- # )
225
-
226
  with gr.Column(visible=True) as bloc_2:
227
  with gr.Tab("ESRS Q&A"):
228
  with gr.Row():
@@ -262,16 +171,29 @@ with gr.Blocks(title=f"{demo_name}", css=css, theme=theme) as demo:
262
  with gr.Column(scale=1):
263
  gr.Markdown("WIP")
264
 
265
- # demo.load(update_visible, inputs=None, outputs=[bloc_1, bloc_2, bloc_3])
266
- # login.click(update_visible, inputs=[], outputs=[bloc_1, bloc_2, bloc_3])
 
 
 
 
 
267
 
268
  ask.submit(
 
 
 
 
 
 
269
  fn=chat,
270
  inputs=[
271
  ask,
272
- state,
273
  ],
274
- outputs=[ask, chatbot, state, sources_textbox],
 
 
275
  )
276
 
277
 
 
1
  import os
2
+ from datetime import datetime
3
  import gradio as gr
 
4
  from pinecone import Pinecone
5
  from huggingface_hub import whoami
6
  from langchain.prompts import ChatPromptTemplate
 
 
7
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
8
  from langchain_openai import AzureChatOpenAI
9
  from langchain.prompts.prompt import PromptTemplate
 
16
  _combine_documents,
17
  get_llm,
18
  init_env,
19
+ parse_output_llm_with_sources,
20
  )
21
+ from celsius_csrd_chatbot.agent import make_graph_agent, display_graph
22
 
23
  init_env()
24
  chat_model_init = get_llm()
 
33
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
34
  index = pc.Index(os.getenv("PINECONE_API_INDEX"))
35
  vectorstore = PineconeVectorstore(index, embeddings, "page_content")
36
+ llm = AzureChatOpenAI()
37
+ agent = make_graph_agent(llm, vectorstore)
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  memory = ConversationBufferMemory(
40
  return_messages=True, output_key="answer", input_key="question"
41
  )
42
 
43
+
44
+ async def chat(query, history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
46
  (messages in gradio format, messages in langchain format, source documents)"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
49
+ print(f">> NEW QUESTION ({date_now}) : {query}")
50
+ inputs = {"query": query}
51
+ result = agent.astream_events(inputs, version="v1")
 
52
 
53
+ docs = []
54
+ docs_html = ""
55
+ output_query = ""
56
+ start_streaming = False
57
 
58
+ steps_display = {
59
+ "categorize_esrs": ("🔄️ Analyzing user query", True),
60
+ "retrieve_documents": ("🔄️ Searching in the knowledge base", True),
61
+ }
62
 
63
+ try:
64
+ async for event in result:
65
+ print(event)
66
+ if event["event"] == "on_chat_model_stream":
67
+ print("line 66")
68
+ if start_streaming == False:
69
+ print("line 68")
70
+ start_streaming = True
71
+ history[-1] = (query, "")
72
+
73
+ new_token = event["data"]["chunk"].content
74
+ previous_answer = history[-1][1]
75
+ previous_answer = previous_answer if previous_answer is not None else ""
76
+ answer_yet = previous_answer + new_token
77
+ answer_yet = parse_output_llm_with_sources(answer_yet)
78
+ history[-1] = (query, answer_yet)
79
+
80
+ elif (
81
+ event["name"] == "retrieve_documents"
82
+ and event["event"] == "on_chain_end"
83
+ ):
84
+ try:
85
+ print("line 84")
86
+ docs = event["data"]["output"]["documents"]
87
+ docs_html = []
88
+ for i, d in enumerate(docs, 1):
89
+ docs_html.append(make_html_source(d, i))
90
+ docs_html = "".join(docs_html)
91
+ except Exception as e:
92
+ print(f"Error getting documents: {e}")
93
+ print(event)
94
+
95
+ for event_name, (
96
+ event_description,
97
+ display_output,
98
+ ) in steps_display.items():
99
+ if event["name"] == event_name:
100
+ print("line 99")
101
+ if event["event"] == "on_chain_start":
102
+ print("line 101")
103
+ answer_yet = event_description
104
+ history[-1] = (query, answer_yet)
105
+
106
+ history = [tuple(x) for x in history]
107
+ yield history, docs_html
108
+
109
+ except Exception as e:
110
+ raise gr.Error(f"{e}")
111
 
112
 
113
  with open("./assets/style.css", "r") as f:
114
  css = f.read()
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # Set up Gradio Theme
117
  theme = gr.themes.Base(
118
  primary_hue="blue",
 
132
 
133
 
134
  with gr.Blocks(title=f"{demo_name}", css=css, theme=theme) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
135
  with gr.Column(visible=True) as bloc_2:
136
  with gr.Tab("ESRS Q&A"):
137
  with gr.Row():
 
171
  with gr.Column(scale=1):
172
  gr.Markdown("WIP")
173
 
174
+ def start_chat(query, history):
175
+ history = history + [(query, None)]
176
+ history = [tuple(x) for x in history]
177
+ return (gr.update(interactive=False), history)
178
+
179
+ def finish_chat():
180
+ return gr.update(interactive=True, value="")
181
 
182
  ask.submit(
183
+ start_chat,
184
+ [ask, chatbot],
185
+ [ask, chatbot],
186
+ queue=False,
187
+ api_name="start_chat_textbox",
188
+ ).then(
189
  fn=chat,
190
  inputs=[
191
  ask,
192
+ chatbot,
193
  ],
194
+ outputs=[chatbot, sources_textbox],
195
+ ).then(
196
+ finish_chat, None, [ask], api_name="finish_chat_textbox"
197
  )
198
 
199
 
celsius_csrd_chatbot/__init__.py ADDED
File without changes
celsius_csrd_chatbot/agent.py CHANGED
@@ -6,6 +6,7 @@ from langchain.schema import Document
6
  from langgraph.graph import END, StateGraph
7
  from langchain_core.runnables.graph import MermaidDrawMethod
8
 
 
9
  from typing_extensions import TypedDict
10
  from typing import List
11
 
@@ -14,6 +15,9 @@ from IPython.display import display, HTML, Image
14
  from celsius_csrd_chatbot.chains.esrs_categorization import (
15
  make_esrs_categorization_node,
16
  )
 
 
 
17
  from celsius_csrd_chatbot.chains.retriever import make_retriever_node
18
  from celsius_csrd_chatbot.chains.answer_rag import make_rag_node
19
 
@@ -26,6 +30,7 @@ class GraphState(TypedDict):
26
  query: str
27
  esrs_type: str
28
  answer: str
 
29
 
30
 
31
  def route_intent(state):
@@ -33,30 +38,49 @@ def route_intent(state):
33
  if esrs == "none":
34
  return "intent_esrs"
35
 
 
 
 
36
  else:
37
  return "retrieve_documents"
38
 
39
 
 
 
 
 
40
  def make_graph_agent(llm, vectorstore):
41
  workflow = StateGraph(GraphState)
42
 
43
  # Define the node functions
44
- categorize_esrs = make_esrs_categorization_node(llm)
 
45
  retrieve_documents = make_retriever_node(vectorstore)
46
- answer_rag = make_rag_node(llm)
 
47
 
48
  # Define the nodes
49
  workflow.add_node("categorize_esrs", categorize_esrs)
 
50
  workflow.add_node("retrieve_documents", retrieve_documents)
51
  workflow.add_node("answer_rag", answer_rag)
 
52
 
53
  # Entry point
54
  workflow.set_entry_point("categorize_esrs")
55
 
 
 
 
 
 
 
 
56
  # Define the edges
57
- workflow.add_edge("categorize_esrs", "retrieve_documents")
58
  workflow.add_edge("retrieve_documents", "answer_rag")
59
  workflow.add_edge("answer_rag", END)
 
60
 
61
  # Compile
62
  app = workflow.compile()
 
6
  from langgraph.graph import END, StateGraph
7
  from langchain_core.runnables.graph import MermaidDrawMethod
8
 
9
+ from tomlkit import document
10
  from typing_extensions import TypedDict
11
  from typing import List
12
 
 
15
  from celsius_csrd_chatbot.chains.esrs_categorization import (
16
  make_esrs_categorization_node,
17
  )
18
+ from celsius_csrd_chatbot.chains.esrs_intent import (
19
+ make_esrs_intent_node,
20
+ )
21
  from celsius_csrd_chatbot.chains.retriever import make_retriever_node
22
  from celsius_csrd_chatbot.chains.answer_rag import make_rag_node
23
 
 
30
  query: str
31
  esrs_type: str
32
  answer: str
33
+ documents: List[Document]
34
 
35
 
36
  def route_intent(state):
 
38
  if esrs == "none":
39
  return "intent_esrs"
40
 
41
+ elif esrs == "wrong_esrs":
42
+ return "answer_rag"
43
+
44
  else:
45
  return "retrieve_documents"
46
 
47
 
48
+ def make_id_dict(values):
49
+ return {k: k for k in values}
50
+
51
+
52
  def make_graph_agent(llm, vectorstore):
53
  workflow = StateGraph(GraphState)
54
 
55
  # Define the node functions
56
+ categorize_esrs = make_esrs_categorization_node()
57
+ intent_esrs = make_esrs_intent_node(llm)
58
  retrieve_documents = make_retriever_node(vectorstore)
59
+ answer_rag = make_rag_node(llm, wrong_esrs=False)
60
+ answer_rag_wrong = make_rag_node(llm, wrong_esrs=True)
61
 
62
  # Define the nodes
63
  workflow.add_node("categorize_esrs", categorize_esrs)
64
+ workflow.add_node("intent_esrs", intent_esrs)
65
  workflow.add_node("retrieve_documents", retrieve_documents)
66
  workflow.add_node("answer_rag", answer_rag)
67
+ workflow.add_node("answer_rag_wrong", answer_rag_wrong)
68
 
69
  # Entry point
70
  workflow.set_entry_point("categorize_esrs")
71
 
72
+ # CONDITIONAL EDGES
73
+ workflow.add_conditional_edges(
74
+ "categorize_esrs",
75
+ route_intent,
76
+ make_id_dict(["intent_esrs", "retrieve_documents", "answer_rag_wrong"]),
77
+ )
78
+
79
  # Define the edges
80
+ workflow.add_edge("intent_esrs", "retrieve_documents")
81
  workflow.add_edge("retrieve_documents", "answer_rag")
82
  workflow.add_edge("answer_rag", END)
83
+ workflow.add_edge("answer_rag_wrong", END)
84
 
85
  # Compile
86
  app = workflow.compile()
celsius_csrd_chatbot/chains/answer_rag.py CHANGED
@@ -13,15 +13,6 @@ The Corporate Sustainability Reporting Directive (CSRD) is a mandate that requir
13
 
14
  """
15
 
16
- reformulation_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
17
-
18
- Chat History:
19
- {chat_history}
20
- Follow Up Input: {question}
21
- Standalone question:"""
22
-
23
- CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(reformulation_template)
24
-
25
  answering_template = """
26
  You are an ESG expert, with 20 years experience analyzing corporate sustainability reports.
27
  You are specialist in the upcoming CSRD regulation and in general with corporate sustainability disclosure requirements.
@@ -46,7 +37,7 @@ answering_template = """
46
  11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
47
  12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
48
 
49
- Question: {question}
50
  Answer:
51
  """
52
 
@@ -78,6 +69,7 @@ def make_rag_chain(llm):
78
  {
79
  "context": lambda x: _combine_documents(x["documents"]),
80
  "query": itemgetter("query"),
 
81
  }
82
  | prompt
83
  | llm
@@ -86,11 +78,17 @@ def make_rag_chain(llm):
86
  return chain
87
 
88
 
89
- def make_rag_node(llm):
90
  rag_chain = make_rag_chain(llm)
91
 
92
  async def answer_rag(state, config):
93
- answer = await rag_chain.ainvoke(state, config)
94
- return {"answer": answer}
 
 
 
 
 
 
95
 
96
  return answer_rag
 
13
 
14
  """
15
 
 
 
 
 
 
 
 
 
 
16
  answering_template = """
17
  You are an ESG expert, with 20 years experience analyzing corporate sustainability reports.
18
  You are specialist in the upcoming CSRD regulation and in general with corporate sustainability disclosure requirements.
 
37
  11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
38
  12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
39
 
40
+ Question: {query}
41
  Answer:
42
  """
43
 
 
69
  {
70
  "context": lambda x: _combine_documents(x["documents"]),
71
  "query": itemgetter("query"),
72
+ "esrs_wiki": lambda x: esrs_wiki,
73
  }
74
  | prompt
75
  | llm
 
78
  return chain
79
 
80
 
81
+ def make_rag_node(llm, wrong_esrs=False):
82
  rag_chain = make_rag_chain(llm)
83
 
84
  async def answer_rag(state, config):
85
+ if wrong_esrs:
86
+ return {
87
+ "answer": "I'm sorry, I can't answer that question. Please provide a valid ESRS."
88
+ }
89
+
90
+ else:
91
+ answer = await rag_chain.ainvoke(state, config)
92
+ return {"answer": answer}
93
 
94
  return answer_rag
celsius_csrd_chatbot/chains/esrs_categorization.py CHANGED
@@ -1,121 +1,34 @@
1
- from langchain_core.pydantic_v1 import BaseModel, Field
2
- from langchain.prompts.prompt import PromptTemplate
3
- from langchain.output_parsers import PydanticOutputParser
4
- from typing import Literal
5
- from operator import itemgetter
6
- import json
7
- from langchain_core.exceptions import OutputParserException
8
 
9
 
10
- class ESRSAnalysis(BaseModel):
11
- """Analyzing the user query to get ESRS type, sources and intent"""
12
-
13
- esrs_type: Literal[
14
- "ESRS 1",
15
- "ESRS 2",
16
- "ESRS E1",
17
- "ESRS E2",
18
- "ESRS E3",
19
- "ESRS E4",
20
- "ESRS E5",
21
- "ESRS S1",
22
- "ESRS S2",
23
- "ESRS S3",
24
- "ESRS S4",
25
- "ESRS G1",
26
- "none",
27
- ] = Field(
28
- description="""
29
- Given a user question choose which documents would be most relevant for answering their question :
30
-
31
- - ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
32
- - ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
33
- - ESRS E1 is for questions about climate change, global warming, GES and energy
34
- - ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
35
- - ESRS E3 is for questions about water and marine resources
36
- - ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
37
- - ESRS E5 is for questions about resource use and circular economy
38
- - ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
39
- - ESRS S2 is for questions about workers in the value chain, workers' treatment
40
- - SRS S3 is for questions about affected communities, impact on local communities
41
- - ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
42
- - ESRS G1 is for questions about governance, risk management, internal control, and business conduct
43
- - none is for questions that do not fit into any of the above categories
44
-
45
- Follow these guidelines :
46
-
47
- - Do not take into account upper or lower case letter distinction. For example, 'esrs 1', 'Esrs 1' and 'ESRS 1' should be considered as 'ESRS 1'.
48
- - Some questions could be related to multiple ESRS. In this case, keep all options and format the output as such : 'ESRS 1', 'ESRS 2'.
49
- - Remember, if the question is not related to any ESRS, the output should be 'none'.
50
- """,
51
- )
52
-
53
-
54
- def make_esrs_categorization_chain(llm):
55
- parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
56
- prompt_template = """
57
- The following question is about ESRS related topics. Please analyze the question and indicate if it refers to specific ESRS.
58
-
59
- {format_instructions}
60
-
61
- Please answer with the appropriate ESRS to answer the question.
62
-
63
- Question: '{query}'
64
- Answer:
65
- """
66
-
67
- prompt = PromptTemplate(
68
- template=prompt_template,
69
- input_variables=["query"],
70
- partial_variables={"format_instructions": parser.get_format_instructions()},
71
- )
72
- chain = {"query": itemgetter("query")} | prompt | llm | parser
73
-
74
- return chain
75
-
76
-
77
- def make_esrs_categorization_node(llm):
78
 
79
  def categorize_message(state):
80
  query = state["query"]
81
- categorization_chain = make_esrs_categorization_chain(llm)
82
- output = categorization_chain.invoke(query)
83
-
84
- if not output:
85
- raise OutputParserException("Output is empty")
86
-
87
- try:
88
- # Attempt to parse the output as JSON
89
- parsed_output = json.loads(output)
90
- except json.JSONDecodeError as e:
91
- # Raise a more informative error if the output is not valid JSON
92
- raise OutputParserException(f"Invalid JSON output: {output}") from e
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  return output
95
 
96
  return categorize_message
97
-
98
- # intent: str = Field(
99
- # enum=[
100
- # "Specific topic",
101
- # "Implementation reco",
102
- # "KPI extraction",
103
- # ],
104
- # description="""
105
- # Categorize the user query in one of the following categories,
106
-
107
- # Examples:
108
- # - Specific topic: "What are the specificities of ESRS E1 ?"
109
- # - Implementation reco: "How should I compute my scope 1 reduction target ?"
110
- # - KPI extraction: "When will the CSRD be mandatory for my small French company ?"
111
- # """,
112
- # )
113
-
114
- # sources: str = Field(
115
- # enum=["ESRS", "External"],
116
- # description="""
117
- # Given a user question choose which documents would be most relevant for answering their question,
118
- # - ESRS is for questions about a specific environmental, social or governance topic, as well as CSRD's general principles and disclosures
119
- # - External is for questions about how to implement the CSRD, or general questions about CSRD's context
120
- # """,
121
- # )
 
1
+ import re
 
 
 
 
 
 
2
 
3
 
4
+ def make_esrs_categorization_node():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def categorize_message(state):
7
  query = state["query"]
8
+ pattern = r"ESRS \d|ESRS [A-Z]\d|ESRS [A-Z] \d"
9
+ esrs_truth = [
10
+ "ESRS 1",
11
+ "ESRS 2",
12
+ "ESRS E1",
13
+ "ESRS E2",
14
+ "ESRS E3",
15
+ "ESRS E4",
16
+ "ESRS E5",
17
+ "ESRS S1",
18
+ "ESRS S2",
19
+ "ESRS S3",
20
+ "ESRS S4",
21
+ "ESRS G1",
22
+ ]
23
+
24
+ matches = re.findall(pattern, query.upper())
25
+ if matches:
26
+ true_matches = [match for match in matches if match in esrs_truth]
27
+ output = {"esrs_type": true_matches if true_matches else "wrong_esrs"}
28
+
29
+ else:
30
+ output = {"esrs_type": "none"}
31
 
32
  return output
33
 
34
  return categorize_message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
celsius_csrd_chatbot/chains/esrs_intent.py CHANGED
@@ -44,17 +44,16 @@ class ESRSAnalysis(BaseModel):
44
 
45
  Follow these guidelines :
46
 
47
- - Do not take into account upper or lower case letter distinction. For example, 'esrs 1', 'Esrs 1' and 'ESRS 1' should be considered as 'ESRS 1'.
48
- - Some questions could be related to multiple ESRS. In this case, keep all options and format the output as such : 'ESRS 1', 'ESRS 2'.
49
  - Remember, if the question is not related to any ESRS, the output should be 'none'.
50
  """,
51
  )
52
 
53
 
54
- def make_esrs_categorization_chain(llm):
55
  parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
56
  prompt_template = """
57
- The following question is about ESRS related topics. Please analyze the question and indicate if it refers to specific ESRS.
58
 
59
  {format_instructions}
60
 
@@ -74,26 +73,16 @@ def make_esrs_categorization_chain(llm):
74
  return chain
75
 
76
 
77
- def make_esrs_categorization_node(llm):
78
 
79
- def categorize_message(state):
80
  query = state["query"]
81
- categorization_chain = make_esrs_categorization_chain(llm)
82
  output = categorization_chain.invoke(query)
83
 
84
- if not output:
85
- raise OutputParserException("Output is empty")
86
-
87
- try:
88
- # Attempt to parse the output as JSON
89
- parsed_output = json.loads(output)
90
- except json.JSONDecodeError as e:
91
- # Raise a more informative error if the output is not valid JSON
92
- raise OutputParserException(f"Invalid JSON output: {output}") from e
93
-
94
  return output
95
 
96
- return categorize_message
97
 
98
  # intent: str = Field(
99
  # enum=[
 
44
 
45
  Follow these guidelines :
46
 
47
+ - Some questions could be related to multiple ESRS. In such case, choose the most appropriate one.
 
48
  - Remember, if the question is not related to any ESRS, the output should be 'none'.
49
  """,
50
  )
51
 
52
 
53
+ def make_esrs_intent_chain(llm):
54
  parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
55
  prompt_template = """
56
+ The following question is about ESRS related topics. Please analyze the question and indicate if it refers to a specific ESRS.
57
 
58
  {format_instructions}
59
 
 
73
  return chain
74
 
75
 
76
+ def make_esrs_intent_node(llm):
77
 
78
+ def intent_message(state):
79
  query = state["query"]
80
+ categorization_chain = make_esrs_intent_chain(llm)
81
  output = categorization_chain.invoke(query)
82
 
 
 
 
 
 
 
 
 
 
 
83
  return output
84
 
85
+ return intent_message
86
 
87
  # intent: str = Field(
88
  # enum=[
celsius_csrd_chatbot/chains/retriever.py CHANGED
@@ -6,14 +6,15 @@ def make_retriever_node(vectorstore, k=10):
6
  if sources == "none":
7
  filters_full = {}
8
  else:
9
- filters_full = {"ESRS": {"$in": [sources]}}
10
  docs = []
11
- retriever = vectorstore.as_retriever()
12
- docs = retriever.similarity_search_with_score(
13
  query=query, filter=filters_full, k=k
14
  )
15
- for doc in docs:
16
- doc.metadata["similarity_score"] = doc.metadata["similarity_score"]
 
 
17
 
18
  docs = sorted(docs, key=lambda x: x.metadata["similarity_score"], reverse=True)
19
  new_state = {"documents": docs}
 
6
  if sources == "none":
7
  filters_full = {}
8
  else:
9
+ filters_full = {"ESRS_filter": {"$in": sources}}
10
  docs = []
11
+ docs_retrieved = vectorstore.similarity_search_with_score(
 
12
  query=query, filter=filters_full, k=k
13
  )
14
+ for doc in docs_retrieved:
15
+ doc_append = doc[0]
16
+ doc_append.metadata["similarity_score"] = doc[1]
17
+ docs.append(doc_append)
18
 
19
  docs = sorted(docs, key=lambda x: x.metadata["similarity_score"], reverse=True)
20
  new_state = {"documents": docs}
celsius_csrd_chatbot/utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from typing import Tuple, List
3
  from dotenv import load_dotenv
4
  from msal import ConfidentialClientApplication
@@ -58,13 +59,13 @@ def make_pairs(lst):
58
  def make_html_source(i, doc):
59
  if doc.metadata["source"] == "ESRS":
60
  return f"""
61
- <div class="card">
62
  <div class="card-content">
63
  <h3>Doc {i}</h2>
64
  <p>{doc.page_content}</p>
65
  </div>
66
  <div class="card-footer">
67
- <span>{doc.metadata['ESRS']} \n</span>
68
  <span>DR: {doc.metadata['DR']} \n</span>
69
  <span>Data type: {doc.metadata['Data type']} \n</span>
70
  </div>
@@ -82,3 +83,24 @@ def make_html_source(i, doc):
82
  </div>
83
  </div>
84
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import re
3
  from typing import Tuple, List
4
  from dotenv import load_dotenv
5
  from msal import ConfidentialClientApplication
 
59
  def make_html_source(i, doc):
60
  if doc.metadata["source"] == "ESRS":
61
  return f"""
62
+ <div class="card" id="doc{i}">
63
  <div class="card-content">
64
  <h3>Doc {i}</h2>
65
  <p>{doc.page_content}</p>
66
  </div>
67
  <div class="card-footer">
68
+ <span>{doc.metadata['ESRS_filter']} \n</span>
69
  <span>DR: {doc.metadata['DR']} \n</span>
70
  <span>Data type: {doc.metadata['Data type']} \n</span>
71
  </div>
 
83
  </div>
84
  </div>
85
  """
86
+
87
+
88
+ def parse_output_llm_with_sources(output):
89
+ # Split the content into a list of text and "[Doc X]" references
90
+ content_parts = re.split(r"\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]", output)
91
+ parts = []
92
+ for part in content_parts:
93
+ if part.startswith("Doc"):
94
+ subparts = part.split(",")
95
+ subparts = [
96
+ subpart.lower().replace("doc", "").strip() for subpart in subparts
97
+ ]
98
+ subparts = [
99
+ f"""<a href="#doc{subpart}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{subpart}</sup></span></a>"""
100
+ for subpart in subparts
101
+ ]
102
+ parts.append("".join(subparts))
103
+ else:
104
+ parts.append(part)
105
+ content_parts = "".join(parts)
106
+ return content_parts
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ version = "0.1.0"
4
  description = ""
5
  authors = ["Miguel Omenaca Muro <[email protected]>"]
6
  readme = "README.md"
7
- package-mode = false
8
 
9
  [tool.poetry.dependencies]
10
  python = ">=3.10,<3.13"
@@ -17,7 +17,10 @@ loadenv = "^0.1.1"
17
  openai = "^1.34.0"
18
  langchain-openai = "^0.1.8"
19
  pinecone = "^4.0.0"
20
- pinecone-client = "^4.1.1"
 
 
 
21
 
22
 
23
  [build-system]
 
4
  description = ""
5
  authors = ["Miguel Omenaca Muro <[email protected]>"]
6
  readme = "README.md"
7
+ package-mode = true
8
 
9
  [tool.poetry.dependencies]
10
  python = ">=3.10,<3.13"
 
17
  openai = "^1.34.0"
18
  langchain-openai = "^0.1.8"
19
  pinecone = "^4.0.0"
20
+ pinecone-client = "^5.0.1"
21
+ langgraph = "^0.2.0"
22
+ langchain-core = "^0.2.29"
23
+ ipython = "^8.26.0"
24
 
25
 
26
  [build-system]