aredden committed
Commit 9dc5b0b · 1 Parent(s): 58082af

Add all relevant args to argparser & update readme

Files changed (3)
  1. README.md +102 -23
  2. main.py +64 -3
  3. util.py +23 -5
README.md CHANGED
@@ -1,6 +1,6 @@
1
  # Flux FP8 (true) Matmul Implementation with FastAPI
2
 
3
- This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts. The API can be run via command-line arguments.
4
 
5
  ## Speed Comparison
6
 
@@ -73,13 +73,21 @@ If you get errors installing `torch-cublas-hgemm`, feel free to comment it out i
73
 
74
  ## Usage
75
 
76
  You can run the API server using the following command:
77
 
78
  ```bash
79
  python main.py --config-path <path_to_config> --port <port_number> --host <host_address>
80
  ```
81
 
82
- ### Command-Line Arguments
83
 
84
  - `--config-path`: Path to the configuration file. If not provided, the model will be loaded from the command line arguments.
85
  - `--port`: Port to run the server on (default: 8088).
@@ -91,17 +99,47 @@ python main.py --config-path <path_to_config> --port <port_number> --host <host_
91
  - `--flux-device`: Device to run the flow model on (default: cuda:0).
92
  - `--text-enc-device`: Device to run the text encoder on (default: cuda:0).
93
  - `--autoencoder-device`: Device to run the autoencoder on (default: cuda:0).
94
- - `--num-to-quant`: Number of linear layers in the flow transformer to quantize (default: 20).
95
 
96
  ## Configuration
97
 
98
  The configuration files are located in the `configs` directory. You can specify different configurations for different model versions and devices.
99
 
100
- Example configuration file (`configs/config-dev.json`):
101
 
102
- ```json
103
  {
104
- "version": "flux-dev",
105
  "params": {
106
  "in_channels": 64,
107
  "vec_in_dim": 768,
@@ -114,7 +152,7 @@ Example configuration file (`configs/config-dev.json`):
114
  "axes_dim": [16, 56, 56],
115
  "theta": 10000,
116
  "qkv_bias": true,
117
- "guidance_embed": true
118
  },
119
  "ae_params": {
120
  "resolution": 256,
@@ -127,23 +165,27 @@ Example configuration file (`configs/config-dev.json`):
127
  "scale_factor": 0.3611,
128
  "shift_factor": 0.1159
129
  },
130
- "ckpt_path": "/path/to/your/flux1-dev.sft",
131
- "ae_path": "/path/to/your/ae.sft",
132
- "repo_id": "black-forest-labs/FLUX.1-dev",
133
- "repo_flow": "flux1-dev.sft",
134
- "repo_ae": "ae.sft",
135
- "text_enc_max_length": 512,
136
- "text_enc_path": "path/to/your/t5-v1_1-xxl-encoder-bf16",
137
- "text_enc_device": "cuda:1",
138
- "ae_device": "cuda:1",
139
  "flux_device": "cuda:0",
140
  "flow_dtype": "float16",
141
  "ae_dtype": "bfloat16",
142
  "text_enc_dtype": "bfloat16",
143
- "text_enc_quantization_dtype": "qfloat8",
144
- "compile_extras": true,
145
- "compile_blocks": true,
146
- ...
  }
148
  ```
149
 
@@ -157,6 +199,12 @@ The only things you should need to change in general are the:
157
 
158
  Other things you may want to change are:
159
160
  - `"text_enc_quantization_dtype": "qfloat8"`
161
  quantization dtype for the text encoder; `qfloat8` or `qint2` will use quanto, while `qint4` or `qint8` will use bitsandbytes
162
 
@@ -220,9 +268,11 @@ python main.py --port 8088 --host 0.0.0.0 \
220
  --autoencoder-path /path/to/your/ae.sft \
221
  --model-version flux-dev \
222
  --flux-device cuda:0 \
223
- --text-enc-device cuda:1 \
224
- --autoencoder-device cuda:1 \
225
- --num-to-quant 20
 
 
226
  ```
227
 
228
  ### Generating an Image
@@ -263,3 +313,32 @@ with open(f"output.jpg", "wb") as f:
263
  f.write(io.BytesIO(res.content).read())
264
 
265
  ```
 
1
  # Flux FP8 (true) Matmul Implementation with FastAPI
2
 
3
+ This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts, and a simple single-line interface for using the generator as a single object, similar to diffusers pipelines.
4
 
5
  ## Speed Comparison
6
 
 
73
 
74
  ## Usage
75
 
76
+ For a single ADA GPU with more than 16GB and less than 24GB of VRAM, you should use the `configs/config-dev-1-4080.json` config file as a base, and then tweak the parameters to fit your needs. It offloads all models to CPU when not in use, compiles the flow model with extra optimizations, and quantizes the text encoder to nf4 and the autoencoder to qfloat8.
77
+
78
+ For a single ADA GPU with more than ~32GB of VRAM, you should use the `configs/config-dev-1-RTX6000ADA.json` config file as a base, and then tweak the parameters to fit your needs. It does not offload any models to CPU, compiles the flow model with extra optimizations, quantizes the text encoder to qfloat8, and leaves the autoencoder as bfloat16.
79
+
80
+ For a single 4090 GPU, you should use the `configs/config-dev-1-4090.json` config file as a base, and then tweak the parameters to fit your needs. It offloads the text encoder and the autoencoder to CPU, compiles the flow model with extra optimizations, and quantizes the text encoder to nf4 and the autoencoder to float8.
81
+
82
+ **NOTE:** For all of these configs, you must change the `ckpt_path`, `ae_path`, and `text_enc_path` parameters to the paths of your own checkpoint, autoencoder, and text encoder.
83
+
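For example, a quick way to make your own config is to copy the one that matches your GPU and then edit the three path entries (a minimal sketch; the destination filename is arbitrary and the paths are placeholders):

```bash
# start from the config that matches your GPU
cp configs/config-dev-1-4090.json configs/config-dev-custom.json

# then edit configs/config-dev-custom.json and point these entries at your files:
#   "ckpt_path":     local path to your flux1-dev.sft checkpoint
#   "ae_path":       local path to your ae.sft autoencoder checkpoint
#   "text_enc_path": local path (or HF repo id) for the bf16 T5 text encoder
```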
84
  You can run the API server using the following command:
85
 
86
  ```bash
87
  python main.py --config-path <path_to_config> --port <port_number> --host <host_address>
88
  ```
89
 
90
+ ### API Command-Line Arguments
91
 
92
  - `--config-path`: Path to the configuration file. If not provided, the model will be loaded from the command line arguments.
93
  - `--port`: Port to run the server on (default: 8088).
 
99
  - `--flux-device`: Device to run the flow model on (default: cuda:0).
100
  - `--text-enc-device`: Device to run the text encoder on (default: cuda:0).
101
  - `--autoencoder-device`: Device to run the autoencoder on (default: cuda:0).
102
+ - `--compile`: Compile the flow model with extra optimizations (default: False).
103
+ - `--quant-text-enc`: Quantize the T5 text encoder to the given dtype (`qint4`, `qfloat8`, `qint2`, `qint8`, `bf16`); if `bf16`, the encoder is not quantized (default: `qfloat8`).
104
+ - `--quant-ae`: Quantize the autoencoder with float8 linear layers, otherwise will use bfloat16 (default: False).
105
+ - `--offload-flow`: Offload the flow model to the CPU when not being used to save memory (default: False).
106
+ - `--no-offload-ae`: Disable offloading the autoencoder to the CPU when it is not in use; keeping it on the GPU increases end-to-end inference speed (offloading is enabled by default).
107
+ - `--no-offload-text-enc`: Disable offloading the text encoder to the CPU when it is not in use; keeping it on the GPU increases end-to-end inference speed (offloading is enabled by default).
108
+ - `--prequantized-flow`: Load the flow model from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: False).
109
+
110
+ ## Examples
111
+
112
+ ### Running the Server
113
+
114
+ ```bash
115
+ python main.py --config-path configs/config-dev-1-4090.json --port 8088 --host 0.0.0.0
116
+ ```
117
+
118
+ Or, if you need more granular control over all of the settings, you can run the server with something like this:
119
+
120
+ ```bash
121
+ python main.py --port 8088 --host 0.0.0.0 \
122
+ --flow-model-path /path/to/your/flux1-dev.sft \
123
+ --text-enc-path /path/to/your/t5-v1_1-xxl-encoder-bf16 \
124
+ --autoencoder-path /path/to/your/ae.sft \
125
+ --model-version flux-dev \
126
+ --flux-device cuda:0 \
127
+ --text-enc-device cuda:0 \
128
+ --autoencoder-device cuda:0 \
129
+ --compile \
130
+ --quant-text-enc qfloat8 \
131
+ --quant-ae
132
+ ```
133
 
134
  ## Configuration
135
 
136
  The configuration files are located in the `configs` directory. You can specify different configurations for different model versions and devices.
137
 
138
+ Example configuration file for a single 4090 (`configs/config-dev-1-4090.json`):
139
 
140
+ ```js
141
  {
142
+ "version": "flux-dev", // or flux-schnell
143
  "params": {
144
  "in_channels": 64,
145
  "vec_in_dim": 768,
 
152
  "axes_dim": [16, 56, 56],
153
  "theta": 10000,
154
  "qkv_bias": true,
155
+ "guidance_embed": true // if you are using flux-schnell, set this to false
156
  },
157
  "ae_params": {
158
  "resolution": 256,
 
165
  "scale_factor": 0.3611,
166
  "shift_factor": 0.1159
167
  },
168
+ "ckpt_path": "/your/path/to/flux1-dev.sft", // local path to original bf16 BFL flux checkpoint
169
+ "ae_path": "/your/path/to/ae.sft", // local path to original bf16 BFL autoencoder checkpoint
170
+ "repo_id": "black-forest-labs/FLUX.1-dev", // can ignore
171
+ "repo_flow": "flux1-dev.sft", // can ignore
172
+ "repo_ae": "ae.sft", // can ignore
173
+ "text_enc_max_length": 512, // use 256 if you are using flux-schnell
174
+ "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", // or custom HF full bf16 T5EncoderModel repo id
175
+ "text_enc_device": "cuda:0",
176
+ "ae_device": "cuda:0",
177
  "flux_device": "cuda:0",
178
  "flow_dtype": "float16",
179
  "ae_dtype": "bfloat16",
180
  "text_enc_dtype": "bfloat16",
181
+ "flow_quantization_dtype": "qfloat8", // will always be qfloat8, so can ignore
182
+ "text_enc_quantization_dtype": "qint4", // choose between qint4, qint8, qfloat8, qint2 or delete entry for no quantization
183
+ "ae_quantization_dtype": "qfloat8", // can either be qfloat8 or delete entry for no quantization
184
+ "compile_extras": true, // compile the layers not included in the single-blocks or double-blocks
185
+ "compile_blocks": true, // compile the single-blocks and double-blocks
186
+ "offload_text_encoder": true, // offload the text encoder to cpu when not in use
187
+ "offload_vae": true, // offload the autoencoder to cpu when not in use
188
+ "offload_flow": false // offload the flow transformer to cpu when not in use
189
  }
190
  ```
191
 
 
199
 
200
  Other things you may want to change are:
201
 
202
+ - `"text_enc_max_length": 512`
203
+ max length for the text encoder; use 256 if you are using flux-schnell
204
+
205
+ - `"ae_quantization_dtype": "qfloat8"`
206
+ quantization dtype for the autoencoder; can be `qfloat8`, or delete the entry for no quantization. Quantization uses the float8 linear layer implementation included in this repo.
207
+
208
  - `"text_enc_quantization_dtype": "qfloat8"`
209
  quantization dtype for the text encoder; `qfloat8` or `qint2` will use quanto, while `qint4` or `qint8` will use bitsandbytes
210
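Taken together, those optional entries might look like this inside a config file (a sketch only; the surrounding entries are elided and the values are illustrative):

```js
{
  // ... model params, ae_params, path and device entries as shown above ...
  "text_enc_max_length": 256,             // 256 for flux-schnell, 512 for flux-dev
  "ae_quantization_dtype": "qfloat8",     // or delete the entry for no quantization
  "text_enc_quantization_dtype": "qint4"  // qfloat8/qint2 use quanto, qint4/qint8 use bitsandbytes
}
```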
 
 
268
  --autoencoder-path /path/to/your/ae.sft \
269
  --model-version flux-dev \
270
  --flux-device cuda:0 \
271
+ --text-enc-device cuda:0 \
272
+ --autoencoder-device cuda:0 \
273
+ --compile \
274
+ --quant-text-enc qfloat8 \
275
+ --quant-ae
276
  ```
277
 
278
  ### Generating an Image
 
313
  f.write(io.BytesIO(res.content).read())
314
 
315
  ```
316
+
317
+ You can also generate an image by importing the FluxPipeline class directly and using it in your own code. This is useful if you have a custom model configuration and want to generate images without running the server.
318
+
319
+ ```py
320
+ import io
321
+ from flux_pipeline import FluxPipeline
322
+
323
+
324
+ pipe = FluxPipeline.load_pipeline_from_config_path(
325
+ "configs/config-dev-1-4090.json" # or whatever your config is
326
+ )
327
+
328
+ output_jpeg_bytes: io.BytesIO = pipe.generate(
329
+ # Required args:
330
+ prompt="A beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
331
+ # Optional args:
332
+ width=1024,
333
+ height=1024,
334
+ num_inference_steps=20,
335
+ guidance_scale=3.5,
336
+ seed=13456,
337
+ init_image="path/to/your/init_image.jpg",
338
+ strength=0.8,
339
+ )
340
+
341
+ with open("output.jpg", "wb") as f:
342
+ f.write(output_jpeg_bytes.getvalue())
343
+
344
+ ```
main.py CHANGED
@@ -1,8 +1,6 @@
1
  import argparse
2
  import uvicorn
3
  from api import app
4
- from flux_pipeline import FluxPipeline
5
- from util import load_config, ModelVersion
6
 
7
 
8
  def parse_args():
@@ -79,13 +77,68 @@ def parse_args():
79
  default=False,
80
  help="Compile the flow model with extra optimizations",
81
  )
82
-
83
  return parser.parse_args()
84
 
85
 
86
  def main():
87
  args = parse_args()
88
 
89
  if args.config_path:
90
  app.state.model = FluxPipeline.load_pipeline_from_config_path(
91
  args.config_path, flow_model_path=args.flow_model_path
@@ -110,6 +163,14 @@ def main():
110
  num_to_quant=args.num_to_quant,
111
  compile_extras=args.compile,
112
  compile_blocks=args.compile,
113
  )
114
  app.state.model = FluxPipeline.load_pipeline_from_config(config)
115
 
 
1
  import argparse
2
  import uvicorn
3
  from api import app
 
 
4
 
5
 
6
  def parse_args():
 
77
  default=False,
78
  help="Compile the flow model with extra optimizations",
79
  )
80
+ parser.add_argument(
81
+ "-qT",
82
+ "--quant-text-enc",
83
+ type=str,
84
+ default="qfloat8",
85
+ choices=["qint4", "qfloat8", "qint2", "qint8", "bf16"],
86
+ help="Quantize the t5 text encoder to the given dtype, if bf16, will not quantize",
87
+ dest="quant_text_enc",
88
+ )
89
+ parser.add_argument(
90
+ "-qA",
91
+ "--quant-ae",
92
+ action="store_true",
93
+ default=False,
94
+ help="Quantize the autoencoder with float8 linear layers, otherwise will use bfloat16",
95
+ dest="quant_ae",
96
+ )
97
+ parser.add_argument(
98
+ "-OF",
99
+ "--offload-flow",
100
+ action="store_true",
101
+ default=False,
102
+ dest="offload_flow",
103
+ help="Offload the flow model to the CPU when not being used to save memory",
104
+ )
105
+ parser.add_argument(
106
+ "-OA",
107
+ "--no-offload-ae",
108
+ action="store_false",
109
+ default=True,
110
+ dest="offload_ae",
111
+ help="Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed",
112
+ )
113
+ parser.add_argument(
114
+ "-OT",
115
+ "--no-offload-text-enc",
116
+ action="store_false",
117
+ default=True,
118
+ dest="offload_text_enc",
119
+ help="Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed",
120
+ )
121
+ parser.add_argument(
122
+ "-PF",
123
+ "--prequantized-flow",
124
+ action="store_true",
125
+ default=False,
126
+ dest="prequantized_flow",
127
+ help="Load the flow model from a prequantized checkpoint "
128
+ + "(requires loading the flow model, running a minimum of 24 steps, "
129
+ + "and then saving the state_dict as a safetensors file), "
130
+ + "which reduces the size of the checkpoint by about 50% & reduces startup time",
131
+ )
132
  return parser.parse_args()
133
 
134
 
135
  def main():
136
  args = parse_args()
137
 
138
+ # lazy loading so cli returns fast instead of waiting for torch to load modules
139
+ from flux_pipeline import FluxPipeline
140
+ from util import load_config, ModelVersion
141
+
142
  if args.config_path:
143
  app.state.model = FluxPipeline.load_pipeline_from_config_path(
144
  args.config_path, flow_model_path=args.flow_model_path
 
163
  num_to_quant=args.num_to_quant,
164
  compile_extras=args.compile,
165
  compile_blocks=args.compile,
166
+ quant_text_enc=(
167
+ None if args.quant_text_enc == "bf16" else args.quant_text_enc
168
+ ),
169
+ quant_ae=args.quant_ae,
170
+ offload_flow=args.offload_flow,
171
+ offload_ae=args.offload_ae,
172
+ offload_text_enc=args.offload_text_enc,
173
+ prequantized_flow=args.prequantized_flow,
174
  )
175
  app.state.model = FluxPipeline.load_pipeline_from_config(config)
176
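For illustration, the short flags registered above could be combined with the existing path arguments like this (a sketch; paths are placeholders and all other options keep their defaults):

```bash
# -qT: quantize the T5 text encoder (qint4/qint8 use bitsandbytes, qfloat8/qint2 use quanto)
# -qA: quantize the autoencoder with float8 linear layers
# -OF: offload the flow model to the CPU when it is not in use
python main.py --port 8088 --host 0.0.0.0 \
    --flow-model-path /path/to/your/flux1-dev.sft \
    --text-enc-path /path/to/your/t5-v1_1-xxl-encoder-bf16 \
    --autoencoder-path /path/to/your/ae.sft \
    --model-version flux-dev \
    -qT qint4 -qA -OF
```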
 
util.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from pathlib import Path
3
- from typing import Optional
4
 
5
  import torch
6
  from modules.autoencoder import AutoEncoder, AutoEncoderParams
@@ -113,7 +113,16 @@ def load_config(
113
  num_to_quant: Optional[int] = 20,
114
  compile_extras: bool = False,
115
  compile_blocks: bool = False,
116
- ):
117
  text_enc_device = str(parse_device(text_enc_device))
118
  ae_device = str(parse_device(ae_device))
119
  flux_device = str(parse_device(flux_device))
@@ -166,6 +175,17 @@ def load_config(
166
  num_to_quant=num_to_quant,
167
  compile_extras=compile_extras,
168
  compile_blocks=compile_blocks,
169
  )
170
 
171
 
@@ -193,12 +213,10 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
193
  )
194
 
195
 
196
- def load_flow_model(config: ModelSpec) -> Flux:
197
  ckpt_path = config.ckpt_path
198
  FluxClass = Flux
199
  if config.prequantized_flow:
200
- from modules.flux_model_f8 import Flux as FluxF8
201
-
202
  FluxClass = FluxF8
203
 
204
  with torch.device("meta"):
 
1
  import json
2
  from pathlib import Path
3
+ from typing import Literal, Optional
4
 
5
  import torch
6
  from modules.autoencoder import AutoEncoder, AutoEncoderParams
 
113
  num_to_quant: Optional[int] = 20,
114
  compile_extras: bool = False,
115
  compile_blocks: bool = False,
116
+ offload_text_enc: bool = False,
117
+ offload_ae: bool = False,
118
+ offload_flow: bool = False,
119
+ quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
120
+ quant_ae: bool = False,
121
+ prequantized_flow: bool = False,
122
+ ) -> ModelSpec:
123
+ """
124
+ Load a model configuration using the passed arguments.
125
+ """
126
  text_enc_device = str(parse_device(text_enc_device))
127
  ae_device = str(parse_device(ae_device))
128
  flux_device = str(parse_device(flux_device))
 
175
  num_to_quant=num_to_quant,
176
  compile_extras=compile_extras,
177
  compile_blocks=compile_blocks,
178
+ offload_flow=offload_flow,
179
+ offload_text_encoder=offload_text_enc,
180
+ offload_vae=offload_ae,
181
+ text_enc_quantization_dtype={
182
+ "float8": QuantizationDtype.qfloat8,
183
+ "qint2": QuantizationDtype.qint2,
184
+ "qint4": QuantizationDtype.qint4,
185
+ "qint8": QuantizationDtype.qint8,
186
+ }.get(quant_text_enc, None),
187
+ ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
188
+ prequantized_flow=prequantized_flow,
189
  )
190
 
191
 
 
213
  )
214
 
215
 
216
+ def load_flow_model(config: ModelSpec) -> Flux | FluxF8:
217
  ckpt_path = config.ckpt_path
218
  FluxClass = Flux
219
  if config.prequantized_flow:
 
 
220
  FluxClass = FluxF8
221
 
222
  with torch.device("meta"):
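For reference, a rough sketch (not part of this commit) of how the new keyword arguments line up with the CLI flags added in `main.py`; only names visible in this diff are used, and the remaining required arguments (model version, checkpoint paths, and so on) are omitted for brevity:

```py
from util import load_config

# hypothetical call showing only the arguments touched by this commit
config = load_config(
    flux_device="cuda:0",
    text_enc_device="cuda:0",
    ae_device="cuda:0",
    compile_extras=True,
    compile_blocks=True,
    quant_text_enc="qint4",   # mapping keys are "float8", "qint2", "qint4", "qint8";
                              # note the CLI choice "qfloat8" is not a key, so it maps to None
    quant_ae=True,            # sets ae_quantization_dtype to QuantizationDtype.qfloat8
    offload_text_enc=True,    # stored as offload_text_encoder in the ModelSpec
    offload_ae=True,          # stored as offload_vae in the ModelSpec
    offload_flow=False,
    prequantized_flow=False,
)
```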