Ilyas KHIAT commited on
Commit
d07e69f
·
1 Parent(s): ba94ff8

graph improv

Browse files
audit_page/knowledge_graph.py CHANGED
@@ -6,13 +6,33 @@ from streamlit_agraph import agraph, Node, Edge, Config
6
 
7
  from utils.kg.construct_kg import get_graph # if still needed for something else
8
  from utils.audit.rag import get_text_from_content_for_doc, get_text_from_content_for_audio
9
- from utils.audit.response_llm import generate_response_via_langchain
10
  from utils.audit.rag import get_vectorstore
11
  from langchain_core.messages import AIMessage, HumanMessage
12
  from langchain_core.prompts import PromptTemplate
13
 
14
  from itext2kg.models import KnowledgeGraph
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  ################################################################################
18
  # Utility Functions
@@ -22,7 +42,7 @@ def if_node_exists(nodes, node_id):
22
  """Check if a node with the given id already exists in a list of Node objects."""
23
  for node in nodes:
24
  if node.id == node_id:
25
- return True
26
  return False
27
 
28
  def generate_random_color():
@@ -234,6 +254,7 @@ def convert_advanced_neo4j_to_agraph(neo4j_graph: KnowledgeGraph, node_colors):
234
  )
235
  return edges, nodes, config
236
 
 
237
  def display_graph(edges, nodes, config):
238
  """Render Agraph."""
239
  return agraph(edges=edges, nodes=nodes, config=config)
@@ -247,6 +268,22 @@ def filter_nodes_by_types(nodes, node_types_filter):
247
  return nodes
248
  return [node for node in nodes if node.title in node_types_filter]
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  ################################################################################
252
  # Dialog Components (same as your original code)
@@ -328,6 +365,29 @@ def change_color_dialog():
328
  if st.button("Valider"):
329
  st.rerun()
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
  ################################################################################
333
  # Main KG Function
@@ -345,6 +405,7 @@ def kg_main():
345
  # Depending on how you stored it, it might be a tuple (graph, extra_info)
346
  # or directly a single object. Adjust as needed.
347
  st.session_state.graph = pickle.load(f)
 
348
  print("Graph loaded.")
349
 
350
  # 2. Initialize other session keys if they don’t exist
@@ -359,8 +420,6 @@ def kg_main():
359
 
360
  st.title("Graphe de connaissance")
361
 
362
- edges,nodes,config = None, None, None
363
-
364
  # If we haven’t set up node types yet, do it now
365
  if st.session_state.node_types is None:
366
  # st.session_state.graph is presumably a list/tuple => st.session_state.graph[0]
@@ -370,13 +429,21 @@ def kg_main():
370
  print("Couleurs attribuées")
371
  # Initialize a default filter view
372
  st.session_state.filter_views["Vue par défaut"] = list(node_types)
373
- st.session_state.current_view = "Vue par défaut"
 
 
 
 
 
 
 
 
 
374
 
375
  # 3. Convert the graph to agraph format
376
- edges, nodes, config = convert_advanced_neo4j_to_agraph(
377
- st.session_state.graph, # or st.session_state.graph[0] if needed
378
- st.session_state.node_types
379
- )
380
  print("Graph converti en Agraph")
381
 
382
  # 4. UI layout: (left) the graph itself, (right) the chat
@@ -408,8 +475,21 @@ def kg_main():
408
  # Filter out nodes that don’t match the chosen types
409
  filtered_nodes = filter_nodes_by_types(nodes, filter_selection)
410
 
 
411
  # Render the graph
412
- selected_node_id = display_graph(edges, filtered_nodes, config)
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
  # 5. Chat UI
415
  with col2.container(border=True,height=800):
@@ -435,6 +515,35 @@ def kg_main():
435
  with st.chat_message("AI"):
436
  # Example retrieval (if you have a vectorstore in session state)
437
  # and want to incorporate scenes or graph data:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  if "vectorstore" in st.session_state:
439
  retriever = st.session_state.vectorstore.as_retriever()
440
  context = retriever.invoke(last_message.content)
 
6
 
7
  from utils.kg.construct_kg import get_graph # if still needed for something else
8
  from utils.audit.rag import get_text_from_content_for_doc, get_text_from_content_for_audio
9
+ from utils.audit.response_llm import generate_response_via_langchain,generate_llm_with_tools
10
  from utils.audit.rag import get_vectorstore
11
  from langchain_core.messages import AIMessage, HumanMessage
12
  from langchain_core.prompts import PromptTemplate
13
 
14
  from itext2kg.models import KnowledgeGraph
15
 
16
+ from typing_extensions import Annotated, TypedDict
17
+
18
+ import json
19
+
20
+ class AddRelationship(TypedDict):
21
+ '''Ajouter une relation au graphe'''
22
+
23
+ source_id : Annotated[str, 'The source node ID']
24
+ target_id : Annotated[str, 'The target node ID']
25
+ relationship_type : Annotated[str, 'The type of relationship']
26
+
27
+ class DeleteRelationship(TypedDict):
28
+ '''Supprimer une relation du graphe, une information est donnée sur la relation à supprimer'''
29
+
30
+ source_id : Annotated[str, 'The source node ID']
31
+ target_id : Annotated[str, 'The target node ID']
32
+ relationship_type : Annotated[str, 'The type of relationship']
33
+
34
+ tools = [AddRelationship, DeleteRelationship]
35
+
36
 
37
  ################################################################################
38
  # Utility Functions
 
42
  """Check if a node with the given id already exists in a list of Node objects."""
43
  for node in nodes:
44
  if node.id == node_id:
45
+ return node
46
  return False
47
 
48
  def generate_random_color():
 
254
  )
255
  return edges, nodes, config
256
 
257
+
258
  def display_graph(edges, nodes, config):
259
  """Render Agraph."""
260
  return agraph(edges=edges, nodes=nodes, config=config)
 
268
  return nodes
269
  return [node for node in nodes if node.title in node_types_filter]
270
 
271
+ def format_relationships(relationships : list[Edge]):
272
+ """Format relationships for display in the chat."""
273
+ return "\n".join(
274
+ f"- **{rel.source}** -- {rel.label} --> **{rel.to}**"
275
+ for rel in relationships
276
+ )
277
+
278
+ def add_relationship_to_graph(source_id, target_id, relationship_type):
279
+ st.session_state.edges.append(Edge(source=source_id, label=relationship_type, target=target_id))
280
+
281
+ def delete_relationship_from_graph(source_id, target_id, relationship_type):
282
+ st.session_state.edges = [edge for edge in st.session_state.edges if not (
283
+ edge.source == source_id and edge.to == target_id and edge.label == relationship_type
284
+ )]
285
+
286
+
287
 
288
  ################################################################################
289
  # Dialog Components (same as your original code)
 
365
  if st.button("Valider"):
366
  st.rerun()
367
 
368
+ @st.dialog(title="Modifier l'etiquette du noeud")
369
+ def change_node_label_dialog(selected_node_id):
370
+ """Dialog to change the label of a node."""
371
+ node : Node = if_node_exists(st.session_state.nodes, selected_node_id)
372
+ st.write("- **Nom:** ", node.label)
373
+ st.write("- **Etiquette:** ", node.title)
374
+ if node:
375
+ new_label = st.selectbox("Etiquette du noeud",list(st.session_state.node_types.keys())+["Autre"],index=list(st.session_state.node_types.keys()).index(node.title))
376
+
377
+ if new_label == "Autre":
378
+ new_label_text = st.text_input("Nouvelle étiquette")
379
+
380
+ if st.button("Valider") and new_label:
381
+ if new_label == "Autre" and new_label_text:
382
+ st.session_state.node_types[new_label_text] = rgb_to_hex(generate_random_color())
383
+ node.title = new_label_text
384
+ st.success(f"Etiquette du noeud {selected_node_id} modifiée en {new_label_text}")
385
+ st.rerun()
386
+ node.title = new_label
387
+ st.success(f"Etiquette du noeud {selected_node_id} modifiée en {new_label}")
388
+ st.rerun()
389
+
390
+
391
 
392
  ################################################################################
393
  # Main KG Function
 
405
  # Depending on how you stored it, it might be a tuple (graph, extra_info)
406
  # or directly a single object. Adjust as needed.
407
  st.session_state.graph = pickle.load(f)
408
+
409
  print("Graph loaded.")
410
 
411
  # 2. Initialize other session keys if they don’t exist
 
420
 
421
  st.title("Graphe de connaissance")
422
 
 
 
423
  # If we haven’t set up node types yet, do it now
424
  if st.session_state.node_types is None:
425
  # st.session_state.graph is presumably a list/tuple => st.session_state.graph[0]
 
429
  print("Couleurs attribuées")
430
  # Initialize a default filter view
431
  st.session_state.filter_views["Vue par défaut"] = list(node_types)
432
+ st.session_state.filter_views["Personnages"] = "Person"
433
+ st.session_state.filter_views["Lieux"] = ["Location"]
434
+ st.session_state.filter_views["Concepts"] = ["Concept"]
435
+ st.session_state.current_view = "Personnages"
436
+
437
+ if "edges" not in st.session_state or "nodes" not in st.session_state:
438
+ # Convert the graph to Agraph format
439
+ st.session_state.edges, st.session_state.nodes, st.session_state.config = convert_advanced_neo4j_to_agraph(
440
+ st.session_state.graph, st.session_state.node_types
441
+ )
442
 
443
  # 3. Convert the graph to agraph format
444
+ edges = st.session_state.edges
445
+ nodes = st.session_state.nodes
446
+ config = st.session_state.config
 
447
  print("Graph converti en Agraph")
448
 
449
  # 4. UI layout: (left) the graph itself, (right) the chat
 
475
  # Filter out nodes that don’t match the chosen types
476
  filtered_nodes = filter_nodes_by_types(nodes, filter_selection)
477
 
478
+ col_graph , col_buttons = st.columns([12, 1])
479
  # Render the graph
480
+ print("Affichage du graphe")
481
+ with col_graph.container():
482
+ selected_node_id = display_graph(edges, filtered_nodes, config)
483
+ print("Graphe affiché")
484
+ with col_buttons.container():
485
+ # modify node button with emoji
486
+ if selected_node_id:
487
+ if st.button("📝",key="change label"):
488
+ st.write(f"**Node sélectionné**: `{selected_node_id}`")
489
+ change_node_label_dialog(selected_node_id)
490
+
491
+ if selected_node_id:
492
+ st.write(f"**Noeud sélectionné**: `{selected_node_id}`")
493
 
494
  # 5. Chat UI
495
  with col2.container(border=True,height=800):
 
515
  with st.chat_message("AI"):
516
  # Example retrieval (if you have a vectorstore in session state)
517
  # and want to incorporate scenes or graph data:
518
+ prompt_tool_calling = ("Ta mission est de decider selon la query de l'utilisateur s'il y'a un outil qui correspont et il faut l'appeler, tu dois aussi savoir si on va supprimer une relation ou plutot ajouter une relation\n"
519
+ "Tu as 2 outils , un pour supprimer une relation et l'autre ajouter une relation dans un graphe\n"
520
+ "si un outil est appelé, tu dois le dire à l'utilisateur et tu dois bien extraire les id des noeuds et le type de relation\n"
521
+ "si l'id du noeud existe dans le graphe, extrait le exactement et si le type de relation existe dans le graphe, extrait le exactement\n"
522
+ f"**query de l'utilisateur** : {last_message.content}\n"
523
+ f"**Graph**: {format_relationships(st.session_state.edges)}\n"
524
+ f"sinon tu dois renvoyé: 'Pas d'outils appelé'\n"
525
+ f"les outils sont: {tools}\n"
526
+ f"Output: tu dois ecrire soit 'outil appelé' apres avoir identifier les differents elements soit 'Pas d'outils appelé'\n")
527
+ tools_called = generate_llm_with_tools(tools=tools,query=prompt_tool_calling)
528
+ print(tools_called)
529
+ if 'tool_calls' in tools_called.additional_kwargs:
530
+ for tool_call in tools_called.additional_kwargs['tool_calls']:
531
+ func_name = tool_call["function"]["name"]
532
+
533
+ raw_args = tool_call["function"]["arguments"]
534
+ parsed_args = json.loads(raw_args) # Convert JSON string to dict
535
+
536
+ source_id = parsed_args["source_id"]
537
+ target_id = parsed_args["target_id"]
538
+ relationship_type = parsed_args["relationship_type"]
539
+
540
+ if func_name == "AddRelationship":
541
+ add_relationship_to_graph(source_id, target_id, relationship_type)
542
+ st.write(f"Relation ajoutée: {source_id} -- {relationship_type} --> {target_id}")
543
+ elif func_name == "DeleteRelationship":
544
+ delete_relationship_from_graph(source_id, target_id, relationship_type)
545
+ st.write(f"Relation supprimée: {source_id} -- {relationship_type} --> {target_id}")
546
+
547
  if "vectorstore" in st.session_state:
548
  retriever = st.session_state.vectorstore.as_retriever()
549
  context = retriever.invoke(last_message.content)
test.ipynb CHANGED
@@ -2163,6 +2163,44 @@
2163
  " print(graph)\n",
2164
  " graphs.append(graph)"
2165
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2166
  }
2167
  ],
2168
  "metadata": {
 
2163
  " print(graph)\n",
2164
  " graphs.append(graph)"
2165
  ]
2166
+ },
2167
+ {
2168
+ "cell_type": "code",
2169
+ "execution_count": 7,
2170
+ "metadata": {},
2171
+ "outputs": [],
2172
+ "source": [
2173
+ "#test\n",
2174
+ "\n",
2175
+ "dict = {\n",
2176
+ " \"name\": \"test\",\n",
2177
+ " \"age\": 20\n",
2178
+ "}\n",
2179
+ "\n",
2180
+ "if \"name\" not in dict:\n",
2181
+ " print(\"yes\")"
2182
+ ]
2183
+ },
2184
+ {
2185
+ "cell_type": "code",
2186
+ "execution_count": 2,
2187
+ "metadata": {},
2188
+ "outputs": [
2189
+ {
2190
+ "ename": "TypeError",
2191
+ "evalue": "argument of type 'type' is not iterable",
2192
+ "output_type": "error",
2193
+ "traceback": [
2194
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
2195
+ "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
2196
+ "Cell \u001b[1;32mIn[2], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mname\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mdict\u001b[39;49m\n",
2197
+ "\u001b[1;31mTypeError\u001b[0m: argument of type 'type' is not iterable"
2198
+ ]
2199
+ }
2200
+ ],
2201
+ "source": [
2202
+ "\"name\" in dict"
2203
+ ]
2204
  }
2205
  ],
2206
  "metadata": {
utils/assets/kg_ia_signature.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:55b49436038a45405798f6d05591464b1a35360409d83dbead163921707ac592
3
- size 7354091
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf25348b339fce7a0b9fc1368bb2e5d8c06e06c0c2b47350ca4f9220fc230513
3
+ size 7476666
utils/audit/response_llm.py CHANGED
@@ -66,3 +66,9 @@ def generate_structured_response(query: str, stream: bool = False, model: str =
66
 
67
  # Invoke the LLM chain and return the result
68
  return llm_chain.invoke({"query": query})
 
 
 
 
 
 
 
66
 
67
  # Invoke the LLM chain and return the result
68
  return llm_chain.invoke({"query": query})
69
+
70
+ def generate_llm_with_tools(tools,query: str,model = "gpt-4o-2024-08-06"):
71
+ # Define the prompt template
72
+ llm = ChatOpenAI(model=model)
73
+ llm_with_tools = llm.bind_tools(tools)
74
+ return llm_with_tools.invoke(query)