# ArticleChatbot / streamlit_app.py
# Author: mqcm2 — last commit d4690c5 ("Update streamlit_app.py", verified)
import os
import re
from hashlib import blake2b
from tempfile import NamedTemporaryFile
import dotenv
from grobid_quantities.quantities import QuantitiesAPI
from langchain.memory import ConversationBufferWindowMemory
# from langchain_community.callbacks import PromptLayerCallbackHandler
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
from streamlit_pdf_viewer import pdf_viewer
from document_qa.ner_client_generic import NERClientGeneric
dotenv.load_dotenv(override=True)
import streamlit as st
from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
# Chat models served through the OpenAI API (instantiated with ChatOpenAI in init_qa).
OPENAI_MODELS = ['gpt-3.5-turbo',
                 "gpt-4",
                 "gpt-4-1106-preview"]

# OpenAI embedding model ids, passed verbatim to OpenAIEmbeddings(model=...).
# The third entry used to be 'openai-text-embedding-3-small', which is not an
# OpenAI model id; the API name is 'text-embedding-3-small'.
OPENAI_EMBEDDINGS = [
    'text-embedding-ada-002',
    'text-embedding-3-large',
    'text-embedding-3-small'
]

# Display name -> Hugging Face repo id used as HuggingFaceEndpoint(repo_id=...).
OPEN_MODELS = {
    'Mistral-Nemo-Instruct-2407': 'mistralai/Mistral-Nemo-Instruct-2407',
    'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.3',
    'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct"
}

# Embeddings selected by default for open models when none is given.
DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)'
# Display name -> Hugging Face model id passed to HuggingFaceEmbeddings(model_name=...).
OPEN_EMBEDDINGS = {
    DEFAULT_OPEN_EMBEDDING_NAME: 'all-MiniLM-L6-v2',
    'SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral',
    'SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R',
    'NV-Embed': 'nvidia/NV-Embed-v1',
    'e5-mistral-7b-instruct': 'intfloat/e5-mistral-7b-instruct'
}
# Per-session defaults. Streamlit re-executes this script on every interaction,
# so each key is only initialised when it is not yet in st.session_state.
_SESSION_DEFAULTS = {
    'rqa': {},
    'model': None,
    'api_keys': {},
    'doc_id': None,
    'loaded_embeddings': None,
    'hash': None,
    'messages': [],
    'ner_processing': False,
    'uploaded': False,
    'memory': None,
    'binary': None,
    'annotations': None,
    'should_show_annotations': True,
    'pdf': None,
    'embeddings': None,
    'scroll_to_first_annotation': False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Revision shown (and linked) in the sidebar; read once per session from revision.txt.
if 'git_rev' not in st.session_state:
    st.session_state['git_rev'] = "unknown"
    if os.path.exists("revision.txt"):
        with open("revision.txt", 'r', encoding="utf-8") as fr:
            # strip(): a trailing newline would otherwise end up inside the commit URL.
            from_file = fr.read().strip()
        st.session_state['git_rev'] = from_file if from_file else "unknown"
# Page metadata and layout for the whole app (wide layout, sidebar open by default).
st.set_page_config(
    page_title="Article Chatbot",  # was misspelled "Articel Chatbot"
    page_icon="📝",
    initial_sidebar_state="expanded",
    layout="wide",
    menu_items={
        'About': "Upload a scientific article in PDF, ask questions, get insights."
    }
)
# Override the padding of Streamlit's main ".block-container" so the two-pane
# layout uses more of the page. Injected as raw HTML, hence unsafe_allow_html.
_LAYOUT_CSS = """
<style>
.block-container {
padding-top: 3rem;
padding-bottom: 1rem;
padding-left: 1rem;
padding-right: 1rem;
}
</style>
"""
st.markdown(_LAYOUT_CSS, unsafe_allow_html=True)
def new_file():
    """Reset per-document state after a new file is chosen in the uploader.

    Registered as the file uploader's ``on_change`` callback: forgets the
    previous document's embeddings/id, marks that an upload happened, and
    clears the conversational memory when one exists.
    """
    state = st.session_state
    state['loaded_embeddings'] = None
    state['doc_id'] = None
    state['uploaded'] = True

    memory = state['memory']
    if memory:
        memory.clear()
def clear_memory():
    """Clear the session's conversational memory, if there is one.

    ``st.session_state['memory']`` is initialised to ``None`` and ``init_qa``
    only creates a memory for OpenAI models, so it can legitimately be ``None``;
    previously this raised ``AttributeError`` in that case.
    """
    memory = st.session_state['memory']
    if memory is not None:
        memory.clear()
# @st.cache_resource
def init_qa(model, embeddings_name=None, api_key=None):
    """Build a DocumentQAEngine for ``model`` with the requested embeddings.

    Args:
        model: a key of OPENAI_MODELS or OPEN_MODELS.
        embeddings_name: an entry of OPENAI_EMBEDDINGS (OpenAI models) or a key
            of OPEN_EMBEDDINGS (open models). Defaults to
            'text-embedding-ada-002' or DEFAULT_OPEN_EMBEDDING_NAME respectively.
        api_key: optional credential. For OpenAI models it is passed to both
            ChatOpenAI and OpenAIEmbeddings; for open models it is passed to
            HuggingFaceEndpoint as ``huggingfacehub_api_token``. When omitted,
            the client libraries fall back to their environment variables.

    Returns:
        A DocumentQAEngine backed by the GROBID server in ``GROBID_URL`` and the
        session's conversational memory (``st.session_state['memory']``).

    Calls ``st.stop()`` (after showing an error) for an unknown model or an
    embeddings name that does not belong to the model's family.
    """
    ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
    if model in OPENAI_MODELS:
        if embeddings_name is None:
            embeddings_name = 'text-embedding-ada-002'
        # Validate regardless of how the key is supplied (previously only checked
        # when api_key was passed explicitly).
        if embeddings_name not in OPENAI_EMBEDDINGS:
            st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.")
            st.stop()
            return
        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
        # Forward the key only when given; otherwise the OpenAI clients read OPENAI_API_KEY.
        credentials = {'openai_api_key': api_key} if api_key else {}
        chat = ChatOpenAI(model_name=model,
                          temperature=0,
                          frequency_penalty=0.1,
                          **credentials)
        embeddings = OpenAIEmbeddings(model=embeddings_name, **credentials)
    elif model in OPEN_MODELS:
        if embeddings_name is None:
            embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME
        # Unknown names would otherwise surface as a bare KeyError below.
        if embeddings_name not in OPEN_EMBEDDINGS:
            st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.")
            st.stop()
            return
        # A token typed into the sidebar used to be ignored here; pass it explicitly
        # instead of exporting it into os.environ (which is shared by all sessions).
        token = {'huggingfacehub_api_token': api_key} if api_key else {}
        chat = HuggingFaceEndpoint(
            repo_id=OPEN_MODELS[model],
            temperature=0.01,
            max_new_tokens=4092,  # NOTE(review): possibly meant 4096 — confirm
            model_kwargs={"max_length": 8192},
            # callbacks=[PromptLayerCallbackHandler(pl_tags=[model, "document-qa"])]
            **token
        )
        embeddings = HuggingFaceEmbeddings(
            model_name=OPEN_EMBEDDINGS[embeddings_name])
        # NOTE(review): memory is not (re)created for open models, so the engine
        # receives whatever st.session_state['memory'] currently holds (None, or a
        # memory left over from an OpenAI model) — confirm this is intended.
        # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
    else:
        st.error("The model was not loaded properly. Try reloading. ")
        st.stop()
        return

    storage = DataStorage(embeddings)
    return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
@st.cache_resource
def init_ner():
    """Build the NER post-processor applied to LLM answers.

    Combines a grobid-quantities client (server from ``GROBID_QUANTITIES_URL``)
    with a generic NER client pointed at ``GROBID_MATERIALS_URL``. Decorated
    with ``st.cache_resource`` so it is not rebuilt on every script rerun.
    """
    quantities = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
    materials = NERClientGeneric(ping=True)

    # Text is sent to the "disableLinking" variant of the process/text endpoint.
    endpoint_mapping = {
        'processText_disable_linking': "/service/process/text?disableLinking=True",
    }
    grobid_settings = {
        "server": os.environ['GROBID_MATERIALS_URL'],
        'sleep_time': 5,
        'timeout': 60,
        'url_mapping': endpoint_mapping,
    }
    materials.set_config({'grobid': grobid_settings})

    return GrobidAggregationProcessor(
        grobid_quantities_client=quantities,
        grobid_superconductors_client=materials,
    )


# Created eagerly when the script runs; used to highlight entities in LLM answers.
gqa = init_ner()
def get_file_hash(fname):
    """Return the hexadecimal BLAKE2b digest of the file at ``fname``.

    The file is read in 4096-byte chunks so it never has to fit in memory.
    """
    digest = blake2b()
    with open(fname, "rb") as stream:
        while chunk := stream.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
def play_old_messages(container):
    """Re-render the stored chat history into ``container``.

    Assistant messages are rendered according to the query mode they were
    produced in (``message['mode']``), the same way they were shown the first
    time: "llm" answers may contain HTML from NER highlighting, so they are
    rendered as markdown with HTML allowed; other answers go through ``write``.

    Previously the global ``mode`` was compared to "LLM", but mode values are
    lowercase ("llm"), so LLM answers were never rendered with HTML enabled.
    """
    for message in st.session_state['messages'] or []:
        role = message['role']
        if role == 'user':
            container.chat_message("user").markdown(message['content'])
        elif role == 'assistant':
            if message.get('mode') == "llm":
                container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True)
            else:
                container.chat_message("assistant").write(message['content'])
# Sidebar: model/embeddings selection and API-key handling. A QA engine is
# created once per model, as soon as a key is available for it.
with st.sidebar:
    st.title("Article Chatbot")
    st.markdown("Upload a scientific article in PDF, ask questions, get insights.")
    st.divider()

    model_options = OPENAI_MODELS + list(OPEN_MODELS.keys())
    # Fall back to the first model when DEFAULT_MODEL is unset, empty or unknown
    # (list.index previously raised ValueError for an unknown name).
    default_model = os.environ.get("DEFAULT_MODEL")
    default_model_index = model_options.index(default_model) if default_model in model_options else 0
    # Model and embeddings cannot change once a document is uploaded/processed.
    selection_locked = st.session_state['doc_id'] is not None or st.session_state['uploaded']

    st.session_state['model'] = model = st.selectbox(
        "Model:",
        options=model_options,
        index=default_model_index,
        placeholder="Select model",
        help="Select the LLM model:",
        disabled=selection_locked
    )
    embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS
    st.session_state['embeddings'] = embedding_name = st.selectbox(
        "Embeddings:",
        options=embedding_choices,
        index=0,
        placeholder="Select embedding",
        help="Select the Embedding function:",
        disabled=selection_locked
    )

    if model in OPEN_MODELS and model not in st.session_state['api_keys']:
        if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
            api_key = st.text_input('Huggingface API Key', type="password")
            st.markdown("Get it [here](https://huggingface.co/docs/hub/security-tokens)")
        else:
            api_key = os.environ['HUGGINGFACEHUB_API_TOKEN']
        if api_key:
            with st.spinner("Preparing environment"):
                st.session_state['api_keys'][model] = api_key
                # Pass the token explicitly: a key typed here was previously stored
                # but never forwarded to the Hugging Face endpoint.
                st.session_state['rqa'][model] = init_qa(model, embedding_name, api_key)
    elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
        if 'OPENAI_API_KEY' not in os.environ:
            api_key = st.text_input('OpenAI API Key', type="password")
            st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")
        else:
            api_key = os.environ['OPENAI_API_KEY']
        if api_key:
            with st.spinner("Preparing environment"):
                st.session_state['api_keys'][model] = api_key
                # When the key comes from OPENAI_API_KEY, passing it explicitly is
                # equivalent to letting the client read the environment.
                st.session_state['rqa'][model] = init_qa(model, embedding_name, api_key)

    # Disabled feature: a button to reset the conversational memory via clear_memory().
    # st.button(
    #     'Reset chat memory.',
    #     key="reset-memory-button",
    #     on_click=clear_memory,
    #     help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
    #     disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
# Two bordered panes: the PDF viewer goes on the left, upload + chat on the right.
left_column, right_column = st.columns([5, 4])
right_column = right_column.container(border=True)
left_column = left_column.container(border=True)

with right_column:
    # Uploading is blocked while the selected model still has no API key.
    upload_disabled = (
        st.session_state['model'] is not None
        and st.session_state['model'] not in st.session_state['api_keys']
    )
    uploaded_file = st.file_uploader(
        "Upload a scientific article",
        type="pdf",
        on_change=new_file,
        disabled=upload_disabled,
        help="The full-text is extracted using Grobid."
    )

    # Spinner slot used while answering, and the scrollable chat history box.
    placeholder = st.empty()
    messages = st.container(height=300)

    question = st.chat_input("Ask something about the article", disabled=not uploaded_file)
# Query mode id -> label displayed by the sidebar radio.
query_modes = {
    "llm": "LLM Q/A",
    "embeddings": "Embeddings",
    "question_coefficient": "Question coefficient"
}

with st.sidebar:
    st.header("Settings")
    mode = st.radio(
        "Query mode",
        tuple(query_modes),
        disabled=not uploaded_file,
        index=0,
        horizontal=True,
        format_func=query_modes.__getitem__,
        help="LLM will respond the question, Embedding will show the "
             "relevant paragraphs to the question in the paper. "
             "Question coefficient attempt to estimate how effective the question will be answered."
    )

    st.session_state['scroll_to_first_annotation'] = st.checkbox(
        "Scroll to context",
        help='The PDF viewer will automatically scroll to the first relevant passage in the document.'
    )
    st.session_state['ner_processing'] = st.checkbox(
        "Identify materials and properties.",
        help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.'
    )

    # Chunk size is used when the document is indexed, so it is frozen after upload.
    chunk_size = st.slider("Text chunks size", -1, 2000, value=-1,
                           help="Size of chunks in which split the document. -1: use paragraphs, > 0 paragraphs are aggregated.",
                           disabled=uploaded_file is not None)

    # The context is counted in paragraphs (chunk_size == -1) or in chunks.
    if chunk_size == -1:
        context_unit, context_max, context_default = "paragraphs", 20, 10
    else:
        context_unit, context_max, context_default = "chunks", 10, 4
    context_size = st.slider(f"Context size ({context_unit})", 3, context_max, value=context_default,
                             help=f"Number of {context_unit} to consider when answering a question",
                             disabled=not uploaded_file)

    st.divider()
    st.markdown(
        """Upload a scientific article as PDF document. Once the spinner stops, you can proceed to ask your questions.""")

    git_rev = st.session_state['git_rev']
    if git_rev != "unknown":
        st.markdown(f"**Revision number**: [{git_rev}](https://github.com/lfoppiano/document-qa/commit/{git_rev})")
# First rerun after an upload: send the PDF through Grobid and build its embeddings.
if uploaded_file and not st.session_state.loaded_embeddings:
    if model not in st.session_state['api_keys']:
        st.error("Before uploading a document, you must enter the API key. ")
        st.stop()

    with left_column:
        with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
            binary = uploaded_file.getvalue()
            st.session_state['binary'] = binary

            # The engine reads the document by path, so spool the upload to a temp file.
            # flush() is required: without it the tail of the PDF may still sit in
            # Python's write buffer when the file is opened by name. Leaving the
            # with-block closes (and deletes) the temporary file.
            with NamedTemporaryFile() as tmp_file:
                tmp_file.write(binary)
                tmp_file.flush()
                st.session_state['doc_id'] = st.session_state['rqa'][model].create_memory_embeddings(
                    tmp_file.name,
                    chunk_size=chunk_size,
                    perc_overlap=0.1
                )

            st.session_state['loaded_embeddings'] = True
            st.session_state.messages = []
def rgb_to_hex(rgb):
    """Format the first three components of ``rgb`` as a ``#rrggbb`` string."""
    red, green, blue, *_ = rgb
    return f"#{red:02x}{green:02x}{blue:02x}"


def generate_color_gradient(num_elements):
    """Return ``num_elements`` hex colours blending from orange towards blue.

    Colour ``i`` puts weight ``i / num_elements`` on blue, so the first entry is
    pure orange and the last stops one step short of pure blue. Zero (or a
    negative count) gives an empty list.
    """
    warm_color = (255, 165, 0)  # orange
    cold_color = (0, 0, 255)  # blue

    gradient = []
    for step in range(num_elements):
        ratio = step / num_elements
        blended = tuple(
            int(warm * (1 - ratio) + cold * ratio)
            for warm, cold in zip(warm_color, cold_color)
        )
        gradient.append(rgb_to_hex(blended))
    return gradient
# Right pane: answer a new question, or replay the history on a plain rerun.
with right_column:
    if st.session_state.loaded_embeddings and question and st.session_state.doc_id:
        st.session_state.messages.append({"role": "user", "mode": mode, "content": question})

        # Re-render the whole history (including the question just asked).
        for message in st.session_state.messages:
            if message['mode'] == "embeddings":
                messages.chat_message(message["role"]).write(message["content"])
            elif message['mode'] in ("llm", "question_coefficient"):
                messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)

        if model not in st.session_state['rqa']:
            st.error("The API Key for the " + model + " is missing. Please add it before sending any query.")
            st.stop()

        qa_engine = st.session_state['rqa'][model]
        text_response = None
        coordinates = []
        # Every mode shows its spinner in the same placeholder slot.
        with placeholder:
            if mode == "embeddings":
                with st.spinner("Fetching the relevant context..."):
                    text_response, coordinates = qa_engine.query_storage(
                        question,
                        st.session_state.doc_id,
                        context_size=context_size
                    )
            elif mode == "llm":
                with st.spinner("Generating LLM response..."):
                    _, text_response, coordinates = qa_engine.query_document(
                        question,
                        st.session_state.doc_id,
                        context_size=context_size
                    )
            elif mode == "question_coefficient":
                with st.spinner("Estimate question/context relevancy..."):
                    text_response, coordinates = qa_engine.analyse_query(
                        question,
                        st.session_state.doc_id,
                        context_size=context_size
                    )

        # One colour per returned passage (first orange, later ones shifting towards
        # blue); every box of a passage shares that colour. Each coordinate entry is
        # a comma-separated string handed to box_to_dict.
        annotations = [[GrobidAggregationProcessor.box_to_dict(c.split(",")) for c in coord_doc]
                       for coord_doc in (coordinates or [])]
        for color, passage_boxes in zip(generate_color_gradient(len(annotations)), annotations):
            for annotation in passage_boxes:
                annotation['color'] = color
        st.session_state['annotations'] = [annotation for passage_boxes in annotations
                                           for annotation in passage_boxes]

        if not text_response:
            # Stop here: previously execution continued, crashing in NER on None.strip()
            # and storing a None assistant message in the history.
            # NOTE(review): the contact address looks mangled by an e-mail obfuscation
            # filter ("[email protected]"); restore the real address.
            st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
        else:
            if mode == "llm":
                if st.session_state['ner_processing']:
                    with st.spinner("Processing NER on LLM response..."):
                        entities = gqa.process_single_text(text_response)
                        decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
                        # Materials in green, every other label (e.g. quantities) in orange.
                        decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
                        decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
                        text_response = decorated_text
                messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True)
            else:
                messages.chat_message("assistant").write(text_response)
            st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})

    elif st.session_state.loaded_embeddings and st.session_state.doc_id:
        play_old_messages(messages)
# Left pane: show the uploaded PDF with the latest answer's passages highlighted.
with left_column:
    pdf_binary = st.session_state['binary']
    if pdf_binary:
        current_annotations = st.session_state['annotations']
        # Jump to the first highlighted box only when asked to and there is one.
        scroll_target = None
        if current_annotations and st.session_state['scroll_to_first_annotation']:
            scroll_target = 1
        with st.container(height=600):
            pdf_viewer(
                input=pdf_binary,
                annotation_outline_size=2,
                annotations=current_annotations,
                render_text=True,
                scroll_to_annotation=scroll_target
            )