---
license: mit
language:
- en
tags:
- t5
model-index:
- name: metro_t0pp_base
  results:
  - task:
      type: natural-language-inference
    dataset:
      type: super_glue
      name: RTE
      config: rte
      split: validation
    metrics:
    - type: accuracy
      value: 75.41516245487364
  - task:
      type: natural-language-inference
    dataset:
      type: super_glue
      name: CB
      config: cb
      split: validation
    metrics:
    - type: accuracy
      value: 46.904761904761905
  - task:
      type: natural-language-inference
    dataset:
      type: anli
      name: ANLI R1
      split: dev_r1
    metrics:
    - type: accuracy
      value: 34.233333333333334
  - task:
      type: natural-language-inference
    dataset:
      type: anli
      name: ANLI R2
      split: dev_r2
    metrics:
    - type: accuracy
      value: 33.906666666666666
  - task:
      type: natural-language-inference
    dataset:
      type: anli
      name: ANLI R3
      split: dev_r3
    metrics:
    - type: accuracy
      value: 35.71111111111111
  - task:
      type: coreference-resolution
    dataset:
      type: super_glue
      name: WSC
      config: wsc.fixed
      split: validation
    metrics:
    - type: accuracy
      value: 55.0
  - task:
      type: coreference-resolution
    dataset:
      type: winogrande
      name: Winogrande XL
      config: winogrande_xl
      split: validation
    metrics:
    - type: accuracy
      value: 51.22336227308604
  - task:
      type: multiple-choice-qa
    dataset:
      type: super_glue
      name: COPA
      config: copa
      split: validation
    metrics:
    - type: accuracy
      value: 69.5
  - task:
      type: multiple-choice-qa
    dataset:
      type: story_cloze
      name: StoryCloze 2016
      config: '2016'
      split: validation
    metrics:
    - type: accuracy
      value: 84.17958311063602
  - task:
      type: multiple-choice-qa
    dataset:
      type: hellaswag
      name: HellaSwag
      split: validation
    metrics:
    - type: accuracy
      value: 43.432583150766774
  - task:
      type: word-sense-disambiguation
    dataset:
      type: super_glue
      name: WiC
      config: wic
      split: validation
    metrics:
    - type: accuracy
      value: 65.12539184952979
---
Official repository: https://github.com/gonglinyuan/metro_t0
# METRO-T0
Paper: [Model-Generated Pretraining Signals Improves Zero-Shot Generalization of Text-to-Text Transformers](https://arxiv.org/abs/2305.12567) (ACL 2023)
METRO-T0 is a T5-style text-to-text Transformer pretrained using model-generated pretraining signals and then prompt-finetuned on a family of public NLP tasks proposed in [T0](https://arxiv.org/abs/2110.08207).
METRO-T0 is highly parameter-efficient. For example, METRO-T0-Large++ (775M parameters) outperforms GPT-3 (175B parameters) and T0-3B (3B parameters) on a wide range of NLP tasks.
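For a sense of scale, the parameter counts quoted above imply the following size ratios (simple arithmetic on the stated figures, not a result from the paper):

```latex
\frac{175 \times 10^{9}}{775 \times 10^{6}} \approx 226,
\qquad
\frac{3 \times 10^{9}}{775 \times 10^{6}} \approx 3.9
```

That is, METRO-T0-Large++ is roughly 226 times smaller than GPT-3 and about 3.9 times smaller than T0-3B.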
![The architecture of METRO-T0 during pretraining using BERT as the auxiliary model to generate signals](https://github.com/gonglinyuan/metro_t0/raw/main/assets/metro_t0_method.png)
![Prompt learning results of METRO-T0 versus our T0 baseline and T0-3B by Sanh et al. (2022) on 4 tasks in the T0 Eval benchmark. Each point denotes the accuracy using one prompt template, except that the median accuracy over all templates of T0-3B is indicated by the blue point. The plots of other tasks are in our paper.](https://github.com/gonglinyuan/metro_t0/raw/main/assets/metro_t0_selected_results.png)
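The figure reports one accuracy per prompt template plus a median across templates. The snippet below is a minimal sketch of that aggregation, assuming each template's predictions have already been decoded into answer strings aligned with a gold-label list; the function names and template names are hypothetical and not part of the official METRO-T0 code.

```python
from statistics import median


def template_accuracies(predictions_by_template, gold):
    """Return accuracy (in %) for each prompt template.

    predictions_by_template: dict mapping a template name to a list of
    predicted answers, aligned index-by-index with `gold`.
    """
    if not gold:
        raise ValueError("gold labels must be non-empty")
    accs = {}
    for name, preds in predictions_by_template.items():
        if len(preds) != len(gold):
            raise ValueError(f"template {name!r}: expected {len(gold)} predictions, got {len(preds)}")
        correct = sum(p == g for p, g in zip(preds, gold))
        accs[name] = 100.0 * correct / len(gold)
    return accs


def median_accuracy(predictions_by_template, gold):
    """Median of the per-template accuracies (one summary number per task)."""
    return median(template_accuracies(predictions_by_template, gold).values())
```

Reporting the median rather than the best template reduces the influence of any single, luckily phrased prompt.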
## Use METRO-T0++-Base
To use METRO-T0++-Base in PyTorch (prerequisites: Python 3.7+, PyTorch 1.12+, and transformers 4.17+), refer to the code snippet below:
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# trust_remote_code=True allows transformers to run the custom model/tokenizer code hosted in the model repository
model = AutoModelForSeq2SeqLM.from_pretrained("gonglinyuan/metro_t0pp_base", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("gonglinyuan/metro_t0pp_base", trust_remote_code=True)

# The model takes a natural-language prompt as input
input_text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
# Tokenize (truncating to at most 512 tokens) and keep only the input ID tensor
inputs = tokenizer([input_text], max_length=512, truncation=True, add_special_tokens=True, return_tensors="pt").input_ids
# Greedy decoding (do_sample=False), generating at most 256 new tokens
outputs = model.generate(inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))  # expected: positive
```
## Other METRO-T0 Models
| Model | # Parameters | Pretraining Data | Prompt-Finetuning Data |
|--------------------|--------------|------------------|------------------------|
| [METRO-T0-Base](https://huggingface.co/gonglinyuan/metro_t0_base) | 226M | Wikibook (16G) | T0 Train |
| [METRO-T0+-Base](https://huggingface.co/gonglinyuan/metro_t0p_base) | 226M | Wikibook (16G) | T0+ Train |
| [METRO-T0++-Base](https://huggingface.co/gonglinyuan/metro_t0pp_base) | 226M | Wikibook (16G) | T0++ Train |
| [METRO-T0-Base++](https://huggingface.co/gonglinyuan/metro_t0_basepp) | 256M | 160G corpus | T0 Train |
| [METRO-T0+-Base++](https://huggingface.co/gonglinyuan/metro_t0p_basepp) | 256M | 160G corpus | T0+ Train |
| [METRO-T0++-Base++](https://huggingface.co/gonglinyuan/metro_t0pp_basepp) | 256M | 160G corpus | T0++ Train |
| [METRO-T0-Large++](https://huggingface.co/gonglinyuan/metro_t0_largepp) | 775M | 160G corpus | T0 Train |
| [METRO-T0+-Large++](https://huggingface.co/gonglinyuan/metro_t0p_largepp) | 775M | 160G corpus | T0+ Train |
| [METRO-T0++-Large++](https://huggingface.co/gonglinyuan/metro_t0pp_largepp) | 775M | 160G corpus | T0++ Train |
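The Hugging Face repository IDs in the table above follow a regular naming pattern. The helper below is a hypothetical convenience sketch that reconstructs an ID from the prompt-finetuning data and model size; the pattern is inferred from the table rather than from any official API, so check the resulting ID against the table before relying on it.

```python
# Naming pattern inferred from the table (an assumption, not an official API):
#   "gonglinyuan/metro_t0" + {"": T0, "p": T0+, "pp": T0++} + "_" + {"base", "basepp", "largepp"}
FINETUNE_SUFFIX = {"T0": "", "T0+": "p", "T0++": "pp"}
SIZE_SUFFIX = {"Base": "base", "Base++": "basepp", "Large++": "largepp"}


def metro_t0_repo_id(finetune: str, size: str) -> str:
    """Build a METRO-T0 repository ID, e.g. ("T0++", "Base") -> "gonglinyuan/metro_t0pp_base"."""
    try:
        finetune_part = FINETUNE_SUFFIX[finetune]
        size_part = SIZE_SUFFIX[size]
    except KeyError as err:
        raise ValueError(f"unknown METRO-T0 variant component: {err}") from None
    return f"gonglinyuan/metro_t0{finetune_part}_{size_part}"
```

The returned string can be passed to `AutoModelForSeq2SeqLM.from_pretrained` and `AutoTokenizer.from_pretrained` with `trust_remote_code=True`, exactly as in the usage snippet for METRO-T0++-Base.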
## Citation
If you find the code and models useful for your research, please cite the following paper:
```
@misc{gong2023modelgenerated,
title={Model-Generated Pretraining Signals Improves Zero-Shot Generalization of Text-to-Text Transformers},
author={Linyuan Gong and Chenyan Xiong and Xiaodong Liu and Payal Bajaj and Yiqing Xie and Alvin Cheung and Jianfeng Gao and Xia Song},
year={2023},
eprint={2305.12567},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2305.12567}
}
```