domenicrosati committed
Commit f1fd3e1 · 1 Parent(s): f5555cd

use ms2 for summarization

Files changed (1): app.py (+64, -7)
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, LEDForConditionalGeneration
 import requests
 from bs4 import BeautifulSoup
 import nltk
@@ -149,10 +149,11 @@ def init_models():
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-    summarizer = pipeline("summarization")
-    return question_answerer, reranker, stop, device, summarizer
+    summ_tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
+    summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')
+    return question_answerer, reranker, stop, device, summ_mdl, summ_tok
 
-qa_model, reranker, stop, device, summarizer = init_models()  # queryexp_model, queryexp_tokenizer
+qa_model, reranker, stop, device, summ_mdl, summ_tok = init_models()  # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
@@ -270,15 +271,71 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
         return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
     return None
 
+def process_document(documents, tokenizer, docsep_token_id, pad_token_id, device=device):
+    input_ids_all = []
+    for data in documents:
+        all_docs = data.split("|||||")
+        for i, doc in enumerate(all_docs):
+            doc = doc.replace("\n", " ")
+            doc = " ".join(doc.split())
+            all_docs[i] = doc
+
+        # concatenate the documents, marking each boundary with the doc-sep token
+        input_ids = []
+        for doc in all_docs:
+            input_ids.extend(
+                tokenizer.encode(
+                    doc,
+                    truncation=True,
+                    max_length=4096 // len(all_docs),
+                )[1:-1]
+            )
+            input_ids.append(docsep_token_id)
+        input_ids = (
+            [tokenizer.bos_token_id]
+            + input_ids
+            + [tokenizer.eos_token_id]
+        )
+        input_ids_all.append(torch.tensor(input_ids))
+    input_ids = torch.nn.utils.rnn.pad_sequence(
+        input_ids_all, batch_first=True, padding_value=pad_token_id
+    )
+    return input_ids
+
+
+def batch_process(batch, model, tokenizer, docsep_token_id, pad_token_id, device=device):
+    input_ids = process_document(batch['document'], tokenizer, docsep_token_id, pad_token_id)
+    # move the inputs to the device and build a matching global attention mask
+    global_attention_mask = torch.zeros_like(input_ids).to(device)
+    input_ids = input_ids.to(device)
+    # put global attention on the <s> token and on every doc-sep token
+
+    global_attention_mask[:, 0] = 1
+    global_attention_mask[input_ids == docsep_token_id] = 1
+    generated_ids = model.generate(
+        input_ids=input_ids,
+        global_attention_mask=global_attention_mask,
+        use_cache=True,
+        max_length=1024,
+        num_beams=5,
+    )
+    generated_str = tokenizer.batch_decode(
+        generated_ids.tolist(), skip_special_tokens=True
+    )
+    result = {}
+    result['generated_summaries'] = generated_str
+    return result
+
 
 def gen_summary(query, sorted_result):
-    doc_sep = '\n'
-    summary = summarizer(f'{query} '.join([f'{doc_sep}'.join(r['texts']) + r['context'] for r in sorted_result]))[0]['summary_text']
+    pad_token_id = summ_tok.pad_token_id
+    docsep_token_id = summ_tok.convert_tokens_to_ids("</s>")
+    out = batch_process({'document': ['|||||'.join([f'{query} '.join(r['texts']) + r['context'] for r in sorted_result])]}, summ_mdl, summ_tok, docsep_token_id, pad_token_id)
     st.markdown(f"""
         <div class="container-fluid">
             <div class="row align-items-start">
                 <div class="col-md-12 col-sm-12">
-                    <strong>Answer:</strong> {summary}
+                    <strong>Answer:</strong> {out['generated_summaries'][0]}
                 </div>
             </div>
         </div>
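
For reference, the summarizer this commit wires in is allenai/led-base-16384-ms2, a Longformer Encoder-Decoder (LED) checkpoint fine-tuned on the MS^2 multi-document summarization dataset of medical studies. The sketch below reproduces the committed packing-and-generation scheme outside the Streamlit app so it can be tried in isolation; it is a minimal illustration assuming torch and transformers are installed, and the two document strings are invented placeholders standing in for the retrieved contexts in sorted_result.

import torch
from transformers import AutoTokenizer, LEDForConditionalGeneration

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2').to(device)

# Hypothetical retrieved passages.
docs = [
    "Trial A reports a modest reduction in symptoms with the intervention.",
    "Trial B finds no statistically significant difference between arms.",
]

# Pack every document into one sequence, <s> doc1 </s> doc2 </s> ... </s>,
# splitting the 4096-token encoder budget evenly across documents
# (mirrors process_document in the commit).
docsep_id = tok.convert_tokens_to_ids("</s>")
ids = []
for doc in docs:
    ids.extend(tok.encode(doc, truncation=True, max_length=4096 // len(docs))[1:-1])
    ids.append(docsep_id)
input_ids = torch.tensor([[tok.bos_token_id] + ids + [tok.eos_token_id]]).to(device)

# LED uses windowed local attention; granting global attention to <s> and to
# each doc separator lets the model relate content across document boundaries.
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1
global_attention_mask[input_ids == docsep_id] = 1

summary_ids = mdl.generate(
    input_ids=input_ids,
    global_attention_mask=global_attention_mask,
    use_cache=True,
    max_length=1024,
    num_beams=5,
)
print(tok.batch_decode(summary_ids, skip_special_tokens=True)[0])

One detail worth keeping in sync: the string used to join documents upstream must match the '|||||' separator that process_document splits on; otherwise all passages collapse into a single document and share one truncation budget.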