flash attn pip install (#426)
Browse files* flash attn pip
* add packaging
* add packaging to apt get
* install flash attn in dockerfile
* remove unused whls
* add wheel
* clean up pr
fix packaging requirement for ci
upgrade pip for ci
skip build isolation for requiremnents to get flash-attn working
install flash-attn seperately
* install wheel for ci
* no flash-attn for basic cicd
* install flash-attn as pip extras
---------
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: mhenrichsen <[email protected]>
Co-authored-by: Mads Henrichsen <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
- .github/workflows/main.yml +6 -5
- README.md +1 -1
- docker/Dockerfile +2 -2
- docker/Dockerfile-base +1 -26
- requirements.txt +1 -0
- setup.py +4 -1
.github/workflows/main.yml
CHANGED
@@ -13,17 +13,17 @@ jobs:
|
|
13 |
fail-fast: false
|
14 |
matrix:
|
15 |
include:
|
16 |
-
- cuda:
|
17 |
cuda_version: 11.8.0
|
18 |
python_version: "3.9"
|
19 |
pytorch: 2.0.1
|
20 |
axolotl_extras:
|
21 |
-
- cuda:
|
22 |
cuda_version: 11.8.0
|
23 |
python_version: "3.10"
|
24 |
pytorch: 2.0.1
|
25 |
axolotl_extras:
|
26 |
-
- cuda:
|
27 |
cuda_version: 11.8.0
|
28 |
python_version: "3.9"
|
29 |
pytorch: 2.0.1
|
@@ -49,10 +49,11 @@ jobs:
|
|
49 |
with:
|
50 |
context: .
|
51 |
build-args: |
|
52 |
-
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}
|
|
|
53 |
file: ./docker/Dockerfile
|
54 |
push: ${{ github.event_name != 'pull_request' }}
|
55 |
-
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}
|
56 |
labels: ${{ steps.metadata.outputs.labels }}
|
57 |
build-axolotl-runpod:
|
58 |
needs: build-axolotl
|
|
|
13 |
fail-fast: false
|
14 |
matrix:
|
15 |
include:
|
16 |
+
- cuda: 118
|
17 |
cuda_version: 11.8.0
|
18 |
python_version: "3.9"
|
19 |
pytorch: 2.0.1
|
20 |
axolotl_extras:
|
21 |
+
- cuda: 118
|
22 |
cuda_version: 11.8.0
|
23 |
python_version: "3.10"
|
24 |
pytorch: 2.0.1
|
25 |
axolotl_extras:
|
26 |
+
- cuda: 118
|
27 |
cuda_version: 11.8.0
|
28 |
python_version: "3.9"
|
29 |
pytorch: 2.0.1
|
|
|
49 |
with:
|
50 |
context: .
|
51 |
build-args: |
|
52 |
+
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
53 |
+
CUDA=${{ matrix.cuda }}
|
54 |
file: ./docker/Dockerfile
|
55 |
push: ${{ github.event_name != 'pull_request' }}
|
56 |
+
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
57 |
labels: ${{ steps.metadata.outputs.labels }}
|
58 |
build-axolotl-runpod:
|
59 |
needs: build-axolotl
|
README.md
CHANGED
@@ -69,7 +69,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
|
69 |
```bash
|
70 |
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
71 |
|
72 |
-
pip3 install -e .
|
73 |
pip3 install -U git+https://github.com/huggingface/peft.git
|
74 |
|
75 |
# finetune lora
|
|
|
69 |
```bash
|
70 |
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
71 |
|
72 |
+
pip3 install -e .[flash-attn]
|
73 |
pip3 install -U git+https://github.com/huggingface/peft.git
|
74 |
|
75 |
# finetune lora
|
docker/Dockerfile
CHANGED
@@ -16,9 +16,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|
16 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
17 |
RUN cd axolotl && \
|
18 |
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
19 |
-
pip install -e .[
|
20 |
else \
|
21 |
-
pip install -e
|
22 |
fi
|
23 |
|
24 |
# fix so that git fetch/pull from remote works
|
|
|
16 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
17 |
RUN cd axolotl && \
|
18 |
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
19 |
+
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
20 |
else \
|
21 |
+
pip install -e .[flash-attn]; \
|
22 |
fi
|
23 |
|
24 |
# fix so that git fetch/pull from remote works
|
docker/Dockerfile-base
CHANGED
@@ -31,26 +31,6 @@ WORKDIR /workspace
|
|
31 |
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
32 |
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
33 |
|
34 |
-
|
35 |
-
FROM base-builder AS flash-attn-builder
|
36 |
-
|
37 |
-
WORKDIR /workspace
|
38 |
-
|
39 |
-
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
40 |
-
|
41 |
-
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
42 |
-
cd flash-attention && \
|
43 |
-
git checkout v2.0.4 && \
|
44 |
-
python3 setup.py bdist_wheel && \
|
45 |
-
cd csrc/fused_dense_lib && \
|
46 |
-
python3 setup.py bdist_wheel && \
|
47 |
-
cd ../xentropy && \
|
48 |
-
python3 setup.py bdist_wheel && \
|
49 |
-
cd ../rotary && \
|
50 |
-
python3 setup.py bdist_wheel && \
|
51 |
-
cd ../layer_norm && \
|
52 |
-
python3 setup.py bdist_wheel
|
53 |
-
|
54 |
FROM base-builder AS deepspeed-builder
|
55 |
|
56 |
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
@@ -90,13 +70,8 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
|
|
90 |
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
91 |
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
92 |
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
93 |
-
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
|
94 |
-
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
|
95 |
-
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
|
96 |
-
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
|
97 |
-
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
|
98 |
|
99 |
-
RUN pip3 install wheels/deepspeed-*.whl
|
100 |
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
101 |
RUN git lfs install --skip-repo
|
102 |
RUN pip3 install awscli && \
|
|
|
31 |
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
32 |
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
FROM base-builder AS deepspeed-builder
|
35 |
|
36 |
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
|
|
70 |
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
71 |
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
72 |
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
RUN pip3 install wheels/deepspeed-*.whl
|
75 |
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
76 |
RUN git lfs install --skip-repo
|
77 |
RUN pip3 install awscli && \
|
requirements.txt
CHANGED
@@ -6,6 +6,7 @@ addict
|
|
6 |
fire
|
7 |
PyYAML==6.0
|
8 |
datasets
|
|
|
9 |
sentencepiece
|
10 |
wandb
|
11 |
einops
|
|
|
6 |
fire
|
7 |
PyYAML==6.0
|
8 |
datasets
|
9 |
+
flash-attn==2.0.8
|
10 |
sentencepiece
|
11 |
wandb
|
12 |
einops
|
setup.py
CHANGED
@@ -7,6 +7,7 @@ with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
|
7 |
# don't include peft yet until we check the int4
|
8 |
# need to manually install peft for now...
|
9 |
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
|
|
|
10 |
reqs = [r for r in reqs if r and r[0] != "#"]
|
11 |
for r in reqs:
|
12 |
install_requires.append(r)
|
@@ -25,8 +26,10 @@ setup(
|
|
25 |
"gptq_triton": [
|
26 |
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
27 |
],
|
|
|
|
|
|
|
28 |
"extras": [
|
29 |
-
"flash-attn",
|
30 |
"deepspeed",
|
31 |
],
|
32 |
},
|
|
|
7 |
# don't include peft yet until we check the int4
|
8 |
# need to manually install peft for now...
|
9 |
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
|
10 |
+
reqs = [r for r in reqs if "flash-attn" not in r]
|
11 |
reqs = [r for r in reqs if r and r[0] != "#"]
|
12 |
for r in reqs:
|
13 |
install_requires.append(r)
|
|
|
26 |
"gptq_triton": [
|
27 |
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
28 |
],
|
29 |
+
"flash-attn": [
|
30 |
+
"flash-attn==2.0.8",
|
31 |
+
],
|
32 |
"extras": [
|
|
|
33 |
"deepspeed",
|
34 |
],
|
35 |
},
|