---
license: apache-2.0
---

# Model description

A BertForSequenceClassification model fine-tuned on Wikipedia for zero-shot text classification. For details, see our NAACL'22 paper.


# Usage

Pair the input text with each candidate label and pass each pair to the model, which outputs a score for that label. Below is an example.

```
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("CogComp/ZeroShotWiki")
model = AutoModelForSequenceClassification.from_pretrained("CogComp/ZeroShotWiki")

labels = ["sports", "business", "politics"]
texts = ["As of the 2018 FIFA World Cup, twenty-one final tournaments have been held and a total of 79 national teams have competed."]

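with torch.no_grad():  # inference only, so gradients are not needed
    for text in texts:
        label_score = {}
        for label in labels:
            # Encode the text and the candidate label as a sentence pair.
            inputs = tokenizer(text, label, return_tensors='pt')
            out = model(**inputs)
            # Use the softmax probability at index 0 as the score for this label.
            probs = torch.nn.functional.softmax(out.logits, dim=-1)
            label_score[label] = float(probs[0][0])
        print(label_score)  # Predict the label with the highest score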
```
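
The example prints the per-label scores and leaves the final selection as a comment. A minimal sketch of that last step is shown here, assuming `label_score` is a dict mapping each label to its score as built in the example. The helper name `predict_label` and the numeric scores are hypothetical placeholders for illustration, not real model outputs.

```python
def predict_label(label_score):
    """Return the (label, score) pair with the highest score.

    `label_score` is assumed to be a dict of label -> float score.
    """
    if not label_score:
        raise ValueError("label_score must contain at least one label")
    best = max(label_score, key=label_score.get)
    return best, label_score[best]


# Placeholder scores for illustration only; real values come from the model.
example_scores = {"sports": 0.9, "business": 0.2, "politics": 0.1}
label, score = predict_label(example_scores)
print(label, score)
```

Because each label is scored in a separate forward pass, the scores are independent and need not sum to 1 across labels, so taking the maximum is the simplest way to pick a single label.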