domenicrosati committed · Commit f1fd3e1 · 1 Parent(s): f5555cd
use ms2 for summarization

app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, LEDForConditionalGeneration
 import requests
 from bs4 import BeautifulSoup
 import nltk
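The import change pulls in AutoTokenizer and LEDForConditionalGeneration so the app can load the MS^2 checkpoint directly (allenai/led-base-16384-ms2 is a Longformer Encoder-Decoder fine-tuned on the MS^2 medical multi-document summarization dataset). As an aside not in the commit, the same checkpoint could also be driven through the high-level pipeline API, though that forgoes the custom global-attention handling added further down; a minimal sketch:

    # Not from the commit: high-level pipeline alternative for the same checkpoint.
    from transformers import pipeline

    summarizer = pipeline("summarization", model="allenai/led-base-16384-ms2")
    # The input text here is a placeholder abstract.
    print(summarizer("Placeholder study abstract ...", max_length=128)[0]["summary_text"])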
@@ -149,10 +149,11 @@ def init_models():
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-
-
+    summ_tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
+    summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')
+    return question_answerer, reranker, stop, device, summ_mdl, summ_tok
 
-qa_model, reranker, stop, device = init_models()  # queryexp_model, queryexp_tokenizer
+qa_model, reranker, stop, device, summ_mdl, summ_tok = init_models()  # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
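init_models now returns the MS2 tokenizer and model alongside the QA model and reranker, and the module-level unpack picks them up as summ_tok and summ_mdl. For reference, a minimal, self-contained sketch (not part of the commit) of a bare single-document call to this checkpoint; LED requires a global_attention_mask, with at least the first <s> token marked global:

    import torch
    from transformers import AutoTokenizer, LEDForConditionalGeneration

    tok = AutoTokenizer.from_pretrained("allenai/led-base-16384-ms2")
    mdl = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384-ms2")

    # The input text is a placeholder abstract.
    inputs = tok("Placeholder study abstract ...", return_tensors="pt",
                 truncation=True, max_length=4096)
    # Mark the first (<s>) token for global attention, as LED expects.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    ids = mdl.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=256,
        num_beams=5,
    )
    print(tok.batch_decode(ids, skip_special_tokens=True)[0])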
@@ -270,15 +271,71 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
             return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
     return None
 
+def process_document(documents, tokenizer, docsep_token_id, pad_token_id, device=device):
+    input_ids_all = []
+    for data in documents:
+        all_docs = data.split("|||||")
+        for i, doc in enumerate(all_docs):
+            doc = doc.replace("\n", " ")
+            doc = " ".join(doc.split())
+            all_docs[i] = doc
+
+        #### concat with global attention on doc-sep
+        input_ids = []
+        for doc in all_docs:
+            input_ids.extend(
+                tokenizer.encode(
+                    doc,
+                    truncation=True,
+                    max_length=4096 // len(all_docs),
+                )[1:-1]
+            )
+            input_ids.append(docsep_token_id)
+        input_ids = (
+            [tokenizer.bos_token_id]
+            + input_ids
+            + [tokenizer.eos_token_id]
+        )
+        input_ids_all.append(torch.tensor(input_ids))
+    input_ids = torch.nn.utils.rnn.pad_sequence(
+        input_ids_all, batch_first=True, padding_value=pad_token_id
+    )
+    return input_ids
+
+
+def batch_process(batch, model, tokenizer, docsep_token_id, pad_token_id, device=device):
+    input_ids = process_document(batch['document'], tokenizer, docsep_token_id, pad_token_id)
+    # get the input ids and attention masks together
+    global_attention_mask = torch.zeros_like(input_ids).to(device)
+    input_ids = input_ids.to(device)
+    # put global attention on <s> token
+
+    global_attention_mask[:, 0] = 1
+    global_attention_mask[input_ids == docsep_token_id] = 1
+    generated_ids = model.generate(
+        input_ids=input_ids,
+        global_attention_mask=global_attention_mask,
+        use_cache=True,
+        max_length=1024,
+        num_beams=5,
+    )
+    generated_str = tokenizer.batch_decode(
+        generated_ids.tolist(), skip_special_tokens=True
+    )
+    result = {}
+    result['generated_summaries'] = generated_str
+    return result
+
 
 def gen_summary(query, sorted_result):
-
-
+    pad_token_id = summ_tok.pad_token_id
+    docsep_token_id = summ_tok.convert_tokens_to_ids("</s>")
+    out = batch_process({'document': ['|||||'.join([f'{query} '.join(r['texts']) + r['context'] for r in sorted_result])]}, summ_mdl, summ_tok, docsep_token_id, pad_token_id)
     st.markdown(f"""
     <div class="container-fluid">
     <div class="row align-items-start">
     <div class="col-md-12 col-sm-12">
-    <strong>Answer:</strong> {
+    <strong>Answer:</strong> {out['generated_summaries'][0]}
     </div>
     </div>
     </div>
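process_document packs every retrieved document into one LED input: each document is truncated to 4096 // n_docs tokens, its own <s>/</s> wrappers are stripped via the [1:-1] slice, a </s> doc-separator is appended after each, and the whole sequence is re-wrapped in bos/eos before padding into a batch. batch_process then sets global attention on position 0 and on every separator token, mirroring the PRIMERA-style multi-document recipe. Note the join separator in gen_summary must be the same five-pipe '|||||' string that process_document splits on. A sketch of driving this path directly, assuming the summ_mdl and summ_tok globals and the functions from this commit are in scope; the two abstracts are placeholders:

    # Two placeholder documents joined by the five-pipe separator.
    docs = "First study abstract ...|||||Second study abstract ..."
    pad_token_id = summ_tok.pad_token_id
    docsep_token_id = summ_tok.convert_tokens_to_ids("</s>")
    out = batch_process({"document": [docs]}, summ_mdl, summ_tok,
                        docsep_token_id, pad_token_id)
    print(out["generated_summaries"][0])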
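gen_summary itself only needs a query string and the reranked results. The shape below is hypothetical, inferred from the r['texts'] and r['context'] accesses in the hunk (the real structure is built outside this diff), and the call assumes a running Streamlit app since the function renders via st.markdown:

    # Hypothetical shape of sorted_result, inferred from gen_summary's accesses.
    sorted_result = [
        {"texts": ["Aspirin reduced stroke incidence."], "context": "Full passage one."},
        {"texts": ["No significant effect was observed."], "context": "Full passage two."},
    ]
    gen_summary("does aspirin reduce stroke risk?", sorted_result)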