domenicrosati committed on
Commit e996282 · 1 Parent(s): f2eab41

add threshold for predictions

Files changed (1)
  1. app.py +27 -15
app.py CHANGED
@@ -11,6 +11,8 @@ import numpy as np
 from typing import List, Tuple
 import torch
 
+SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
+
 class CrossEncoder:
     def __init__(self, model_path: str, **kwargs):
         self.model = CE(model_path, **kwargs)
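The API key now loads at module import via Streamlit's secrets store. A minimal sketch of what that line expects, assuming the standard `.streamlit/secrets.toml` mechanism (the key value shown is a placeholder):

    # .streamlit/secrets.toml (or a Hugging Face Space secret) must define:
    #   SCITE_API_KEY = "your-key-here"
    import streamlit as st

    SCITE_API_KEY = st.secrets["SCITE_API_KEY"]  # raises if the secret is missing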
@@ -22,18 +24,21 @@ class CrossEncoder:
             show_progress_bar=show_progress_bar)
 
 
-SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
-
-
 def remove_html(x):
     soup = BeautifulSoup(x, 'html.parser')
     text = soup.get_text()
     return text
 
 
+# 4 searches: strict y/n, supported y/n
+# deduplicate
+# search per query
+
+
 def search(term, limit=10, clean=True, strict=True):
     term = clean_query(term, clean=clean, strict=strict)
     # heuristic, 2 searches strict and not? and then merge?
+    # https://api.scite.ai/search?mode=citations&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
     search = f"https://api.scite.ai/search?mode=citations&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
     req = requests.get(
         search,
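For context, `search` issues a GET request against the scite search endpoint shown above. A standalone sketch of the same call; how the app attaches `SCITE_API_KEY` is not visible in this hunk, so the Bearer-token header below is an assumption:

    import requests

    SCITE_API_KEY = "..."  # placeholder; the app reads this from st.secrets

    def scite_search(term, limit=10):
        url = (
            "https://api.scite.ai/search?mode=citations"
            f"&term={term}&limit={limit}&offset=0"
            "&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        )
        # Assumption: the key is sent as a Bearer token; the diff does not show the headers.
        resp = requests.get(url, headers={"Authorization": f"Bearer {SCITE_API_KEY}"})
        resp.raise_for_status()
        return resp.json()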
@@ -67,6 +72,7 @@ def find_source(text, docs):
                 'source_title': doc[2],
                 'source_link': f"https://scite.ai/reports/{doc[0]}"
             }
+    print("None found for", text)
    return None
 
 
@@ -79,7 +85,7 @@ def init_models():
         "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
         device=device
     )
-    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', device=device)
+    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
     queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     return question_answerer, reranker, stop, device, queryexp_model, queryexp_tokenizer
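This hunk swaps the 12-layer MS MARCO cross-encoder for the 6-layer variant, trading a little reranking quality for substantially cheaper inference. A minimal sketch of how such a cross-encoder scores query/passage pairs (the `CE` alias in app.py wraps this sentence-transformers class; the example pairs are made up):

    from sentence_transformers import CrossEncoder

    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [
        ("do tanning beds cause cancer?", "UV exposure from tanning beds is a known carcinogen."),
        ("do tanning beds cause cancer?", "Fund size and returns show only weak correlation."),
    ]
    scores = reranker.predict(pairs)  # one relevance score per pair; higher means more relevant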
@@ -98,7 +104,6 @@ def clean_query(query, strict=True, clean=True):
     return query
 
 
-
 def card(title, context, score, link, supporting):
     st.markdown(f"""
     <div class="container-fluid">
@@ -138,7 +143,7 @@ st.write("""
 Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
 Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer.
 
-For example try: Are tanning beds safe to use? Does size of venture capital fund correlate with returns?
+For example try: Do tanning beds cause cancer?
 """)
 
 st.markdown("""
@@ -146,26 +151,27 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
+    confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
     strict_mode = st.radio(
         "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
         ('lenient', 'strict'))
     use_reranking = st.radio(
         "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
         ('yes', 'no'))
+    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 50)
+    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
     use_query_exp = st.radio(
         "(Experimental) use query expansion? Right now it just recommends queries",
         ('yes', 'no'))
-    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 100)
-    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
+    suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
 
 def paraphrase(text, max_length=128):
     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
-    generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=5, num_beams=5, max_length=max_length)
+    generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
     queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
     preds = '\n * '.join(queries)
     return preds
 
-
 def run_query(query):
     if use_query_exp == 'yes':
         query_exp = paraphrase(f"question2question: {query}")
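The new `suggested_queries` slider feeds directly into beam search, and `suggested_queries or 5` guards against a slider value of 0, which would otherwise request zero beams from `generate()` and fail. A runnable sketch of the same doc2query expansion outside Streamlit (model names are from the diff; the sample question is illustrative):

    from transformers import AutoModelWithLMHead, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")

    def paraphrase(text, n=5, max_length=128):
        input_ids = tok.encode(text, return_tensors="pt", add_special_tokens=True)
        generated = model.generate(
            input_ids=input_ids,
            num_return_sequences=n or 5,  # mirrors the diff's guard against n == 0
            num_beams=n or 5,
            max_length=max_length)
        # A set deduplicates beams that decode to the same string.
        return {tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                for g in generated}

    print(paraphrase("question2question: Do tanning beds cause cancer?"))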
@@ -186,7 +192,6 @@ If you are not getting good results try one of:
     </div>
     </div>
     """, unsafe_allow_html=True)
-
     if use_reranking == 'yes':
         sentence_pairs = [[query, context] for context in contexts]
         scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
@@ -195,7 +200,6 @@ If you are not getting good results try one of:
         context = '\n'.join(sorted_contexts[:context_limit])
     else:
         context = '\n'.join(contexts[:context_limit])
-
     results = []
     model_results = qa_model(question=query, context=context, top_k=10)
     for result in model_results:
@@ -210,14 +214,23 @@ If you are not getting good results try one of:
             "score": result['score'],
             "doi": support["supporting"]
         })
-
-    sorted_result = sorted(results, key=lambda x: x['score'], reverse=True)
+    sorted_result = sorted(results, key=lambda x: x['score'])
     sorted_result = list({
         result['context']: result for result in sorted_result
     }.values())
     sorted_result = sorted(
         sorted_result, key=lambda x: x['score'], reverse=True)
 
+    if confidence_threshold == 0:
+        threshold = 0
+    else:
+        threshold = (confidence_threshold or 10) / 100
+
+    sorted_result = filter(
+        lambda x: x['score'] > threshold,
+        sorted_result
+    )
+
     for r in sorted_result:
         answer = r["answer"]
         ctx = remove_html(r["context"]).replace(answer, f"<mark>{answer}</mark>").replace(
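This is the commit's namesake change. The first sort now runs ascending so that, when duplicate contexts collapse in the dict comprehension, the highest-scoring answer per context is the one kept; the threshold block then converts the percent slider into a fractional cutoff on the QA pipeline's 0-to-1 scores. A self-contained sketch of that filtering logic (the result dicts are made-up examples):

    def apply_threshold(results, confidence_threshold):
        # Percent slider -> fractional cutoff, as in the diff. A setting of 0
        # effectively disables filtering, since QA scores are positive.
        if confidence_threshold == 0:
            threshold = 0
        else:
            threshold = (confidence_threshold or 10) / 100
        return [r for r in results if r['score'] > threshold]

    results = [{'answer': 'yes', 'score': 0.72}, {'answer': 'unclear', 'score': 0.004}]
    print(apply_threshold(results, 1))   # 1% cutoff keeps only the 0.72 answer
    print(apply_threshold(results, 0))   # 0 keeps both

Note that the diff itself uses `filter()`, which returns a lazy iterator; that is fine there because `sorted_result` is consumed exactly once by the loop that follows.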
@@ -227,7 +240,6 @@ If you are not getting good results try one of:
         card(title, ctx, score, r['link'], r['doi'])
 
 query = st.text_input("Ask scientific literature a question", "")
-
 if query != "":
     with st.spinner('Loading...'):
         run_query(query)