Spaces:
Runtime error
Runtime error
domenicrosati
committed on
Commit
·
e15c8b9
1
Parent(s):
a812db5
improve efficiency
Browse files
app.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import pipeline
|
3 |
import requests
|
4 |
from bs4 import BeautifulSoup
|
5 |
-
from nltk.corpus import stopwords
|
6 |
import nltk
|
7 |
import string
|
8 |
from streamlit.components.v1 import html
|
@@ -78,18 +77,19 @@ def find_source(text, docs):
|
|
78 |
@st.experimental_singleton
|
79 |
def init_models():
|
80 |
nltk.download('stopwords')
|
|
|
81 |
stop = set(stopwords.words('english') + list(string.punctuation))
|
82 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
83 |
question_answerer = pipeline(
|
84 |
"question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
|
85 |
device=device
|
86 |
)
|
87 |
-
reranker = CrossEncoder('cross-encoder/ms-marco-
|
88 |
-
queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
89 |
-
queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
90 |
-
return question_answerer, reranker, stop, device
|
91 |
|
92 |
-
qa_model, reranker, stop, device
|
93 |
|
94 |
|
95 |
def clean_query(query, strict=True, clean=True):
|
@@ -157,27 +157,27 @@ with st.expander("Settings (strictness, context limit, top hits)"):
|
|
157 |
use_reranking = st.radio(
|
158 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
159 |
('yes', 'no'))
|
160 |
-
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200
|
161 |
-
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25
|
162 |
use_query_exp = st.radio(
|
163 |
"(Experimental) use query expansion? Right now it just recommends queries",
|
164 |
('yes', 'no'))
|
165 |
suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
|
166 |
|
167 |
-
def paraphrase(text, max_length=128):
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
|
174 |
def run_query(query):
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
If you are not getting good results try one of:
|
179 |
-
* {query_exp}
|
180 |
-
""")
|
181 |
limit = top_hits_limit or 100
|
182 |
context_limit = context_lim or 10
|
183 |
contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import pipeline
|
3 |
import requests
|
4 |
from bs4 import BeautifulSoup
|
|
|
5 |
import nltk
|
6 |
import string
|
7 |
from streamlit.components.v1 import html
|
|
|
77 |
@st.experimental_singleton
|
78 |
def init_models():
|
79 |
nltk.download('stopwords')
|
80 |
+
from nltk.corpus import stopwords
|
81 |
stop = set(stopwords.words('english') + list(string.punctuation))
|
82 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
83 |
question_answerer = pipeline(
|
84 |
"question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
|
85 |
device=device
|
86 |
)
|
87 |
+
reranker = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2', device=device)
|
88 |
+
# queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
89 |
+
# queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
90 |
+
return question_answerer, reranker, stop, device # queryexp_model, queryexp_tokenizer
|
91 |
|
92 |
+
qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
|
93 |
|
94 |
|
95 |
def clean_query(query, strict=True, clean=True):
|
|
|
157 |
use_reranking = st.radio(
|
158 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
159 |
('yes', 'no'))
|
160 |
+
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200)
|
161 |
+
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
|
162 |
use_query_exp = st.radio(
|
163 |
"(Experimental) use query expansion? Right now it just recommends queries",
|
164 |
('yes', 'no'))
|
165 |
suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
|
166 |
|
167 |
+
# def paraphrase(text, max_length=128):
|
168 |
+
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
169 |
+
# generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
|
170 |
+
# queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
|
171 |
+
# preds = '\n * '.join(queries)
|
172 |
+
# return preds
|
173 |
|
174 |
def run_query(query):
|
175 |
+
# if use_query_exp == 'yes':
|
176 |
+
# query_exp = paraphrase(f"question2question: {query}")
|
177 |
+
# st.markdown(f"""
|
178 |
+
# If you are not getting good results try one of:
|
179 |
+
# * {query_exp}
|
180 |
+
# """)
|
181 |
limit = top_hits_limit or 100
|
182 |
context_limit = context_lim or 10
|
183 |
contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
|