Spaces:
Runtime error
Runtime error
domenicrosati
commited on
Commit
Β·
69d7ac6
1
Parent(s):
4c36cd4
add ability to specify strict or lenient
Browse files
app.py
CHANGED
@@ -41,6 +41,10 @@ def search(term, limit=10, clean=True, strict=True):
|
|
41 |
'Authorization': f'Bearer {SCITE_API_KEY}'
|
42 |
}
|
43 |
)
|
|
|
|
|
|
|
|
|
44 |
return (
|
45 |
[remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']],
|
46 |
[(doc['doi'], doc['citations'], doc['title'])
|
@@ -80,6 +84,7 @@ def init_models():
|
|
80 |
|
81 |
qa_model, reranker, stop, device = init_models()
|
82 |
|
|
|
83 |
def clean_query(query, strict=True, clean=True):
|
84 |
operator = ' '
|
85 |
if strict:
|
@@ -136,6 +141,10 @@ st.markdown("""
|
|
136 |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
137 |
""", unsafe_allow_html=True)
|
138 |
|
|
|
|
|
|
|
|
|
139 |
def run_query(query):
|
140 |
if device == 'cpu':
|
141 |
limit = 50
|
@@ -143,7 +152,7 @@ def run_query(query):
|
|
143 |
else:
|
144 |
limit = 100
|
145 |
context_limit = 25
|
146 |
-
contexts, orig_docs = search(query, limit=limit)
|
147 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
148 |
return st.markdown("""
|
149 |
<div class="container-fluid">
|
|
|
41 |
'Authorization': f'Bearer {SCITE_API_KEY}'
|
42 |
}
|
43 |
)
|
44 |
+
try:
|
45 |
+
req.json()
|
46 |
+
except:
|
47 |
+
return [], []
|
48 |
return (
|
49 |
[remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']],
|
50 |
[(doc['doi'], doc['citations'], doc['title'])
|
|
|
84 |
|
85 |
qa_model, reranker, stop, device = init_models()
|
86 |
|
87 |
+
|
88 |
def clean_query(query, strict=True, clean=True):
|
89 |
operator = ' '
|
90 |
if strict:
|
|
|
141 |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
142 |
""", unsafe_allow_html=True)
|
143 |
|
144 |
+
strict_mode = st.radio(
|
145 |
+
"Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
|
146 |
+
('strict', 'lenient'))
|
147 |
+
|
148 |
def run_query(query):
|
149 |
if device == 'cpu':
|
150 |
limit = 50
|
|
|
152 |
else:
|
153 |
limit = 100
|
154 |
context_limit = 25
|
155 |
+
contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
|
156 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
157 |
return st.markdown("""
|
158 |
<div class="container-fluid">
|