Spaces:
Running
Running
import streamlit as st | |
from textwrap import dedent | |
from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio | |
from utils.audit.response_llm import * | |
from langchain_core.messages import AIMessage, HumanMessage | |
from st_copy_to_clipboard import st_copy_to_clipboard | |
from utils.kg.construct_kg import get_graph,get_advanced_graph | |
from audit_page.knowledge_graph import * | |
import json | |
from time import sleep | |
import pickle | |
def _to_node_id(name):
    """Turn an entity name into a node identifier by replacing spaces with underscores."""
    return name.replace(" ", "_")


def graph_doc_to_json(graph):
    """Convert a graph exposing ``.nodes`` / ``.relationships`` into a JSON-ready dict.

    Presumably *graph* is a LangChain ``GraphDocument`` (TODO confirm): each
    node must have ``.id`` and ``.type``; each relationship must have
    ``.source`` / ``.target`` (nodes) and ``.type``.

    Returns:
        ``{"noeuds": [...], "relations": [...]}`` where each node is
        ``{"id", "label", "type"}`` (``id`` is the node id with spaces replaced
        by underscores, ``label`` the raw id) and each relation is
        ``{"source", "label", "cible"}`` (``label`` is the relationship type).
    """
    nodes = [
        {"id": _to_node_id(node.id), "label": node.id, "type": node.type}
        for node in graph.nodes
    ]
    edges = [
        {
            "source": _to_node_id(rel.source.id),
            "label": rel.type,
            "cible": _to_node_id(rel.target.id),
        }
        for rel in graph.relationships
    ]
    return {"noeuds": nodes, "relations": edges}


def advanced_graph_to_json(graph: "KnowledgeGraph"):
    """Convert a ``KnowledgeGraph`` into the same JSON-ready shape as ``graph_doc_to_json``.

    Each entity must have ``.name`` and ``.label``; each relationship must
    have ``.startEntity`` / ``.endEntity`` (entities) and ``.name``.

    Returns:
        ``{"noeuds": [...], "relations": [...]}``. Node ``id`` is the entity
        name with spaces replaced by underscores, ``label`` the raw name and
        ``type`` the entity label. Relation ``label`` is the relationship name.
    """
    nodes = [
        {"id": _to_node_id(entity.name), "label": entity.name, "type": entity.label}
        for entity in graph.entities
    ]
    edges = [
        {
            "source": _to_node_id(rel.startEntity.name),
            "label": rel.name,
            "cible": _to_node_id(rel.endEntity.name),
        }
        for rel in graph.relationships
    ]
    return {"noeuds": nodes, "relations": edges}
def chat_history_formatter(chat_history):
    """Render a chat history as plain text, one turn per paragraph.

    ``AIMessage`` turns are prefixed with ``AI:`` and ``HumanMessage`` turns
    with ``Human:``; each turn is followed by a blank line. Messages of any
    other type are skipped.

    Returns:
        The formatted transcript (empty string for an empty history).
    """
    parts = []
    for message in chat_history:
        if isinstance(message, AIMessage):
            parts.append(f"AI:{message.content}\n\n")
        elif isinstance(message, HumanMessage):
            parts.append(f"Human:{message.content}\n\n")
    # Join once instead of repeated ``+=``, which is quadratic on long chats.
    return "".join(parts)
def filter_correspondance(source_list: list[str], ref_dict: dict, reverse=False):
    """Translate the items of *source_list* through the mapping *ref_dict*.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        source_list: Items to look up.
        ref_dict: Mapping whose keys and values are strings (in this file,
            node label -> display name shown in the filter widget).
        reverse: When False (default), return the values whose key matches an
            item. When True, return the keys whose value matches an item.

    Returns:
        The matching values (or keys), in *ref_dict*'s iteration order rather
        than *source_list*'s order.
    """
    # Normalise once into a set: O(1) membership instead of a list scan per entry.
    wanted = {item.lower().strip() for item in source_list}
    if reverse:
        return [key for key, value in ref_dict.items() if value.lower().strip() in wanted]
    return [value for key, value in ref_dict.items() if key.lower().strip() in wanted]
def radio_choice():
    """Show the page-mode radio and remember the picked option's index.

    The chosen index is stored in ``st.session_state.radio_choice`` whenever it
    differs from the stored one.

    Returns:
        The selected label, or whatever ``st.radio`` returns when nothing is
        selected (None when ``index`` is None).

    NOTE(review): this helper is not called in the visible code
    (``doc_dialog_main`` builds its own radio). The ``sleep(1)`` calls add
    1–2 s of latency per rerun — confirm they are still needed.
    """
    labels = ["compte_rendu", "graphe de connaissance"]
    picked = st.radio(
        "Choisissez une option",
        labels,
        index=st.session_state.radio_choice,
        horizontal=True,
        label_visibility="collapsed",
    )
    sleep(1)
    if not picked:
        return picked
    picked_index = labels.index(picked)
    if picked_index != st.session_state.radio_choice:
        sleep(1)
        st.session_state.radio_choice = picked_index
    return picked
def format_cr(cr: "report"):
    """Render a structured report as Markdown.

    Reads ``cr.summary``, ``cr.Notes`` and ``cr.Actions`` and emits them under
    ``### Résumé :``, ``### Notes :`` and ``### Actions :`` headings, with a
    blank line between sections.
    """
    sections = (
        ("Résumé", cr.summary),
        ("Notes", cr.Notes),
        ("Actions", cr.Actions),
    )
    return "\n\n".join(f"### {title} :\n{body}" for title, body in sections)
def _load_pickle(file_path: str):
    """Unpickle and return the object stored at *file_path*.

    Security: ``pickle.load`` can execute arbitrary code. Only call this on
    trusted files (callers in this page pass paths under ``./utils/assets``),
    never on user uploads.
    """
    with open(file_path, "rb") as f:
        return pickle.load(f)


def load_text_from_pkl(file_path: str):
    """Load the pickled document text stored at *file_path* (trusted files only)."""
    return _load_pickle(file_path)


def load_graph_from_pkl(file_path: str):
    """Load the pickled knowledge graph stored at *file_path* (trusted files only)."""
    return _load_pickle(file_path)
# Bundled assets loaded by this page (paths relative to the app's working dir).
_GRAPH_PKL_PATH = "./utils/assets/kg_ia_signature.pkl"
_TEXT_PKL_PATH = "./utils/assets/scenes.pkl"
_DEFAULT_VIEW = "Vue par défaut"
_MODE_OPTIONS = ["compte_rendu", "graphe de connaissance"]

# Prompt sent to ``generate_structured_response`` to build the report.
# Filled with ``str.format`` only when a report must be generated; the template
# is written at column 0, so no dedent is needed (``dedent`` applied after
# interpolating the document could be defeated by unindented document lines).
_PROMPT_CR_TEMPLATE = '''
À partir du document ci-dessous, identifiez le type d'ecrit puis, générez un compte rendu détaillé contenant les sections suivantes :
2. **Résumé** : Fournissez une synthèse complète du document, en mettant en avant les points principaux, les relations essentielles, les concepts , les dates et les lieux, les conclusions et les détails importants.
3. **Notes** :
- Présentez les points clés sous forme de liste à puces avec des émojis pertinents pour souligner la nature de chaque point.
- N'oubliez pas de relever tout les entités et les relations.
- Incluez des sous-points (sans émojis) sous les points principaux pour offrir des détails ou explications supplémentaires.
4. **Actions** : Identifiez et listez les actions spécifiques, tâches ou étapes recommandées ou nécessaires selon le contenu du document.
**Document :**
{text}
*Sortie :**
Soit exhaustive dans votre réponse, en incluant toutes les informations pertinentes et en les structurant de manière claire et précise, voici des mots clés extraits du document: {keywords}.
'''


def _reset_graph_views(graph):
    """Recompute node-type colours for *graph* and reset the filter views.

    Stores the result of ``list_to_dict_colors`` (presumably node type ->
    colour) in ``st.session_state.node_types``. It then replaces the views with
    a single default view listing every node type, and makes that view current.
    """
    sorted_node_types = sorted(get_node_types_advanced(graph), key=lambda x: x.lower())
    st.session_state.node_types = list_to_dict_colors(sorted_node_types)
    st.session_state.filter_views = {_DEFAULT_VIEW: list(sorted_node_types)}
    st.session_state.current_view = _DEFAULT_VIEW


def _load_default_graph():
    """Load the bundled knowledge graph into session state and reset graph state.

    Returns:
        The loaded graph.
    """
    graph = load_graph_from_pkl(_GRAPH_PKL_PATH)
    st.session_state.graph = graph
    st.session_state.current_chunk_index = 0
    # Zero counters are recomputed from the rendered graph on the next draw.
    st.session_state.number_of_entities = 0
    st.session_state.number_of_relationships = 0
    st.session_state.chat_graph_history = []
    _reset_graph_views(graph)
    return graph


def _init_session_state():
    """Create every session-state entry this page reads; load the default graph once."""
    defaults = {
        "cr": "",
        "cr_chat_history": [],
        "graph": None,
        "current_chunk_index": 0,
        "number_of_entities": 0,
        "number_of_relationships": 0,
        "filter_views": {},
        "current_view": None,
        "node_types": None,
        "chat_graph_history": [],
        "radio_choice": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
    if st.session_state.graph is None:
        _load_default_graph()


def _retrieve_context(query):
    """Return documents relevant to *query* from ``st.session_state.vectorstore``.

    Falls back to an empty context when no vector store has been stored in the
    session (NOTE(review): presumably set by another page — confirm).
    """
    vectorstore = st.session_state.get("vectorstore")
    if vectorstore is None:
        return ""
    return vectorstore.as_retriever().invoke(query)


def _render_chat_history(history):
    """Display each AI / Human message of *history* as a chat bubble."""
    for message in history:
        if isinstance(message, AIMessage):
            with st.chat_message("AI"):
                st.markdown(message.content)
        elif isinstance(message, HumanMessage):
            with st.chat_message("Human"):
                st.write(message.content)


def _answer_pending_question(history, build_prompt):
    """Answer the last message of *history* if it comes from the user.

    Retrieves context for the question, streams the model's answer into an
    "AI" bubble, and appends it to *history* as an ``AIMessage``.
    ``build_prompt(context, question)`` must return the prompt string.
    """
    if not history or not isinstance(history[-1], HumanMessage):
        return
    question = history[-1].content
    with st.chat_message("AI"):
        context = _retrieve_context(question)
        response = st.write_stream(
            generate_response_via_langchain(build_prompt(context, question), stream=True)
        )
    history.append(AIMessage(content=response))


def _render_report_tab(text, keywords):
    """Generate the structured report once, then show it beside its chat panel."""
    if st.session_state.cr == "":
        with st.spinner("Génération du compte rendu..."):
            cr = generate_structured_response(
                _PROMPT_CR_TEMPLATE.format(text=text, keywords=keywords)
            )
            st.session_state.cr = cr
            st.session_state.cr_chat_history = []
    else:
        cr = st.session_state.cr
    if not cr:
        return

    formatted_cr = format_cr(cr)
    col1, col2 = st.columns([2.5, 1.5])
    with col1.container(border=True, height=850):
        st.markdown("##### Compte rendu")
        keywords_paragraph = f"### Mots clés extraits:\n- {keywords}"
        with st.container(height=650, border=False):
            st.markdown(keywords_paragraph)
            st.write(formatted_cr)
        with st.container(height=50, border=False):
            # Copy the Markdown that is displayed, not the report object's repr.
            st_copy_to_clipboard(f"{keywords_paragraph}\n\n{formatted_cr}", key="cp_but_cr")

    with col2.container(border=True, height=850):
        st.markdown("##### Dialoguer avec le CR")
        history = st.session_state.cr_chat_history
        user_query = st.chat_input("Par ici ...")
        if user_query:
            history.append(HumanMessage(content=user_query))
        with st.container(height=600, border=False):
            _render_chat_history(history)
            _answer_pending_question(
                history,
                lambda context, question: f'''Étant donné le contexte suivant {context} et le compte rendu du document {formatted_cr}, {question}''',
            )
        with st.container(height=50, border=False):
            st_copy_to_clipboard(
                chat_history_formatter(history), key="cp_but_cr_chat", show_text=False
            )


def _render_chunk_status(generate_clicked):
    """Extend the graph with the next text chunk when asked; otherwise show progress.

    Chunks are read from ``st.session_state.chunks`` when present (treated as
    empty otherwise). After a successful extension the views are reset and the
    script is rerun.
    """
    chunks = st.session_state.get("chunks") or []
    index = st.session_state.current_chunk_index
    if not generate_clicked:
        st.write(
            f"{index}/ {len(chunks)} chunks traités ({st.session_state.number_of_entities} entités, {st.session_state.number_of_relationships} relations)"
        )
        return
    if index >= len(chunks):
        st.info("Tous les chunks ont été traités")
        return
    with st.spinner(f"Regénération du graphe en incluant le chunk {index} ..."):
        new_graph = get_advanced_graph(chunks[index], st.session_state.graph)
    st.session_state.graph = new_graph
    st.session_state.current_chunk_index = index + 1
    st.session_state.number_of_entities = len(new_graph.entities)
    st.session_state.number_of_relationships = len(new_graph.relationships)
    _reset_graph_views(new_graph)
    st.rerun()


def _render_graph_panel(graph, edges, nodes, config, keywords):
    """Draw the filter toolbar, the graph, and the export/extend controls.

    Returns:
        The value returned by ``display_graph`` (used as the selected node).
    """
    st.write(f"##### Visualisation du graphe (**{st.session_state.current_view}**)")
    filter_col, add_view_col, change_view_col, color_col = st.columns([9, 1, 1, 1])
    if color_col.button("🎨", help="Changer la couleur"):
        change_color_dialog()
    if change_view_col.button("🔍", help="Changer de vue"):
        change_view_dialog()

    keyword_set = {kw.strip().lower() for kw in keywords.split(",")}
    view_labels = st.session_state.filter_views[st.session_state.current_view]
    # Offer every node type (not only the current view's) so a view can be
    # widened again; view labels are included so every default is a valid option.
    all_labels = list(dict.fromkeys([*st.session_state.node_types, *view_labels]))
    display_names = {
        label: ("Mot clé : " + label if label.strip().lower() in keyword_set else label)
        for label in all_labels
    }
    options = sorted(display_names.values(), key=lambda x: x.lower())
    defaults = sorted(
        filter_correspondance(view_labels, display_names), key=lambda x: x.lower()
    )
    chosen = filter_col.multiselect(
        "Filtrer selon l'étiquette",
        options,
        placeholder="Sélectionner une ou plusieurs étiquettes",
        default=defaults,
        label_visibility="collapsed",
    )
    selected_types = filter_correspondance(chosen, display_names, reverse=True)
    if add_view_col.button("➕", help="Ajouter une vue"):
        add_view_dialog(selected_types)
    if selected_types:
        nodes = filter_nodes_by_types(nodes, selected_types)
    selected = display_graph(edges, nodes, config)

    with st.container(height=100, border=False):
        copy_col, generate_col, status_col = st.columns([1, 2, 7])
        with copy_col:
            st_copy_to_clipboard(json.dumps(advanced_graph_to_json(graph)), key="cp_but_graph")
        generate_clicked = generate_col.button("génerer plus", key="generate_more")
        with status_col:
            _render_chunk_status(generate_clicked)
    return selected


def _render_graph_chat(graph, selected):
    """Show the graph chat panel plus quick prompts for the selected node."""
    st.markdown("##### Dialoguer avec le graphe")
    history = st.session_state.chat_graph_history
    user_query = st.chat_input("Par ici ...")
    if user_query:
        history.append(HumanMessage(content=user_query))
    with st.container(height=600, border=False):
        _render_chat_history(history)
        _answer_pending_question(
            history,
            lambda context, question: f"Étant donné le contexte suivant {context}, et le graph de connaissance: {graph}, {question}",
        )
        if selected is not None:
            with st.chat_message("AI"):
                st.markdown(f" EXPLORER LES DONNEES CONTENUES DANS **{selected}**")
                prompts = [
                    f"Extrait moi toutes les informations du noeud ''{selected}'' ➡️",
                    f"Montre moi les conversations autour du noeud ''{selected}'' ➡️",
                ]
                for i, prompt in enumerate(prompts):
                    # Bind the prompt now; the callback runs on the next rerun.
                    st.button(
                        prompt,
                        key=f"p_{i}",
                        on_click=lambda p=prompt: st.session_state.chat_graph_history.append(
                            HumanMessage(content=p)
                        ),
                    )
    with st.container(height=50, border=False):
        st_copy_to_clipboard(
            chat_history_formatter(history), key="cp_but_graph_chat", show_text=False
        )


def _render_graph_tab(keywords):
    """Render the knowledge-graph view (left) and its chat panel (right)."""
    graph = st.session_state.graph
    if graph is None:
        # Normally loaded by _init_session_state; reload defensively.
        graph = _load_default_graph()
        if graph is None:
            return
    edges, nodes, config = convert_advanced_neo4j_to_agraph(graph, st.session_state.node_types)
    if not st.session_state.number_of_entities or not st.session_state.number_of_relationships:
        st.session_state.number_of_entities = len(nodes)
        st.session_state.number_of_relationships = len(edges)
    col1, col2 = st.columns([2.5, 1.5])
    with col1.container(border=True, height=900):
        selected = _render_graph_panel(graph, edges, nodes, config, keywords)
    with col2.container(border=True, height=900):
        _render_graph_chat(graph, selected)


def doc_dialog_main():
    """Render the "Dialogue avec le document" Streamlit page.

    Loads a pickled document text and a pickled knowledge graph from
    ``./utils/assets``. The user can either generate and chat with a structured
    report ("compte_rendu") or explore and chat with the knowledge graph
    ("graphe de connaissance"). State is kept in ``st.session_state`` so it
    persists across Streamlit reruns.
    """
    st.title("Dialogue avec le document")
    _init_session_state()

    choice = st.radio(
        "Choisissez une option",
        _MODE_OPTIONS,
        index=st.session_state.radio_choice,
        horizontal=True,
        label_visibility="collapsed",
    )
    if choice and _MODE_OPTIONS.index(choice) != st.session_state.radio_choice:
        st.session_state.radio_choice = _MODE_OPTIONS.index(choice)

    # Keyword extraction is not wired into this page: the keyword string is empty.
    keywords = ""
    text = load_text_from_pkl(_TEXT_PKL_PATH)
    st.write(text)

    if choice == "compte_rendu":
        _render_report_tab(text, keywords)
    elif choice == "graphe de connaissance":
        _render_graph_tab(keywords)


doc_dialog_main()