metadata
license: bsd-3-clause-clear
datasets:
  - cnn_dailymail
language:
  - en
metrics:
  - f1

FactCC factuality prediction model

Original paper: Evaluating the Factual Consistency of Abstractive Text Summarization

This is a more modern implementation of the model and code from the original GitHub repository.

This model is trained to predict whether a summary is factually consistent with the original text. Basic usage:

from transformers import BertForSequenceClassification, BertTokenizer
model_path = 'manueldeprada/FactCC'

tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)

text='''The US has "passed the peak" on new coronavirus cases, the White House reported. They predict that some states would reopen this month.
The US has over 637,000 confirmed Covid-19 cases and over 30,826 deaths, the highest for any country in the world.'''
wrong_summary = '''The pandemic has almost not affected the US'''

input_dict = tokenizer(text, wrong_summary, max_length=512, padding='max_length', truncation='only_first', return_tensors='pt')
logits = model(**input_dict).logits
pred = logits.argmax(dim=1)
print(model.config.id2label[pred.item()])  # prints: INCORRECT
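
The snippet above returns only the arg-max label. If a confidence value is also wanted, the logits can be passed through a softmax. The following is a minimal, dependency-free sketch of that step. The helper names `softmax` and `label_with_score`, the example logits, and the `example_id2label` mapping are all made up for illustration; in practice the mapping should be read from `model.config.id2label` rather than assumed.

```python
import math

def softmax(logits):
    """Turn a list of raw logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def label_with_score(logits, id2label):
    """Return the arg-max label together with its softmax probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}

# Hypothetical mapping and logits, for illustration only. With the real model,
# use model.config.id2label and logits[0].tolist() instead.
example_id2label = {0: "CORRECT", 1: "INCORRECT"}
example = label_with_score([-2.0, 3.0], example_id2label)
print(example["label"], round(example["score"], 4))
```

With PyTorch tensors the same result should be obtainable via `logits.softmax(dim=1)`; the pure-Python version is shown only to make the arithmetic explicit.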

It can also be used with a pipeline. Beware that pipelines are not designed for sentence pairs, so you have to use this double-list hack:

>>> from transformers import pipeline

>>> pipe = pipeline(model="manueldeprada/FactCC")
>>> pipe([[[text1, summary1]], [[text2, summary2]]], truncation='only_first', padding='max_length')
# output [{'label': 'INCORRECT', 'score': 0.9979124665260315}, {'label': 'CORRECT', 'score': 0.879124665260315}]
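
When the texts and summaries live in two parallel lists, the nested structure used in the call above can be built mechanically. A small sketch, where `to_pipeline_inputs` is a hypothetical helper name and not part of `transformers`:

```python
def to_pipeline_inputs(texts, summaries):
    """Wrap each (text, summary) pair as [[text, summary]], matching the
    double-list shape used in the pipeline call above."""
    if len(texts) != len(summaries):
        raise ValueError("texts and summaries must have the same length")
    return [[[text, summary]] for text, summary in zip(texts, summaries)]

# Placeholder strings, for illustration only.
inputs = to_pipeline_inputs(["text one", "text two"], ["summary one", "summary two"])
print(inputs)
```

The result could then be passed as the first argument, for example `pipe(inputs, truncation='only_first', padding='max_length')`.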

Example of batched inference to reproduce the authors' results on the test set:

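import json
import torch
from sklearn.metrics import balanced_accuracy_score, f1_score
from tqdm import tqdm

def batched_FactCC(text_l, summary_l, max_length=512):
    # Tokenize (text, summary) pairs; only the text is truncated to fit max_length.
    input_dict = tokenizer(text_l, summary_l, max_length=max_length, padding='max_length', truncation='only_first', return_tensors='pt')
    with torch.no_grad():
        logits = model(**input_dict).logits
        preds = logits.argmax(dim=1)
    return logits, preds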

texts = []
claims = []
labels = []
with open('factCC/annotated_data/test/data-dev.jsonl', 'r') as file:
    for line in file:
        obj = json.loads(line)  # Load the JSON data from each line
        texts.append(obj['text'])
        claims.append(obj['claim'])
        labels.append(model.config.label2id[obj['label']])

preds = []
batch_size = 8
for i in tqdm(range(0, len(texts), batch_size)):
    batch_texts = texts[i:i+batch_size]
    batch_claims = claims[i:i+batch_size]
    _, batch_preds = batched_FactCC(batch_texts, batch_claims)
    preds.extend(batch_preds.tolist())

print(f"F1 micro: {f1_score(labels, preds, average='micro')}")
print(f"Balanced accuracy: {balanced_accuracy_score(labels, preds)}")
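
For readers unfamiliar with the two metrics printed above (`f1_score` and `balanced_accuracy_score` appear to be the scikit-learn functions), here is a dependency-free sketch of what they compute in the single-label setting. Micro-averaged F1 then reduces to plain accuracy, and balanced accuracy is the mean of per-class recall. The function names and the toy labels are illustrative, not taken from the FactCC data.

```python
from collections import defaultdict

def micro_f1(y_true, y_pred):
    """For single-label classification, micro-averaged F1 equals accuracy."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall over the classes present in y_true."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    recalls = [hits[c] / totals[c] for c in totals]
    return sum(recalls) / len(recalls)

# Toy, imbalanced example (not real FactCC results).
toy_labels = [0, 0, 0, 1]
toy_preds = [0, 0, 1, 1]
print(micro_f1(toy_labels, toy_preds))           # 3 of 4 correct
print(balanced_accuracy(toy_labels, toy_preds))  # mean of recalls 2/3 and 1
```

Balanced accuracy is the more informative of the two when one class dominates the test set, since every class contributes equally regardless of its size.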