domenicrosati committed
Commit b7e15be · 1 Parent(s): a14da38

improve tokenization

Files changed (1)
  1. app.py +32 -13
app.py CHANGED
@@ -90,7 +90,7 @@ def find_source(text, docs):
         for snippet in doc[1]:
             if text in remove_html(snippet.get('snippet', '')):
                 new_text = text
-                for sent in remove_html(snippet.get('snippet', '')).split('.'):
+                for sent in nltk.sent_tokenize(remove_html(snippet.get('snippet', ''))):
                     if text in sent:
                         new_text = sent
                 return {
@@ -103,7 +103,7 @@ def find_source(text, docs):
                 }
         if text in remove_html(doc[3]):
             new_text = text
-            for sent in remove_html(doc[3]).split('.'):
+            for sent in nltk.sent_tokenize(remove_html(doc[3])):
                 if text in sent:
                     new_text = sent
             return {
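Why these two hunks matter: str.split('.') breaks text at every period, including periods inside tokens like the chromosome region "4p16.3" mentioned in this commit's own TODO comments, while NLTK's Punkt tokenizer splits only at real sentence boundaries. A minimal sketch of the difference, assuming the punkt model has been downloaded (find_source now also needs nltk imported at runtime):

    import nltk

    nltk.download('punkt')  # one-time download of the Punkt sentence model

    text = ("Deletions involving chromosome region 4p16.3 cause Wolf-Hirschhorn "
            "syndrome (WHS, OMIM 194190) [Battaglia et al, 2001]. Risk factors vary.")

    print(text.split('.')[0])
    # 'Deletions involving chromosome region 4p16' -- split mid-token

    print(nltk.sent_tokenize(text)[0])
    # the full first sentence, with '4p16.3' left intact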
@@ -206,8 +206,8 @@ with st.expander("Settings (strictness, context limit, top hits)"):
     use_reranking = st.radio(
         "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
         ('yes', 'no'))
-    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
-    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
+    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 10)
+    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 5)
 
 # def paraphrase(text, max_length=128):
 #     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
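Note that this hunk only lowers the defaults, not the ranges: in Streamlit's positional form st.slider(label, min_value, max_value, value), the fourth argument is the starting value, so the app now opens with 10 reranking documents and 5 answering contexts instead of 100 and 25, trading recall for speed on first load:

    # st.slider(label, min_value, max_value, value) -- only the default changes
    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 10)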
@@ -216,6 +216,22 @@ with st.expander("Settings (strictness, context limit, top hits)"):
 #     preds = '\n * '.join(queries)
 #     return preds
 
+
+def group_results_by_context(results):
+    result_groups = {}
+    for result in results:
+        if result['context'] not in result_groups:
+            result_groups[result['context']] = result
+            result_groups[result['context']]['texts'] = []
+
+        result_groups[result['context']]['texts'].append(
+            result['answer']
+        )
+        if result['score'] > result_groups[result['context']]['score']:
+            result_groups[result['context']]['score'] = result['score']
+    return list(result_groups.values())
+
+
 def run_query(query):
     # if use_query_exp == 'yes':
     #     query_exp = paraphrase(f"question2question: {query}")
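The new helper collapses QA hits that share a snippet into one card: the first result for a context becomes the group's row, every answer string is accumulated in texts, and the group keeps the maximum score seen. A quick illustration on hypothetical input:

    results = [
        {'context': 'ctx A', 'answer': 'foo', 'score': 0.2},
        {'context': 'ctx A', 'answer': 'bar', 'score': 0.9},
        {'context': 'ctx B', 'answer': 'baz', 'score': 0.5},
    ]
    grouped = group_results_by_context(results)
    # -> 'ctx A' group with texts == ['foo', 'bar'] and score == 0.9,
    #    'ctx B' group with texts == ['baz'] and score == 0.5

One design note: the group reuses and mutates the first result dict per context, so that dict gains a texts list and may have its score raised in place.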
@@ -224,6 +240,10 @@ def run_query(query):
     #     * {query_exp}
     #     """)
 
+    # address period in highlight: avoidability. Risk factors
+    # address poor tokenization: Deletions involving chromosome region 4p16.3 cause Wolf-Hirschhorn syndrome (WHS, OMIM 194190) [Battaglia et al, 2001].
+    # address highlight html
+
     # could also try fallback if there are no good answers by score...
     limit = top_hits_limit or 100
     context_limit = context_lim or 10
@@ -280,12 +300,9 @@ def run_query(query):
             "score": result['score'],
             "doi": support["supporting"]
         })
-    sorted_result = sorted(results, key=lambda x: x['score'])
-    sorted_result = list({
-        result['context']: result for result in sorted_result
-    }.values())
-    sorted_result = sorted(
-        sorted_result, key=lambda x: x['score'], reverse=True)
+
+    grouped_results = group_results_by_context(results)
+    sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
 
     if confidence_threshold == 0:
         threshold = 0
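The deleted block deduplicated by context with a dict-comprehension trick: sorting ascending first meant later (higher-scored) entries overwrote earlier ones under the same context key, keeping only the single best hit per snippet. The grouping helper keeps all answers per snippet instead. The old idiom, condensed for reference:

    # ascending sort + dict overwrite == keep the highest-scored hit per context
    best_per_context = {r['context']: r for r in sorted(results, key=lambda x: x['score'])}
    sorted_result = sorted(best_per_context.values(), key=lambda x: x['score'], reverse=True)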
@@ -299,9 +316,11 @@ def run_query(query):
 
     for r in sorted_result:
         answer = r["answer"]
-        ctx = remove_html(r["context"]).replace(answer, f"<mark>{answer}</mark>").replace(
-            '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
-        title = r.get("title", '').replace("_", " ")
+        ctx = remove_html(r["context"])
+        for answer in r['texts']:
+            ctx = ctx.replace(answer, f"<mark>{answer}</mark>")
+        # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
+        title = r.get("title", '')
         score = round(r["score"], 4)
         card(title, ctx, score, r['link'], r['doi'])
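With the grouped texts list, each card now highlights every answer extracted from its snippet rather than only one; the <cite> → <a> rewriting is parked in a comment pending the "address highlight html" TODO above. A toy run of the new loop:

    ctx = "Region 4p16.3 deletions cause Wolf-Hirschhorn syndrome."
    for answer in ["4p16.3", "Wolf-Hirschhorn syndrome"]:
        ctx = ctx.replace(answer, f"<mark>{answer}</mark>")
    # ctx == 'Region <mark>4p16.3</mark> deletions cause <mark>Wolf-Hirschhorn syndrome</mark>.'

(Note the loop variable reuses the name answer, shadowing the answer = r["answer"] assignment above it.)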
 
 