import streamlit as st from transformers import pipeline import requests from bs4 import BeautifulSoup from nltk.corpus import stopwords import nltk import string from streamlit.components.v1 import html from sentence_transformers.cross_encoder import CrossEncoder as CE import numpy as np from typing import List, Tuple import torch class CrossEncoder: def __init__(self, model_path: str, **kwargs): self.model = CE(model_path, **kwargs) def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]: return self.model.predict( sentences=sentences, batch_size=batch_size, show_progress_bar=show_progress_bar) SCITE_API_KEY = st.secrets["SCITE_API_KEY"] def remove_html(x): soup = BeautifulSoup(x, 'html.parser') text = soup.get_text() return text def search(term, limit=10, clean=True, strict=True): term = clean_query(term, clean=clean, strict=strict) # heuristic, 2 searches strict and not? and then merge? search = f"https://api.scite.ai/search?mode=citations&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false" req = requests.get( search, headers={ 'Authorization': f'Bearer {SCITE_API_KEY}' } ) try: req.json() except: return [], [] return ( [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']], [(doc['doi'], doc['citations'], doc['title']) for doc in req.json()['hits']] ) def find_source(text, docs): for doc in docs: if text in remove_html(doc[1][0]['snippet']): new_text = text for snip in remove_html(doc[1][0]['snippet']).split('.'): if text in snip: new_text = snip return { 'citation_statement': doc[1][0]['snippet'].replace('', '').replace('', ''), 'text': new_text, 'from': doc[1][0]['source'], 'supporting': doc[1][0]['target'], 'source_title': doc[2], 'source_link': f"https://scite.ai/reports/{doc[0]}" } return None @st.experimental_singleton def init_models(): nltk.download('stopwords') stop = set(stopwords.words('english') + list(string.punctuation)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") question_answerer = pipeline( "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B', device=device ) reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device) return question_answerer, reranker, stop, device qa_model, reranker, stop, device = init_models() def clean_query(query, strict=True, clean=True): operator = ' ' if strict: operator = ' AND ' query = operator.join( [i for i in query.lower().split(' ') if clean and i not in stop]) if clean: query = query.translate(str.maketrans('', '', string.punctuation)) return query def card(title, context, score, link, supporting): st.markdown(f"""