aredden committed · Commit b6617b1 · 1 Parent(s): 2f2c44c

cuda version checks

Files changed (3)
  1. README.md +24 -2
  2. float8_quantize.py +22 -6
  3. requirements.txt +2 -1
README.md CHANGED
@@ -1,4 +1,4 @@
- # Flux FP16 Accumulate Model Implementation with FastAPI
+ # Flux FP8 (true) Matmul Implementation with FastAPI
 
  This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts. The API can be run via command-line arguments.
 
@@ -13,12 +13,34 @@ This repository contains an implementation of the Flux model, along with an API
 
  ## Installation
 
+ This repo _requires_ at least PyTorch 2.4 built with CUDA 12.4 and an Ada GPU with fp8 support; otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba:
+
+ ```bash
+ mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
+ mamba activate flux-fp8-matmul-api
+
+ # or with conda
+ conda create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
+ conda activate flux-fp8-matmul-api
+
+ # or with the nightly channel (which is what I am using) - switch 'mamba' to 'conda' if you are using conda
+ mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch-nightly -c nvidia
+ mamba activate flux-fp8-matmul-api
+
+ # or with pip
+ python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+ # or pip nightly
+ python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
+ ```
+
  To install the required dependencies, run:
 
  ```bash
- pip install -r requirements.txt
+ python -m pip install -r requirements.txt
  ```
 
+ If you get errors installing `torch-cublas-hgemm`, feel free to comment it out in requirements.txt; it isn't required, but it speeds up inference for non-fp8 linear layers.
+
  ## Usage
 
  You can run the API server using the following command:
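The README states the hardware and version requirements only in prose. Below is a minimal preflight sketch (a hypothetical helper, not part of this repo) that mirrors the version checks the commit adds to `float8_quantize.py`, plus the Ada compute-capability requirement the README implies:

```python
# preflight.py - hypothetical helper, not part of this repo.
# Mirrors the stated requirements: PyTorch >= 2.4, CUDA >= 12.4,
# and an Ada-or-newer GPU (compute capability >= 8.9) for fp8 matmul.
import torch

def assert_fp8_capable() -> None:
    # torch.__version__ is a TorchVersion, which supports tuple comparison
    if torch.__version__ < (2, 4):
        raise RuntimeError(f"PyTorch >= 2.4 required, found {torch.__version__}")
    if torch.version.cuda is None:
        raise RuntimeError("PyTorch was built without CUDA support")
    # parse as ints so e.g. "12.10" compares correctly
    major, minor = (int(v) for v in torch.version.cuda.split(".")[:2])
    if (major, minor) < (12, 4):
        raise RuntimeError(f"CUDA >= 12.4 required, found {torch.version.cuda}")
    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA device available")
    cc = torch.cuda.get_device_capability()
    if cc < (8, 9):  # sm_89 == Ada; torch._scaled_mm needs fp8 hardware
        raise RuntimeError(f"GPU compute capability {cc} lacks fp8 support")
```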
float8_quantize.py CHANGED
@@ -9,8 +9,21 @@ from torchao.float8.float8_utils import (
  from torch.nn import init
  import math
  from torch.compiler import is_compiling
+ from torch import __version__
+ from torch.version import cuda
 
-
+ IS_TORCH_2_4 = __version__ >= (2, 4) and __version__ < (2, 5)
+ LT_TORCH_2_4 = __version__ < (2, 4)
+ if LT_TORCH_2_4:
+     if not hasattr(torch, "_scaled_mm"):
+         raise RuntimeError(
+             "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later."
+         )
+ CUDA_VERSION = float(cuda) if cuda else 0
+ if CUDA_VERSION < 12.4:
+     raise RuntimeError(
+         f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later; got torch version {__version__} and CUDA version {cuda}."
+     )
  try:
      from cublas_ops import CublasLinear
  except ImportError:
@@ -244,19 +257,22 @@ class F8Linear(nn.Module):
          x = self.quantize_input(x)
 
          prev_dims = x.shape[:-1]
-
          x = x.view(-1, self.in_features)
 
          # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
-         return torch._scaled_mm(
+         out = torch._scaled_mm(
              x,
              self.float8_data.T,
-             self.input_scale_reciprocal,
-             self.scale_reciprocal,
+             scale_a=self.input_scale_reciprocal,
+             scale_b=self.scale_reciprocal,
              bias=self.bias,
              out_dtype=self.weight.dtype,
              use_fast_accum=True,
-         ).view(*prev_dims, self.out_features)
+         )
+         if IS_TORCH_2_4:
+             out = out[0]
+         out = out.view(*prev_dims, self.out_features)
+         return out
 
      @classmethod
      def from_linear(
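For readers unfamiliar with `torch._scaled_mm`, the sketch below shows the pattern the new `forward` uses: quantize both operands to `float8_e4m3fn`, pass the dequantization scales as `scale_a`/`scale_b`, and unwrap the tuple that torch 2.4 returns (the reason for the `IS_TORCH_2_4` branch above). The shapes and the simple tensor-wise scale choice are illustrative, not the repo's exact quantization logic:

```python
# Illustrative only: tensor-wise fp8 matmul in the style of F8Linear.forward,
# assuming PyTorch >= 2.4 with CUDA 12.4 on an Ada GPU.
import torch

device = "cuda"
x = torch.randn(16, 64, device=device, dtype=torch.bfloat16)  # activations
w = torch.randn(32, 64, device=device, dtype=torch.bfloat16)  # (out, in) weight

f8 = torch.float8_e4m3fn
x_scale = x.abs().max().float() / torch.finfo(f8).max  # float32 dequant scale
w_scale = w.abs().max().float() / torch.finfo(f8).max
x_f8 = (x / x_scale).to(f8)  # quantize: multiply by the reciprocal of the scale
w_f8 = (w / w_scale).to(f8)

out = torch._scaled_mm(
    x_f8,
    w_f8.T,              # second operand must be column-major
    scale_a=x_scale,     # dequant scales, multiplied into the result
    scale_b=w_scale,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
if torch.__version__ >= (2, 4) and torch.__version__ < (2, 5):
    out = out[0]  # torch 2.4 returns (out, amax); 2.5+ returns just the tensor
print(out.shape)  # torch.Size([16, 32])
```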
requirements.txt CHANGED
@@ -12,4 +12,5 @@ sentencepiece
  click
  accelerate
  quanto
- pydash
+ pydash
+ pybase64
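`pybase64` is a SIMD-accelerated drop-in replacement for the stdlib `base64` module. The diff doesn't show its call site; a plausible (assumed) use is encoding generated images for the API's JSON responses, sketched below:

```python
# Hypothetical use of the new pybase64 dependency, not shown in this diff:
# fast base64 encoding of PNG bytes for a JSON API response.
import io

import pybase64
from PIL import Image

def encode_image(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    # pybase64 mirrors the stdlib base64 API but uses an accelerated codec
    return pybase64.b64encode(buf.getvalue()).decode("utf-8")
```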