domenicrosati committed · 00e4b2e
1 Parent(s): 577cb80

add more settings
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead
 import requests
 from bs4 import BeautifulSoup
 from nltk.corpus import stopwords
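Note on the new import: `AutoModelWithLMHead` is deprecated in recent transformers releases. A minimal sketch of the modern equivalent for a seq2seq (T5-style) checkpoint like the doc2query model this commit loads; the committed code keeps the old class, which still works but emits a deprecation warning:

```python
# Sketch only: modern replacement for the deprecated AutoModelWithLMHead.
# For seq2seq models such as doc2query/all-with_prefix-t5-base-v1,
# AutoModelForSeq2SeqLM is the drop-in equivalent.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
queryexp_model = AutoModelForSeq2SeqLM.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
```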
@@ -80,9 +80,11 @@ def init_models():
         device=device
     )
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
-    return question_answerer, reranker, stop, device
+    queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
+    queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
+    return question_answerer, reranker, stop, device, queryexp_model, queryexp_tokenizer

-qa_model, reranker, stop, device = init_models()
+qa_model, reranker, stop, device, queryexp_model, queryexp_tokenizer = init_models()


 def clean_query(query, strict=True, clean=True):
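The return tuple and the call-site unpacking now have to stay in sync by position. A sketch, not from the commit, of one way to make that less fragile as more models get added:

```python
# Hypothetical refactor: return a named container instead of a growing
# positional tuple, so adding a model cannot silently shift the unpacking.
from typing import Any, NamedTuple

class Models(NamedTuple):
    qa_model: Any
    reranker: Any
    stop: Any
    device: Any
    queryexp_model: Any
    queryexp_tokenizer: Any

# init_models() would end with:
#     return Models(question_answerer, reranker, stop, device,
#                   queryexp_model, queryexp_tokenizer)
# and callers would read models.qa_model, models.reranker, etc.
```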
@@ -134,7 +136,8 @@ st.title("Scientific Question Answering with Citations")
|
|
134 |
|
135 |
st.write("""
|
136 |
Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
|
137 |
-
Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer.
|
|
|
138 |
""")
|
139 |
|
140 |
st.markdown("""
|
@@ -145,13 +148,35 @@ with st.expander("Settings (strictness, context limit, top hits)"):
     strict_mode = st.radio(
         "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
         ('strict', 'lenient'))
-
+    use_reranking = st.radio(
+        "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
+        ('yes', 'no'))
+    use_query_exp = st.radio(
+        "(Experimental) use query expansion? Right now it just recommends queries",
+        ('yes', 'no'))
+    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 100)
     context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)

+def paraphrase(text, max_length=128):
+
+    input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
+
+    generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=5, num_beams=5, max_length=max_length)
+
+    preds = '\n'.join([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
+    return preds
+
+
 def run_query(query):
-
-
+    if use_query_exp == 'yes':
+        query_exp = paraphrase(f"question2question: {query}")
+        st.markdown(f"""
+        If you are not getting good results try one of:

+        {query_exp}
+        """)
+    limit = top_hits_limit or 100
+    context_limit = context_lim or 10
     contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
     if len(contexts) == 0 or not ''.join(contexts).strip():
         return st.markdown("""
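The new `paraphrase` helper drives the query-expansion setting: with `num_beams=5` and `num_return_sequences=5`, `generate()` returns five beam-search variants of the question, which the app joins with newlines and shows as suggestions. A standalone sketch of that step, assuming the doc2query checkpoint downloads; the `question2question:` prefix is taken from the commit:

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

name = "doc2query/all-with_prefix-t5-base-v1"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

query = "Are tanning beds safe to use?"
input_ids = tokenizer.encode(f"question2question: {query}", return_tensors="pt")
# Five beams, five returned sequences: one rewrite per beam.
generated = model.generate(input_ids=input_ids, num_beams=5,
                           num_return_sequences=5, max_length=128)
for g in generated:
    print(tokenizer.decode(g, skip_special_tokens=True))
```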
@@ -164,12 +189,15 @@ def run_query(query):
         </div>
         """, unsafe_allow_html=True)

-    sentence_pairs = [[query, context] for context in contexts]
-    scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
-    hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
-    sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
+    if use_reranking == 'yes':
+        sentence_pairs = [[query, context] for context in contexts]
+        scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
+        hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
+        sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
+        context = '\n'.join(sorted_contexts[:context_limit])
+    else:
+        context = '\n'.join(contexts[:context_limit])

-    context = '\n'.join(sorted_contexts[:context_limit])
     results = []
     model_results = qa_model(question=query, context=context, top_k=10)
     for result in model_results:
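One thing to flag in the reranking branch: `sorted(hits.items(), key=lambda x: x[0], reverse=True)` sorts by `x[0]`, the snippet text, so contexts come back in reverse-alphabetical order rather than by relevance; `x[1]` (the CrossEncoder score) is presumably what was intended. A sketch of score-ordered reranking:

```python
from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
query = "Are tanning beds safe to use?"
contexts = [
    "Tanning beds emit UV radiation, a known carcinogen.",
    "Venture capital fund size varies widely.",
]

scores = reranker.predict([[query, c] for c in contexts])
# Sort by descending score (x[1]), not by the snippet text (x[0]).
sorted_contexts = [c for c, s in sorted(zip(contexts, scores),
                                        key=lambda x: x[1], reverse=True)]
```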
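For reference, the QA call at the end of the last hunk uses the transformers question-answering pipeline with `top_k=10`, which returns a list of candidate answers rather than a single dict. A sketch of consuming those results; the model name here is illustrative, since the commit does not show which checkpoint `qa_model` loads:

```python
from transformers import pipeline

# Illustrative checkpoint; the app's actual QA model is defined elsewhere in app.py.
qa_model = pipeline("question-answering",
                    model="distilbert-base-cased-distilled-squad")

context = "Tanning beds emit ultraviolet radiation, which is a known carcinogen."
results = qa_model(question="Are tanning beds safe to use?",
                   context=context, top_k=3)
# Each result is a dict with 'answer', 'score', and character offsets
# 'start'/'end' into the context string.
for r in results:
    print(f"{r['score']:.3f}  {r['answer']}  ({r['start']}-{r['end']})")
```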