domenicrosati commited on
Commit
e15c8b9
·
1 Parent(s): a812db5

improve efficiency

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead
3
  import requests
4
  from bs4 import BeautifulSoup
5
- from nltk.corpus import stopwords
6
  import nltk
7
  import string
8
  from streamlit.components.v1 import html
@@ -78,18 +77,19 @@ def find_source(text, docs):
78
  @st.experimental_singleton
79
  def init_models():
80
  nltk.download('stopwords')
 
81
  stop = set(stopwords.words('english') + list(string.punctuation))
82
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
  question_answerer = pipeline(
84
  "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
85
  device=device
86
  )
87
- reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
88
- queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
89
- queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
90
- return question_answerer, reranker, stop, device, queryexp_model, queryexp_tokenizer
91
 
92
- qa_model, reranker, stop, device, queryexp_model, queryexp_tokenizer = init_models()
93
 
94
 
95
  def clean_query(query, strict=True, clean=True):
@@ -157,27 +157,27 @@ with st.expander("Settings (strictness, context limit, top hits)"):
157
  use_reranking = st.radio(
158
  "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
159
  ('yes', 'no'))
160
- top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 50)
161
- context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
162
  use_query_exp = st.radio(
163
  "(Experimental) use query expansion? Right now it just recommends queries",
164
  ('yes', 'no'))
165
  suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
166
 
167
- def paraphrase(text, max_length=128):
168
- input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
169
- generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
170
- queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
171
- preds = '\n * '.join(queries)
172
- return preds
173
 
174
  def run_query(query):
175
- if use_query_exp == 'yes':
176
- query_exp = paraphrase(f"question2question: {query}")
177
- st.markdown(f"""
178
- If you are not getting good results try one of:
179
- * {query_exp}
180
- """)
181
  limit = top_hits_limit or 100
182
  context_limit = context_lim or 10
183
  contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  import requests
4
  from bs4 import BeautifulSoup
 
5
  import nltk
6
  import string
7
  from streamlit.components.v1 import html
 
77
  @st.experimental_singleton
78
  def init_models():
79
  nltk.download('stopwords')
80
+ from nltk.corpus import stopwords
81
  stop = set(stopwords.words('english') + list(string.punctuation))
82
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
  question_answerer = pipeline(
84
  "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
85
  device=device
86
  )
87
+ reranker = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2', device=device)
88
+ # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
89
+ # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
90
+ return question_answerer, reranker, stop, device # uqeryexp_model, queryexp_tokenizer
91
 
92
+ qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
93
 
94
 
95
  def clean_query(query, strict=True, clean=True):
 
157
  use_reranking = st.radio(
158
  "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
159
  ('yes', 'no'))
160
+ top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200)
161
+ context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
162
  use_query_exp = st.radio(
163
  "(Experimental) use query expansion? Right now it just recommends queries",
164
  ('yes', 'no'))
165
  suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
166
 
167
+ # def paraphrase(text, max_length=128):
168
+ # input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
169
+ # generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
170
+ # queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
171
+ # preds = '\n * '.join(queries)
172
+ # return preds
173
 
174
  def run_query(query):
175
+ # if use_query_exp == 'yes':
176
+ # query_exp = paraphrase(f"question2question: {query}")
177
+ # st.markdown(f"""
178
+ # If you are not getting good results try one of:
179
+ # * {query_exp}
180
+ # """)
181
  limit = top_hits_limit or 100
182
  context_limit = context_lim or 10
183
  contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')