Carto-RSE / partie_prenante_carte.py
Ilyas KHIAT
test
5f3c554
import streamlit as st
import pandas as pd
import numpy as np
import re
import random
import time
import streamlit as st
from dotenv import load_dotenv
from langchain_experimental.text_splitter import SemanticChunker
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.chat_models import ChatOpenAI
from langchain import hub
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import WebBaseLoader,FireCrawlLoader,PyPDFLoader
from langchain_core.prompts.prompt import PromptTemplate
import os
from high_chart import test_chart
from chat_with_pps import get_response
from ecologits.tracers.utils import compute_llm_impacts
from codecarbon import EmissionsTracker
# Load variables from a local `.env` file into the process environment at import
# time, so keys read later with os.getenv (e.g. FIRECRAWL_API_KEY in display_pp)
# are available. Presumably it also supplies the OpenAI credentials used by the
# LangChain OpenAI clients — confirm against the deployment setup.
load_dotenv()
def get_docs_from_website(urls):
    """Scrape *urls* with LangChain's ``WebBaseLoader`` (fast, text-only mode).

    Args:
        urls: A URL string or a list of URL strings to fetch.

    Returns:
        The list of loaded ``Document`` objects, or ``None`` if building the
        loader or fetching fails. Callers treat ``None`` as "collection failed".
    """
    try:
        # The loader is built inside the ``try`` (as in get_docs_from_website_fc)
        # so construction errors are reported as ``None`` just like fetch errors.
        loader = WebBaseLoader(urls, header_template={
            # Browser-like User-Agent; presumably to avoid sites that block
            # non-browser clients.
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36',
        })
        return loader.load()
    except Exception:
        # Best-effort scraping: any failure is signalled to the caller as None.
        return None
def get_docs_from_website_fc(urls, firecrawl_api_key):
    """Scrape each URL with FireCrawl ("scrape" mode) and concatenate the results.

    Args:
        urls: Iterable of URL strings.
        firecrawl_api_key: API key handed to ``FireCrawlLoader``.

    Returns:
        A flat list of the documents from every URL, in input order, or
        ``None`` as soon as any URL fails (partial results are discarded).
    """
    try:
        return [
            document
            for target in urls
            for document in FireCrawlLoader(
                api_key=firecrawl_api_key, url=target, mode="scrape"
            ).load()
        ]
    except Exception:
        return None
def get_doc_chunks(docs):
    """Split LangChain documents into semantically coherent chunks.

    Uses ``SemanticChunker`` backed by OpenAI ``text-embedding-3-small``
    embeddings to choose the split points.

    Args:
        docs: List of ``Document`` objects.

    Returns:
        The list of chunk ``Document`` objects.
    """
    chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-small"))
    return chunker.split_documents(docs)
def get_doc_chunks_fc(docs):
    """Semantically split raw texts and return all chunks as one flat list.

    Args:
        docs: Iterable of texts; each item is passed to
            ``SemanticChunker.split_text``, so plain strings are presumably
            expected — confirm with callers.

    Returns:
        List of chunk strings, in the order of *docs*.
    """
    chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-small"))
    return [piece for body in docs for piece in chunker.split_text(body)]
def get_vectorstore_from_docs(doc_chunks):
    """Build an in-memory FAISS index over *doc_chunks* (``Document`` objects).

    Embeddings come from OpenAI ``text-embedding-3-small``.
    """
    return FAISS.from_documents(
        documents=doc_chunks,
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    )
def get_vectorstore_from_text(texts):
    """Build an in-memory FAISS index over plain strings *texts*.

    Embeddings come from OpenAI ``text-embedding-3-small``.
    """
    return FAISS.from_texts(
        texts=texts,
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    )
def _format_docs(docs):
    """Join the text of retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def get_conversation_chain(vectorstore):
    """Build a retrieval-augmented QA chain over *vectorstore*.

    The chain takes a question string, retrieves matching chunks, fills the
    hub prompt ``rlm/rag-prompt`` with ``context`` and ``question``, and calls
    GPT-4o (temperature 0.5, at most 2048 output tokens).

    Returns:
        A runnable whose ``invoke(question)`` returns the chat model's message
        (callers read ``.content`` and ``.response_metadata``).
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0.5, max_tokens=2048)
    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")
    # The retriever yields a list of Document objects. Without the formatting
    # step, the prompt would receive that list's repr (including metadata)
    # instead of the plain page text.
    rag_chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
    )
    return rag_chain
def fill_promptQ_template(input_variables, template):
    """Render the *question* text that is later sent to the RAG chain.

    This is not the RAG prompt itself (that one is pulled from the LangChain
    hub in get_conversation_chain). It only fills the brand placeholders of
    *template*.

    Args:
        input_variables: Mapping that must contain the keys ``"BRAND_NAME"``
            and ``"BRAND_DESCRIPTION"``.
        template: A ``PromptTemplate`` string using those placeholders.

    Returns:
        The formatted question string.
    """
    placeholders = ["BRAND_NAME", "BRAND_DESCRIPTION"]
    question_template = PromptTemplate(input_variables=placeholders, template=template)
    values = {name: input_variables[name] for name in placeholders}
    return question_template.format(**values)
def text_to_list(text):
    """Parse ``name score`` lines into ``[name, score_digits]`` pairs.

    Every ``"- "`` in *text* is removed first (bullet markers). Each remaining
    line is split on whitespace: the last token is the score, reduced to its
    digits only (``"85%"`` -> ``"85"``). The preceding tokens, re-joined with
    single spaces, form the name. Blank lines (e.g. a trailing newline) are
    skipped rather than raising ``IndexError``.

    Args:
        text: Multi-line answer text.

    Returns:
        List of two-element ``[name, digits]`` lists, in line order.
    """
    items = []
    for line in text.replace("- ", "").split('\n'):
        tokens = line.split()
        if not tokens:
            continue
        items.append([' '.join(tokens[:-1]), re.sub(r'\D', '', tokens[-1])])
    return items
def delete_pp(pps):
    """Remove each named stakeholder from ``st.session_state['pp_grouped']``.

    For every name in *pps*, only the first entry whose ``'name'`` matches is
    deleted. Names that are not present are ignored.
    """
    for target in pps:
        grouped = st.session_state['pp_grouped']
        position = next(
            (i for i, entry in enumerate(grouped) if entry['name'] == target),
            None,
        )
        if position is not None:
            del grouped[position]
def display_list_urls():
    """Render one row per analysed source, each with a delete button.

    The left column holds an expander listing the stakeholders extracted from
    that source. The right column holds a "❌" button; clicking it removes the
    source's stakeholders from the map (delete_pp), drops the source and its
    stakeholder list from session state, and reruns the app.

    NOTE: delete_pp removes map entries by name only, so a stakeholder that
    another source also produced is removed from the map as well.
    """
    sources = st.session_state["urls"]
    for position, source in enumerate(sources):
        slot = st.empty()
        left, right = slot.columns([7, 3])
        if right.button("❌", key=f"but{position}"):
            delete_pp(st.session_state['parties_prenantes'][position])
            del st.session_state.urls[position]
            del st.session_state["parties_prenantes"][position]
            st.rerun()
        if len(st.session_state.urls) <= position:
            # Nothing left to show at this position: clear the placeholder.
            slot.empty()
            continue
        with left.expander(f"Source {position+1}: {source}"):
            extracted = st.session_state["parties_prenantes"][position]
            st.write(pd.DataFrame(extracted, columns=["Partie prenante"]))
def colored_circle(color):
    """Return an inline HTML ``<span>`` drawn as a 15px disc filled with *color*."""
    style = (
        "display: inline-block; width: 15px; height: 15px; "
        "border-radius: 50%; "
        f"background-color: {color};"
    )
    return f'<span style="{style}"></span>'
def display_list_pps():
    """Render the map's stakeholders as a coloured bullet + name with a delete button.

    Clicking "❌" removes the entry from ``st.session_state["pp_grouped"]`` and
    reruns the app.
    """
    import html

    for index, _entry in enumerate(st.session_state["pp_grouped"]):
        emp = st.empty()
        col1, col2 = emp.columns([7, 3])
        if col2.button("❌", key=f"butp{index}"):
            del st.session_state["pp_grouped"][index]
            st.rerun()
        if len(st.session_state["pp_grouped"]) > index:
            current = st.session_state["pp_grouped"][index]
            # Names come from scraped pages, LLM output or free-text input and
            # are rendered with unsafe_allow_html, so escape them before
            # embedding them in HTML.
            name = html.escape(str(current["name"]))
            col1.markdown(
                f'<p>{colored_circle(current["color"])} {name}</p>',
                unsafe_allow_html=True,
            )
        else:
            emp.empty()
def extract_pp(docs, input_variables):
    """Ask the LLM, via RAG over *docs*, to list the stakeholders of a brand.

    Args:
        docs: LangChain documents to index (scraped pages or PDF pages), or
            ``None`` when collection failed.
        input_variables: Mapping with ``"BRAND_NAME"`` and
            ``"BRAND_DESCRIPTION"`` used to fill the question template.

    Returns:
        One of the following:

        * A list of stakeholder names: one per non-blank line of the answer,
          with any leading ``"- "`` bullet removed.
        * The status string ``"445"`` when there are no documents.
        * The status string ``"444"`` when the answer contains "ne sais pas"
          ("don't know").

    Side effects:
        Adds the GWP estimate of the LLM call (``compute_llm_impacts``) to
        ``st.session_state["partial_emissions"]["extraction_pp"]["el"]``.
    """
    template_extraction_PP = """
Objectif : Identifiez toutes les parties prenantes de la marque suivante :
Le nom de la marque de référence est le suivant : {BRAND_NAME}
TA RÉPONSE DOIT ÊTRE SOUS FORME DE LISTE DE NOMS DE MARQUES, CHAQUE NOM SUR UNE LIGNE SÉPARÉE.
"""
    # Covers both None and an empty document list, which would otherwise
    # fail while building the vector store.
    if not docs:
        return "445"

    text_chunks = get_doc_chunks(docs)
    vectorstore = get_vectorstore_from_docs(text_chunks)
    chain = get_conversation_chain(vectorstore)
    question = fill_promptQ_template(input_variables, template_extraction_PP)

    start = time.perf_counter()
    response = chain.invoke(question)
    response_latency = time.perf_counter() - start

    if "ne sais pas" in response.content:
        return "444"

    # Estimate the environmental impact of this call from its output size and latency.
    nbre_out_tokens = response.response_metadata["token_usage"]["completion_tokens"]
    impact = compute_llm_impacts(
        provider="openai",
        model_name="gpt-4o",
        output_token_count=nbre_out_tokens,
        request_latency=response_latency,
    )
    # NOTE(review): assumes st.session_state["partial_emissions"] has been
    # initialised elsewhere; it is not created in this module's visible code.
    st.session_state["partial_emissions"]["extraction_pp"]["el"] += impact.gwp.value

    partie_prenante = []
    for line in response.content.split('\n'):
        # Drop only a leading "- " bullet; dashes inside a name are kept.
        name = re.sub(r'^\s*-\s+', '', line).strip()
        # Blank lines would otherwise become nameless stakeholders.
        if name:
            partie_prenante.append(name)
    return partie_prenante
def generate_random_color():
    """Return a random colour as a lowercase ``#rrggbb`` hex string.

    Three ``random.randint(0, 255)`` draws are made, in red, green, blue order.
    """
    channels = (random.randint(0, 255) for _ in range(3))
    return '#' + ''.join(f'{value:02x}' for value in channels)
def format_pp_add_viz(pp):
    """Add stakeholder *pp* to the map at the first free default position.

    Names already present in ``st.session_state['pp_grouped']`` are ignored.
    Otherwise, candidate positions are tried in this order:
    (50, 50), (50, 55), ..., (50, 95), (55, 50), ...
    That is, ``y`` steps by 5 up to 95, then wraps to 50 while ``x`` advances
    by 5; ``x`` is not capped. The first ``(x, y)`` not used by any existing
    entry is taken, and the new entry gets a random colour.

    Returns:
        None.
    """
    grouped = st.session_state['pp_grouped']
    if any(entry['name'] == pp for entry in grouped):
        return None
    # Check every existing entry, not just those scanned so far, because
    # entries loaded from CSV or moved on the chart can be in any order.
    occupied = {(entry['x'], entry['y']) for entry in grouped}
    x, y = 50, 50
    while (x, y) in occupied:
        y += 5
        if y > 95:
            y = 50
            x += 5
    grouped.append({'name': pp, 'x': x, 'y': y, 'color': generate_random_color()})
    return None
def add_pp(new_pp, default_value=50):
    """Record a batch of extracted stakeholders and place them on the map.

    Names are normalised (surrounding whitespace removed, first letter
    upper-case, rest lower-case), blank names are dropped, and the result is
    sorted alphabetically. The list is appended to
    ``st.session_state['parties_prenantes']`` (one list per source), and each
    name is added to the map via format_pp_add_viz, which skips names already
    on the map.

    Args:
        new_pp: Iterable of raw stakeholder names.
        default_value: Unused; kept for backward compatibility.
    """
    # Normalise before sorting so the order does not depend on the raw casing,
    # and strip before capitalising so leading spaces do not block it.
    normalized = sorted(
        name for name in (item.strip().capitalize() for item in new_pp) if name
    )
    st.session_state['parties_prenantes'].append(normalized)
    for pp in normalized:
        format_pp_add_viz(pp)
def add_existing_pps(pp, pouvoir, influence):
    """Place stakeholder *pp* at (x=influence, y=pouvoir) on the map.

    Updates the coordinates of the first existing entry named *pp*. If there is
    no such entry, appends a new one with a random colour.
    """
    grouped = st.session_state['pp_grouped']
    match = next((entry for entry in grouped if entry['name'] == pp), None)
    if match is None:
        grouped.append({'name': pp, 'x': influence, 'y': pouvoir, 'color': generate_random_color()})
        return None
    match['x'] = influence
    match['y'] = pouvoir
    return None
def load_csv(file):
    """Import an existing stakeholder map from a CSV file.

    The CSV must have the columns ``'parties prenantes'``, ``'pouvoir'`` and
    ``'influence'``. Each row is placed on the map via add_existing_pps, with
    x = influence and y = pouvoir.

    Args:
        file: Path or buffer accepted by ``pandas.read_csv``.
    """
    table = pd.read_csv(file)
    for _, record in table.iterrows():
        add_existing_pps(
            pp=record['parties prenantes'],
            pouvoir=record['pouvoir'],
            influence=record['influence'],
        )
def add_pp_input_text():
    """Render a text input and an "Ajouter" button to add one stakeholder by hand.

    The typed name is stripped of surrounding whitespace. An empty entry is
    ignored instead of adding a nameless point to the map.
    """
    new_pp = st.text_input("Ajouter une partie prenante")
    if st.button("Ajouter", key="add_single_pp"):
        new_pp = new_pp.strip()
        if new_pp:
            format_pp_add_viz(new_pp)
# Accepts http(s)/ftp(s) URLs whose host is a dotted domain name, "localhost"
# or a dotted-quad IPv4 address, with an optional port and an optional
# path / query / fragment. Compiled once at import time.
_URL_PATTERN = re.compile(
    r'^(?:http|ftp)s?://'  # http:// or https:// (ftp also accepted)
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,8}\.?|'  # domain name
    r'localhost|'  # or localhost
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # or IPv4 address
    r'(?::\d+)?'  # optional port
    r'(?:[/?#][^\s]*)?$',  # optional path, query, or fragment
    re.IGNORECASE,
)


def complete_and_verify_url(partial_url):
    """Complete a user-typed URL and check that it looks valid.

    Surrounding whitespace is removed, then the URL is completed:

    * An ``http://`` or ``https://`` scheme (in any letter case) is left as is.
    * A leading ``www.`` gets ``https://`` prepended.
    * Anything else gets ``https://www.`` prepended.

    Args:
        partial_url: Raw text entered by the user.

    Returns:
        A ``(is_valid, complete_url)`` tuple.
    """
    candidate = partial_url.strip()
    prefix = candidate.lower()
    if prefix.startswith(('http://', 'https://')):
        complete_url = candidate
    elif prefix.startswith('www.'):
        complete_url = 'https://' + candidate
    else:
        complete_url = 'https://www.' + candidate
    return (_URL_PATTERN.match(complete_url) is not None, complete_url)
@st.dialog("Conseil IA", width="large")
def show_conseil_ia():
    """Modal dialog that streams stakeholder advice for the last analysed document.

    Sends a fixed question, together with the text of the first document in
    ``st.session_state["latest_doc"]``, to get_response and streams the answer
    into the dialog.
    """
    question = "Prenant compte les données de l'entreprise (activité, produits, services ...), quelles sont les principales parties prenantes à animer pour une démarche RSE réussie ?"
    st.markdown(f"**{question}**")
    context = st.session_state["latest_doc"][0].page_content
    st.write_stream(get_response(question, "", context))
    st.warning("Quittez et saisissez une autre URL")
def _init_display_state():
    """Start the per-session emissions tracker and seed this page's session-state keys."""
    if "emission" not in st.session_state:
        tracker = EmissionsTracker()
        tracker.start()
        st.session_state["emission"] = tracker
    defaults = {
        "Nom de la marque": "",
        "urls": [],
        "parties_prenantes": [],
        # Used for the plot / stakeholder map; grouped without duplicates.
        "pp_grouped": [],
        "latest_doc": "",
        "not_pp": "",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _save_upload(uploaded_file):
    """Write an uploaded file to the working directory and return its file name.

    Only the base name of the client-supplied name is used, so a crafted name
    cannot write outside the working directory.
    """
    file_name = os.path.basename(uploaded_file.name)
    with open(file_name, mode='wb') as w:
        w.write(uploaded_file.getvalue())
    return file_name


def _handle_extraction_result(partie_prenante, source, missing_source_message):
    """Apply an extract_pp result to session state.

    The result is handled as follows:

    * "444" (no stakeholder identifiable) sets ``not_pp`` to "444".
    * "445" (no documents) shows *missing_source_message*.
    * Otherwise *source* is recorded and the sorted stakeholders are added.

    Returns:
        True when stakeholders were added, False otherwise.
    """
    if "444" in partie_prenante:
        st.session_state["not_pp"] = "444"
        return False
    st.session_state["not_pp"] = ""
    if "445" in partie_prenante:
        st.error(missing_source_message)
        return False
    st.session_state["urls"].append(source)
    add_pp(sorted(partie_prenante))
    return True


def _add_from_website(brand_name, fire_crawl_api_key):
    """Website source: scrape a URL (FireCrawl or WebBaseLoader) and extract stakeholders."""
    url = st.text_input("Ajouter une URL")
    captions = ["L’IA prend en compte uniquement les textes contenus dans les pages web analysées","L’IA prend en compte les textes, les images et les liens URL contenus dans les pages web analysées"]
    scraping_option = st.radio("Mode", ("Analyse rapide", "Analyse profonde"), horizontal=True, captions=captions)
    if not st.button("ajouter", key="add_pp"):
        return
    st.session_state["not_pp"] = ""
    is_valid, url = complete_and_verify_url(url)
    if not is_valid:
        st.error("URL invalide")
        return
    if url in st.session_state["urls"]:
        st.error("URL déjà ajoutée")
        return

    if scraping_option == "Analyse profonde":
        with st.spinner("Collecte des données..."):
            docs = get_docs_from_website_fc([url], fire_crawl_api_key)
        if docs is None:
            # Deep scraping failed: fall back to the fast loader.
            st.warning("Erreur lors de la collecte des données, 2eme essai avec collecte rapide...")
            with st.spinner("2eme essai, collecte rapide..."):
                docs = get_docs_from_website([url])
    else:
        with st.spinner("Collecte des données..."):
            docs = get_docs_from_website([url])

    # An empty result is treated like a failure; otherwise the document
    # preview would later index into an empty list.
    if not docs:
        st.error("Erreur lors de la collecte des données, URL unvalide")
        st.session_state["latest_doc"] = ""
        return

    # NOTE(review): the session tracker is started once, yet stop() is called
    # here and again after extraction, on every URL, without a restart.
    # Confirm codecarbon returns meaningful values on repeated stop() calls.
    # "partial_emissions" is presumably initialised elsewhere.
    st.session_state["partial_emissions"]["Scrapping"]["cc"] = st.session_state["emission"].stop()
    st.session_state["latest_doc"] = docs
    with st.spinner("Processing..."):
        input_variables = {"BRAND_NAME": brand_name, "BRAND_DESCRIPTION": ""}
        partie_prenante = extract_pp(docs, input_variables)
        if _handle_extraction_result(partie_prenante, url, "Aucun site web trouvé avec l'url donnée"):
            st.session_state["partial_emissions"]["extraction_pp"]["cc"] = st.session_state["emission"].stop()


def _add_from_pdf(brand_name):
    """PDF source: load an uploaded PDF and extract stakeholders from its pages."""
    uploaded_file = st.file_uploader("Télécharger le fichier PDF", type="pdf")
    if uploaded_file is None or not st.button("ajouter", key="add_pp_pdf"):
        return
    st.session_state["not_pp"] = ""
    with st.spinner("Processing..."):
        file_name = _save_upload(uploaded_file)
        text = PyPDFLoader(file_name).load()
        st.session_state["latest_doc"] = text
        input_variables = {"BRAND_NAME": brand_name, "BRAND_DESCRIPTION": ""}
        partie_prenante = extract_pp(text, input_variables)
        _handle_extraction_result(partie_prenante, file_name, "Aucun contenu exploitable trouvé dans le document")


def _add_from_csv():
    """CSV source: import an existing stakeholder map and infer the brand name from the file name."""
    uploaded_file = st.file_uploader("Télécharger le fichier CSV", type="csv")
    if uploaded_file is None or not st.button("ajouter", key="add_pp_csv"):
        return
    file_name = _save_upload(uploaded_file)
    try:
        load_csv(file_name)
    except Exception:
        st.error("Erreur lors de la lecture du fichier")
        return
    # The brand name is taken from the second "-"-separated part of the file
    # name, when there is one. It is kept outside the ``try`` so a name without
    # "-" no longer reports a successful import as a read error.
    parts = file_name.split("-")
    if len(parts) > 1:
        st.session_state["Nom de la marque"] = parts[1]


def display_pp():
    """Render the "identify and engage your stakeholders" page.

    The user names the brand and picks a source: a website URL, a PDF
    document, or an existing CSV map. Stakeholders from that source are added
    to session state. The page then shows, in order:

    * the "Conseil IA" offer when nothing was found;
    * a preview of the last analysed document (the edited text is not saved);
    * the list of sources;
    * the stakeholder list;
    * the chart (test_chart).
    """
    _init_display_state()
    load_dotenv()
    fire_crawl_api_key = os.getenv("FIRECRAWL_API_KEY")

    st.title("IDENTIFIER ET ANIMER VOS PARTIES PRENANTES")
    brand_name = st.text_input("Nom de la marque", st.session_state["Nom de la marque"])
    st.session_state["Nom de la marque"] = brand_name

    option = st.radio("Source", ("A partir de votre site web", "A partir de vos documents entreprise","A partir de cartographie existante"))
    if option == "A partir de votre site web":
        _add_from_website(brand_name, fire_crawl_api_key)
    elif option == "A partir de vos documents entreprise":
        _add_from_pdf(brand_name)
    elif option == "A partir de cartographie existante":
        _add_from_csv()

    if st.session_state["not_pp"] == "444":
        st.warning("Aucune parties prenantes n'est identifiable sur l'URL fournie. Fournissez une autre URL ou bien cliquez sur le boutton ci-dessous pour un Conseils IA")
        if st.button("Conseil IA"):
            show_conseil_ia()

    # Truthiness covers both the "" sentinel and an empty document list.
    latest_doc = st.session_state["latest_doc"]
    if latest_doc:
        with st.expander("Cliquez ici pour éditer et voir le document"):
            cleaned_text = re.sub(r'\n\n+', '\n\n', latest_doc[0].page_content.strip())
            st.text_area("Modifier le texte ci-dessous:", value=cleaned_text, height=300)
            if st.button('Sauvegarder', key="save_doc_fake"):
                # The edited text is not persisted; only a confirmation is shown.
                st.success("Texte sauvegardé avec succès!")

    display_list_urls()
    with st.expander("Liste des parties prenantes"):
        add_pp_input_text()
        display_list_pps()
    test_chart()