""", unsafe_allow_html=True)
# Inject an HTML component reserving 42px of vertical space.
# NOTE(review): the f-string body is empty — presumably a placeholder for
# injected JS/CSS or a spacer; confirm intent against the full file.
html(f"""
""", width=None, height=42, scrolling=False)
# Page header and introductory copy shown above the query box.
st.title("Scientific Question Answering with Citations")
st.write("""
Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer.
For example try: Are tanning beds safe to use? Does size of venture capital fund correlate with returns?
""")
# Empty markdown call — NOTE(review): renders nothing visible; looks
# vestigial (perhaps once carried custom CSS). Verify before removing.
st.markdown("""
""", unsafe_allow_html=True)
# Collapsible settings panel: retrieval strictness, reranking, query
# expansion, and result-size knobs read later by run_query().
with st.expander("Settings (strictness, context limit, top hits)"):
    # 'strict' requires every query word to appear in a source snippet;
    # 'lenient' matches on a subset of words.
    strict_mode = st.radio(
        "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
        ('lenient', 'strict'))
    # Whether to rerank retrieved snippets by query/document similarity
    # (run_query() calls reranker.predict when this is 'yes').
    use_reranking = st.radio(
        "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
        ('yes', 'no'))
    # Experimental: when 'yes', run_query() shows paraphrased query
    # suggestions generated by paraphrase().
    use_query_exp = st.radio(
        "(Experimental) use query expansion? Right now it just recommends queries",
        ('yes', 'no'))
    # Larger defaults when CUDA is available, since reranking and QA run
    # faster on GPU; both sliders share the 10..300 range.
    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 100)
    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
def paraphrase(text, max_length=128):
    """Generate up to five paraphrases of *text* with the query-expansion model.

    Args:
        text: Prompt string (the caller prefixes it with "question2question: ").
        max_length: Maximum token length of each generated sequence.

    Returns:
        str: The unique paraphrases joined with "\\n * " so they render as a
        Markdown bullet list when placed after a leading "* ".
    """
    # assumes queryexp_tokenizer / queryexp_model are loaded at module level
    # elsewhere in this file — TODO confirm.
    input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
    # Beam search with 5 beams, returning all 5 beam sequences.
    generated_ids = queryexp_model.generate(
        input_ids=input_ids, num_return_sequences=5, num_beams=5, max_length=max_length
    )
    decoded = [
        queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for g in generated_ids
    ]
    # Fix: the original deduplicated with set(), whose iteration order is
    # nondeterministic, so the suggestion list could reshuffle between runs.
    # dict.fromkeys dedupes while preserving first-seen (beam) order.
    queries = dict.fromkeys(decoded)
    return '\n * '.join(queries)
def run_query(query):
if use_query_exp == 'yes':
query_exp = paraphrase(f"question2question: {query}")
st.markdown(f"""
If you are not getting good results try one of:
* {query_exp}
""")
limit = top_hits_limit or 100
context_limit = context_lim or 10
contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
if len(contexts) == 0 or not ''.join(contexts).strip():
return st.markdown("""
Sorry... no results for that question! Try another...
""", unsafe_allow_html=True)
if use_reranking == 'yes':
sentence_pairs = [[query, context] for context in contexts]
scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
context = '\n'.join(sorted_contexts[:context_limit])
else:
context = '\n'.join(contexts[:context_limit])
results = []
model_results = qa_model(question=query, context=context, top_k=10)
for result in model_results:
support = find_source(result['answer'], orig_docs)
if not support:
continue
results.append({
"answer": support['text'],
"title": support['source_title'],
"link": support['source_link'],
"context": support['citation_statement'],
"score": result['score'],
"doi": support["supporting"]
})
sorted_result = sorted(results, key=lambda x: x['score'], reverse=True)
sorted_result = list({
result['context']: result for result in sorted_result
}.values())
sorted_result = sorted(
sorted_result, key=lambda x: x['score'], reverse=True)
for r in sorted_result:
answer = r["answer"]
ctx = remove_html(r["context"]).replace(answer, f"{answer}").replace(
'