domenicrosati committed
Commit 00e4b2e · 1 Parent(s): 577cb80

add more settings

Files changed (1)
  1. app.py +40 -12
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead
 import requests
 from bs4 import BeautifulSoup
 from nltk.corpus import stopwords
@@ -80,9 +80,11 @@ def init_models():
         device=device
     )
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
-    return question_answerer, reranker, stop, device
+    queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
+    queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
+    return question_answerer, reranker, stop, device, queryexp_model, queryexp_tokenizer
 
-qa_model, reranker, stop, device = init_models()
+qa_model, reranker, stop, device, queryexp_model, queryexp_tokenizer = init_models()
 
 
 def clean_query(query, strict=True, clean=True):
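
Note: `AutoModelWithLMHead` is deprecated in recent `transformers` releases; for a T5-style seq2seq checkpoint such as doc2query, `AutoModelForSeq2SeqLM` is the supported loader. A minimal sketch of the equivalent, non-deprecated loading code, assuming a current transformers version:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Same checkpoint as in the commit, loaded via the non-deprecated seq2seq class.
queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
queryexp_model = AutoModelForSeq2SeqLM.from_pretrained("doc2query/all-with_prefix-t5-base-v1")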
@@ -134,7 +136,8 @@ st.title("Scientific Question Answering with Citations")
 
 st.write("""
 Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
-Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer.
+Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer. For example try:
+Are tanning beds safe to use? Does size of venture capital fund correlate with returns?
 """)
 
 st.markdown("""
@@ -145,13 +148,35 @@ with st.expander("Settings (strictness, context limit, top hits)"):
     strict_mode = st.radio(
         "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
         ('strict', 'lenient'))
-    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100 if torch.cuda.is_available() else 25)
+    use_reranking = st.radio(
+        "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
+        ('yes', 'no'))
+    use_query_exp = st.radio(
+        "(Experimental) use query expansion? Right now it just recommends queries",
+        ('yes', 'no'))
+    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 100)
     context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
 
+def paraphrase(text, max_length=128):
+
+    input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
+
+    generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=5, num_beams=5, max_length=max_length)
+
+    preds = '\n'.join([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
+    return preds
+
+
 def run_query(query):
-    limit = top_hits_limit or 100
-    context_limit = context_lim or 50
+    if use_query_exp == 'yes':
+        query_exp = paraphrase(f"question2question: {query}")
+        st.markdown(f"""
+        If you are not getting good results try one of:
 
+        {query_exp}
+        """)
+    limit = top_hits_limit or 100
+    context_limit = context_lim or 10
     contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
     if len(contexts) == 0 or not ''.join(contexts).strip():
         return st.markdown("""
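
Note: the new `paraphrase` helper drives a multi-task doc2query checkpoint that selects its task through a text prefix; `run_query` passes `question2question:` to generate related questions. A self-contained sketch of the same generation call, using the non-deprecated loader from the note above and an example question taken from the app's own copy:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
model = AutoModelForSeq2SeqLM.from_pretrained("doc2query/all-with_prefix-t5-base-v1")

# The prefix tells the multi-task checkpoint which kind of expansion to generate.
text = "question2question: Are tanning beds safe to use?"
input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
# Beam search returning 5 alternative phrasings, mirroring paraphrase() in the diff.
generated_ids = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=5, max_length=128)
for g in generated_ids:
    print(tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True))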
@@ -164,12 +189,15 @@ def run_query(query):
     </div>
     """, unsafe_allow_html=True)
 
-    sentence_pairs = [[query, context] for context in contexts]
-    scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
-    hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
-    sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
+    if use_reranking == 'yes':
+        sentence_pairs = [[query, context] for context in contexts]
+        scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
+        hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
+        sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
+        context = '\n'.join(sorted_contexts[:context_limit])
+    else:
+        context = '\n'.join(contexts[:context_limit])
 
-    context = '\n'.join(sorted_contexts[:context_limit])
     results = []
     model_results = qa_model(question=query, context=context, top_k=10)
     for result in model_results:
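
Note: in the reranking branch, `sorted(hits.items(), key=lambda x: x[0], reverse=True)` orders contexts by the snippet text (the dict key) rather than by the cross-encoder score, and keying the dict on snippet text silently collapses duplicate snippets. A sketch of the presumably intended ordering, sorting on the score (`x[1]`) without the intermediate dict:

# Rank contexts by cross-encoder score, highest first, keeping duplicates.
ranked = sorted(zip(contexts, scores), key=lambda pair: pair[1], reverse=True)
context = '\n'.join(c for c, _ in ranked[:context_limit])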
 