JunxiongWang committed
Commit 0e2cc08 · 1 Parent(s): 50fe51f

Update ReadMe

Files changed (1): README.md (+56 / -53)

README.md CHANGED
@@ -1,56 +1,18 @@
- ## Pretrained Models
- |**Sentence Length**|**Trained Tokens**|**Link**|
- |----------|----------|----------|
- |128|~11B|[BiGS-11B-128](https://drive.google.com/drive/folders/1-nhzeWVgpXwMyNEQ5j-MwJxSzwKyT2an?usp=sharing)|
- |128|~29B|[BiGS-29B-128](https://drive.google.com/drive/folders/10Mtl8_XUJb2mmHLyRC9x1wltdIWy6aaP?usp=sharing)|
- |128|~97B|[BiGS-97B-128](https://huggingface.co/JunxiongWang/BiGS_128)|
- |512|~108B|[BiGS-108B-512](https://huggingface.co/JunxiongWang/BiGS_512)|
- |1024|~110B|[BiGS-110B-1024](https://huggingface.co/JunxiongWang/BiGS_1024)|
- |4096|~110B|[BiGS-110B-4096](https://huggingface.co/JunxiongWang/BiGS_4096)|
 
- ### MNLI Checkpoints
-
- |**Sentence Length**|**Trained Tokens**|**Model**|
- |----------|----------|----------|
- |128|~11B|[BiGS-11B-128MNLI](https://drive.google.com/drive/folders/1-tn5ar_tRi9DnK_bNMZtPpappUdNnVET?usp=sharing)|
- |128|~29B|[BiGS-29B-128MNLI](https://drive.google.com/drive/folders/116JwMbChYp9tBuPTz5jbiaulhXrXt1P2?usp=sharing)|
- |128|~97B|[BiGS-97B-128MNLI](https://huggingface.co/JunxiongWang/BiGS_128_MNLI)|
- |512|~108B|[BiGS-108B-512MNLI](https://huggingface.co/JunxiongWang/BiGS_512_MNLI)|
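These MNLI checkpoints are sequence-classification models; the sketch below shows one way to load and query them. It is only an illustration: it assumes the repository exposes a `FlaxBiGSForSequenceClassification` head analogous to the other heads shown under Example Usage, and the premise/hypothesis strings are made up for this example.

```python
import jax.numpy as jnp
from transformers import BertTokenizer
from BiGS.modeling_flax_bigs import FlaxBiGSForSequenceClassification  # assumed head, by analogy with the heads below

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = FlaxBiGSForSequenceClassification.from_pretrained('JunxiongWang/BiGS_512_MNLI')

premise = "A soccer game with multiple males playing."
hypothesis = "Some men are playing a sport."
encoded = tokenizer(premise, hypothesis, return_tensors='np',
                    padding='max_length', max_length=512)

logits = model(**encoded).logits
# Predicted class index; the index-to-label mapping follows whatever MNLI
# convention was used at fine-tuning time.
print(int(jnp.argmax(logits, axis=-1)[0]))
```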
 
- ## Example Usage
  ### Load Masked Language Model
@@ -62,10 +24,10 @@ from transformers import BertTokenizer
 from BiGS.modeling_flax_bigs import FlaxBiGSForMaskedLM

 tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
- model = FlaxBiGSForMaskedLM.from_pretrained('JunxiongWang/BiGS_128')

 text = "The goal of life is [MASK]."
- encoded_input = tokenizer(text, return_tensors='np', padding='max_length', max_length=128)
 output = model(**encoded_input)
 tokenizer.convert_ids_to_tokens(jnp.flip(jnp.argsort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10])
 # output: ['happiness', 'love', 'peace', 'perfection', 'life', 'enlightenment', 'god', 'survival', 'freedom', 'good']
@@ -73,9 +35,9 @@ jnp.flip(jnp.sort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103])
 # probability: [0.16052087, 0.04306792, 0.03651363, 0.03468223, 0.02927081, 0.02549769, 0.02385132, 0.02261189, 0.01672831, 0.01619471]

 text = "Paris is the [MASK] of France."
- encoded_input = tokenizer(text, return_tensors='np', padding='max_length', max_length=128)
 output = model(**encoded_input)
- tokenizer.convert_ids_to_tokens(jnp.flip(jnp.argsort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:8])
 # output: ['capital', 'centre', 'center', 'city', 'capitol', 'prefecture', 'headquarters', 'president', 'metropolis', 'heart']
 jnp.flip(jnp.sort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10]
 # probability: [0.9981787 , 0.00034076, 0.00026992, 0.00026926, 0.00017787, 0.00004816, 0.00004256, 0.00003716, 0.00003634, 0.00002893]
@@ -100,4 +62,45 @@ model = FlaxBiGSForQuestionAnswering.from_pretrained('JunxiongWang/BiGS_512')
 ```python
 from BiGS.modeling_flax_bigs import FlaxBiGSForMultipleChoice
 model = FlaxBiGSForMultipleChoice.from_pretrained('JunxiongWang/BiGS_512')
- ```
+ ## Pretraining Without Attention (BiGS)
+
+ ## Official JAX models with a maximum sequence length of 512
+
+ ### [Paper](https://arxiv.org/abs/2212.10544) | [![Hugging Face Hub](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Hub-blue)](https://huggingface.co/JunxiongWang) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Fz3OSRF3PZEF_dlnyJ3KZ8Bq35DfUrIB?usp=sharing)
+
+ <img width="537" alt="BiGS" src="https://user-images.githubusercontent.com/16102460/221464744-06b6538a-7e84-4c95-909f-239eab1dba71.png">
+
+ This [repository](https://github.com/jxiw/BiGS) contains BiGS's JAX model definitions, pretrained model weights, and training and fine-tuning code for our paper exploring the use of state-space models for pretraining. You can find more details in the paper.
+
+ [**Pretraining Without Attention**](https://arxiv.org/abs/2212.10544)<br>
+ [Junxiong Wang](), [Jing Nathan Yan](), [Albert Gu](), [Alexander M. Rush]()
+ <br>Cornell University, Cornell Tech, DeepMind<br>
+
+ Transformers have been essential to pretraining success in NLP. While other architectures have been used, downstream accuracy is either significantly worse, or requires attention layers to match standard benchmarks such as GLUE. This work explores pretraining without attention by using recent advances in sequence routing based on state-space models (SSMs). Our proposed model, Bidirectional Gated SSM (BiGS), combines SSM layers with a multiplicative gating architecture that has been effective in simplified sequence modeling architectures. The model learns static layers that do not consider pair-wise interactions. Even so, BiGS is able to match BERT pretraining accuracy on GLUE and can be extended to long-form pretraining of 4096 tokens without approximation. Analysis shows that while the models have similar accuracy, the approach has significantly different inductive biases than BERT in terms of interactions and syntactic representations.
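As a rough illustration of the pattern described in the abstract, the sketch below shows a bidirectional gated sequence-mixing block in JAX. This is a schematic only, not the BiGS implementation: the state-space layer is replaced by a simple exponential-moving-average recurrence, the weights are random placeholders, and the real model uses trained SSM kernels (see the repository for the actual Flax modules).

```python
import jax
import jax.numpy as jnp

def ema_mix(x, alpha=0.9):
    """Causal stand-in for an SSM: exponential moving average along the sequence axis."""
    def step(carry, x_t):
        carry = alpha * carry + (1.0 - alpha) * x_t
        return carry, carry
    _, y = jax.lax.scan(step, jnp.zeros_like(x[0]), x)
    return y

def gated_bidirectional_block(x, params):
    """x: (seq_len, d_model). Multiplicative gating over forward + backward sequence mixes."""
    gate = jax.nn.gelu(x @ params["W_gate"])            # gating branch, no sequence mixing
    val = jax.nn.gelu(x @ params["W_val"])              # branch fed to the sequence mixers
    fwd = ema_mix(val)                                  # left-to-right pass
    bwd = ema_mix(val[::-1])[::-1]                      # right-to-left pass
    mixed = jax.nn.gelu((fwd + bwd) @ params["W_mix"])  # combine the two directions
    return (gate * mixed) @ params["W_out"] + x         # multiplicative gate + residual

# Shape check with random placeholder weights.
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 5)
d = 64
params = {name: 0.02 * jax.random.normal(k, (d, d))
          for name, k in zip(["W_gate", "W_val", "W_mix", "W_out"], keys)}
x = jax.random.normal(keys[4], (16, d))                # (seq_len=16, d_model=64)
print(gated_bidirectional_block(x, params).shape)      # (16, 64)
```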
  ### Load Masked Language Model
 
 from BiGS.modeling_flax_bigs import FlaxBiGSForMaskedLM

 tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
+ model = FlaxBiGSForMaskedLM.from_pretrained('JunxiongWang/BiGS_512')

 text = "The goal of life is [MASK]."
+ encoded_input = tokenizer(text, return_tensors='np', padding='max_length', max_length=512)
 output = model(**encoded_input)
 tokenizer.convert_ids_to_tokens(jnp.flip(jnp.argsort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10])
 # output: ['happiness', 'love', 'peace', 'perfection', 'life', 'enlightenment', 'god', 'survival', 'freedom', 'good']
 jnp.flip(jnp.sort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10]
 # probability: [0.16052087, 0.04306792, 0.03651363, 0.03468223, 0.02927081, 0.02549769, 0.02385132, 0.02261189, 0.01672831, 0.01619471]

 text = "Paris is the [MASK] of France."
+ encoded_input = tokenizer(text, return_tensors='np', padding='max_length', max_length=512)
 output = model(**encoded_input)
+ tokenizer.convert_ids_to_tokens(jnp.flip(jnp.argsort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10])
 # output: ['capital', 'centre', 'center', 'city', 'capitol', 'prefecture', 'headquarters', 'president', 'metropolis', 'heart']
 jnp.flip(jnp.sort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10]
 # probability: [0.9981787 , 0.00034076, 0.00026992, 0.00026926, 0.00017787, 0.00004816, 0.00004256, 0.00003716, 0.00003634, 0.00002893]
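For convenience, the same top-k lookup can be wrapped in a small helper. This is just a sketch: the function name is ours rather than part of the repository, and it reuses the `tokenizer` and `model` objects created above.

```python
import jax
import jax.numpy as jnp

def top_k_mask_predictions(text, k=10):
    # Pad to the fixed length the checkpoint expects (512 for BiGS_512).
    enc = tokenizer(text, return_tensors='np', padding='max_length', max_length=512)
    logits = model(**enc).logits
    # Probabilities at the [MASK] position (id 103, i.e. tokenizer.mask_token_id).
    probs = jax.nn.softmax(logits[enc['input_ids'] == tokenizer.mask_token_id])[0]
    top = jnp.argsort(probs)[::-1][:k]
    return tokenizer.convert_ids_to_tokens([int(i) for i in top]), probs[top]

tokens, probabilities = top_k_mask_predictions("The goal of life is [MASK].")
```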
 
 ```python
 from BiGS.modeling_flax_bigs import FlaxBiGSForMultipleChoice
 model = FlaxBiGSForMultipleChoice.from_pretrained('JunxiongWang/BiGS_512')
+ ```
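For reference, here is a minimal sketch of how such a multiple-choice head is usually called. It assumes `FlaxBiGSForMultipleChoice` follows the standard Hugging Face Flax convention of inputs shaped `(batch, num_choices, seq_len)`; the prompt and choice strings are purely illustrative.

```python
import jax.numpy as jnp
from transformers import BertTokenizer
from BiGS.modeling_flax_bigs import FlaxBiGSForMultipleChoice

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = FlaxBiGSForMultipleChoice.from_pretrained('JunxiongWang/BiGS_512')

prompt = "Paris is the capital of"
choices = ["France.", "Germany."]

# Tokenize each (prompt, choice) pair, then add a leading batch dimension so the
# inputs have shape (batch=1, num_choices, seq_len).
encoded = tokenizer([prompt] * len(choices), choices,
                    return_tensors='np', padding='max_length', max_length=512)
inputs = {k: v[None, ...] for k, v in encoded.items()}

logits = model(**inputs).logits          # shape (1, num_choices)
best = int(jnp.argmax(logits, axis=-1)[0])
print(choices[best])
```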
+
+ ### GLUE Experiments
+
+ GLUE consists of 9 tasks. You can use this Python [script](https://github.com/jxiw/BiGS/blob/main/run_glue2.py) to run them.
+
+ We fine-tune BiGS on a TPU-v3 with 8 cores. With a per-device batch size of 2, the effective batch size is 8 × 2 = 16.
+
+ ```bash
+ export TASK_NAME=cola
+
+ python run_glue2.py \
+     --model_name_or_path JunxiongWang/BiGS_512 \
+     --task_name $TASK_NAME \
+     --max_seq_length 512 \
+     --learning_rate 2e-5 \
+     --num_train_epochs 3 \
+     --per_device_train_batch_size 2 \
+     --logging_steps 100 \
+     --eval_steps 500 \
+     --weight_decay 0.01 \
+     --output_dir BiGS_$TASK_NAME/
+ ```
+
+ This gives the following results:
+
+ | Task  | Metric                       | Result    |
+ |-------|------------------------------|-----------|
+ | CoLA  | Matthews corr.               | 67.9      |
+ | SST-2 | Accuracy                     | 93.8      |
+ | QQP   | Accuracy/F1                  | 91.4/88.4 |
+ | MNLI  | Matched acc./Mismatched acc. | 86.2      |
+ | QNLI  | Accuracy                     | 91.6      |
+ | MRPC  | F1/Accuracy                  | 86.4/80.4 |
+ | STS-B | Pearson/Spearman corr.       | 89.1/89.0 |
+ | RTE   | Accuracy                     | 73.3      |