import streamlit as st from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead import requests from bs4 import BeautifulSoup from nltk.corpus import stopwords import nltk import string from streamlit.components.v1 import html from sentence_transformers.cross_encoder import CrossEncoder as CE import numpy as np from typing import List, Tuple import torch class CrossEncoder: def __init__(self, model_path: str, **kwargs): self.model = CE(model_path, **kwargs) def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]: return self.model.predict( sentences=sentences, batch_size=batch_size, show_progress_bar=show_progress_bar) SCITE_API_KEY = st.secrets["SCITE_API_KEY"] def remove_html(x): soup = BeautifulSoup(x, 'html.parser') text = soup.get_text() return text def search(term, limit=10, clean=True, strict=True): term = clean_query(term, clean=clean, strict=strict) # heuristic, 2 searches strict and not? and then merge? search = f"https://api.scite.ai/search?mode=citations&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false" req = requests.get( search, headers={ 'Authorization': f'Bearer {SCITE_API_KEY}' } ) try: req.json() except: return [], [] return ( [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']], [(doc['doi'], doc['citations'], doc['title']) for doc in req.json()['hits']] ) def find_source(text, docs): for doc in docs: if text in remove_html(doc[1][0]['snippet']): new_text = text for snip in remove_html(doc[1][0]['snippet']).split('.'): if text in snip: new_text = snip return { 'citation_statement': doc[1][0]['snippet'].replace('', '').replace('', ''), 'text': new_text, 'from': doc[1][0]['source'], 'supporting': doc[1][0]['target'], 'source_title': doc[2], 'source_link': f"https://scite.ai/reports/{doc[0]}" } return None @st.experimental_singleton def init_models(): nltk.download('stopwords') stop = set(stopwords.words('english') + list(string.punctuation)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") question_answerer = pipeline( "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B', device=device ) reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device) queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1") queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1") return question_answerer, reranker, stop, device, queryexp_model, queryexp_tokenizer qa_model, reranker, stop, device, queryexp_model, queryexp_tokenizer = init_models() def clean_query(query, strict=True, clean=True): operator = ' ' if strict: operator = ' AND ' query = operator.join( [i for i in query.lower().split(' ') if clean and i not in stop]) if clean: query = query.translate(str.maketrans('', '', string.punctuation)) return query def card(title, context, score, link, supporting): st.markdown(f"""

{context} [Score: {score}]
From {title}
""", unsafe_allow_html=True) html(f"""
""", width=None, height=42, scrolling=False) st.title("Scientific Question Answering with Citations") st.write(""" Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements. Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer. For example try: Are tanning beds safe to use? Does size of venture capital fund correlate with returns? """) st.markdown(""" """, unsafe_allow_html=True) with st.expander("Settings (strictness, context limit, top hits)"): strict_mode = st.radio( "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.", ('lenient', 'strict')) use_reranking = st.radio( "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.", ('yes', 'no')) use_query_exp = st.radio( "(Experimental) use query expansion? Right now it just recommends queries", ('yes', 'no')) top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 100) context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10) def paraphrase(text, max_length=128): input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True) generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=5, num_beams=5, max_length=max_length) queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]) preds = '\n * '.join(queries) return preds def run_query(query): if use_query_exp == 'yes': query_exp = paraphrase(f"question2question: {query}") st.markdown(f""" If you are not getting good results try one of: * {query_exp} """) limit = top_hits_limit or 100 context_limit = context_lim or 10 contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict') if len(contexts) == 0 or not ''.join(contexts).strip(): return st.markdown("""
Sorry... no results for that question! Try another...
""", unsafe_allow_html=True) if use_reranking == 'yes': sentence_pairs = [[query, context] for context in contexts] scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False) hits = {contexts[idx]: scores[idx] for idx in range(len(scores))} sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)] context = '\n'.join(sorted_contexts[:context_limit]) else: context = '\n'.join(contexts[:context_limit]) results = [] model_results = qa_model(question=query, context=context, top_k=10) for result in model_results: support = find_source(result['answer'], orig_docs) if not support: continue results.append({ "answer": support['text'], "title": support['source_title'], "link": support['source_link'], "context": support['citation_statement'], "score": result['score'], "doi": support["supporting"] }) sorted_result = sorted(results, key=lambda x: x['score'], reverse=True) sorted_result = list({ result['context']: result for result in sorted_result }.values()) sorted_result = sorted( sorted_result, key=lambda x: x['score'], reverse=True) for r in sorted_result: answer = r["answer"] ctx = remove_html(r["context"]).replace(answer, f"{answer}").replace( '