---
license: apache-2.0
datasets:
- allenai/c4
---
# Hydra
![Hydra](./matrix_mixer.png "Matrix Mixer")
> **Hydra: Bidirectional State Space Models Through Generalized Matrix Mixers**\
> Sukjun Hwang*, Aakash Lahoti*, Tri Dao, Albert Gu\
> Paper: https://arxiv.org/abs/2407.09941 \
> Blogpost: https://goombalab.github.io/blog/2024/hydra-part1-matrix-mixer/
## About
## Installation
Follow the installation instructions of [Mamba](https://github.com/state-spaces/mamba). In short:
```bash
pip install mamba-ssm
```
(Optional) To train BERT (`./hydra/bert`), install the additional required packages:
```bash
pip install -r requirements.txt
```
## Usage
### Hydra Block
The quasiseparable matrix mixer, ***Hydra***, is our best model for bidirectional sequence processing (see Section 3 of the paper for details).\
The implementation is at [./hydra/modules/hydra.py](https://github.com/goombalab/hydra/blob/main/hydra/modules/hydra.py).
```python
import torch
# Assumes the repository root is on the Python path (module at ./hydra/modules/hydra.py)
from hydra.modules.hydra import Hydra

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Hydra(
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor
    d_conv=7,     # Local non-causal convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
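To make the quasiseparable structure concrete, here is a naive NumPy sketch that materializes the mixer matrix for a single channel. It assumes the construction QS(X) = shift(SS(X)) + flip(shift(SS(flip(X)))) + DX described in the paper, with scalar per-step decays `a`, and uses separate generators for the forward and backward directions for illustration. The helpers `ss_matrix`, `ssm_scan`, and `quasiseparable_matrix` are hypothetical; this is not the repository's implementation, which avoids building the L×L matrix.

```python
import numpy as np

# Hypothetical helpers for illustration; not part of the Hydra codebase.

def ss_matrix(a, B, C):
    """Materialize a causal semiseparable matrix (single channel).

    M[i, j] = (C[i] . B[j]) * a[j+1] * ... * a[i] for i >= j, and 0 otherwise.
    """
    L = a.shape[0]
    M = np.zeros((L, L))
    for i in range(L):
        for j in range(i + 1):
            M[i, j] = (C[i] @ B[j]) * np.prod(a[j + 1:i + 1])
    return M


def ssm_scan(a, B, C, x):
    """The same operator as a recurrence: h_t = a_t h_{t-1} + B_t x_t, y_t = C_t . h_t."""
    h = np.zeros(B.shape[1])
    y = np.zeros_like(x)
    for t in range(x.shape[0]):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y


def quasiseparable_matrix(fwd, bwd, d):
    """Matrix form of shift(SS_fwd(x)) + flip(shift(SS_bwd(flip(x)))) + d * x.

    fwd and bwd are (a, B, C) generator tuples; bwd is indexed in reversed order.
    """
    L = d.shape[0]
    S = np.eye(L, k=-1)  # right shift: (S @ x)[i] = x[i - 1], (S @ x)[0] = 0
    J = np.eye(L)[::-1]  # flip: reverses the sequence
    lower = S @ ss_matrix(*fwd)          # strictly lower triangle
    upper = J @ S @ ss_matrix(*bwd) @ J  # strictly upper triangle
    return lower + upper + np.diag(d)    # free diagonal


rng = np.random.default_rng(0)
L, N = 6, 4  # sequence length, state size


def random_generators():
    return (rng.uniform(0.5, 1.0, L),     # decays a
            rng.standard_normal((L, N)),  # B
            rng.standard_normal((L, N)))  # C


fwd, bwd = random_generators(), random_generators()
d = rng.standard_normal(L)
x = rng.standard_normal(L)

QS = quasiseparable_matrix(fwd, bwd, d)
y = QS @ x

# Same output without materializing QS: two scans plus a diagonal term.
y_fwd = np.concatenate([[0.0], ssm_scan(*fwd, x)[:-1]])
y_bwd = np.concatenate([[0.0], ssm_scan(*bwd, x[::-1])[:-1]])[::-1]
assert np.allclose(y, y_fwd + y_bwd + d * x)
```

The final check shows why the structure is useful: the bidirectional output can be computed with two linear-time scans and a diagonal term instead of a dense L×L product.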
### Matrix Mixer Block
The matrix mixer framework is implemented at [./hydra/modules/matrix_mixer.py](https://github.com/goombalab/hydra/blob/main/hydra/modules/matrix_mixer.py).\
To integrate your own mixer matrix, follow our implementations of the various sequence mixers in [./hydra/modules/matrix_mixers/](https://github.com/goombalab/hydra/blob/main/hydra/modules/matrix_mixers/).
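```python
import torch
# Assumes the repository root is on the Python path (module at ./hydra/modules/matrix_mixer.py)
from hydra.modules.matrix_mixer import MatrixMixer

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")

# matrix_mixer_type: one of
#   'dense', 'toeplitz', 'vandermonde', 'cauchy',
#   'low_rank', 'attention', 'quasiseparable'
# is_data_dependent: boolean flag to parameterize the mixer matrix as a SAM
matrix_mixer_type = "quasiseparable"
is_data_dependent = True
qk_dim = 16  # example value

model = MatrixMixer(
    matrix_mixer_type=matrix_mixer_type,
    is_data_dependent=is_data_dependent,
    d_model=dim,    # Model dimension d_model
    qk_dim=qk_dim,  # Dimension for QK
).to("cuda")
y = model(x)
assert y.shape == x.shape
```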
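In the matrix mixer view, each sequence mixer is an L×L matrix M applied along the sequence axis, Y = MX; mixers differ only in how M is structured and parameterized. The NumPy sketch below illustrates this for three of the listed types. The builder functions are hypothetical and use random parameters rather than the learned parameterizations in the repository; only `attention_mixer` depends on the input `X`, loosely mirroring the data-dependent setting.

```python
import numpy as np

# Hypothetical builders for illustration; each returns an (L, L) mixer matrix M.

def toeplitz_mixer(t):
    """M[i, j] = t[i - j + L - 1]: one weight per relative offset i - j."""
    L = (t.shape[0] + 1) // 2
    i, j = np.indices((L, L))
    return t[i - j + L - 1]


def low_rank_mixer(U, V):
    """M = U V^T with U, V of shape (L, r), so rank(M) <= r."""
    return U @ V.T


def attention_mixer(X, Wq, Wk):
    """Data-dependent M = softmax(Q K^T / sqrt(qk_dim)), with Q = X Wq, K = X Wk."""
    Q, K = X @ Wq, X @ Wk
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    P = np.exp(scores)
    return P / P.sum(axis=1, keepdims=True)      # each row sums to 1


def matrix_mix(M, X):
    """All mixers act on the sequence axis in the same way: Y = M X."""
    return M @ X


rng = np.random.default_rng(0)
L, D, r, qk_dim = 8, 16, 2, 4  # length, model dim, rank, QK dim
X = rng.standard_normal((L, D))

mixers = {
    "toeplitz": toeplitz_mixer(rng.standard_normal(2 * L - 1)),
    "low_rank": low_rank_mixer(rng.standard_normal((L, r)),
                               rng.standard_normal((L, r))),
    "attention": attention_mixer(X, rng.standard_normal((D, qk_dim)),
                                 rng.standard_normal((D, qk_dim))),
}
for name, M in mixers.items():
    Y = matrix_mix(M, X)
    print(name, M.shape, Y.shape)
```

In this picture, adding a mixer amounts to writing another builder that returns M; the structured types are interesting because their M can be applied without materializing a dense matrix.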
### BERT
Our code for training BERT ([./hydra/bert/](https://github.com/goombalab/hydra/blob/main/hydra/bert/)) is based on [MosaicBERT](https://github.com/mosaicml/examples/tree/main/examples/benchmarks/bert) and [M2](https://github.com/HazyResearch/m2).
Follow the instructions of MosaicBERT ([./hydra/bert/README.md](https://github.com/goombalab/hydra/blob/main/hydra/bert/README.md)) for details (*e.g.*, setting up the dataset and running the code). \
The default configurations for Hydra and MatrixMixer are located at:
- Pretrain: [./hydra/bert/yamls/pretrain](https://github.com/goombalab/hydra/blob/main/hydra/bert/yamls/pretrain)
- Finetune: [./hydra/bert/yamls/finetune](https://github.com/goombalab/hydra/blob/main/hydra/bert/yamls/finetune)
#### Example commands
Run the following commands from `./hydra/bert/`.
Pretrain Hydra on C4 using a single GPU:
```bash
python main.py yamls/pretrain/hydra.yaml
```
Pretrain Hydra on C4 using 8 GPUs:
```bash
composer -n 8 main.py yamls/pretrain/hydra.yaml
```
Finetune Hydra on GLUE:
```bash
python glue.py yamls/finetune/hydra.yaml
```
## Acknowledgement
We thank the authors of [Mamba](https://github.com/state-spaces/mamba), [MosaicBERT](https://github.com/mosaicml/examples/tree/main/examples/benchmarks/bert), and [M2](https://github.com/HazyResearch/m2) for their wonderful codebases.
## Citation
If you use this codebase, or otherwise find our work valuable, please cite Hydra:
```bibtex
@article{hydra,
title={Hydra: Bidirectional State Space Models Through Generalized Matrix Mixers},
author={Hwang, Sukjun and Lahoti, Aakash and Dao, Tri and Gu, Albert},
journal={arXiv preprint arXiv:2407.09941},
year={2024}
}
```