# Training
This directory provides a training code for Stable Cascade, as well as guides to download the models you need.
Specifically, you can find training scripts for the following use-cases:
- Text-to-Image
- ControlNet
- LoRA
- Image Reconstruction
#### Note:
A quick clarification, Stable Cascade uses Stage A & B to compress images and Stage C is used for the text-conditional
learning. Therefore, it makes sense to train a LoRA or ControlNet **only** for Stage C. You also don't train a LoRA or
ControlNet for the Stable Diffusion VAE right?
## Basics
In the [training configs](../configs/training) folder we provide config files for all trainings. All config files
follow a similar structure and only contain the most essential parameters you need to set. Let's take a look at the
structure each config follows:
At first, you will set the run name, checkpoint-, & output-folder and which version you want to train.
```yaml
experiment_id: stage_c_3b_controlnet_base
checkpoint_path: /path/to/checkpoint
output_path: /path/to/output
model_version: 3.6B
```
Next, you can set your [Weights & Biases]() information if you want to use it for logging.
```yaml
wandb_project: StableCascade
wandb_entity: wandb_username
```
Afterwards, you define the training parameters.
```yaml
lr: 1.0e-4
batch_size: 256
image_size: 768
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
grad_accum_steps: 1
updates: 500000
backup_every: 50000
save_every: 2000
warmup_updates: 1
use_fsdp: False
```
Most, of them will be quite familiar to you probably already. A few clarification tho: `updates` refers to the number of
training steps, `backup_every` creates additional checkpoints, so you can revert to earlier ones if you want,
`save_every` concerns how often models will be saved and sampling will be done. Furthermore, since distributed training
is essential when training large models from scratch or doing large finetunes, we have an option to use PyTorch's
[**Fully Shared Data Parallel (FSDP)**](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/). You
can use it by setting `use_fsdp: True`. Note, that you will need multiple GPUs for FSDP. However, this as mentioned
above, this is only needed for large runs. You can still train and finetune our largest models on a powerful single
machine.
Another thing we provide is training with **Multi-Aspect-Ratio**. You can set the aspect ratios you want in the list
for `multi_aspect_ratio`.
For diffusion models, having an EMA (Exponential Moving Average) model, can drastically improve the performance of
your model. To include an EMA model in your training you can set the following parameters, otherwise you can just
leave them away.
```yaml
ema_start_iters: 5000
ema_iters: 100
ema_beta: 0.9
```
Next, you can define the dataset that you want to use. Note, that the code uses
[webdataset](https://github.com/webdataset/webdataset) for this.
```yaml
webdataset_path:
- s3://path/to/your/first/dataset/on/s3
- file:/path/to/your/local/dataset.tar
```
You can set as many dataset paths as you want, and they can either be on
[Amazon S3 storage](https://aws.amazon.com/s3/) or just local.
There are a few more specifics to each kind of training and to datasets in general. These will be discussed below.
## Starting a Training
You can start an actual training very easily by first moving to the root directory of this repository (so [here](..)).
Next, the python command looks like the following:
```python
python3 training_file training_config
```
For example, if you want to train a LoRA model, the command would look like this:
```python
python3 train/train_c_lora.py configs/training/finetune_c_3b_lora.yaml
```
Moreover, we also provide a [bash script](example_train.sh) for working with slurm. Note, this assumes you have access to a cluster
that runs slurm as the cluster manager.
## Dataset
As mentioned above, the code uses [webdataset](https://github.com/webdataset/webdataset) for working with datasets,
because this library supports working with large amounts of data very easily. In case you want to **finetune** a model,
train a **LoRA** or train a **ControlNet**, you might not have them in a webdataset format. Therefore, here follows
a simple example how you can convert your dataset into the appropriate format.
1. Put all your images and captions into a folder
2. Rename them to have the same number / id as the name. For example:
`0000.jpg, 0000.txt, 0001.jpg, 0001.txt, 0002.jpg, 0002.txt, 0003.jpg, 0003.txt`
3. Run the following command: ``tar --sort=name -cf dataset.tar dataset/`` or manually create a tar file from the folder
4. Set the `webdataset_path: file:/path/to/your/local/dataset.tar` in the config file
Next, there are a few more settings that might be helpful to you, especially when working with large datasets that
might contain more information about images, like some kind of variables that you want to filter for. You can apply
dataset filters like the following in the config file:
```yaml
dataset_filters:
- ['aesthetic_score', 'lambda s: s > 4.5']
- ['nsfw_probability', 'lambda s: s < 0.01']
```
In this case, you would have `0000.json, 0001.json, 0002.json, 0003.json` in your dataset as well, with keys for
`aesthetic_score` and `nsfw_probability`.
## Starting from a Pretrained Model
If you want to finetune any model you need the pretrained models. You can find details on how to download them in the
[models](../models) section. After downloading them, you need to modify the checkpoint paths in the config file too.
See below for example config files.
## Text-to-Image Training
You can use the following configs for finetuning Stage C on your own datasets. All necessary parameters were already
explained above. So there is nothing new here. Take a look at the config for finetuning the
[3.6B Stage C](../configs/training/finetune_c_3b.yaml) and the [1B Stage C](../configs/training/finetune_c_1b.yaml).
## ControlNet Training
Training a ControlNet requires setting some extra parameters as well as adding the specific ControlNet Filter you want.
With filter, we simply mean a class that for example performs Canny Edge Detection, Human Pose Detection, etc.
```yaml
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
controlnet_filter: CannyFilter
controlnet_filter_params:
resize: 224
```
Here we need to give a little more detail on how Stage C's architecture looks like. It basically is just a stack of
residual blocks (convolutional and attention) that all work at the same latent resolution. We **do not** use a UNet.
And this is where `controlnet_blocks` comes in. It determines at which blocks you want to inject the controlling
information. This way, the ControlNet architecture differs from the common one used in Stable Diffusion where you
create an entire copy of the encoder of the UNet. With Stable Cascade it is a bit simpler and comes with the great
benefit of using much fewer parameters.
Next you define the class that filters the images and extracts the information you want to condition Stage C on
(Canny Edge Detection, Human Pose Detection, etc.) with the `controlnet_filter` parameter. In the example, we use the
CannyFilter defined in the [controlnet.py](../modules/controlnet.py) file. This is the place where you can add your own
ControlNet Filters. Lastly, `controlnet_filter_params` simply sets additional parameters to your `controlnet_filter`
class. That's it. You can view the example ControlNet configs for
[Inpainting / Outpainting](../configs/training/controlnet_c_3b_inpainting.yaml),
[Face Identity](../configs/training/controlnet_c_3b_identity.yaml),
[Canny](../configs/training/controlnet_c_3b_canny.yaml) and
[Super Resolution](../configs/training/controlnet_c_3b_sr.yaml).
## LoRA Training
To train a LoRA on Stage C, you have a few more parameters available to set for the training.
```yaml
module_filters: ['.attn']
rank: 4
train_tokens:
# - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
- ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails
```
These include the `module_filters`, which simply determines on what modules you want to train LoRA-layers. In the
example above, it is using the attention layers (`.attn`). Currently, only linear layers can be lora'd.
However, adding different layers (like convolutions) is possible as well.
You can also set the `rank` and if you want to learn a specific token for your training. The latter can be done by
setting `train_tokens` which expects a list of two things for each element: the token you want to train and a regex for
the token / tokens that you want to use for initializing the token. In the example above, a token `[fernando]` is
created and is initialized with the average of all tokens that include the word `dog`. Note, in order to **add** a new
token, **it has to start with `[` and end with `]`**. There is also the option of using existing tokens which will be
trained. For this, you just enter the token, **without** placing `[ ]` around it, like in the commented example above
for the token `sanil`. The second element is `null`, because we don't initialize this token and just finetune the
`snail` token.
You can find an example config for training a LoRA [here](../configs/training/finetune_c_3b_lora.yaml).
Additionally, you can also download an
[example dataset](https://huggingface.co/dome272/stable-cascade/blob/main/fernando.tar) for a cute little good boy dog.
Simply download it and set the path in the config file to your destination path.
## Image Reconstruction Training
Here we mainly focus on training **Stage B**, because it is doing most of the heavy lifting for the compression, while
Stage A only applies a very small compression and thus the results are near perfect. Why do we use Stage A even? The
reason is just to make the training and inference of Stage B cheaper and faster. With Stage A in place, Stage B works
at a 4x smaller space (for example `1 x 4 x 256 x 256` instead of `1 x 3 x 1024 x 1024`). Furthermore, we observed that
Stage B learns faster when using Stage A compared to learning Stage B directly at pixel space. Anyway, why would you
even want to train Stage B? Either you want to try to create an even higher compression or finetune on something
very specific. But this probably is a rare occasion. If you do want to, you can take a look at the training config
for the large Stage B [here](../configs/training/finetune_b_3b.yaml) or for the small Stage B
[here](../configs/training/finetune_b_700m.yaml).
## Remarks
The codebase is in early development. You might encounter unexpected errors or not perfectly optimized training and
inference code. We apologize for that in advance. If there is interest, we will continue releasing updates to it,
aiming to bring in the latest improvements and optimizations. Moreover, we would be more than happy to receive
ideas, feedback or even updates from people that would like to contribute. Cheers.