Update README.md
Browse files
README.md
CHANGED
@@ -2,11 +2,31 @@
|
|
2 |
license: apache-2.0
|
3 |
---
|
4 |
|
|
|
|
|
|
|
|
|
|
|
5 |
# Usage
|
6 |
|
|
|
|
|
7 |
```
|
8 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
9 |
|
10 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
11 |
model = AutoModelForSequenceClassification.from_pretrained("CogComp/ZeroShotWiki")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
```
|
|
|
2 |
license: apache-2.0
|
3 |
---
|
4 |
|
5 |
+
# Model description
|
6 |
+
|
7 |
+
A BertForSequenceClassification model that is finetuned on Wikipedia for zero-shot text classification. For details, see our NAACL'22 paper.
|
8 |
+
|
9 |
+
|
10 |
# Usage
|
11 |
|
12 |
+
Concatenate the text sentence with each of the candidate labels as input to the model. The model will output a score for each label. Below is an example.
|
13 |
+
|
14 |
```
|
15 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
16 |
+
import torch
|
17 |
|
18 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
19 |
model = AutoModelForSequenceClassification.from_pretrained("CogComp/ZeroShotWiki")
|
20 |
+
|
21 |
+
labels = ["sports", "business", "politics"]
|
22 |
+
texts = ["As of the 2018 FIFA World Cup, twenty-one final tournaments have been held and a total of 79 national teams have competed."]
|
23 |
+
|
24 |
+
with torch.no_grad():
|
25 |
+
for text in texts:
|
26 |
+
label_score = {}
|
27 |
+
for label in labels:
|
28 |
+
inputs = tokenizer(text, label, return_tensors='pt')
|
29 |
+
out = model(**inputs)
|
30 |
+
label_score[label]=float(torch.nn.functional.softmax(out[0], dim=-1)[0][0])
|
31 |
+
print(label_score) # Predict the label with the highest score
|
32 |
```
|