domenicrosati commited on
Commit
577cb80
Β·
1 Parent(s): 69d7ac6

add settings

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -141,17 +141,17 @@ st.markdown("""
141
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
142
  """, unsafe_allow_html=True)
143
 
144
- strict_mode = st.radio(
145
- "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
146
- ('strict', 'lenient'))
 
 
 
147
 
148
  def run_query(query):
149
- if device == 'cpu':
150
- limit = 50
151
- context_limit = 10
152
- else:
153
- limit = 100
154
- context_limit = 25
155
  contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
156
  if len(contexts) == 0 or not ''.join(contexts).strip():
157
  return st.markdown("""
 
141
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
142
  """, unsafe_allow_html=True)
143
 
144
+ with st.expander("Settings (strictness, context limit, top hits)"):
145
+ strict_mode = st.radio(
146
+ "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
147
+ ('strict', 'lenient'))
148
+ top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100 if torch.cuda.is_available() else 25)
149
+ context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
150
 
151
  def run_query(query):
152
+ limit = top_hits_limit or 100
153
+ context_limit = context_lim or 50
154
+
 
 
 
155
  contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
156
  if len(contexts) == 0 or not ''.join(contexts).strip():
157
  return st.markdown("""