domenicrosati committed on
Commit
5ed186b
·
1 Parent(s): a776895

support abstracts in QA

Browse files
Files changed (1) hide show
  1. app.py +65 -28
app.py CHANGED
@@ -37,30 +37,51 @@ def remove_html(x):
37
  # all search
38
 
39
 
40
- def search(term, limit=10, clean=True, strict=True, abstracts=True):
41
  term = clean_query(term, clean=clean, strict=strict)
42
  # heuristic, 2 searches strict and not? and then merge?
43
  # https://api.scite.ai/search?mode=all&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
44
- mode = 'all'
45
- if not abstracts:
46
- mode = 'citations'
47
- search = f"https://api.scite.ai/search?mode={mode}&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
48
- req = requests.get(
49
- search,
50
- headers={
51
- 'Authorization': f'Bearer {SCITE_API_KEY}'
52
- }
53
- )
54
- try:
55
- req.json()
56
- except:
57
- return [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- citation_contexts = [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
60
  return (
61
- citation_contexts,
62
- [(doc['doi'], doc['citations'], doc['title'])
63
- for doc in req.json()['hits']]
64
  )
65
 
66
 
@@ -69,15 +90,28 @@ def find_source(text, docs):
69
  for snippet in doc[1]:
70
  if text in remove_html(snippet.get('snippet', '')):
71
  new_text = text
72
- for snip in remove_html(snippet.get('snippet', '')).split('.'):
73
- if text in snip:
74
- new_text = snip
75
  return {
76
  'citation_statement': snippet['snippet'].replace('<strong class="highlight">', '').replace('</strong>', ''),
77
  'text': new_text,
78
  'from': snippet['source'],
79
  'supporting': snippet['target'],
80
- 'source_title': doc[2],
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  'source_link': f"https://scite.ai/reports/{doc[0]}"
82
  }
83
  return None
@@ -159,9 +193,12 @@ st.markdown("""
159
  """, unsafe_allow_html=True)
160
 
161
  with st.expander("Settings (strictness, context limit, top hits)"):
162
- support_abstracts = st.radio(
163
- "Use abstracts as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
164
  ('yes', 'no'))
 
 
 
165
  strict_lenient_mix = st.radio(
166
  "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
167
  ('fallback', 'mix'))
@@ -170,7 +207,7 @@ with st.expander("Settings (strictness, context limit, top hits)"):
170
  "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
171
  ('yes', 'no'))
172
  top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
173
- context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 10)
174
 
175
  # def paraphrase(text, max_length=128):
176
  # input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
@@ -190,9 +227,9 @@ def run_query(query):
190
  # could also try fallback if there are no good answers by score...
191
  limit = top_hits_limit or 100
192
  context_limit = context_lim or 10
193
- contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, abstracts=support_abstracts == 'yes')
194
  if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
195
- contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, abstracts=support_abstracts == 'yes')
196
  contexts = list(
197
  set(contexts_strict + contexts_lenient)
198
  )
 
37
  # all search
38
 
39
 
40
+ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True, abstract_only=False):
41
  term = clean_query(term, clean=clean, strict=strict)
42
  # heuristic, 2 searches strict and not? and then merge?
43
  # https://api.scite.ai/search?mode=all&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
44
+ contexts, docs = [], []
45
+ if not abstract_only:
46
+ mode = 'all'
47
+ if not all_mode:
48
+ mode = 'citations'
49
+ search = f"https://api.scite.ai/search?mode={mode}&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
50
+ req = requests.get(
51
+ search,
52
+ headers={
53
+ 'Authorization': f'Bearer {SCITE_API_KEY}'
54
+ }
55
+ )
56
+ try:
57
+ req.json()
58
+ except:
59
+ pass
60
+
61
+ contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
62
+ docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
63
+ for doc in req.json()['hits']]
64
+
65
+ if abstracts or abstract_only:
66
+ search = f"https://api.scite.ai/search?mode=papers&abstract={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
67
+ req = requests.get(
68
+ search,
69
+ headers={
70
+ 'Authorization': f'Bearer {SCITE_API_KEY}'
71
+ }
72
+ )
73
+ try:
74
+ req.json()
75
+ contexts += [remove_html(doc['abstract'] or '') for doc in req.json()['hits']]
76
+ docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
77
+ for doc in req.json()['hits']]
78
+ except:
79
+ pass
80
+
81
 
 
82
  return (
83
+ contexts,
84
+ docs
 
85
  )
86
 
87
 
 
90
  for snippet in doc[1]:
91
  if text in remove_html(snippet.get('snippet', '')):
92
  new_text = text
93
+ for sent in remove_html(snippet.get('snippet', '')).split('.'):
94
+ if text in sent:
95
+ new_text = sent
96
  return {
97
  'citation_statement': snippet['snippet'].replace('<strong class="highlight">', '').replace('</strong>', ''),
98
  'text': new_text,
99
  'from': snippet['source'],
100
  'supporting': snippet['target'],
101
+ 'source_title': remove_html(doc[2]),
102
+ 'source_link': f"https://scite.ai/reports/{doc[0]}"
103
+ }
104
+ if text in remove_html(doc[3]):
105
+ new_text = text
106
+ for sent in remove_html(doc[3]).split('.'):
107
+ if text in sent:
108
+ new_text = sent
109
+ return {
110
+ 'citation_statement': "ABSTRACT: " + remove_html(doc[3]).replace('<strong class="highlight">', '').replace('</strong>', ''),
111
+ 'text': new_text,
112
+ 'from': '...',
113
+ 'supporting': '...',
114
+ 'source_title': "ABSTRACT of " + remove_html(doc[2]),
115
  'source_link': f"https://scite.ai/reports/{doc[0]}"
116
  }
117
  return None
 
193
  """, unsafe_allow_html=True)
194
 
195
  with st.expander("Settings (strictness, context limit, top hits)"):
196
+ support_all = st.radio(
197
+ "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
198
  ('yes', 'no'))
199
+ support_abstracts = st.radio(
200
+ "Use abstracts as a source document?",
201
+ ('yes', 'no', 'abstract only'))
202
  strict_lenient_mix = st.radio(
203
  "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
204
  ('fallback', 'mix'))
 
207
  "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
208
  ('yes', 'no'))
209
  top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
210
+ context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
211
 
212
  # def paraphrase(text, max_length=128):
213
  # input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
 
227
  # could also try fallback if there are no good answers by score...
228
  limit = top_hits_limit or 100
229
  context_limit = context_lim or 10
230
+ contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
231
  if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
232
+ contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only= support_abstracts == 'abstract only')
233
  contexts = list(
234
  set(contexts_strict + contexts_lenient)
235
  )