|
from climateqa.engine.embeddings import get_embeddings_function |
|
# Embedding function shared by the Chroma graph store and the Pinecone
# vectorstore created further down in this module.
embeddings_function = get_embeddings_function()
|
|
|
from climateqa.knowledge.openalex import OpenAlex |
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
# OpenAlex client instance.
# NOTE(review): not referenced in the code visible here — possibly unused; confirm before removing.
oa = OpenAlex()
|
|
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import time |
|
import re |
|
import json |
|
|
|
from gradio import ChatMessage |
|
|
|
|
|
|
|
from io import BytesIO |
|
import base64 |
|
|
|
from datetime import datetime |
|
from azure.storage.fileshare import ShareServiceClient |
|
|
|
from utils import create_user_id |
|
|
|
from langchain_chroma import Chroma |
|
from collections import defaultdict |
|
from gradio_modal import Modal |
|
|
|
|
|
|
|
|
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.vectorstore import get_pinecone_vectorstore |
|
|
|
from climateqa.engine.reranker import get_reranker |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
from climateqa.engine.chains.prompts import audience_prompts |
|
from climateqa.sample_questions import QUESTIONS |
|
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES |
|
from climateqa.utils import get_image_from_azure_blob_storage |
|
from climateqa.engine.keywords import make_keywords_chain |
|
|
|
from climateqa.engine.graph import make_graph_agent,display_graph |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
|
|
from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs |
|
|
|
from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox |
|
|
|
|
|
# Load variables from a local ``.env`` file when python-dotenv is available.
# This is best-effort: deployments may inject the environment directly, so a
# failure here is reported but does not stop the app.
try:
    from dotenv import load_dotenv

    load_dotenv()
except Exception as e:
    print(f"Could not load .env file: {e}")
|
|
|
|
|
|
|
# Gradio theme: blue primary / red secondary hues, Poppins (Google Font) with
# system sans-serif fallbacks.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
|
|
|
|
|
|
|
# System message template whose content is the current value of ``init_prompt``
# (an empty string at this point). NOTE: ``init_prompt`` is reassigned later in
# the module; that reassignment does not update ``system_template``.
# NOTE(review): ``system_template`` is not referenced in the visible code.
init_prompt = ""

system_template = {
    "role": "system",
    "content": init_prompt,
}
|
|
|
# Azure file share used to store the chat/feedback logs written by ``log_on_azure``.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
# The key is presumably base64: its trailing "=" padding can be lost when it is
# stored in an environment variable (the original special case was an
# 86-character key missing "=="). Re-padding to a multiple of 4 characters
# restores any number of stripped "=" and is a no-op for a complete key.
account_key += "=" * (-len(account_key) % 4)

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

# Identifier written (as a string) into every chat log entry.
user_id = create_user_id()
|
|
|
# Chroma store of OWID graphs used for the "Recommended content" tab.
# The location can be overridden with OWID_VECTORSTORE_PATH so the app does not
# depend on a developer-specific path; the default keeps the original location.
vectorstore_graphs = Chroma(
    persist_directory=os.getenv(
        "OWID_VECTORSTORE_PATH",
        "/home/tim/ai4s/climate_qa/climate-question-answering/data/vectorstore_owid",
    ),
    embedding_function=embeddings_function,
)
|
|
|
|
|
|
|
# Pinecone vectorstore of the report chunks (passed to the agent as ``vectorstore_ipcc``).
vectorstore = get_pinecone_vectorstore(embeddings_function)
# OpenAI chat model with temperature 0; max_tokens=1024 presumably caps the
# completion length — confirm in get_llm.
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
# "nano" reranker variant — NOTE(review): the concrete model is chosen in get_reranker.
reranker = get_reranker("nano")

# LangGraph agent consumed by ``chat`` through ``agent.astream_events``.
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
|
|
# Maps the audience dropdown label to its key in ``audience_prompts``; any other
# value (including None) falls back to the expert prompt.
_AUDIENCE_PROMPT_KEYS = {
    "Children": "children",
    "General public": "general",
    "Experts": "experts",
}

# Agent nodes announced in the chat by a titled assistant message when they
# start: node name -> (message title, display flag). The flag is currently unused.
_STEPS_DISPLAY = {
    "categorize_intent": ("🔄️ Analyzing user message", True),
    "transform_query": ("🔄️ Thinking step by step to answer the question", True),
    "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
}

# Agent nodes whose chat-model token stream is the user-facing answer.
_ANSWER_NODES = ("answer_rag", "answer_search", "answer_chitchat")

# Initial content of the "Figures" tab; figure cards are appended after it.
_EMPTY_FIGURES_HTML = '<div class="figures-container"><p></p> </div>'


def _last_step_title(history):
    """Return the metadata title of the last history entry, or None.

    Entries that are not ``ChatMessage`` instances (e.g. plain dicts) and
    messages without a title both yield None.
    """
    if not history or not isinstance(history[-1], ChatMessage):
        return None
    metadata = getattr(history[-1], "metadata", None) or {}
    return metadata.get("title")


def _build_docs_html(docs):
    """Return the concatenated source cards for the textual chunks of ``docs``.

    Cards are numbered from 1 in retrieval order; non-text chunks are skipped.
    """
    textual_docs = [d for d in docs if d.metadata.get("chunk_type") == "text"]
    return "".join(make_html_source(d, i) for i, d in enumerate(textual_docs, 1))


def _build_graphs_html(recommended_content):
    """Render recommended graphs as HTML, one ``<h3>`` header per category.

    Graphs are de-duplicated on ``metadata["returned_content"]`` (the markup
    inserted in the page); categories and graphs keep their first-seen order.
    """
    graphs_by_category = defaultdict(list)
    seen = set()
    for doc in recommended_content:
        graph_markup = doc.metadata["returned_content"]
        if graph_markup in seen:
            continue
        seen.add(graph_markup)
        graphs_by_category[doc.metadata["category"]].append(graph_markup)

    parts = []
    for category, graphs in graphs_by_category.items():
        parts.append(f"<h3>{category}</h3>")
        parts.extend(f"<div>{graph}</div>" for graph in graphs)
    return "".join(parts)


def _build_figures(docs, figures):
    """Append a figure card to ``figures`` for every image chunk in ``docs``.

    Returns ``(figures, gallery)`` where ``gallery`` lists the fetched images.
    An image that cannot be fetched or encoded is skipped with a printed message.
    """
    gallery = []
    image_docs = [d for d in docs if d.metadata.get("chunk_type") == "image"]
    for i, doc in enumerate(image_docs):
        try:
            # Only the part of the stored path after "documents/" is requested.
            image_path = doc.metadata["image_path"].split("documents/")[1]
            img = get_image_from_azure_blob_storage(image_path)

            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            figures += make_html_figure_sources(doc, i, img_str)
            gallery.append(img)
        except Exception as e:
            print(f"Skipped adding image {i} because of {e}")
    return figures, gallery


def _log_chat_on_azure(query, sources, output_query, docs, history):
    """Upload the question/answer trace to the Azure file share.

    Skipped when ``GRADIO_ENV`` is ``"local"``. The logged answer is the
    ``content`` of the last history entry (None if it has no such attribute).
    """
    if os.getenv("GRADIO_ENV") == "local":
        return
    timestamp = str(datetime.now().timestamp())
    logs = {
        "user_id": str(user_id),
        "prompt": query,
        "query": query,
        "question": output_query,
        "sources": sources,
        "docs": serialize_docs(docs),
        "answer": getattr(history[-1], "content", None),
        "time": timestamp,
    }
    log_on_azure(timestamp + ".json", logs, share_client)


async def chat(query, history, audience, sources, reports, current_graphs):
    """Answer ``query`` by streaming the events of the LangGraph agent.

    Taking a query and a message history, run the agent pipeline
    (intent categorisation, query transformation, retrieval, answering) and
    yield after every agent event, then once more at the end, a tuple:
    ``(history, docs_html, output_query, output_language, gallery, figures, graphs_html)``.

    Args:
        query: The user question.
        history: Chat history from the chatbot; step and answer messages are
            appended to it as ``ChatMessage`` objects while streaming.
        audience: "Children", "General public" or "Experts"; any other value
            uses the expert prompt.
        sources: Selected source families; ``["IPCC", "IPBES", "IPOS"]`` when empty.
        reports: Selected reports. NOTE(review): not forwarded to the agent.
        current_graphs: Unused; present to match the Gradio event inputs.

    Raises:
        gr.Error: If the agent event stream fails.
    """
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    audience_prompt = audience_prompts[_AUDIENCE_PROMPT_KEYS.get(audience, "experts")]

    if not sources:
        sources = ["IPCC", "IPBES", "IPOS"]

    inputs = {"user_input": query, "audience": audience_prompt, "sources": sources}
    result = agent.astream_events(inputs, version="v1")

    docs = []
    docs_html = ""
    output_query = ""
    output_language = ""
    gallery = []
    graphs_html = ""
    figures = _EMPTY_FIGURES_HTML
    used_documents = []
    # Raw LLM text; it is re-parsed as a whole on every chunk so the parser
    # never runs on its own already-transformed output.
    raw_answer = ""
    start_streaming = False
    # ``node`` keeps the last seen LangGraph node across events that lack it.
    node = None
    event = None

    try:
        async for event in result:
            if "langgraph_node" in event["metadata"]:
                node = event["metadata"]["langgraph_node"]

            name = event["name"]
            kind = event["event"]

            if kind == "on_chain_end" and name == "retrieve_documents":
                try:
                    docs = event["data"]["output"]["documents"]
                    docs_html = _build_docs_html(docs)
                    used_documents += [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
                    # Only rewrite a message created by this run, never the
                    # user's own message.
                    if isinstance(history[-1], ChatMessage):
                        history[-1].content = "Adding sources :\n\n - " + "\n - ".join(sorted(set(used_documents)))
                except Exception as e:
                    print(f"Error getting documents: {e}")
                    print(event)

            elif name in _STEPS_DISPLAY and kind == "on_chain_start":
                event_description = _STEPS_DISPLAY[name][0]
                if _last_step_title(history) != event_description:
                    history.append(ChatMessage(role="assistant", content="", metadata={"title": event_description}))

            elif name != "transform_query" and kind == "on_chat_model_stream" and node in _ANSWER_NODES:
                if not start_streaming:
                    start_streaming = True
                    history.append(ChatMessage(role="assistant", content=""))
                raw_answer += event["data"]["chunk"].content
                history[-1] = ChatMessage(role="assistant", content=parse_output_llm_with_sources(raw_answer))

            elif name in ("retrieve_graphs", "retrieve_graphs_ai") and kind == "on_chain_end":
                try:
                    graphs_html += _build_graphs_html(event["data"]["output"]["recommended_content"])
                except Exception as e:
                    print(f"Error getting graphs: {e}")

            if name == "transform_query" and kind == "on_chain_end" and isinstance(history[-1], ChatMessage):
                sub_questions = [q["question"] for q in event["data"]["output"]["remaining_questions"]]
                history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join(sub_questions)

            yield history, docs_html, output_query, output_language, gallery, figures, graphs_html

    except Exception as e:
        print(event, "has failed")
        raise gr.Error(f"{e}") from e

    # Logging is best-effort: a storage failure must not prevent the figures
    # below from being delivered to the user.
    try:
        _log_chat_on_azure(query, sources, output_query, docs, history)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")

    figures, gallery = _build_figures(docs, figures)

    yield history, docs_html, output_query, output_language, gallery, figures, graphs_html
|
|
|
|
|
def save_feedback(feed: str, user_id):
    """Upload a user feedback message to the Azure file share.

    Feedback that is None or shorter than two characters is ignored.

    Args:
        feed: The feedback text.
        user_id: Identifier of the user. It is converted with ``str`` so that
            non-string ids work both in the file name and in the JSON payload;
            ``chat`` applies the same conversion when it logs ``user_id``.

    Returns:
        A confirmation message when the feedback was uploaded, otherwise None.
    """
    if feed is None or len(feed) <= 1:
        return None

    user_id = str(user_id)
    timestamp = str(datetime.now().timestamp())
    logs = {
        "user_id": user_id,
        "feedback": feed,
        "time": timestamp,
    }
    log_on_azure(user_id + timestamp + ".json", logs, share_client)
    return "Feedback submitted, thank you!"
|
|
|
|
|
|
|
|
|
def log_on_azure(file, logs, share_client):
    """Serialize ``logs`` as JSON and upload it as ``file`` on ``share_client``.

    Args:
        file: Name of the file to create on the share.
        logs: JSON-serializable payload.
        share_client: Share client providing ``get_file_client``.
    """
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)
|
|
|
|
|
def generate_keywords(query):
    """Extract keywords from ``query`` with the keywords chain.

    Returns:
        The extracted keywords joined with " AND ".
    """
    keywords_chain = make_keywords_chain(llm)
    extracted = keywords_chain.invoke(query)["keywords"]
    return " AND ".join(extracted)
|
|
|
|
|
|
|
# (column name, display width) pairs for the papers table, in display order.
_PAPERS_COLUMNS = (
    ("doc", 50),
    ("id", 100),
    ("title", 300),
    ("doi", 100),
    ("publication_year", 100),
    ("abstract", 500),
    ("rerank_score", 100),
    ("is_oa", 50),
)

# Column names and their widths as two parallel lists.
papers_cols = [column for column, _ in _PAPERS_COLUMNS]
papers_cols_widths = [width for _, width in _PAPERS_COLUMNS]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Welcome message shown as the first assistant message of the chatbot.
# The section emoji were previously mis-encoded (rendered as "β", "β οΈ", "π").
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.

❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

🛈 Information
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.


What do you want to learn ?
"""
|
|
|
|
|
def vote(data: gr.LikeData):
    """Print the liked message's value, or the whole event payload otherwise."""
    print(data.value if data.liked else data)
|
|
|
def save_graph(saved_graphs_state, embedding, category):
    """Record ``embedding`` under ``category`` in the saved-graphs state.

    The state dict is updated in place; an embedding already saved in that
    category is not added twice.

    Returns:
        The updated state and a button labelled "Graph Saved".
    """
    print(f"\nCategory:\n{saved_graphs_state}\n")
    graphs_in_category = saved_graphs_state.setdefault(category, [])
    if embedding not in graphs_in_category:
        graphs_in_category.append(embedding)
    return saved_graphs_state, gr.Button("Graph Saved")
|
|
|
|
|
|
|
with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd() + "/style.css", theme=theme, elem_id="main-component") as demo:
    # Per-session state.
    chat_completed_state = gr.State(0)  # NOTE(review): not wired to any event here.
    current_graphs = gr.State([])  # Recommended-graphs HTML produced by ``chat``.
    saved_graphs = gr.State({})  # NOTE(review): presumably the state used by ``save_graph``.

    with gr.Tab("ClimateQ&A"):
        with gr.Row(elem_id="chatbot-row"):
            # Left: conversation and input box.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    value=[ChatMessage(role="assistant", content=init_prompt)],
                    type="messages",
                    show_copy_button=True,
                    show_label=False,
                    elem_id="chatbot",
                    layout="panel",
                    avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
                    max_height="80vh",
                    height="100vh",
                )

                with gr.Row(elem_id="input-message"):
                    textbox = gr.Textbox(placeholder="Ask me anything here!", show_label=False, scale=7, lines=1, interactive=True, elem_id="input-textbox")

            # Right: examples, sources, configuration, figures, recommended graphs.
            with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                        # Hidden textbox filled by the examples; its change event starts a chat.
                        examples_hidden = gr.Textbox(visible=False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(), value=first_key, interactive=True, show_label=True, label="Select a category of sample questions", elem_id="dropdown-samples")

                        # One row of examples per category; only the first is visible initially.
                        samples = []
                        for i, key in enumerate(QUESTIONS.keys()):
                            examples_visible = i == 0
                            with gr.Row(visible=examples_visible) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                )
                            samples.append(group_examples)

                    with gr.Tab("Sources", elem_id="tab-citations", id=1) as tab_sources:
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")

                    with gr.Tab("Configuration", elem_id="tab-config", id=2) as tab_config:
                        gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")

                        dropdown_sources = gr.CheckboxGroup(
                            ["IPCC", "IPBES", "IPOS"],
                            label="Select source",
                            value=["IPCC"],
                            interactive=True,
                        )

                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )

                        dropdown_audience = gr.Dropdown(
                            ["Children", "General public", "Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )

                        output_query = gr.Textbox(label="Query used for retrieval", show_label=True, elem_id="reformulated-query", lines=2, interactive=False)
                        output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1, interactive=False)

                    with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
                        with Modal(visible=False, elem_id="modal_figure_galery") as modal:
                            gallery_component = gr.Gallery(object_fit='scale-down', elem_id="gallery-component", height="80vh")

                        show_full_size_figures = gr.Button("Show figures in full size", elem_id="show-figures", interactive=True)
                        show_full_size_figures.click(lambda: Modal(visible=True), None, modal)

                        figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")

                    with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=4) as tab_recommended_content:
                        graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
                        current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])

    with gr.Tab("About", elem_classes="max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")

    def start_chat(query, history):
        """Lock the input, switch to the Sources tab and append the user message."""
        history = history + [ChatMessage(role="user", content=query)]
        return (gr.update(interactive=False), gr.update(selected=1), history)

    def finish_chat():
        """Re-enable and clear the input, then switch to the Figures tab."""
        return (gr.update(interactive=True, value=""), gr.update(selected=3))

    def change_completion_status(current_state):
        """Toggle a 0/1 completion flag."""
        current_state = 1 - current_state
        return current_state

    def update_sources_number_display(sources_textbox, figures_cards, current_graphs):
        """Relabel the Sources/Figures/Recommended tabs with their item counts.

        Sources and figures are counted by ``<h2>`` headers, graphs by
        ``<iframe`` tags. Missing values (None or the initial empty state)
        count as zero.
        """
        sources_number = (sources_textbox or "").count("<h2>")
        figures_number = (figures_cards or "").count("<h2>")
        graphs_number = (current_graphs or "").count("<iframe")

        sources_notif_label = f"Sources ({sources_number})"
        figures_notif_label = f"Figures ({figures_number})"
        graphs_notif_label = f"Recommended content ({graphs_number})"

        return gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label)

    # Components refreshed by every ``chat`` yield, in the order of its tuple.
    chat_outputs = [chatbot, sources_textbox, output_query, output_language, gallery_component, figures_cards, current_graphs]
    # Tabs whose labels show the number of items they contain.
    counted_tabs = [tab_sources, tab_figures, tab_recommended_content]

    (textbox
        .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
        .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, current_graphs], chat_outputs, concurrency_limit=8, api_name="chat_textbox")
        .then(finish_chat, None, [textbox, tabs], api_name="finish_chat_textbox")
        .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs], counted_tabs)
    )

    (examples_hidden
        .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
        .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, current_graphs], chat_outputs, concurrency_limit=8, api_name="chat_examples")
        # finish_chat returns two updates, so it needs both outputs (the
        # examples chain previously passed only [textbox]).
        .then(finish_chat, None, [textbox, tabs], api_name="finish_chat_examples")
        .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs], counted_tabs)
    )

    def change_sample_questions(key):
        """Show only the examples row of the selected question category."""
        index = list(QUESTIONS.keys()).index(key)
        visible_bools = [False] * len(samples)
        visible_bools[index] = True
        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]

    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
|
|
|
|
|
# Enable Gradio's request queue, then start the server with server-side
# rendering disabled (ssr_mode=False).
demo.queue()

demo.launch(ssr_mode=False)
|
|