---
license: apache-2.0
---
# Model description
A BertForSequenceClassification model fine-tuned on Wikipedia for zero-shot text classification. For details, see our NAACL'22 paper.
# Usage
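Encode the input text together with each candidate label as a sentence pair (the tokenizer joins the two with BERT's separator token) and feed each pair to the model. The model outputs a score for each label, and the label with the highest score is the prediction. Below is an example.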
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("CogComp/ZeroShotWiki")
model = AutoModelForSequenceClassification.from_pretrained("CogComp/ZeroShotWiki")

labels = ["sports", "business", "politics"]
texts = ["As of the 2018 FIFA World Cup, twenty-one final tournaments have been held and a total of 79 national teams have competed."]

with torch.no_grad():
    for text in texts:
        label_score = {}
        for label in labels:
            # Encode the text and the candidate label as a sentence pair
            inputs = tokenizer(text, label, return_tensors="pt")
            out = model(**inputs)
            # Score = softmax probability at class index 0 for this (text, label) pair
            label_score[label] = float(torch.nn.functional.softmax(out.logits, dim=-1)[0][0])
        print(label_score)
        # Predict the label with the highest score
        print(max(label_score, key=label_score.get))
```
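To make the scoring rule explicit without downloading the model, here is a minimal pure-Python sketch of the post-processing step only: for each label, apply a softmax over the logits the model returns for that (text, label) pair, take the probability at class index 0, and pick the label with the highest score. The helper names `softmax` and `rank_labels` and the logit values are hypothetical placeholders, not real model outputs; treating index 0 as the "label applies" class simply follows the `[0][0]` indexing in the usage example.

```python
import math
from typing import Dict, List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a single row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(label_logits: Dict[str, List[float]]) -> Tuple[str, Dict[str, float]]:
    """Score each label as the softmax probability at class index 0
    (mirroring the softmax-then-index-0 scoring in the usage example)
    and return the highest-scoring label together with all scores."""
    scores = {label: softmax(logits)[0] for label, logits in label_logits.items()}
    best = max(scores, key=scores.get)
    return best, scores


# Hypothetical per-label logits for one text; placeholders, not real model outputs.
example = {
    "sports": [2.0, -1.0],
    "business": [-0.5, 0.5],
    "politics": [-1.0, 1.0],
}
best, scores = rank_labels(example)
print(best)  # → sports
print(scores)
```

Because each label is scored independently, the scores do not sum to 1 across labels; only the ranking is used for the prediction.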