domenicrosati committed
Commit 5ed186b · 1 Parent(s): a776895
support abstracts in QA
app.py
CHANGED
@@ -37,30 +37,51 @@ def remove_html(x):
 # all search
 
 
-def search(term, limit=10, clean=True, strict=True, abstracts=True):
+def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True, abstract_only=False):
     term = clean_query(term, clean=clean, strict=strict)
     # heuristic, 2 searches strict and not? and then merge?
     # https://api.scite.ai/search?mode=all&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
-    if not …
-    mode = '…
-    search …
-    …
+    contexts, docs = [], []
+    if not abstract_only:
+        mode = 'all'
+        if not all_mode:
+            mode = 'citations'
+        search = f"https://api.scite.ai/search?mode={mode}&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
+        req = requests.get(
+            search,
+            headers={
+                'Authorization': f'Bearer {SCITE_API_KEY}'
+            }
+        )
+        try:
+            req.json()
+        except:
+            pass
+
+        contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
+        docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
+                 for doc in req.json()['hits']]
+
+    if abstracts or abstract_only:
+        search = f"https://api.scite.ai/search?mode=papers&abstract={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
+        req = requests.get(
+            search,
+            headers={
+                'Authorization': f'Bearer {SCITE_API_KEY}'
+            }
+        )
+        try:
+            req.json()
+            contexts += [remove_html(doc['abstract'] or '') for doc in req.json()['hits']]
+            docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
+                     for doc in req.json()['hits']]
+        except:
+            pass
+
 
-    citation_contexts = [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
     return (
-        …
-        for doc in req.json()['hits']]
+        contexts,
+        docs
     )
 
 
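To make the new flags easier to follow, here is a small standalone sketch of which scite.ai searches the flag combinations above trigger: a citation-snippet search in 'all' or 'citations' mode unless abstract_only is set, plus a separate papers/abstract search whenever abstracts or abstract_only is enabled. planned_queries is a hypothetical helper for illustration only, not a function in app.py.

# Hypothetical helper illustrating the branching in the new search(); not in app.py.
def planned_queries(all_mode=True, abstracts=True, abstract_only=False):
    queries = []
    if not abstract_only:
        # snippet search; mode 'all' also matches titles/abstracts for ranking
        queries.append('mode=all' if all_mode else 'mode=citations')
    if abstracts or abstract_only:
        # papers search whose abstracts are used as source documents
        queries.append('mode=papers (abstract match)')
    return queries

print(planned_queries())                                   # ['mode=all', 'mode=papers (abstract match)']
print(planned_queries(all_mode=False, abstracts=False))    # ['mode=citations']
print(planned_queries(abstract_only=True))                 # ['mode=papers (abstract match)']

One design note on the committed code itself: in the first branch the bare try/except only guards the initial req.json() call, so a non-JSON response would still raise on the later req.json() calls, whereas in the abstract branch the list comprehensions sit inside the try and failures are silently skipped.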
@@ -69,15 +90,28 @@ def find_source(text, docs):
         for snippet in doc[1]:
             if text in remove_html(snippet.get('snippet', '')):
                 new_text = text
-                for …
-                if text in …
-                new_text = …
+                for sent in remove_html(snippet.get('snippet', '')).split('.'):
+                    if text in sent:
+                        new_text = sent
                 return {
                     'citation_statement': snippet['snippet'].replace('<strong class="highlight">', '').replace('</strong>', ''),
                     'text': new_text,
                     'from': snippet['source'],
                     'supporting': snippet['target'],
-                    'source_title': doc[2],
+                    'source_title': remove_html(doc[2]),
+                    'source_link': f"https://scite.ai/reports/{doc[0]}"
+                }
+        if text in remove_html(doc[3]):
+            new_text = text
+            for sent in remove_html(doc[3]).split('.'):
+                if text in sent:
+                    new_text = sent
+            return {
+                'citation_statement': "ABSTRACT: " + remove_html(doc[3]).replace('<strong class="highlight">', '').replace('</strong>', ''),
+                'text': new_text,
+                'from': '...',
+                'supporting': '...',
+                'source_title': "ABSTRACT of " + remove_html(doc[2]),
                 'source_link': f"https://scite.ai/reports/{doc[0]}"
             }
     return None
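Both new blocks in find_source share one heuristic: once a match is found, the answer span is narrowed to the '.'-delimited sentence that contains it, and abstracts (doc[3]) are now checked after the citation snippets. A minimal standalone sketch of that narrowing step follows; narrow_to_sentence is an illustrative name and the sample passage is made up, neither is part of app.py.

# Illustrative sketch of the sentence-narrowing heuristic used in find_source() above.
def narrow_to_sentence(text, passage):
    new_text = text                      # fall back to the raw matched span
    for sent in passage.split('.'):      # naive sentence split, as in the commit
        if text in sent:
            new_text = sent              # keep the sentence that contains the match
    return new_text

passage = "Unit tests catch regressions early. They also document intended behaviour."
print(narrow_to_sentence("catch regressions", passage))
# -> "Unit tests catch regressions early"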
@@ -159,9 +193,12 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
-    support_abstracts = st.radio(
-        "Use abstracts as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
+    support_all = st.radio(
+        "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
         ('yes', 'no'))
+    support_abstracts = st.radio(
+        "Use abstracts as a source document?",
+        ('yes', 'no', 'abstract only'))
     strict_lenient_mix = st.radio(
         "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
         ('fallback', 'mix'))
@@ -170,7 +207,7 @@ with st.expander("Settings (strictness, context limit, top hits)"):
         "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
         ('yes', 'no'))
     top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
-    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, …)
+    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
 
 # def paraphrase(text, max_length=128):
 # input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
@@ -190,9 +227,9 @@ def run_query(query):
     # could also try fallback if there are no good answers by score...
     limit = top_hits_limit or 100
     context_limit = context_lim or 10
-    contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, abstracts=support_abstracts == 'yes')
+    contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
     if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
-        contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, abstracts=support_abstracts == 'yes')
+        contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only= support_abstracts == 'abstract only')
         contexts = list(
             set(contexts_strict + contexts_lenient)
         )
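For reference, a self-contained sketch of the fallback-vs-mix behaviour described in the settings text above: under 'fallback' the lenient search only runs when the strict search returns fewer contexts than the context limit, while 'mix' (handled outside this hunk, per the settings description) merges both result sets. combine and the toy data below are illustrative assumptions, not code from app.py.

# Hypothetical helper mirroring the strict/lenient combination logic; not in app.py.
def combine(contexts_strict, run_lenient, mode, context_limit):
    contexts = list(contexts_strict)
    if mode == 'mix' or (mode == 'fallback' and len(contexts_strict) < context_limit):
        # merge and de-duplicate, as run_query() does with set()
        contexts = list(set(contexts_strict + run_lenient()))
    return contexts

def lenient():
    return ["context b", "context c"]

strict_hits = ["context a", "context b"]
print(sorted(combine(strict_hits, lenient, mode='fallback', context_limit=10)))
# strict found fewer than 10 contexts, so lenient results are merged in:
# ['context a', 'context b', 'context c']
print(sorted(combine(strict_hits, lenient, mode='fallback', context_limit=2)))
# strict already met the limit, so only strict results are kept:
# ['context a', 'context b']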