lfoppiano commited on
Commit
d74cacd
·
1 Parent(s): 0188e45

update application

Browse files
Files changed (1) hide show
  1. streamlit_app.py +44 -20
streamlit_app.py CHANGED
@@ -5,8 +5,11 @@ from tempfile import NamedTemporaryFile
5
 
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
8
- from langchain.llms.huggingface_hub import HuggingFaceHub
9
  from langchain.memory import ConversationBufferWindowMemory
 
 
 
 
10
  from streamlit_pdf_viewer import pdf_viewer
11
 
12
  from document_qa.ner_client_generic import NERClientGeneric
@@ -14,9 +17,6 @@ from document_qa.ner_client_generic import NERClientGeneric
14
  dotenv.load_dotenv(override=True)
15
 
16
  import streamlit as st
17
- from langchain.chat_models import ChatOpenAI
18
- from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
19
-
20
  from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
21
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
22
 
@@ -157,9 +157,11 @@ def init_qa(model, api_key=None):
157
  embeddings = OpenAIEmbeddings()
158
 
159
  elif model in OPEN_MODELS:
160
- chat = HuggingFaceHub(
161
  repo_id=OPEN_MODELS[model],
162
- model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048}
 
 
163
  )
164
  embeddings = HuggingFaceEmbeddings(
165
  model_name="all-MiniLM-L6-v2")
@@ -305,16 +307,24 @@ question = st.chat_input(
305
  disabled=not uploaded_file
306
  )
307
 
 
 
 
 
 
 
308
  with st.sidebar:
309
  st.header("Settings")
310
  mode = st.radio(
311
  "Query mode",
312
- ("LLM", "Embeddings"),
313
  disabled=not uploaded_file,
314
  index=0,
315
  horizontal=True,
 
316
  help="LLM will respond the question, Embedding will show the "
317
- "paragraphs relevant to the question in the paper."
 
318
  )
319
 
320
  # Add a checkbox for showing annotations
@@ -429,10 +439,12 @@ with right_column:
429
  if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
430
  for message in st.session_state.messages:
431
  with st.chat_message(message["role"]):
432
- if message['mode'] == "LLM":
433
  st.markdown(message["content"], unsafe_allow_html=True)
434
- elif message['mode'] == "Embeddings":
435
  st.write(message["content"])
 
 
436
  if model not in st.session_state['rqa']:
437
  st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
438
  st.stop()
@@ -442,16 +454,28 @@ with right_column:
442
  st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
443
 
444
  text_response = None
445
- if mode == "Embeddings":
 
 
 
 
 
 
 
446
  with st.spinner("Generating LLM response..."):
447
- text_response, coordinates = st.session_state['rqa'][model].query_storage(question,
448
- st.session_state.doc_id,
449
- context_size=context_size)
450
- elif mode == "LLM":
451
- with st.spinner("Generating response..."):
452
- _, text_response, coordinates = st.session_state['rqa'][model].query_document(question,
453
- st.session_state.doc_id,
454
- context_size=context_size)
 
 
 
 
 
455
 
456
  annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
457
  for coord_doc in coordinates]
@@ -466,7 +490,7 @@ with right_column:
466
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
467
 
468
  with st.chat_message("assistant"):
469
- if mode == "LLM":
470
  if st.session_state['ner_processing']:
471
  with st.spinner("Processing NER on LLM response..."):
472
  entities = gqa.process_single_text(text_response)
 
5
 
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
 
8
  from langchain.memory import ConversationBufferWindowMemory
9
+ from langchain_community.chat_models.openai import ChatOpenAI
10
+ from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
11
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
12
+ from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
13
  from streamlit_pdf_viewer import pdf_viewer
14
 
15
  from document_qa.ner_client_generic import NERClientGeneric
 
17
  dotenv.load_dotenv(override=True)
18
 
19
  import streamlit as st
 
 
 
20
  from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
21
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
22
 
 
157
  embeddings = OpenAIEmbeddings()
158
 
159
  elif model in OPEN_MODELS:
160
+ chat = HuggingFaceEndpoint(
161
  repo_id=OPEN_MODELS[model],
162
+ temperature=0.01,
163
+ max_new_tokens=2048,
164
+ model_kwargs={"max_length": 4096}
165
  )
166
  embeddings = HuggingFaceEmbeddings(
167
  model_name="all-MiniLM-L6-v2")
 
307
  disabled=not uploaded_file
308
  )
309
 
310
+ query_modes = {
311
+ "llm": "LLM Q/A",
312
+ "embeddings": "Embeddings",
313
+ "question_coefficient": "Question coefficient"
314
+ }
315
+
316
  with st.sidebar:
317
  st.header("Settings")
318
  mode = st.radio(
319
  "Query mode",
320
+ ("llm", "embeddings", "question_coefficient"),
321
  disabled=not uploaded_file,
322
  index=0,
323
  horizontal=True,
324
+ format_func=lambda x: query_modes[x],
325
  help="LLM will respond the question, Embedding will show the "
326
+ "relevant paragraphs to the question in the paper. "
327
+ "Question coefficient attempt to estimate how effective the question will be answered."
328
  )
329
 
330
  # Add a checkbox for showing annotations
 
439
  if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
440
  for message in st.session_state.messages:
441
  with st.chat_message(message["role"]):
442
+ if message['mode'] == "llm":
443
  st.markdown(message["content"], unsafe_allow_html=True)
444
+ elif message['mode'] == "embeddings":
445
  st.write(message["content"])
446
+ if message['mode'] == "question_coefficient":
447
+ st.markdown(message["content"], unsafe_allow_html=True)
448
  if model not in st.session_state['rqa']:
449
  st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
450
  st.stop()
 
454
  st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
455
 
456
  text_response = None
457
+ if mode == "embeddings":
458
+ with st.spinner("Fetching the relevant context..."):
459
+ text_response, coordinates = st.session_state['rqa'][model].query_storage(
460
+ question,
461
+ st.session_state.doc_id,
462
+ context_size=context_size
463
+ )
464
+ elif mode == "llm":
465
  with st.spinner("Generating LLM response..."):
466
+ _, text_response, coordinates = st.session_state['rqa'][model].query_document(
467
+ question,
468
+ st.session_state.doc_id,
469
+ context_size=context_size
470
+ )
471
+
472
+ elif mode == "question_coefficient":
473
+ with st.spinner("Estimate question/context relevancy..."):
474
+ text_response, coordinates = st.session_state['rqa'][model].analyse_query(
475
+ question,
476
+ st.session_state.doc_id,
477
+ context_size=context_size
478
+ )
479
 
480
  annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
481
  for coord_doc in coordinates]
 
490
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
491
 
492
  with st.chat_message("assistant"):
493
+ if mode == "llm":
494
  if st.session_state['ner_processing']:
495
  with st.spinner("Processing NER on LLM response..."):
496
  entities = gqa.process_single_text(text_response)