Fraser commited on
Commit
6306a19
·
1 Parent(s): 0b69648

run training with easy setup

Browse files
Files changed (5) hide show
  1. README.md +5 -25
  2. requirements.txt +0 -3
  3. setup_venv.sh +19 -0
  4. train.py +5 -4
  5. train.sh +18 -0
README.md CHANGED
@@ -1,36 +1,16 @@
1
- # Transformer-VAE (flax) (WIP)
2
 
3
  A Transformer-VAE made using flax.
4
 
5
- Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
6
-
7
- Builds on T5, using an autoencoder to convert it into a VAE.
8
-
9
- [See training logs.](https://wandb.ai/fraser/flax-vae)
10
-
11
- ## ToDo
12
-
13
- - [ ] Basic training script working. (Fraser + Theo)
14
- - [ ] Add MMD loss (Theo)
15
 
16
- - [ ] Save a wikipedia sentences dataset to Huggingface (see original https://github.com/ChunyuanLI/Optimus/blob/master/data/download_datasets.md) (Mina)
17
- - [ ] Make a tokenizer using the OPTIMUS tokenized dataset.
18
- - [ ] Train on the OPTIMUS wikipedia sentences dataset.
19
-
20
- - [ ] Make Huggingface widget interpolating sentences! (???) https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-build-a-demo
21
-
22
- Optional ToDos:
23
-
24
- - [ ] Add Funnel transformer encoder to FLAX (don't need weights).
25
- - [ ] Train a Funnel-encoder + T5-decoder transformer VAE.
26
 
27
- - [ ] Additional datasets:
28
- - [ ] Poetry (https://www.gwern.net/GPT-2#data-the-project-gutenberg-poetry-corpus)
29
- - [ ] 8-bit music (https://github.com/chrisdonahue/LakhNES)
30
 
31
  ## Setup
32
 
33
- Follow all steps to install dependencies from https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm
34
 
35
  - [ ] Find dataset storage site.
36
  - [ ] Ask JAX team for dataset storage.
 
1
+ # T5-VAE-Python (flax) (WIP)
2
 
3
  A Transformer-VAE made using flax.
4
 
5
+ It has been trained to interpolate on lines of Python code form the [python-lines dataset](https://huggingface.co/datasets/Fraser/python-lines).
 
 
 
 
 
 
 
 
 
6
 
7
+ Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
 
 
 
 
 
 
 
 
 
8
 
9
+ Builds on T5, using an autoencoder to convert it into an MMD-VAE.
 
 
10
 
11
  ## Setup
12
 
13
+ Follow all steps to install dependencies from https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md#tpu-vm
14
 
15
  - [ ] Find dataset storage site.
16
  - [ ] Ask JAX team for dataset storage.
requirements.txt DELETED
@@ -1,3 +0,0 @@
1
- jax
2
- jaxlib
3
- -r requirements-tpu.txt
 
 
 
 
setup_venv.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # setup training on a TPU VM
2
+ rm -fr venv
3
+ python3 -m venv venv
4
+ source venv/bin/activate
5
+ pip install -U pip
6
+ pip install -U wheel
7
+ pip install requests
8
+ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
9
+
10
+ cd ..
11
+ git clone https://github.com/huggingface/transformers.git
12
+ cd transformers
13
+ pip install -e ".[flax]"
14
+ cd ..
15
+
16
+ git clone https://github.com/huggingface/datasets.git
17
+ cd datasets
18
+ pip install -e ".[streaming]"
19
+ cd ..
train.py CHANGED
@@ -2,8 +2,6 @@
2
  Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
3
 
4
  TODO:
5
- - [x] Get this running.
6
- - [x] Don't make decoder input ids.
7
  - [ ] Add reg loss
8
  - [x] calculate MMD loss
9
  - [ ] schedule MMD loss weight
@@ -87,6 +85,10 @@ class ModelArguments:
87
  "help": "Number of dimensions to use for each latent token."
88
  },
89
  )
 
 
 
 
90
  config_path: Optional[str] = field(
91
  default=None, metadata={"help": "Pretrained config path"}
92
  )
@@ -361,8 +363,7 @@ def main():
361
  model = FlaxT5VaeForAutoencoding.from_pretrained(
362
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
363
  )
364
- # TODO assert token embedding size == len(tokenizer)
365
- assert(model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size.")
366
  else:
367
  vocab_size = len(tokenizer)
368
  config.t5.vocab_size = vocab_size
 
2
  Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
3
 
4
  TODO:
 
 
5
  - [ ] Add reg loss
6
  - [x] calculate MMD loss
7
  - [ ] schedule MMD loss weight
 
85
  "help": "Number of dimensions to use for each latent token."
86
  },
87
  )
88
+ add_special_tokens: bool = field(
89
+ default=False,
90
+ metadata={"help": "Add these special tokens to the tokenizer: {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}"},
91
+ )
92
  config_path: Optional[str] = field(
93
  default=None, metadata={"help": "Pretrained config path"}
94
  )
 
363
  model = FlaxT5VaeForAutoencoding.from_pretrained(
364
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
365
  )
366
+ assert model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
 
367
  else:
368
  vocab_size = len(tokenizer)
369
  config.t5.vocab_size = vocab_size
train.sh CHANGED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export RUN_NAME=single_latent
2
+
3
+ ./venv/bin/python train.py \
4
+ --t5_model_name_or_path="t5-base" \
5
+ --output_dir="output/${RUN_NAME}" \
6
+ --overwrite_output_dir \
7
+ --dataset_name="Fraser/python-lines" \
8
+ --do_train --do_eval \
9
+ --n_latent_tokens 1 \
10
+ --latent_token_size 32 \
11
+ --save_steps="2500" \
12
+ --eval_steps="2500" \
13
+ --block_size="32" \
14
+ --per_device_train_batch_size="10" \
15
+ --per_device_eval_batch_size="10" \
16
+ --overwrite_output_dir \
17
+ --num_train_epochs="1" \
18
+ --push_to_hub \