run training with easy setup
- README.md +5 -25
- requirements.txt +0 -3
- setup_venv.sh +19 -0
- train.py +5 -4
- train.sh +18 -0

README.md CHANGED
@@ -1,36 +1,16 @@
-#
+# T5-VAE-Python (flax) (WIP)
 
 A Transformer-VAE made using flax.
 
-
-
-Builds on T5, using an autoencoder to convert it into a VAE.
-
-[See training logs.](https://wandb.ai/fraser/flax-vae)
-
-## ToDo
-
-- [ ] Basic training script working. (Fraser + Theo)
-- [ ] Add MMD loss (Theo)
+It has been trained to interpolate on lines of Python code from the [python-lines dataset](https://huggingface.co/datasets/Fraser/python-lines).
 
-
-- [ ] Make a tokenizer using the OPTIMUS tokenized dataset.
-- [ ] Train on the OPTIMUS wikipedia sentences dataset.
-
-- [ ] Make Huggingface widget interpolating sentences! (???) https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-build-a-demo
-
-Optional ToDos:
-
-- [ ] Add Funnel transformer encoder to FLAX (don't need weights).
-- [ ] Train a Funnel-encoder + T5-decoder transformer VAE.
+Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
 
-
-- [ ] Poetry (https://www.gwern.net/GPT-2#data-the-project-gutenberg-poetry-corpus)
-- [ ] 8-bit music (https://github.com/chrisdonahue/LakhNES)
+Builds on T5, using an autoencoder to convert it into an MMD-VAE.
 
 ## Setup
 
-Follow all steps to install dependencies from https://
+Follow all steps to install dependencies from https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md#tpu-vm
 
 - [ ] Find dataset storage site.
 - [ ] Ask JAX team for dataset storage.
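
The python-lines dataset referenced in the new README is hosted on the Hugging Face Hub, so it can presumably be loaded straight from the datasets library installed by setup_venv.sh below. A minimal sketch (the split name and streaming mode are assumptions; column names are left uninspected on purpose):

# sketch: peek at a few records from the dataset the README points at
import datasets

ds = datasets.load_dataset("Fraser/python-lines", split="train", streaming=True)
for i, example in enumerate(ds):
    print(example)  # print raw records to discover the actual column names
    if i >= 4:
        break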

requirements.txt DELETED
@@ -1,3 +0,0 @@
-jax
-jaxlib
--r requirements-tpu.txt

setup_venv.sh ADDED
@@ -0,0 +1,19 @@
+# setup training on a TPU VM
+rm -fr venv
+python3 -m venv venv
+source venv/bin/activate
+pip install -U pip
+pip install -U wheel
+pip install requests
+pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+
+cd ..
+git clone https://github.com/huggingface/transformers.git
+cd transformers
+pip install -e ".[flax]"
+cd ..
+
+git clone https://github.com/huggingface/datasets.git
+cd datasets
+pip install -e ".[streaming]"
+cd ..
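
A quick sanity check that the JAX TPU install from setup_venv.sh actually sees the accelerator (a sketch, run inside the venv the script creates; on a TPU VM the device list should show TPU cores rather than CPU):

# sketch: verify the "jax[tpu]" install picked up the TPU
import jax

print(jax.devices())       # expect TPU devices on a TPU VM
print(jax.device_count())  # e.g. 8 cores on a v3-8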

train.py CHANGED
@@ -2,8 +2,6 @@
 Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
 
 TODO:
-- [x] Get this running.
-- [x] Don't make decoder input ids.
 - [ ] Add reg loss
 - [x] calculate MMD loss
 - [ ] schedule MMD loss weight
@@ -87,6 +85,10 @@ class ModelArguments:
             "help": "Number of dimensions to use for each latent token."
         },
     )
+    add_special_tokens: bool = field(
+        default=False,
+        metadata={"help": "Add these special tokens to the tokenizer: {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}"},
+    )
     config_path: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config path"}
     )
@@ -361,8 +363,7 @@ def main():
         model = FlaxT5VaeForAutoencoding.from_pretrained(
             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
-
-        assert(model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size.")
+        assert model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
     else:
         vocab_size = len(tokenizer)
         config.t5.vocab_size = vocab_size
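
The new add_special_tokens flag only carries the token strings in its help text; where it is applied is outside this diff. If it is used the way the help string suggests, the call would look roughly like this (a sketch against the standard transformers tokenizer API, not the repository's actual code):

# hypothetical handling of ModelArguments.add_special_tokens
if model_args.add_special_tokens:
    tokenizer.add_special_tokens({'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'})
    # the assert in main() then requires the embedding matrix to match len(tokenizer)

The TODO list marks "calculate MMD loss" as done, but the implementation itself is also outside this diff. For reference, a generic maximum mean discrepancy between a batch of latent codes and samples from a unit-Gaussian prior, written in JAX with an inverse multiquadratic kernel (the kernel choice and scaling are assumptions, not the repository's actual loss):

import jax
import jax.numpy as jnp

def imq_kernel(x, y, latent_dim, scale=1.0):
    # inverse multiquadratic kernel: k(a, b) = C / (C + ||a - b||^2)
    c = 2.0 * latent_dim * scale
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return c / (c + sq_dists)

def mmd_loss(latents, rng, scale=1.0):
    # compare encoder latents against samples from the N(0, I) prior
    latent_dim = latents.shape[-1]
    prior = jax.random.normal(rng, latents.shape)
    k_zz = imq_kernel(latents, latents, latent_dim, scale).mean()
    k_pp = imq_kernel(prior, prior, latent_dim, scale).mean()
    k_zp = imq_kernel(latents, prior, latent_dim, scale).mean()
    return k_zz + k_pp - 2.0 * k_zp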

train.sh CHANGED
@@ -0,0 +1,18 @@
+export RUN_NAME=single_latent
+
+./venv/bin/python train.py \
+    --t5_model_name_or_path="t5-base" \
+    --output_dir="output/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --dataset_name="Fraser/python-lines" \
+    --do_train --do_eval \
+    --n_latent_tokens 1 \
+    --latent_token_size 32 \
+    --save_steps="2500" \
+    --eval_steps="2500" \
+    --block_size="32" \
+    --per_device_train_batch_size="10" \
+    --per_device_eval_batch_size="10" \
+    --overwrite_output_dir \
+    --num_train_epochs="1" \
+    --push_to_hub \
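
Presumably train.sh is run from the repository root after setup_venv.sh has built the venv (e.g. bash setup_venv.sh && bash train.sh). With --n_latent_tokens 1 and --latent_token_size 32 the run appears to compress each line of Python into a single 32-dimensional latent code, and --push_to_hub uploads the run in output/single_latent to the Hugging Face Hub.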