File size: 25,310 Bytes
f1342ba
5e3fa8e
7b897df
 
5e3fa8e
 
 
 
d07e69f
5e3fa8e
9b707db
c2f2340
5e3fa8e
0222cea
1e3f619
d07e69f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e3f619
5e3fa8e
 
 
1e3f619
5e3fa8e
 
1e3f619
 
d07e69f
1e3f619
 
7b897df
5e3fa8e
7b897df
 
 
 
 
 
5e3fa8e
7b897df
 
 
5e3fa8e
 
 
 
 
 
7b897df
 
5e3fa8e
 
 
 
7b897df
 
 
5e3fa8e
 
7b897df
 
 
5e3fa8e
 
 
 
7b897df
5e3fa8e
 
7b897df
5e3fa8e
 
 
 
 
 
 
 
 
 
 
 
7b897df
5e3fa8e
 
7b897df
5e3fa8e
 
 
1e3f619
5e3fa8e
 
 
 
 
 
 
 
 
 
 
 
 
1e3f619
5e3fa8e
 
 
1e3f619
 
 
 
5e3fa8e
1e3f619
5e3fa8e
1e3f619
5e3fa8e
 
 
 
 
 
 
 
 
 
 
f903fca
1e3f619
5e3fa8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b897df
 
5e3fa8e
0222cea
5e3fa8e
 
0222cea
 
 
 
5e3fa8e
0222cea
5e3fa8e
0222cea
5e3fa8e
 
 
 
 
 
 
 
 
 
 
0222cea
5e3fa8e
0222cea
 
 
 
5e3fa8e
 
0222cea
5e3fa8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0222cea
 
d07e69f
7b897df
5e3fa8e
9b707db
6237635
 
5e3fa8e
 
 
 
 
 
 
7b897df
d07e69f
 
 
 
 
 
 
07a1739
 
 
 
 
 
 
d07e69f
 
07a1739
 
 
 
 
 
 
 
 
 
 
 
 
 
d07e69f
 
 
 
 
07a1739
 
d07e69f
 
f1342ba
5e3fa8e
 
 
5596129
 
5e3fa8e
 
 
 
5596129
 
 
 
 
5e3fa8e
5596129
 
 
 
5e3fa8e
 
6237635
 
5596129
 
5e3fa8e
 
5596129
 
5e3fa8e
33ab192
5e3fa8e
 
 
 
 
 
 
 
33ab192
 
 
5e3fa8e
 
 
 
5596129
 
 
5e3fa8e
5596129
 
5e3fa8e
 
 
5596129
 
5e3fa8e
5596129
 
5e3fa8e
 
 
5596129
 
5e3fa8e
5596129
 
5e3fa8e
5596129
5e3fa8e
 
 
 
 
07a1739
 
5e3fa8e
5596129
 
 
 
d07e69f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5596129
5e3fa8e
 
 
f1342ba
 
5e3fa8e
 
 
 
 
c2f2340
1e3f619
5e3fa8e
 
 
 
d07e69f
5e3fa8e
 
07a1739
 
5e3fa8e
5596129
 
 
 
7b897df
 
9b707db
 
1e3f619
5e3fa8e
c2f2340
5e3fa8e
 
 
 
 
 
 
 
5596129
d07e69f
 
 
 
 
 
 
 
 
 
7b897df
5e3fa8e
d07e69f
 
 
5e3fa8e
 
07a1739
 
 
 
 
 
 
 
 
5e3fa8e
 
7b897df
5e3fa8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b897df
5e3fa8e
 
 
 
 
7b897df
d07e69f
5e3fa8e
d07e69f
 
 
 
 
 
 
 
 
 
 
 
 
5e3fa8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d07e69f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e3fa8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
import streamlit as st
import pickle
import random
import math
from streamlit_agraph import agraph, Node, Edge, Config

from utils.kg.construct_kg import get_graph  # if still needed for something else
from utils.audit.rag import get_text_from_content_for_doc, get_text_from_content_for_audio
from utils.audit.response_llm import generate_response_via_langchain,generate_llm_with_tools
from utils.audit.rag import get_vectorstore
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import PromptTemplate

from itext2kg.models import KnowledgeGraph

from typing_extensions import Annotated, TypedDict

import json

class AddRelationship(TypedDict):
    '''Ajouter une relation au graphe'''
    # Tool schema meaning "add a relationship (edge) to the graph".
    # NOTE(review): the docstring above is deliberately left untouched; tool
    # binding layers commonly forward a schema's docstring to the model as the
    # tool description — confirm in generate_llm_with_tools before rewording.

    # Each field carries its description as Annotated metadata.
    source_id : Annotated[str, 'The source node ID']
    target_id : Annotated[str, 'The target node ID']
    relationship_type : Annotated[str, 'The type of relationship']

class DeleteRelationship(TypedDict):
    '''Supprimer une relation du graphe, une information est donnée sur la relation à supprimer'''
    # Tool schema meaning "remove a relationship (edge) from the graph"; the
    # fields identify which relationship to remove.
    # NOTE(review): the docstring above is deliberately left untouched; tool
    # binding layers commonly forward a schema's docstring to the model as the
    # tool description — confirm in generate_llm_with_tools before rewording.

    # Same three fields as AddRelationship.
    source_id : Annotated[str, 'The source node ID']
    target_id : Annotated[str, 'The target node ID']
    relationship_type : Annotated[str, 'The type of relationship']

# Tool schemas passed to generate_llm_with_tools; kg_main dispatches on their
# class names ("AddRelationship" / "DeleteRelationship").
tools = [AddRelationship, DeleteRelationship]


################################################################################
# Utility Functions
################################################################################

def if_node_exists(nodes, node_id):
    """Look up a node by id.

    Returns the first element of ``nodes`` whose ``id`` attribute equals
    ``node_id``, or ``False`` when no element matches.
    """
    matches = (candidate for candidate in nodes if candidate.id == node_id)
    return next(matches, False)

def generate_random_color():
    """Return a random light (pastel-like) color as an ``(r, g, b)`` tuple.

    Each channel is drawn independently from 180..255 inclusive, in the
    order red, green, blue.
    """
    low, high = 180, 255
    red, green, blue = (random.randint(low, high) for _ in range(3))
    return (red, green, blue)

def rgb_to_hex(rgb):
    """Render the first three items of ``rgb`` as a ``'#rrggbb'`` string.

    Each channel is formatted as lowercase hex, zero-padded to two digits.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"

def color_distance(color1, color2):
    """Return the Euclidean distance between two RGB triples.

    Only the first three components of each color are used.
    """
    delta_r = color1[0] - color2[0]
    delta_g = color1[1] - color2[1]
    delta_b = color1[2] - color2[2]
    return math.sqrt(delta_r**2 + delta_g**2 + delta_b**2)

def generate_distinct_colors(num_colors, min_distance=30, max_attempts=1000):
    """Generate ``num_colors`` pastel-ish hex colors that are mutually distinct.

    Candidates come from ``generate_random_color`` and are accepted only when
    they are at least ``min_distance`` away (Euclidean, RGB space) from every
    color already accepted.

    The pastel cube is small (channels 180..255), so only a limited number of
    colors can be ``min_distance`` apart. To guarantee termination, the
    required distance is relaxed by 10% each time ``max_attempts`` consecutive
    candidates are rejected, and dropped entirely once it falls below 1.

    Args:
        num_colors: How many colors to produce. Zero or less yields ``[]``.
        min_distance: Initial minimum RGB distance between any two colors.
        max_attempts: Consecutive rejections tolerated before relaxing.

    Returns:
        A list of ``'#rrggbb'`` strings of length ``num_colors``.
    """
    colors = []
    threshold = min_distance
    rejected_in_a_row = 0

    while len(colors) < num_colors:
        candidate = generate_random_color()
        if all(color_distance(candidate, kept) >= threshold for kept in colors):
            colors.append(candidate)
            rejected_in_a_row = 0
            continue

        rejected_in_a_row += 1
        if rejected_in_a_row >= max_attempts:
            # The space is (nearly) saturated at this distance: loosen the
            # constraint instead of spinning forever.
            threshold *= 0.9
            if threshold < 1:
                threshold = 0  # any candidate is now accepted
            rejected_in_a_row = 0

    return [rgb_to_hex(color) for color in colors]

def list_to_dict_colors(node_types):
    """Pair every node type with its own hex color.

    The colors come from ``generate_distinct_colors``, one per entry of
    ``node_types``, assigned in iteration order.
    """
    palette = generate_distinct_colors(len(node_types))
    return dict(zip(node_types, palette))

def get_node_types_advanced(graph: KnowledgeGraph):
    """Collect the node labels of a knowledge graph and give each a color.

    Labels are read from ``graph.entities`` (``.label``) and from both ends of
    every item in ``graph.relationships`` (``.startEntity.label`` /
    ``.endEntity.label``).

    Returns:
        A ``(labels, colors)`` pair: the set of labels, and a dict mapping
        each label to a random hex color (colors are drawn independently, so
        they are not guaranteed to be distinct).
    """
    node_types = {entity.label for entity in graph.entities}
    for rel in graph.relationships:
        node_types.update((rel.startEntity.label, rel.endEntity.label))

    palette = {}
    for label in node_types:
        palette[label] = rgb_to_hex(generate_random_color())
    return node_types, palette

################################################################################
# Graph Conversion
################################################################################

def get_node_types(graph):
    """Return the set of node types found in ``graph``.

    Types are read from ``graph.nodes`` (``.type``) and from both endpoints
    of every item in ``graph.relationships`` (``.source.type`` and
    ``.target.type``).
    """
    found = {node.type for node in graph.nodes}
    for rel in graph.relationships:
        found.update((rel.source.type, rel.target.type))
    return found

def convert_neo4j_to_agraph(neo4j_graph, node_colors):
    """Turn a graph exposing ``.nodes`` / ``.relationships`` into agraph objects.

    Node ids are the source ids with spaces replaced by underscores; the
    original id is kept as the visible label and the node type is stored in
    ``title``. Types missing from ``node_colors`` are drawn in ``#cccccc``.
    Endpoints of relationships that were not listed in ``.nodes`` are added
    on the fly.

    Returns:
        ``(edges, nodes, config)`` — note the order: edges come first.
    """
    def build_node(agraph_id, raw_id, kind):
        # One place for the Node styling shared by every node of this graph.
        return Node(
            id=agraph_id,
            title=kind,
            label=raw_id,
            size=25,
            shape="circle",
            color=node_colors.get(kind, "#cccccc"),
        )

    nodes = []
    edges = []

    # Nodes declared by the graph itself (first occurrence of an id wins).
    for graph_node in neo4j_graph.nodes:
        agraph_id = graph_node.id.replace(" ", "_")
        candidate = build_node(agraph_id, graph_node.id, graph_node.type)
        if not if_node_exists(nodes, agraph_id):
            nodes.append(candidate)

    # Edges, creating any endpoint the first pass did not produce.
    for rel in neo4j_graph.relationships:
        endpoint_ids = [
            rel.source.id.replace(" ", "_"),
            rel.target.id.replace(" ", "_"),
        ]
        for endpoint, endpoint_id in zip((rel.source, rel.target), endpoint_ids):
            if not if_node_exists(nodes, endpoint_id):
                nodes.append(build_node(endpoint_id, endpoint.id, endpoint.type))

        source_id, target_id = endpoint_ids
        edges.append(Edge(source=source_id, label=rel.type, target=target_id))

    config = Config(
        width=1200,
        height=800,
        directed=True,
        physics=True,
        hierarchical=True,
        from_json="config.json",
    )
    return edges, nodes, config

def convert_advanced_neo4j_to_agraph(neo4j_graph: KnowledgeGraph, node_colors):
    """Convert an itext2kg ``KnowledgeGraph`` into agraph nodes, edges and config.

    Entities come from ``neo4j_graph.entities`` and from both ends of every
    relationship (``startEntity`` / ``endEntity``). A node id is the entity
    name with spaces replaced by underscores; the name itself is the visible
    label and the entity label is stored in ``title`` (used as the node type
    elsewhere in this file).

    Args:
        neo4j_graph: Graph exposing ``.entities`` and ``.relationships``.
        node_colors: Mapping from entity label to hex color. Labels missing
            from the mapping fall back to ``#CCCCCC`` for every node, instead
            of raising ``KeyError`` for entities only.

    Returns:
        ``(edges, nodes, config)`` — note the order: edges come first.
    """
    nodes = []
    edges = []
    seen_ids = set()  # ids already in ``nodes``; O(1) duplicate check

    def ensure_node(entity):
        """Add ``entity`` as a Node unless its id is already present; return the id."""
        node_id = entity.name.replace(" ", "_")
        if node_id not in seen_ids:
            seen_ids.add(node_id)
            nodes.append(Node(
                id=node_id,
                title=entity.label,
                label=entity.name,
                size=25,
                shape="circle",
                color=node_colors.get(entity.label, "#CCCCCC"),
            ))
        return node_id

    # Declared entities first, so that they win over relationship endpoints
    # sharing the same id.
    for entity in neo4j_graph.entities:
        ensure_node(entity)

    for relationship in neo4j_graph.relationships:
        source_id = ensure_node(relationship.startEntity)
        target_id = ensure_node(relationship.endEntity)
        edges.append(Edge(
            source=source_id,
            label=relationship.name,
            target=target_id,
        ))

    config = Config(
        width=1200,
        height=800,
        directed=True,
        physics=True,
        hierarchical=True,
        from_json="config.json",
    )
    return edges, nodes, config


def display_graph(edges, nodes, config):
    """Draw the graph with ``agraph`` and hand back whatever it returns.

    ``kg_main`` treats the returned value as the id of the selected node.
    """
    rendered = agraph(nodes=nodes, edges=edges, config=config)
    return rendered


def filter_nodes_by_types(nodes, node_types_filter):
    """Keep only the nodes whose ``title`` (the node type) is in the filter.

    An empty or missing filter means "no filtering": the input ``nodes``
    object is returned as-is. Otherwise a new list is returned.
    """
    if node_types_filter:
        return list(filter(lambda node: node.title in node_types_filter, nodes))
    return nodes

def format_relationships(relationships : list[Edge]):
    """Render edges as a markdown bullet list, one ``source -- label --> to`` per line."""
    lines = []
    for edge in relationships:
        lines.append(f"- **{edge.source}** -- {edge.label} --> **{edge.to}**")
    return "\n".join(lines)

def fortmat_nodes(nodes : list[Node]):
    """Render nodes as a markdown bullet list, one ``label (title)`` per line."""
    lines = []
    for entry in nodes:
        lines.append(f"- **{entry.label}** ({entry.title})")
    return "\n".join(lines)

def add_relationship_to_graph(source_id, target_id, relationship_type):
    """Append an edge to the session graph, creating missing endpoint nodes.

    Mutates ``st.session_state.edges`` and ``st.session_state.nodes``. An
    endpoint that is not yet a node is added with the type "Autre", colored
    with the "Autre" entry of ``st.session_state.node_types`` when there is
    one and ``#CCCCCC`` otherwise.

    Args:
        source_id: Id of the node the edge starts from.
        target_id: Id of the node the edge points to.
        relationship_type: Label shown on the edge.
    """
    st.session_state.edges.append(Edge(source=source_id, label=relationship_type, target=target_id))
    print(f"Relation ajoutée: {source_id} -- {relationship_type} --> {target_id}")

    fallback_type = "Autre"
    # The previous code looked up the color through an undefined ``target``
    # name, which raised NameError whenever the source node was new.
    fallback_color = st.session_state.node_types.get(fallback_type, "#CCCCCC")

    # Both endpoints must exist as nodes, otherwise the edge cannot be drawn.
    for endpoint_id in (source_id, target_id):
        if if_node_exists(st.session_state.nodes, endpoint_id):
            continue
        st.session_state.nodes.append(Node(
            id=endpoint_id,
            title=fallback_type,
            label=endpoint_id,
            size=25,
            shape="circle",
            color=fallback_color,
        ))
        print(f"Node ajouté: {endpoint_id}")

    print(f"Nodes: {fortmat_nodes(st.session_state.nodes)}")

def delete_relationship_from_graph(source_id, target_id, relationship_type):
    """Drop every session edge matching the given source, target and label.

    ``st.session_state.edges`` is replaced by a new list without the matches.
    """
    def is_match(edge):
        return (
            edge.source == source_id
            and edge.to == target_id
            and edge.label == relationship_type
        )

    remaining = [edge for edge in st.session_state.edges if not is_match(edge)]
    st.session_state.edges = remaining


    


################################################################################
# Dialog Components
################################################################################
@st.dialog(title="Changer la vue")
def change_view_dialog():
    """Dialog listing the saved views: select, rename or delete one.

    Reads and mutates ``st.session_state.filter_views`` (view name -> node
    types) and ``st.session_state.current_view``. The first view (index 0)
    can be selected but neither renamed nor deleted. Deleting any view resets
    the current view to "Vue par défaut". Each successful action calls
    ``st.rerun()``.
    """
    st.write("Changer la vue")
    views = st.session_state.filter_views

    # Iterate over a snapshot of the names: the loop body deletes and renames
    # keys of the very dict being listed.
    for index, item in enumerate(list(views)):
        emp = st.empty()
        col1, col2, col3 = emp.columns([8, 1, 1])

        # Delete the view (never the first one).
        if index > 0 and col2.button("🗑️", key=f"del{index}"):
            del views[item]
            st.session_state.current_view = "Vue par défaut"
            st.rerun()

        # Choose the view; the active one is shown with a check mark.
        but_content = "🔍" if st.session_state.current_view != item else "✅"
        if col3.button(but_content, key=f"valid{index}"):
            st.session_state.current_view = item
            st.rerun()

        # Show details / rename.
        with col1.expander(item):
            if index > 0:
                change_name = st.text_input(
                    "Nom de la vue",
                    label_visibility="collapsed",
                    placeholder="Changez le nom de la vue",
                    key=f"change_name{index}"
                )
                if st.button("Renommer", key=f"rename{index}"):
                    new_name = change_name.strip()
                    if new_name and new_name != item:
                        if new_name in views:
                            # Renaming onto an existing name would silently
                            # discard that other view's filters.
                            st.warning(f"Une vue nommée '{new_name}' existe déjà.")
                        else:
                            views[new_name] = views.pop(item)
                            st.session_state.current_view = new_name
                            st.rerun()

            labels = views[item]
            if isinstance(labels, str):
                # A view stored as a bare string is a single type, not a
                # sequence of characters.
                labels = [labels]
            st.markdown("\n".join(f"- {label.strip()}" for label in labels))


@st.dialog(title="Ajouter une vue")
def add_view_dialog(filters):
    """Dialog that saves ``filters`` as a named view.

    On "Ajouter la vue", a non-blank name stores ``filters`` under that name
    in ``st.session_state.filter_views`` and makes it the current view. The
    app is rerun after the click whether or not a name was given.
    """
    st.write("Ajouter une vue")
    view_name = st.text_input("Nom de la vue")
    st.markdown("Les filtres actuels :")
    st.write(filters)

    submitted = st.button("Ajouter la vue")
    if not submitted:
        return

    has_name = bool(view_name.strip())
    if has_name:
        st.session_state.filter_views[view_name] = filters
        st.session_state.current_view = view_name
    st.rerun()


@st.dialog(title="Changer la couleur")
def change_color_dialog():
    """Dialog with one color picker per node type.

    The picked colors are written back to ``st.session_state.node_types`` and
    pushed onto the nodes already stored in ``st.session_state.nodes``. Nodes
    receive their color once, when the graph is converted, so updating the
    type -> color map alone would leave the drawn graph unchanged.
    "Valider" reruns the app so the graph is redrawn.
    """
    st.write("Changer la couleur")
    for node_type, color in st.session_state.node_types.items():
        new_color = st.color_picker(
            f"La couleur de l'entité **{node_type.strip()}**",
            color
        )
        print("New color:", new_color)
        print("Old color:", color)
        # Assigning to an existing key does not resize the dict, so this is
        # safe while iterating over its items.
        st.session_state.node_types[node_type] = new_color

    # Propagate to existing nodes; ``title`` holds the node type. Nodes whose
    # type has no entry in the map keep their current color.
    for node in st.session_state.get("nodes", []):
        updated = st.session_state.node_types.get(node.title)
        if updated is not None:
            node.color = updated

    if st.button("Valider"):
        st.rerun()

@st.dialog(title="Modifier l'etiquette du noeud")
def change_node_label_dialog(selected_node_id):
    """Dialog to change the type (stored in ``title``) of one node.

    The user picks an existing type or "Autre"; with "Autre" a free-text type
    can be entered. On "Valider" the node's ``title`` and ``color`` are
    updated, an unknown type is registered in ``st.session_state.node_types``
    with a random color, and the app is rerun.

    Args:
        selected_node_id: Id of the node to edit. When no such node exists a
            warning is shown and nothing else is rendered.
    """
    node = if_node_exists(st.session_state.nodes, selected_node_id)
    if not node:
        # if_node_exists returns False for an unknown id; the attribute reads
        # below would fail on it.
        st.warning(f"Noeud introuvable: {selected_node_id}")
        return

    st.write("- **Nom:** ", node.label)
    st.write("- **Etiquette:** ", node.title)

    node_types = st.session_state.node_types
    options = list(node_types.keys())
    if "Autre" not in options:
        options.append("Autre")
    # A node may carry a type absent from the map (e.g. nodes created as
    # "Autre" by add_relationship_to_graph); preselect "Autre" for those
    # instead of failing on .index().
    if node.title in options:
        default_index = options.index(node.title)
    else:
        default_index = options.index("Autre")
    new_label = st.selectbox("Etiquette du noeud", options, index=default_index)

    new_label_text = ""
    if new_label == "Autre":
        new_label_text = st.text_input("Nouvelle étiquette").strip()

    if st.button("Valider") and new_label:
        # Free text wins when given; a blank "Autre" labels the node "Autre".
        final_label = new_label_text or new_label
        if final_label not in node_types:
            node_types[final_label] = rgb_to_hex(generate_random_color())
        node.title = final_label
        # Keep the node's color in sync with its new type.
        node.color = node_types[final_label]
        st.success(f"Etiquette du noeud {selected_node_id} modifiée en {final_label}")
        st.rerun()
    


################################################################################
# Main KG Function
################################################################################

def kg_main():
    """Render the knowledge-graph page: graph viewer on the left, chat on the right.

    State lives in ``st.session_state`` so it survives reruns: the pickled
    scenes and graph, the vector store, the agraph nodes/edges/config, the
    node-type -> color map, the saved filter views and the chat history.

    The chat can edit the graph: each user message is first sent to the
    tool-calling model, and any ``AddRelationship`` / ``DeleteRelationship``
    call it returns is applied to the session edges before the textual answer
    is generated.

    Returns:
        None. The function only draws widgets and mutates session state.
    """
    # 1. Load the pickled assets once per session.
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file; these paths must only ever point at trusted assets.
    if "scenes" not in st.session_state:
        with open("./utils/assets/scenes.pkl", "rb") as f:
            st.session_state.scenes = pickle.load(f)
            st.session_state.vectorstore = get_vectorstore(st.session_state.scenes)

    if "graph" not in st.session_state:
        with open("./utils/assets/kg_ia_signature.pkl", "rb") as f:
            st.session_state.graph = pickle.load(f)
            print("Graph loaded.")

    # 2. Initialize the remaining session keys.
    if "filter_views" not in st.session_state:
        st.session_state.filter_views = {}
    if "current_view" not in st.session_state:
        st.session_state.current_view = None
    if "node_types" not in st.session_state:
        st.session_state.node_types = None
    if "chat_graph_history" not in st.session_state:
        st.session_state.chat_graph_history = []

    st.title("Graphe de connaissance")

    # First run: discover node types, assign colors, create the stock views.
    if st.session_state.node_types is None:
        node_types, st.session_state.node_types = get_node_types_advanced(st.session_state.graph)
        print("Couleurs attribuées")
        st.session_state.filter_views["Vue par défaut"] = list(node_types)
        # Every view maps to a *list* of node types; a bare string would be
        # read as a sequence of characters by the code that consumes views.
        st.session_state.filter_views["Personnages"] = ["Person"]
        st.session_state.filter_views["Lieux"] = ["Location"]
        st.session_state.filter_views["Concepts"] = ["Concept"]
        st.session_state.current_view = "Personnages"

    if "edges" not in st.session_state or "nodes" not in st.session_state:
        # Convert the graph to agraph format once; later edits (chat tools,
        # dialogs) mutate these session copies.
        st.session_state.edges, st.session_state.nodes, st.session_state.config = convert_advanced_neo4j_to_agraph(
            st.session_state.graph, st.session_state.node_types
        )

    # 3. Working references to the agraph data.
    edges = st.session_state.edges
    nodes = st.session_state.nodes
    config = st.session_state.config
    print("Graph converti en Agraph")

    # Opening analysis of the graph: generated once, as the first message of
    # the conversation. This function runs again on every rerun, so without
    # the guard the model would be called and a message appended each time.
    if not st.session_state.chat_graph_history:
        # One single string: the previous version had a stray trailing comma
        # and passed a tuple nested in a tuple.
        analysis_prompt = (
            "Tu es un expert en graphes de connaissances, analyse le graphe et donne une synthèse et differentes conclusions sur les elements du recit, tout en etant pertinent et precis\n"
            "**Graphe**:\n"
            f"**Noeuds**: {fortmat_nodes(st.session_state.nodes)}\n"
            f"Relations: {format_relationships(st.session_state.edges)}\n"
            "Output: tu dois donner une synthèse et des conclusions sur les elements du recit , ca sera le premier message de la conversation"
        )
        response = generate_response_via_langchain(analysis_prompt)
        st.session_state.chat_graph_history.append(AIMessage(content=response))

    # 4. UI layout: (left) the graph itself, (right) the chat
    col1, col2 = st.columns([3, 1])

    with col1.container(border=True,height=800):
        st.write(f"#### Visualisation du graphe (**{st.session_state.current_view}**)")

        filter_col, add_view_col, change_view_col, color_col = st.columns([9, 1, 1, 1])

        if color_col.button("🎨", help="Changer la couleur"):
            change_color_dialog()

        if change_view_col.button("🔍", help="Changer de vue"):
            change_view_dialog()

        # Filters of the current view, restricted to types that really exist:
        # the multiselect rejects defaults that are not among its options.
        current_filters = st.session_state.filter_views.get(st.session_state.current_view, [])
        if isinstance(current_filters, str):
            current_filters = [current_filters]
        valid_filters = [
            label for label in current_filters
            if label in st.session_state.node_types
        ]
        filter_selection = filter_col.multiselect(
            "Filtrer selon l'étiquette",
            st.session_state.node_types.keys(),
            default=valid_filters,
            label_visibility="collapsed"
        )

        if add_view_col.button("➕", help="Ajouter une vue"):
            add_view_dialog(filter_selection)

        # Keep only the nodes matching the chosen types (all when none chosen).
        filtered_nodes = filter_nodes_by_types(nodes, filter_selection)

        col_graph , col_buttons = st.columns([12, 1])
        # Render the graph
        print("Affichage du graphe")
        with col_graph.container():
            selected_node_id = display_graph(edges, filtered_nodes, config)
        print("Graphe affiché")
        with col_buttons.container():
            # Edit-node button, only shown when a node is selected.
            if selected_node_id:
                if st.button("📝",key="change label"):
                    st.write(f"**Node sélectionné**: `{selected_node_id}`")
                    change_node_label_dialog(selected_node_id)

        if selected_node_id:
            st.write(f"**Noeud sélectionné**: `{selected_node_id}`")

    # 5. Chat UI
    with col2.container(border=True,height=800):
        st.markdown("#### Dialoguer avec le graphe")
        user_query = st.chat_input("Votre question ...")
        if user_query:
            st.session_state.chat_graph_history.append(HumanMessage(content=user_query))

        with st.container():
            # Display the existing chat
            for message in st.session_state.chat_graph_history:
                if isinstance(message, AIMessage):
                    with st.chat_message("AI"):
                        st.markdown(message.content)
                elif isinstance(message, HumanMessage):
                    with st.chat_message("Human"):
                        st.write(message.content)

            # If the last message is from the user, generate a response
            if (len(st.session_state.chat_graph_history) > 0 and
                isinstance(st.session_state.chat_graph_history[-1], HumanMessage)):
                last_message = st.session_state.chat_graph_history[-1]
                with st.chat_message("AI"):
                    # First pass: let the model decide whether the request is
                    # a graph edit (add / delete relationship).
                    prompt_tool_calling = ("Ta mission est de decider selon la query de l'utilisateur s'il y'a un outil qui correspont et il faut l'appeler, tu dois aussi savoir si on va supprimer une relation ou plutot ajouter une relation\n"
                                           "Tu as 2 outils , un pour supprimer une relation et l'autre ajouter une relation dans un graphe\n"
                                           "si un outil est appelé, tu dois le dire à l'utilisateur et tu dois bien extraire les id des noeuds et le type de relation\n"
                                           "si l'id du noeud existe dans le graphe, extrait le exactement et si le type de relation existe dans le graphe, extrait le exactement\n"
                                            f"**query de l'utilisateur** : {last_message.content}\n"
                                            f"**Graph**: {format_relationships(st.session_state.edges)}\n"
                                            f"sinon tu dois renvoyé: 'Pas d'outils appelé'\n"
                                            f"les outils sont: {tools}\n"
                                            f"Output: tu dois ecrire soit 'outil appelé' apres avoir identifier les differents elements soit 'Pas d'outils appelé'\n")
                    tools_called = generate_llm_with_tools(tools=tools,query=prompt_tool_calling)
                    print(tools_called)
                    if 'tool_calls' in tools_called.additional_kwargs:
                        for tool_call in tools_called.additional_kwargs['tool_calls']:
                            func_name = tool_call["function"]["name"]
                            raw_args = tool_call["function"]["arguments"]

                            # The arguments are model output: tolerate bad
                            # JSON or missing fields instead of crashing.
                            try:
                                parsed_args = json.loads(raw_args)
                                source_id = parsed_args["source_id"]
                                target_id = parsed_args["target_id"]
                                relationship_type = parsed_args["relationship_type"]
                            except (json.JSONDecodeError, KeyError, TypeError):
                                st.warning("Appel d'outil ignoré: arguments invalides")
                                continue

                            if func_name == "AddRelationship":
                                add_relationship_to_graph(source_id, target_id, relationship_type)
                                st.write(f"Relation ajoutée: {source_id} -- {relationship_type} --> {target_id}")
                            elif func_name == "DeleteRelationship":
                                delete_relationship_from_graph(source_id, target_id, relationship_type)
                                st.write(f"Relation supprimée: {source_id} -- {relationship_type} --> {target_id}")

                    # Second pass: the textual answer, with retrieved context.
                    if "vectorstore" in st.session_state:
                        retriever = st.session_state.vectorstore.as_retriever()
                        context = retriever.invoke(last_message.content)
                        answer_prompt = (
                            f"Contexte depuis les 'scenes': {st.session_state.scenes}\n"
                            f"Contexte vectorstore: {context}\n"
                            f"Question: {last_message.content}\n"
                            f"Graph: {st.session_state.graph}\n"
                        )
                        response = st.write_stream(
                            generate_response_via_langchain(answer_prompt, stream=True)
                        )
                        st.session_state.chat_graph_history.append(AIMessage(content=response))
                    else:
                        # Fallback if no vectorstore
                        st.write("Aucune base de vecteurs disponible.")
                        st.session_state.chat_graph_history.append(AIMessage(content="(Pas de vectorstore)"))

            # When a node is selected in the graph, offer ready-made questions.
            if selected_node_id:
                with st.chat_message("AI"):
                    st.markdown(f"**Vous avez sélectionné**: `{selected_node_id}`")
                    quick_prompts = [
                        f"Donne-moi plus d'informations sur le noeud '{selected_node_id}'",
                        f"Montre-moi les relations de '{selected_node_id}' dans ce graphe"
                    ]
                    for i, qprompt in enumerate(quick_prompts):
                        if st.button(qprompt, key=f"qp_{i}"):
                            st.session_state.chat_graph_history.append(HumanMessage(content=qprompt))
                            # The answer is produced by the block above, which
                            # has already run: rerun so it handles the prompt.
                            st.rerun()

# Module-level entry point: the page is rendered whenever this file is
# executed or imported (there is no __main__ guard).
kg_main()