license: apache-2.0

Model description

A BertForSequenceClassification model fine-tuned on Wikipedia for zero-shot text classification. For details, see our NAACL'22 paper.

Usage

Pair the input text with each candidate label and pass each (text, label) pair to the model as a single input. The model outputs a score for every pair, and the predicted label is the one with the highest score. Below is an example.

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("CogComp/ZeroShotWiki")
model = AutoModelForSequenceClassification.from_pretrained("CogComp/ZeroShotWiki")

labels = ["sports", "business", "politics"]
texts = ["As of the 2018 FIFA World Cup, twenty-one final tournaments have been held and a total of 79 national teams have competed."]

with torch.no_grad():
    for text in texts:
        label_score = {}
        for label in labels:
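            # Encode the text and the candidate label together as a sentence pair
            inputs = tokenizer(text, label, return_tensors='pt')
            out = model(**inputs)
            # Softmax over the output classes; the probability at index 0 is used as this label's score
            label_score[label] = float(torch.nn.functional.softmax(out.logits, dim=-1)[0][0])
        print(label_score)  # Predict the label with the highest score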
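
The final step, picking the label with the highest score, can be sketched as below. This is a minimal illustration in plain Python, not part of the released code: the helper names `predict_label` and `rank_labels` are hypothetical, and the scores in `example_scores` are made-up values rather than actual model output.

```python
from typing import Dict, List, Tuple


def rank_labels(label_score: Dict[str, float]) -> List[Tuple[str, float]]:
    """Return (label, score) pairs sorted from highest to lowest score."""
    return sorted(label_score.items(), key=lambda item: item[1], reverse=True)


def predict_label(label_score: Dict[str, float]) -> str:
    """Return the label with the highest score."""
    if not label_score:
        raise ValueError("label_score is empty")
    return max(label_score, key=label_score.get)


# Illustrative values only, NOT actual model output.
example_scores = {"sports": 0.9, "business": 0.2, "politics": 0.1}

best = predict_label(example_scores)
ranking = rank_labels(example_scores)
print(best)
print(ranking)
```

Because each score comes from a separate forward pass over one (text, label) pair, the scores for different labels are independent and need not sum to 1. They are only compared with each other to choose a label.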