Spaces:
Runtime error
Runtime error
domenicrosati
commited on
Commit
Β·
a91b925
1
Parent(s):
e15c8b9
strict and then lenient
Browse files
app.py
CHANGED
@@ -151,18 +151,11 @@ st.markdown("""
|
|
151 |
|
152 |
with st.expander("Settings (strictness, context limit, top hits)"):
|
153 |
confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
|
154 |
-
strict_mode = st.radio(
|
155 |
-
"Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
|
156 |
-
('lenient', 'strict'))
|
157 |
use_reranking = st.radio(
|
158 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
159 |
('yes', 'no'))
|
160 |
-
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300,
|
161 |
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
|
162 |
-
use_query_exp = st.radio(
|
163 |
-
"(Experimental) use query expansion? Right now it just recommends queries",
|
164 |
-
('yes', 'no'))
|
165 |
-
suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
|
166 |
|
167 |
# def paraphrase(text, max_length=128):
|
168 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
@@ -180,7 +173,14 @@ def run_query(query):
|
|
180 |
# """)
|
181 |
limit = top_hits_limit or 100
|
182 |
context_limit = context_lim or 10
|
183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
185 |
return st.markdown("""
|
186 |
<div class="container-fluid">
|
@@ -197,8 +197,7 @@ def run_query(query):
|
|
197 |
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
198 |
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
199 |
context = '\n'.join(sorted_contexts[:context_limit])
|
200 |
-
|
201 |
-
context = '\n'.join(contexts[:context_limit])
|
202 |
results = []
|
203 |
model_results = qa_model(question=query, context=context, top_k=10)
|
204 |
for result in model_results:
|
|
|
151 |
|
152 |
with st.expander("Settings (strictness, context limit, top hits)"):
|
153 |
confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
|
|
|
|
|
|
|
154 |
use_reranking = st.radio(
|
155 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
156 |
('yes', 'no'))
|
157 |
+
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
|
158 |
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
|
|
|
|
|
|
|
|
|
159 |
|
160 |
# def paraphrase(text, max_length=128):
|
161 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
|
|
173 |
# """)
|
174 |
limit = top_hits_limit or 100
|
175 |
context_limit = context_lim or 10
|
176 |
+
contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True)
|
177 |
+
contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
|
178 |
+
|
179 |
+
contexts = list(
|
180 |
+
set(contexts_strict + contexts_lenient)
|
181 |
+
)
|
182 |
+
orig_docs = orig_docs_strict + orig_docs_lenient
|
183 |
+
|
184 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
185 |
return st.markdown("""
|
186 |
<div class="container-fluid">
|
|
|
197 |
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
198 |
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
199 |
context = '\n'.join(sorted_contexts[:context_limit])
|
200 |
+
|
|
|
201 |
results = []
|
202 |
model_results = qa_model(question=query, context=context, top_k=10)
|
203 |
for result in model_results:
|