domenicrosati commited on
Commit
3f1f616
Β·
1 Parent(s): bdb2b00

update to use api

Browse files
Files changed (1) hide show
  1. app.py +165 -135
app.py CHANGED
@@ -12,15 +12,15 @@ import torch
12
 
13
  SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
14
 
15
- class CrossEncoder:
16
- def __init__(self, model_path: str, **kwargs):
17
- self.model = CE(model_path, **kwargs)
18
 
19
- def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
20
- return self.model.predict(
21
- sentences=sentences,
22
- batch_size=batch_size,
23
- show_progress_bar=show_progress_bar)
24
 
25
 
26
  def remove_html(x):
@@ -134,23 +134,23 @@ def find_source(text, docs, matched):
134
  return None
135
 
136
 
137
- @st.experimental_singleton
138
- def init_models():
139
- nltk.download('stopwords')
140
- nltk.download('punkt')
141
- from nltk.corpus import stopwords
142
- stop = set(stopwords.words('english') + list(string.punctuation))
143
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
- question_answerer = pipeline(
145
- "question-answering", model='nlpconnect/roberta-base-squad2-nq',
146
- device=device, handle_impossible_answer=False,
147
- )
148
- reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
149
- # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
150
- # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
151
- return question_answerer, reranker, stop, device
152
 
153
- qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
154
 
155
 
156
  def clean_query(query, strict=True, clean=True):
@@ -206,32 +206,32 @@ Answers are linked to source documents containing citations where users can expl
206
  For example try: Do tanning beds cause cancer?
207
  """)
208
 
209
- st.markdown("""
210
- <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
211
- """, unsafe_allow_html=True)
212
-
213
- with st.expander("Settings (strictness, context limit, top hits)"):
214
- concat_passages = st.radio(
215
- "Concatenate passages as one long context?",
216
- ('yes', 'no'))
217
- present_impossible = st.radio(
218
- "Present impossible answers? (if the model thinks its impossible to answer should it still try?)",
219
- ('yes', 'no'))
220
- support_all = st.radio(
221
- "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
222
- ('no', 'yes'))
223
- support_abstracts = st.radio(
224
- "Use abstracts as a source document?",
225
- ('yes', 'no', 'abstract only'))
226
- strict_lenient_mix = st.radio(
227
- "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
228
- ('mix', 'fallback'))
229
- confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
230
- use_reranking = st.radio(
231
- "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
232
- ('yes', 'no'))
233
- top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
234
- context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
235
 
236
  # def paraphrase(text, max_length=128):
237
  # input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
@@ -272,38 +272,120 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
272
  return None
273
 
274
 
275
- def run_query(query, progress_bar):
276
- # if use_query_exp == 'yes':
277
- # query_exp = paraphrase(f"question2question: {query}")
278
- # st.markdown(f"""
279
- # If you are not getting good results try one of:
280
- # * {query_exp}
281
- # """)
282
-
283
- # could also try fallback if there are no good answers by score...
284
- limit = top_hits_limit or 100
285
- context_limit = context_lim or 10
286
- contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
287
- if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
288
- contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only= support_abstracts == 'abstract only')
289
- contexts = list(
290
- set(contexts_strict + contexts_lenient)
291
- )
292
- orig_docs = orig_docs_strict + orig_docs_lenient
293
- elif strict_lenient_mix == 'mix':
294
- contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
295
- contexts = list(
296
- set(contexts_strict + contexts_lenient)
297
- )
298
- orig_docs = orig_docs_strict + orig_docs_lenient
299
- else:
300
- contexts = list(
301
- set(contexts_strict)
302
- )
303
- orig_docs = orig_docs_strict
304
- progress_bar.progress(25)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
- if len(contexts) == 0 or not ''.join(contexts).strip():
307
  return st.markdown("""
308
  <div class="container-fluid">
309
  <div class="row align-items-start">
@@ -314,58 +396,7 @@ def run_query(query, progress_bar):
314
  </div>
315
  """, unsafe_allow_html=True)
316
 
317
- if use_reranking == 'yes':
318
- sentence_pairs = [[query, context] for context in contexts]
319
- scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
320
- hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
321
- sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
322
- contexts = sorted_contexts[:context_limit]
323
- else:
324
- contexts = contexts[:context_limit]
325
-
326
- progress_bar.progress(50)
327
- if concat_passages == 'yes':
328
- context = '\n---'.join(contexts)
329
- model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible=='yes')
330
- else:
331
- context = ['\n---\n'+ctx for ctx in contexts]
332
- model_results = qa_model(question=[query]*len(contexts), context=context, handle_impossible_answer=present_impossible=='yes')
333
-
334
- results = []
335
-
336
- progress_bar.progress(75)
337
- for i, result in enumerate(model_results):
338
- if concat_passages == 'yes':
339
- matched = matched_context(result['start'], result['end'], context)
340
- else:
341
- matched = matched_context(result['start'], result['end'], context[i])
342
- support = find_source(result['answer'], orig_docs, matched)
343
- if not support:
344
- continue
345
- results.append({
346
- "answer": support['text'],
347
- "title": support['source_title'],
348
- "link": support['source_link'],
349
- "context": support['citation_statement'],
350
- "score": result['score'],
351
- "doi": support["supporting"]
352
- })
353
-
354
- grouped_results = group_results_by_context(results)
355
- sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
356
-
357
- if confidence_threshold == 0:
358
- threshold = 0
359
- else:
360
- threshold = (confidence_threshold or 10) / 100
361
-
362
- sorted_result = list(filter(
363
- lambda x: x['score'] > threshold,
364
- sorted_result
365
- ))
366
-
367
- progress_bar.progress(100)
368
- for r in sorted_result:
369
  ctx = remove_html(r["context"])
370
  for answer in r['texts']:
371
  ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
@@ -377,5 +408,4 @@ def run_query(query, progress_bar):
377
  query = st.text_input("Ask scientific literature a question", "")
378
  if query != "":
379
  with st.spinner('Loading...'):
380
- progress_bar = st.progress(0)
381
- run_query(query, progress_bar)
 
12
 
13
  SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
14
 
15
+ # class CrossEncoder:
16
+ # def __init__(self, model_path: str, **kwargs):
17
+ # self.model = CE(model_path, **kwargs)
18
 
19
+ # def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
20
+ # return self.model.predict(
21
+ # sentences=sentences,
22
+ # batch_size=batch_size,
23
+ # show_progress_bar=show_progress_bar)
24
 
25
 
26
  def remove_html(x):
 
134
  return None
135
 
136
 
137
+ # @st.experimental_singleton
138
+ # def init_models():
139
+ # nltk.download('stopwords')
140
+ # nltk.download('punkt')
141
+ # from nltk.corpus import stopwords
142
+ # stop = set(stopwords.words('english') + list(string.punctuation))
143
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
+ # question_answerer = pipeline(
145
+ # "question-answering", model='nlpconnect/roberta-base-squad2-nq',
146
+ # device=0 if torch.cuda.is_available() else -1, handle_impossible_answer=False,
147
+ # )
148
+ # reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
149
+ # # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
150
+ # # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
151
+ # return question_answerer, reranker, stop, device
152
 
153
+ # qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
154
 
155
 
156
  def clean_query(query, strict=True, clean=True):
 
206
  For example try: Do tanning beds cause cancer?
207
  """)
208
 
209
+ # st.markdown("""
210
+ # <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
211
+ # """, unsafe_allow_html=True)
212
+
213
+ # with st.expander("Settings (strictness, context limit, top hits)"):
214
+ # concat_passages = st.radio(
215
+ # "Concatenate passages as one long context?",
216
+ # ('yes', 'no'))
217
+ # present_impossible = st.radio(
218
+ # "Present impossible answers? (if the model thinks its impossible to answer should it still try?)",
219
+ # ('yes', 'no'))
220
+ # support_all = st.radio(
221
+ # "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
222
+ # ('no', 'yes'))
223
+ # support_abstracts = st.radio(
224
+ # "Use abstracts as a source document?",
225
+ # ('yes', 'no', 'abstract only'))
226
+ # strict_lenient_mix = st.radio(
227
+ # "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
228
+ # ('mix', 'fallback'))
229
+ # confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
230
+ # use_reranking = st.radio(
231
+ # "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
232
+ # ('yes', 'no'))
233
+ # top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
234
+ # context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
235
 
236
  # def paraphrase(text, max_length=128):
237
  # input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
 
272
  return None
273
 
274
 
275
+ # def run_query_full(query, progress_bar):
276
+ # # if use_query_exp == 'yes':
277
+ # # query_exp = paraphrase(f"question2question: {query}")
278
+ # # st.markdown(f"""
279
+ # # If you are not getting good results try one of:
280
+ # # * {query_exp}
281
+ # # """)
282
+
283
+ # # could also try fallback if there are no good answers by score...
284
+ # limit = top_hits_limit or 100
285
+ # context_limit = context_lim or 10
286
+ # contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
287
+ # if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
288
+ # contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only= support_abstracts == 'abstract only')
289
+ # contexts = list(
290
+ # set(contexts_strict + contexts_lenient)
291
+ # )
292
+ # orig_docs = orig_docs_strict + orig_docs_lenient
293
+ # elif strict_lenient_mix == 'mix':
294
+ # contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
295
+ # contexts = list(
296
+ # set(contexts_strict + contexts_lenient)
297
+ # )
298
+ # orig_docs = orig_docs_strict + orig_docs_lenient
299
+ # else:
300
+ # contexts = list(
301
+ # set(contexts_strict)
302
+ # )
303
+ # orig_docs = orig_docs_strict
304
+ # progress_bar.progress(25)
305
+
306
+ # if len(contexts) == 0 or not ''.join(contexts).strip():
307
+ # return st.markdown("""
308
+ # <div class="container-fluid">
309
+ # <div class="row align-items-start">
310
+ # <div class="col-md-12 col-sm-12">
311
+ # Sorry... no results for that question! Try another...
312
+ # </div>
313
+ # </div>
314
+ # </div>
315
+ # """, unsafe_allow_html=True)
316
+
317
+ # if use_reranking == 'yes':
318
+ # sentence_pairs = [[query, context] for context in contexts]
319
+ # scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
320
+ # hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
321
+ # sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
322
+ # contexts = sorted_contexts[:context_limit]
323
+ # else:
324
+ # contexts = contexts[:context_limit]
325
+
326
+ # progress_bar.progress(50)
327
+ # if concat_passages == 'yes':
328
+ # context = '\n---'.join(contexts)
329
+ # model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible=='yes')
330
+ # else:
331
+ # context = ['\n---\n'+ctx for ctx in contexts]
332
+ # model_results = qa_model(question=[query]*len(contexts), context=context, handle_impossible_answer=present_impossible=='yes')
333
+
334
+ # results = []
335
+
336
+ # progress_bar.progress(75)
337
+ # for i, result in enumerate(model_results):
338
+ # if concat_passages == 'yes':
339
+ # matched = matched_context(result['start'], result['end'], context)
340
+ # else:
341
+ # matched = matched_context(result['start'], result['end'], context[i])
342
+ # support = find_source(result['answer'], orig_docs, matched)
343
+ # if not support:
344
+ # continue
345
+ # results.append({
346
+ # "answer": support['text'],
347
+ # "title": support['source_title'],
348
+ # "link": support['source_link'],
349
+ # "context": support['citation_statement'],
350
+ # "score": result['score'],
351
+ # "doi": support["supporting"]
352
+ # })
353
+
354
+ # grouped_results = group_results_by_context(results)
355
+ # sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
356
+
357
+ # if confidence_threshold == 0:
358
+ # threshold = 0
359
+ # else:
360
+ # threshold = (confidence_threshold or 10) / 100
361
+
362
+ # sorted_result = list(filter(
363
+ # lambda x: x['score'] > threshold,
364
+ # sorted_result
365
+ # ))
366
+
367
+ # progress_bar.progress(100)
368
+ # for r in sorted_result:
369
+ # ctx = remove_html(r["context"])
370
+ # for answer in r['texts']:
371
+ # ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
372
+ # # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
373
+ # title = r.get("title", '')
374
+ # score = round(round(r["score"], 4) * 100, 2)
375
+ # card(title, ctx, score, r['link'], r['doi'])
376
+
377
+
378
+ def run_query(query):
379
+ api_location = 'http://74.82.31.93'
380
+ resp_raw = requests.get(
381
+ f'{api_location}/question-answer?query={query}'
382
+ )
383
+ try:
384
+ resp = resp_raw.json()
385
+ except:
386
+ resp = {'results': []}
387
 
388
+ if len(resp.get('results', [])) == 0:
389
  return st.markdown("""
390
  <div class="container-fluid">
391
  <div class="row align-items-start">
 
396
  </div>
397
  """, unsafe_allow_html=True)
398
 
399
+ for r in resp['results']:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  ctx = remove_html(r["context"])
401
  for answer in r['texts']:
402
  ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
 
408
  query = st.text_input("Ask scientific literature a question", "")
409
  if query != "":
410
  with st.spinner('Loading...'):
411
+ run_query(query)