domenicrosati committed on
Commit
69d7ac6
·
1 Parent(s): 4c36cd4

add ability to specify strict or lenient

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -41,6 +41,10 @@ def search(term, limit=10, clean=True, strict=True):
41
  'Authorization': f'Bearer {SCITE_API_KEY}'
42
  }
43
  )
 
 
 
 
44
  return (
45
  [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']],
46
  [(doc['doi'], doc['citations'], doc['title'])
@@ -80,6 +84,7 @@ def init_models():
80
 
81
  qa_model, reranker, stop, device = init_models()
82
 
 
83
  def clean_query(query, strict=True, clean=True):
84
  operator = ' '
85
  if strict:
@@ -136,6 +141,10 @@ st.markdown("""
136
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
137
  """, unsafe_allow_html=True)
138
 
 
 
 
 
139
  def run_query(query):
140
  if device == 'cpu':
141
  limit = 50
@@ -143,7 +152,7 @@ def run_query(query):
143
  else:
144
  limit = 100
145
  context_limit = 25
146
- contexts, orig_docs = search(query, limit=limit)
147
  if len(contexts) == 0 or not ''.join(contexts).strip():
148
  return st.markdown("""
149
  <div class="container-fluid">
 
41
  'Authorization': f'Bearer {SCITE_API_KEY}'
42
  }
43
  )
44
+ try:
45
+ req.json()
46
+ except:
47
+ return [], []
48
  return (
49
  [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']],
50
  [(doc['doi'], doc['citations'], doc['title'])
 
84
 
85
  qa_model, reranker, stop, device = init_models()
86
 
87
+
88
  def clean_query(query, strict=True, clean=True):
89
  operator = ' '
90
  if strict:
 
141
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
142
  """, unsafe_allow_html=True)
143
 
144
+ strict_mode = st.radio(
145
+ "Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
146
+ ('strict', 'lenient'))
147
+
148
  def run_query(query):
149
  if device == 'cpu':
150
  limit = 50
 
152
  else:
153
  limit = 100
154
  context_limit = 25
155
+ contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
156
  if len(contexts) == 0 or not ''.join(contexts).strip():
157
  return st.markdown("""
158
  <div class="container-fluid">