cuda version checks
Browse files- README.md +24 -2
- float8_quantize.py +22 -6
- requirements.txt +2 -1
README.md
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# Flux
|
2 |
|
3 |
This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts. The API can be run via command-line arguments.
|
4 |
|
@@ -13,12 +13,34 @@ This repository contains an implementation of the Flux model, along with an API
|
|
13 |
|
14 |
## Installation
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
To install the required dependencies, run:
|
17 |
|
18 |
```bash
|
19 |
-
pip install -r requirements.txt
|
20 |
```
|
21 |
|
|
|
|
|
22 |
## Usage
|
23 |
|
24 |
You can run the API server using the following command:
|
|
|
1 |
+
# Flux FP8 (true) Matmul Implementation with FastAPI
|
2 |
|
3 |
This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts. The API can be run via command-line arguments.
|
4 |
|
|
|
13 |
|
14 |
## Installation
|
15 |
|
16 |
+
This repo _requires_ at least pytorch with cuda=12.4 and an ADA gpu with fp8 support, otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba:
|
17 |
+
|
18 |
+
```bash
|
19 |
+
mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
|
20 |
+
mamba activate flux-fp8-matmul-api
|
21 |
+
|
22 |
+
# or with conda
|
23 |
+
conda create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
|
24 |
+
conda activate flux-fp8-matmul-api
|
25 |
+
|
26 |
+
# or with nightly... (which is what I am using) - also, just switch 'mamba' to 'conda' if you are using conda
|
27 |
+
mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch-nightly -c nvidia
|
28 |
+
mamba activate flux-fp8-matmul-api
|
29 |
+
|
30 |
+
# or with pip
|
31 |
+
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
32 |
+
# or pip nightly
|
33 |
+
python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
|
34 |
+
```
|
35 |
+
|
36 |
To install the required dependencies, run:
|
37 |
|
38 |
```bash
|
39 |
+
python -m pip install -r requirements.txt
|
40 |
```
|
41 |
|
42 |
+
If you get errors installing `torch-cublas-hgemm`, feel free to comment it out in requirements.txt, since it's not necessary, but will speed up inference for non-fp8 linear layers.
|
43 |
+
|
44 |
## Usage
|
45 |
|
46 |
You can run the API server using the following command:
|
float8_quantize.py
CHANGED
@@ -9,8 +9,21 @@ from torchao.float8.float8_utils import (
|
|
9 |
from torch.nn import init
|
10 |
import math
|
11 |
from torch.compiler import is_compiling
|
|
|
|
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
try:
|
15 |
from cublas_ops import CublasLinear
|
16 |
except ImportError:
|
@@ -244,19 +257,22 @@ class F8Linear(nn.Module):
|
|
244 |
x = self.quantize_input(x)
|
245 |
|
246 |
prev_dims = x.shape[:-1]
|
247 |
-
|
248 |
x = x.view(-1, self.in_features)
|
249 |
|
250 |
# float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
|
251 |
-
|
252 |
x,
|
253 |
self.float8_data.T,
|
254 |
-
self.input_scale_reciprocal,
|
255 |
-
self.scale_reciprocal,
|
256 |
bias=self.bias,
|
257 |
out_dtype=self.weight.dtype,
|
258 |
use_fast_accum=True,
|
259 |
-
)
|
|
|
|
|
|
|
|
|
260 |
|
261 |
@classmethod
|
262 |
def from_linear(
|
|
|
9 |
from torch.nn import init
|
10 |
import math
|
11 |
from torch.compiler import is_compiling
|
12 |
+
from torch import __version__
|
13 |
+
from torch.version import cuda
|
14 |
|
15 |
+
IS_TORCH_2_4 = __version__ >= (2, 4) and __version__ < (2, 5)
|
16 |
+
LT_TORCH_2_4 = __version__ < (2, 4)
|
17 |
+
if LT_TORCH_2_4:
|
18 |
+
if not hasattr(torch, "_scaled_mm"):
|
19 |
+
raise RuntimeError(
|
20 |
+
"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
|
21 |
+
)
|
22 |
+
CUDA_VERSION = float(cuda) if cuda else 0
|
23 |
+
if CUDA_VERSION < 12.4:
|
24 |
+
raise RuntimeError(
|
25 |
+
f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later got torch version {__version__} and CUDA version {cuda}."
|
26 |
+
)
|
27 |
try:
|
28 |
from cublas_ops import CublasLinear
|
29 |
except ImportError:
|
|
|
257 |
x = self.quantize_input(x)
|
258 |
|
259 |
prev_dims = x.shape[:-1]
|
|
|
260 |
x = x.view(-1, self.in_features)
|
261 |
|
262 |
# float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
|
263 |
+
out = torch._scaled_mm(
|
264 |
x,
|
265 |
self.float8_data.T,
|
266 |
+
scale_a=self.input_scale_reciprocal,
|
267 |
+
scale_b=self.scale_reciprocal,
|
268 |
bias=self.bias,
|
269 |
out_dtype=self.weight.dtype,
|
270 |
use_fast_accum=True,
|
271 |
+
)
|
272 |
+
if IS_TORCH_2_4:
|
273 |
+
out = out[0]
|
274 |
+
out = out.view(*prev_dims, self.out_features)
|
275 |
+
return out
|
276 |
|
277 |
@classmethod
|
278 |
def from_linear(
|
requirements.txt
CHANGED
@@ -12,4 +12,5 @@ sentencepiece
|
|
12 |
click
|
13 |
accelerate
|
14 |
quanto
|
15 |
-
pydash
|
|
|
|
12 |
click
|
13 |
accelerate
|
14 |
quanto
|
15 |
+
pydash
|
16 |
+
pybase64
|