Fix non-offload inference & add option to load from prequantized flux
Browse files- configs/config-dev-prequant.json +56 -0
- float8_quantize.py +99 -2
- flux_pipeline.py +29 -19
- modules/flux_model_f8.py +491 -0
- util.py +9 -3
configs/config-dev-prequant.json
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"version": "flux-dev",
|
3 |
+
"params": {
|
4 |
+
"in_channels": 64,
|
5 |
+
"vec_in_dim": 768,
|
6 |
+
"context_in_dim": 4096,
|
7 |
+
"hidden_size": 3072,
|
8 |
+
"mlp_ratio": 4.0,
|
9 |
+
"num_heads": 24,
|
10 |
+
"depth": 19,
|
11 |
+
"depth_single_blocks": 38,
|
12 |
+
"axes_dim": [
|
13 |
+
16,
|
14 |
+
56,
|
15 |
+
56
|
16 |
+
],
|
17 |
+
"theta": 10000,
|
18 |
+
"qkv_bias": true,
|
19 |
+
"guidance_embed": true
|
20 |
+
},
|
21 |
+
"ae_params": {
|
22 |
+
"resolution": 256,
|
23 |
+
"in_channels": 3,
|
24 |
+
"ch": 128,
|
25 |
+
"out_ch": 3,
|
26 |
+
"ch_mult": [
|
27 |
+
1,
|
28 |
+
2,
|
29 |
+
4,
|
30 |
+
4
|
31 |
+
],
|
32 |
+
"num_res_blocks": 2,
|
33 |
+
"z_channels": 16,
|
34 |
+
"scale_factor": 0.3611,
|
35 |
+
"shift_factor": 0.1159
|
36 |
+
},
|
37 |
+
"ckpt_path": "/big/generator-ui/flux-testing/flux/flux-fp16-acc/flux_fp8.safetensors",
|
38 |
+
"ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
|
39 |
+
"repo_id": "black-forest-labs/FLUX.1-dev",
|
40 |
+
"repo_flow": "flux1-dev.sft",
|
41 |
+
"repo_ae": "ae.sft",
|
42 |
+
"text_enc_max_length": 512,
|
43 |
+
"text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
|
44 |
+
"text_enc_device": "cuda:1",
|
45 |
+
"ae_device": "cuda:1",
|
46 |
+
"flux_device": "cuda:0",
|
47 |
+
"flow_dtype": "float16",
|
48 |
+
"ae_dtype": "bfloat16",
|
49 |
+
"text_enc_dtype": "bfloat16",
|
50 |
+
"flow_quantization_dtype": "qfloat8",
|
51 |
+
"text_enc_quantization_dtype": "qfloat8",
|
52 |
+
"num_to_quant": 22,
|
53 |
+
"compile_extras": true,
|
54 |
+
"compile_blocks": true,
|
55 |
+
"prequantized_flow": true
|
56 |
+
}
|
float8_quantize.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
from torchao.float8.float8_utils import (
|
@@ -16,6 +17,24 @@ except ImportError:
|
|
16 |
CublasLinear = type(None)
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
class F8Linear(nn.Module):
|
20 |
|
21 |
def __init__(
|
@@ -24,7 +43,7 @@ class F8Linear(nn.Module):
|
|
24 |
out_features: int,
|
25 |
bias: bool = True,
|
26 |
device=None,
|
27 |
-
dtype=
|
28 |
float8_dtype=torch.float8_e4m3fn,
|
29 |
float_weight: torch.Tensor = None,
|
30 |
float_bias: torch.Tensor = None,
|
@@ -53,7 +72,6 @@ class F8Linear(nn.Module):
|
|
53 |
if bias:
|
54 |
self.bias = nn.Parameter(
|
55 |
torch.empty(out_features, **factory_kwargs),
|
56 |
-
requires_grad=bias.requires_grad,
|
57 |
)
|
58 |
else:
|
59 |
self.register_parameter("bias", None)
|
@@ -78,6 +96,85 @@ class F8Linear(nn.Module):
|
|
78 |
"input_scale_reciprocal", None
|
79 |
)
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
def quantize_weight(self):
|
82 |
if self.weight_initialized:
|
83 |
return
|
|
|
1 |
+
from typing import Any, Mapping
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
from torchao.float8.float8_utils import (
|
|
|
17 |
CublasLinear = type(None)
|
18 |
|
19 |
|
20 |
+
def check_scale_tensor(tensor):
|
21 |
+
return (
|
22 |
+
tensor is not None
|
23 |
+
and isinstance(tensor, torch.Tensor)
|
24 |
+
and tensor.dtype == torch.float32
|
25 |
+
and tensor.numel() == 1
|
26 |
+
and tensor != torch.zeros_like(tensor)
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def check_scale_in_state_dict(state_dict, key):
|
31 |
+
return key in state_dict and check_scale_tensor(state_dict[key])
|
32 |
+
|
33 |
+
|
34 |
+
def check_scales_given_state_dict_and_keys(state_dict, keys):
|
35 |
+
return all(check_scale_in_state_dict(state_dict, key) for key in keys)
|
36 |
+
|
37 |
+
|
38 |
class F8Linear(nn.Module):
|
39 |
|
40 |
def __init__(
|
|
|
43 |
out_features: int,
|
44 |
bias: bool = True,
|
45 |
device=None,
|
46 |
+
dtype=torch.float16,
|
47 |
float8_dtype=torch.float8_e4m3fn,
|
48 |
float_weight: torch.Tensor = None,
|
49 |
float_bias: torch.Tensor = None,
|
|
|
72 |
if bias:
|
73 |
self.bias = nn.Parameter(
|
74 |
torch.empty(out_features, **factory_kwargs),
|
|
|
75 |
)
|
76 |
else:
|
77 |
self.register_parameter("bias", None)
|
|
|
96 |
"input_scale_reciprocal", None
|
97 |
)
|
98 |
|
99 |
+
def _load_from_state_dict(
|
100 |
+
self,
|
101 |
+
state_dict,
|
102 |
+
prefix,
|
103 |
+
local_metadata,
|
104 |
+
strict,
|
105 |
+
missing_keys,
|
106 |
+
unexpected_keys,
|
107 |
+
error_msgs,
|
108 |
+
):
|
109 |
+
sd = {k.replace(prefix, ""): v for k, v in state_dict.items()}
|
110 |
+
if "weight" in sd:
|
111 |
+
if (
|
112 |
+
"float8_data" not in sd
|
113 |
+
or sd["float8_data"] is None
|
114 |
+
and sd["weight"].shape == (self.out_features, self.in_features)
|
115 |
+
):
|
116 |
+
# Initialize as if it's an F8Linear that needs to be quantized
|
117 |
+
self._parameters["weight"] = nn.Parameter(
|
118 |
+
sd["weight"], requires_grad=False
|
119 |
+
)
|
120 |
+
if "bias" in sd:
|
121 |
+
self._parameters["bias"] = nn.Parameter(
|
122 |
+
sd["bias"], requires_grad=False
|
123 |
+
)
|
124 |
+
self.quantize_weight()
|
125 |
+
elif sd["float8_data"].shape == (
|
126 |
+
self.out_features,
|
127 |
+
self.in_features,
|
128 |
+
) and sd["weight"] == torch.zeros_like(sd["weight"]):
|
129 |
+
w = sd["weight"]
|
130 |
+
# Set the init values as if it's already quantized float8_data
|
131 |
+
self.float8_data = sd["float8_data"]
|
132 |
+
self._parameters["weight"] = nn.Parameter(
|
133 |
+
torch.zeros(
|
134 |
+
1,
|
135 |
+
dtype=w.dtype,
|
136 |
+
device=w.device,
|
137 |
+
requires_grad=False,
|
138 |
+
)
|
139 |
+
)
|
140 |
+
if "bias" in sd:
|
141 |
+
self._parameters["bias"] = nn.Parameter(
|
142 |
+
sd["bias"], requires_grad=False
|
143 |
+
)
|
144 |
+
self.weight_initialized = True
|
145 |
+
|
146 |
+
# Check if scales and reciprocals are initialized
|
147 |
+
if all(
|
148 |
+
key in sd
|
149 |
+
for key in [
|
150 |
+
"scale",
|
151 |
+
"input_scale",
|
152 |
+
"scale_reciprocal",
|
153 |
+
"input_scale_reciprocal",
|
154 |
+
]
|
155 |
+
):
|
156 |
+
self.scale = sd["scale"].float()
|
157 |
+
self.input_scale = sd["input_scale"].float()
|
158 |
+
self.scale_reciprocal = sd["scale_reciprocal"].float()
|
159 |
+
self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
|
160 |
+
self.input_scale_initialized = True
|
161 |
+
self.trial_index = self.num_scale_trials
|
162 |
+
else:
|
163 |
+
# If scales are not initialized, reset trials
|
164 |
+
self.input_scale_initialized = False
|
165 |
+
self.trial_index = 0
|
166 |
+
self.input_amax_trials = torch.zeros(
|
167 |
+
self.num_scale_trials, requires_grad=False, dtype=torch.float32
|
168 |
+
)
|
169 |
+
else:
|
170 |
+
raise RuntimeError(
|
171 |
+
f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}"
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
raise RuntimeError(
|
175 |
+
"Weight tensor not found or has incorrect shape in state dict"
|
176 |
+
)
|
177 |
+
|
178 |
def quantize_weight(self):
|
179 |
if self.weight_initialized:
|
180 |
return
|
flux_pipeline.py
CHANGED
@@ -92,18 +92,26 @@ class FluxPipeline:
|
|
92 |
self.offload_text_encoder = config.offload_text_encoder
|
93 |
self.offload_vae = config.offload_vae
|
94 |
self.offload_flow = config.offload_flow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
if self.config.compile_blocks or self.config.compile_extras:
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
107 |
to_gpu_extras = [
|
108 |
"vector_in",
|
109 |
"img_in",
|
@@ -247,7 +255,7 @@ class FluxPipeline:
|
|
247 |
im = torch.vstack(images)
|
248 |
|
249 |
torch.cuda.synchronize()
|
250 |
-
im = self.
|
251 |
images.clear()
|
252 |
return io.BytesIO(im)
|
253 |
|
@@ -458,6 +466,7 @@ class FluxPipeline:
|
|
458 |
|
459 |
with torch.inference_mode():
|
460 |
print("flow_quantization_dtype", config.flow_quantization_dtype)
|
|
|
461 |
|
462 |
models = load_models_from_config(config)
|
463 |
config = models.config
|
@@ -466,13 +475,14 @@ class FluxPipeline:
|
|
466 |
clip_device = into_device(config.text_enc_device)
|
467 |
t5_device = into_device(config.text_enc_device)
|
468 |
flux_dtype = into_dtype(config.flow_dtype)
|
469 |
-
flow_model = models.flow
|
470 |
-
memory_format=torch.channels_last
|
471 |
-
)
|
472 |
|
473 |
-
|
474 |
-
flow_model
|
475 |
-
|
|
|
|
|
|
|
476 |
|
477 |
return cls(
|
478 |
name=config.version,
|
@@ -492,7 +502,7 @@ class FluxPipeline:
|
|
492 |
|
493 |
if __name__ == "__main__":
|
494 |
pipe = FluxPipeline.load_pipeline_from_config_path(
|
495 |
-
"configs/config-dev-
|
496 |
)
|
497 |
o = pipe.generate(
|
498 |
prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
|
@@ -503,7 +513,7 @@ if __name__ == "__main__":
|
|
503 |
seed=10,
|
504 |
)
|
505 |
open("out.jpg", "wb").write(o.read())
|
506 |
-
for x in range(
|
507 |
|
508 |
o = pipe.generate(
|
509 |
prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
|
|
|
92 |
self.offload_text_encoder = config.offload_text_encoder
|
93 |
self.offload_vae = config.offload_vae
|
94 |
self.offload_flow = config.offload_flow
|
95 |
+
if not self.offload_flow:
|
96 |
+
self.model.to(self.device_flux)
|
97 |
+
if not self.offload_vae:
|
98 |
+
self.ae.to(self.device_ae)
|
99 |
+
if not self.offload_text_encoder:
|
100 |
+
self.clip.to(self.device_clip)
|
101 |
+
self.t5.to(self.device_t5)
|
102 |
|
103 |
if self.config.compile_blocks or self.config.compile_extras:
|
104 |
+
if not self.config.prequantized_flow:
|
105 |
+
print("Warmups for compile...")
|
106 |
+
warmup_dict = dict(
|
107 |
+
prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
|
108 |
+
height=1024,
|
109 |
+
width=1024,
|
110 |
+
num_steps=30,
|
111 |
+
guidance=3.5,
|
112 |
+
seed=10,
|
113 |
+
)
|
114 |
+
self.generate(**warmup_dict)
|
115 |
to_gpu_extras = [
|
116 |
"vector_in",
|
117 |
"img_in",
|
|
|
255 |
im = torch.vstack(images)
|
256 |
|
257 |
torch.cuda.synchronize()
|
258 |
+
im = self.img_encoder.encode_torch(im, quality=99)
|
259 |
images.clear()
|
260 |
return io.BytesIO(im)
|
261 |
|
|
|
466 |
|
467 |
with torch.inference_mode():
|
468 |
print("flow_quantization_dtype", config.flow_quantization_dtype)
|
469 |
+
print("prequantized_flow?", config.prequantized_flow)
|
470 |
|
471 |
models = load_models_from_config(config)
|
472 |
config = models.config
|
|
|
475 |
clip_device = into_device(config.text_enc_device)
|
476 |
t5_device = into_device(config.text_enc_device)
|
477 |
flux_dtype = into_dtype(config.flow_dtype)
|
478 |
+
flow_model = models.flow
|
|
|
|
|
479 |
|
480 |
+
if not config.prequantized_flow:
|
481 |
+
flow_model = quantize_flow_transformer_and_dispatch_float8(
|
482 |
+
flow_model, flux_device, offload_flow=config.offload_flow
|
483 |
+
)
|
484 |
+
else:
|
485 |
+
flow_model.eval().requires_grad_(False)
|
486 |
|
487 |
return cls(
|
488 |
name=config.version,
|
|
|
502 |
|
503 |
if __name__ == "__main__":
|
504 |
pipe = FluxPipeline.load_pipeline_from_config_path(
|
505 |
+
"configs/config-dev-prequant.json",
|
506 |
)
|
507 |
o = pipe.generate(
|
508 |
prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
|
|
|
513 |
seed=10,
|
514 |
)
|
515 |
open("out.jpg", "wb").write(o.read())
|
516 |
+
for x in range(2):
|
517 |
|
518 |
o = pipe.generate(
|
519 |
prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
|
modules/flux_model_f8.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
|
5 |
+
DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
|
6 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
7 |
+
torch.backends.cudnn.allow_tf32 = True
|
8 |
+
torch.backends.cudnn.benchmark = True
|
9 |
+
torch.backends.cudnn.benchmark_limit = 20
|
10 |
+
torch.set_float32_matmul_precision("high")
|
11 |
+
import math
|
12 |
+
|
13 |
+
from torch import Tensor, nn
|
14 |
+
from pydantic import BaseModel
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from float8_quantize import F8Linear
|
17 |
+
|
18 |
+
try:
|
19 |
+
from cublas_ops import CublasLinear
|
20 |
+
except ImportError:
|
21 |
+
CublasLinear = nn.Linear
|
22 |
+
|
23 |
+
|
24 |
+
class FluxParams(BaseModel):
|
25 |
+
in_channels: int
|
26 |
+
vec_in_dim: int
|
27 |
+
context_in_dim: int
|
28 |
+
hidden_size: int
|
29 |
+
mlp_ratio: float
|
30 |
+
num_heads: int
|
31 |
+
depth: int
|
32 |
+
depth_single_blocks: int
|
33 |
+
axes_dim: list[int]
|
34 |
+
theta: int
|
35 |
+
qkv_bias: bool
|
36 |
+
guidance_embed: bool
|
37 |
+
|
38 |
+
|
39 |
+
# attention is always same shape each time it's called per H*W, so compile with fullgraph
|
40 |
+
# @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
|
41 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
42 |
+
q, k = apply_rope(q, k, pe)
|
43 |
+
x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
|
44 |
+
x = x.reshape(*x.shape[:-2], -1)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
# @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
|
49 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
50 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
|
51 |
+
omega = 1.0 / (theta**scale)
|
52 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
53 |
+
out = torch.stack(
|
54 |
+
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
|
55 |
+
)
|
56 |
+
out = out.reshape(*out.shape[:-1], 2, 2)
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
61 |
+
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
|
62 |
+
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
|
63 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
64 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
65 |
+
return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)
|
66 |
+
|
67 |
+
|
68 |
+
class EmbedND(nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
dim: int,
|
72 |
+
theta: int,
|
73 |
+
axes_dim: list[int],
|
74 |
+
dtype: torch.dtype = torch.bfloat16,
|
75 |
+
):
|
76 |
+
super().__init__()
|
77 |
+
self.dim = dim
|
78 |
+
self.theta = theta
|
79 |
+
self.axes_dim = axes_dim
|
80 |
+
self.dtype = dtype
|
81 |
+
|
82 |
+
def forward(self, ids: Tensor) -> Tensor:
|
83 |
+
n_axes = ids.shape[-1]
|
84 |
+
emb = torch.cat(
|
85 |
+
[
|
86 |
+
rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype)
|
87 |
+
for i in range(n_axes)
|
88 |
+
],
|
89 |
+
dim=-3,
|
90 |
+
)
|
91 |
+
|
92 |
+
return emb.unsqueeze(1)
|
93 |
+
|
94 |
+
|
95 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
96 |
+
"""
|
97 |
+
Create sinusoidal timestep embeddings.
|
98 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
99 |
+
These may be fractional.
|
100 |
+
:param dim: the dimension of the output.
|
101 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
102 |
+
:return: an (N, D) Tensor of positional embeddings.
|
103 |
+
"""
|
104 |
+
t = time_factor * t
|
105 |
+
half = dim // 2
|
106 |
+
freqs = torch.exp(
|
107 |
+
-math.log(max_period)
|
108 |
+
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
109 |
+
/ half
|
110 |
+
)
|
111 |
+
|
112 |
+
args = t[:, None].float() * freqs[None]
|
113 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
114 |
+
if dim % 2:
|
115 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
116 |
+
return embedding
|
117 |
+
|
118 |
+
|
119 |
+
class MLPEmbedder(nn.Module):
|
120 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
121 |
+
super().__init__()
|
122 |
+
self.in_layer = F8Linear(in_dim, hidden_dim, bias=True)
|
123 |
+
self.silu = nn.SiLU()
|
124 |
+
self.out_layer = F8Linear(hidden_dim, hidden_dim, bias=True)
|
125 |
+
|
126 |
+
def forward(self, x: Tensor) -> Tensor:
|
127 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
128 |
+
|
129 |
+
|
130 |
+
class RMSNorm(torch.nn.Module):
|
131 |
+
def __init__(self, dim: int):
|
132 |
+
super().__init__()
|
133 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
134 |
+
|
135 |
+
def forward(self, x: Tensor):
|
136 |
+
return F.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
|
137 |
+
|
138 |
+
|
139 |
+
class QKNorm(torch.nn.Module):
|
140 |
+
def __init__(self, dim: int):
|
141 |
+
super().__init__()
|
142 |
+
self.query_norm = RMSNorm(dim)
|
143 |
+
self.key_norm = RMSNorm(dim)
|
144 |
+
|
145 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
146 |
+
q = self.query_norm(q)
|
147 |
+
k = self.key_norm(k)
|
148 |
+
return q, k
|
149 |
+
|
150 |
+
|
151 |
+
class SelfAttention(nn.Module):
|
152 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
153 |
+
super().__init__()
|
154 |
+
self.num_heads = num_heads
|
155 |
+
head_dim = dim // num_heads
|
156 |
+
|
157 |
+
self.qkv = F8Linear(dim, dim * 3, bias=qkv_bias)
|
158 |
+
self.norm = QKNorm(head_dim)
|
159 |
+
self.proj = F8Linear(dim, dim)
|
160 |
+
self.K = 3
|
161 |
+
self.H = self.num_heads
|
162 |
+
self.KH = self.K * self.H
|
163 |
+
|
164 |
+
def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
165 |
+
B, L, D = x.shape
|
166 |
+
q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
|
167 |
+
return q, k, v
|
168 |
+
|
169 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
170 |
+
qkv = self.qkv(x)
|
171 |
+
q, k, v = self.rearrange_for_norm(qkv)
|
172 |
+
q, k = self.norm(q, k, v)
|
173 |
+
x = attention(q, k, v, pe=pe)
|
174 |
+
x = self.proj(x)
|
175 |
+
return x
|
176 |
+
|
177 |
+
|
178 |
+
ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])
|
179 |
+
|
180 |
+
|
181 |
+
class Modulation(nn.Module):
|
182 |
+
def __init__(self, dim: int, double: bool):
|
183 |
+
super().__init__()
|
184 |
+
self.is_double = double
|
185 |
+
self.multiplier = 6 if double else 3
|
186 |
+
self.lin = F8Linear(dim, self.multiplier * dim, bias=True)
|
187 |
+
self.act = nn.SiLU()
|
188 |
+
|
189 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
190 |
+
out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
191 |
+
|
192 |
+
return (
|
193 |
+
ModulationOut(*out[:3]),
|
194 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
195 |
+
)
|
196 |
+
|
197 |
+
|
198 |
+
class DoubleStreamBlock(nn.Module):
|
199 |
+
def __init__(
|
200 |
+
self,
|
201 |
+
hidden_size: int,
|
202 |
+
num_heads: int,
|
203 |
+
mlp_ratio: float,
|
204 |
+
qkv_bias: bool = False,
|
205 |
+
dtype: torch.dtype = torch.float16,
|
206 |
+
):
|
207 |
+
super().__init__()
|
208 |
+
self.dtype = dtype
|
209 |
+
|
210 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
211 |
+
self.num_heads = num_heads
|
212 |
+
self.hidden_size = hidden_size
|
213 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
214 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
215 |
+
self.img_attn = SelfAttention(
|
216 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
217 |
+
)
|
218 |
+
|
219 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
220 |
+
self.img_mlp = nn.Sequential(
|
221 |
+
F8Linear(hidden_size, mlp_hidden_dim, bias=True),
|
222 |
+
nn.GELU(approximate="tanh"),
|
223 |
+
F8Linear(mlp_hidden_dim, hidden_size, bias=True),
|
224 |
+
)
|
225 |
+
|
226 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
227 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
228 |
+
self.txt_attn = SelfAttention(
|
229 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
230 |
+
)
|
231 |
+
|
232 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
233 |
+
self.txt_mlp = nn.Sequential(
|
234 |
+
F8Linear(hidden_size, mlp_hidden_dim, bias=True),
|
235 |
+
nn.GELU(approximate="tanh"),
|
236 |
+
F8Linear(mlp_hidden_dim, hidden_size, bias=True),
|
237 |
+
)
|
238 |
+
self.K = 3
|
239 |
+
self.H = self.num_heads
|
240 |
+
self.KH = self.K * self.H
|
241 |
+
|
242 |
+
def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
243 |
+
B, L, D = x.shape
|
244 |
+
q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
|
245 |
+
return q, k, v
|
246 |
+
|
247 |
+
def forward(
|
248 |
+
self,
|
249 |
+
img: Tensor,
|
250 |
+
txt: Tensor,
|
251 |
+
vec: Tensor,
|
252 |
+
pe: Tensor,
|
253 |
+
) -> tuple[Tensor, Tensor]:
|
254 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
255 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
256 |
+
|
257 |
+
# prepare image for attention
|
258 |
+
img_modulated = self.img_norm1(img)
|
259 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
260 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
261 |
+
img_q, img_k, img_v = self.rearrange_for_norm(img_qkv)
|
262 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
263 |
+
|
264 |
+
# prepare txt for attention
|
265 |
+
txt_modulated = self.txt_norm1(txt)
|
266 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
267 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
268 |
+
txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv)
|
269 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
270 |
+
|
271 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
272 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
273 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
274 |
+
|
275 |
+
attn = attention(q, k, v, pe=pe)
|
276 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
277 |
+
# calculate the img bloks
|
278 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
279 |
+
img = img + img_mod2.gate * self.img_mlp(
|
280 |
+
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
281 |
+
).clamp(min=-384 * 2, max=384 * 2)
|
282 |
+
|
283 |
+
# calculate the txt bloks
|
284 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
285 |
+
txt = txt + txt_mod2.gate * self.txt_mlp(
|
286 |
+
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
287 |
+
).clamp(min=-384 * 2, max=384 * 2)
|
288 |
+
|
289 |
+
return img, txt
|
290 |
+
|
291 |
+
|
292 |
+
class SingleStreamBlock(nn.Module):
|
293 |
+
"""
|
294 |
+
A DiT block with parallel linear layers as described in
|
295 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
296 |
+
"""
|
297 |
+
|
298 |
+
def __init__(
|
299 |
+
self,
|
300 |
+
hidden_size: int,
|
301 |
+
num_heads: int,
|
302 |
+
mlp_ratio: float = 4.0,
|
303 |
+
qk_scale: float | None = None,
|
304 |
+
dtype: torch.dtype = torch.float16,
|
305 |
+
):
|
306 |
+
super().__init__()
|
307 |
+
self.dtype = dtype
|
308 |
+
self.hidden_dim = hidden_size
|
309 |
+
self.num_heads = num_heads
|
310 |
+
head_dim = hidden_size // num_heads
|
311 |
+
self.scale = qk_scale or head_dim**-0.5
|
312 |
+
|
313 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
314 |
+
# qkv and mlp_in
|
315 |
+
self.linear1 = F8Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
316 |
+
# proj and mlp_out
|
317 |
+
self.linear2 = F8Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
318 |
+
|
319 |
+
self.norm = QKNorm(head_dim)
|
320 |
+
|
321 |
+
self.hidden_size = hidden_size
|
322 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
323 |
+
|
324 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
325 |
+
self.modulation = Modulation(hidden_size, double=False)
|
326 |
+
|
327 |
+
self.K = 3
|
328 |
+
self.H = self.num_heads
|
329 |
+
self.KH = self.K * self.H
|
330 |
+
|
331 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
332 |
+
mod = self.modulation(vec)[0]
|
333 |
+
pre_norm = self.pre_norm(x)
|
334 |
+
x_mod = (1 + mod.scale) * pre_norm + mod.shift
|
335 |
+
qkv, mlp = torch.split(
|
336 |
+
self.linear1(x_mod),
|
337 |
+
[3 * self.hidden_size, self.mlp_hidden_dim],
|
338 |
+
dim=-1,
|
339 |
+
)
|
340 |
+
B, L, D = qkv.shape
|
341 |
+
q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
|
342 |
+
q, k = self.norm(q, k, v)
|
343 |
+
attn = attention(q, k, v, pe=pe)
|
344 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
|
345 |
+
min=-384 * 4, max=384 * 4
|
346 |
+
)
|
347 |
+
return x + mod.gate * output
|
348 |
+
|
349 |
+
|
350 |
+
class LastLayer(nn.Module):
|
351 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
352 |
+
super().__init__()
|
353 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
354 |
+
self.linear = CublasLinear(
|
355 |
+
hidden_size, patch_size * patch_size * out_channels, bias=True
|
356 |
+
)
|
357 |
+
self.adaLN_modulation = nn.Sequential(
|
358 |
+
nn.SiLU(), CublasLinear(hidden_size, 2 * hidden_size, bias=True)
|
359 |
+
)
|
360 |
+
|
361 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
362 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
363 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
364 |
+
x = self.linear(x)
|
365 |
+
return x
|
366 |
+
|
367 |
+
|
368 |
+
class Flux(nn.Module):
|
369 |
+
"""
|
370 |
+
Transformer model for flow matching on sequences.
|
371 |
+
"""
|
372 |
+
|
373 |
+
def __init__(self, params: FluxParams, dtype: torch.dtype = torch.bfloat16):
|
374 |
+
super().__init__()
|
375 |
+
|
376 |
+
self.dtype = dtype
|
377 |
+
self.params = params
|
378 |
+
self.in_channels = params.in_channels
|
379 |
+
self.out_channels = self.in_channels
|
380 |
+
if params.hidden_size % params.num_heads != 0:
|
381 |
+
raise ValueError(
|
382 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
383 |
+
)
|
384 |
+
pe_dim = params.hidden_size // params.num_heads
|
385 |
+
if sum(params.axes_dim) != pe_dim:
|
386 |
+
raise ValueError(
|
387 |
+
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
388 |
+
)
|
389 |
+
self.hidden_size = params.hidden_size
|
390 |
+
self.num_heads = params.num_heads
|
391 |
+
self.pe_embedder = EmbedND(
|
392 |
+
dim=pe_dim,
|
393 |
+
theta=params.theta,
|
394 |
+
axes_dim=params.axes_dim,
|
395 |
+
dtype=self.dtype,
|
396 |
+
)
|
397 |
+
self.img_in = F8Linear(self.in_channels, self.hidden_size, bias=True)
|
398 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
399 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
400 |
+
self.guidance_in = (
|
401 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
402 |
+
if params.guidance_embed
|
403 |
+
else nn.Identity()
|
404 |
+
)
|
405 |
+
self.txt_in = F8Linear(params.context_in_dim, self.hidden_size)
|
406 |
+
|
407 |
+
self.double_blocks = nn.ModuleList(
|
408 |
+
[
|
409 |
+
DoubleStreamBlock(
|
410 |
+
self.hidden_size,
|
411 |
+
self.num_heads,
|
412 |
+
mlp_ratio=params.mlp_ratio,
|
413 |
+
qkv_bias=params.qkv_bias,
|
414 |
+
dtype=self.dtype,
|
415 |
+
)
|
416 |
+
for _ in range(params.depth)
|
417 |
+
]
|
418 |
+
)
|
419 |
+
|
420 |
+
self.single_blocks = nn.ModuleList(
|
421 |
+
[
|
422 |
+
SingleStreamBlock(
|
423 |
+
self.hidden_size,
|
424 |
+
self.num_heads,
|
425 |
+
mlp_ratio=params.mlp_ratio,
|
426 |
+
dtype=self.dtype,
|
427 |
+
)
|
428 |
+
for _ in range(params.depth_single_blocks)
|
429 |
+
]
|
430 |
+
)
|
431 |
+
|
432 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
433 |
+
|
434 |
+
def forward(
|
435 |
+
self,
|
436 |
+
img: Tensor,
|
437 |
+
img_ids: Tensor,
|
438 |
+
txt: Tensor,
|
439 |
+
txt_ids: Tensor,
|
440 |
+
timesteps: Tensor,
|
441 |
+
y: Tensor,
|
442 |
+
guidance: Tensor | None = None,
|
443 |
+
) -> Tensor:
|
444 |
+
if img.ndim != 3 or txt.ndim != 3:
|
445 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
446 |
+
|
447 |
+
# running on sequences img
|
448 |
+
img = self.img_in(img)
|
449 |
+
vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype))
|
450 |
+
|
451 |
+
if self.params.guidance_embed:
|
452 |
+
if guidance is None:
|
453 |
+
raise ValueError(
|
454 |
+
"Didn't get guidance strength for guidance distilled model."
|
455 |
+
)
|
456 |
+
vec = vec + self.guidance_in(
|
457 |
+
timestep_embedding(guidance, 256).type(self.dtype)
|
458 |
+
)
|
459 |
+
vec = vec + self.vector_in(y)
|
460 |
+
|
461 |
+
txt = self.txt_in(txt)
|
462 |
+
|
463 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
464 |
+
pe = self.pe_embedder(ids)
|
465 |
+
|
466 |
+
# double stream blocks
|
467 |
+
for block in self.double_blocks:
|
468 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
469 |
+
|
470 |
+
img = torch.cat((txt, img), 1)
|
471 |
+
|
472 |
+
# single stream blocks
|
473 |
+
for block in self.single_blocks:
|
474 |
+
img = block(img, vec=vec, pe=pe)
|
475 |
+
|
476 |
+
img = img[:, txt.shape[1] :, ...]
|
477 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
478 |
+
return img
|
479 |
+
|
480 |
+
@classmethod
|
481 |
+
def from_pretrained(cls, path: str, dtype: torch.dtype = torch.bfloat16) -> "Flux":
|
482 |
+
from util import load_config_from_path
|
483 |
+
from safetensors.torch import load_file
|
484 |
+
|
485 |
+
config = load_config_from_path(path)
|
486 |
+
with torch.device("meta"):
|
487 |
+
klass = cls(params=config.params, dtype=dtype).type(dtype)
|
488 |
+
|
489 |
+
ckpt = load_file(config.ckpt_path, device="cpu")
|
490 |
+
klass.load_state_dict(ckpt, assign=True)
|
491 |
+
return klass.to("cpu")
|
util.py
CHANGED
@@ -6,7 +6,7 @@ import torch
|
|
6 |
from modules.autoencoder import AutoEncoder, AutoEncoderParams
|
7 |
from modules.conditioner import HFEmbedder
|
8 |
from modules.flux_model import Flux, FluxParams
|
9 |
-
|
10 |
from safetensors.torch import load_file as load_sft
|
11 |
from enum import StrEnum
|
12 |
from pydantic import BaseModel, ConfigDict
|
@@ -53,6 +53,7 @@ class ModelSpec(BaseModel):
|
|
53 |
offload_text_encoder: bool = False
|
54 |
offload_vae: bool = False
|
55 |
offload_flow: bool = False
|
|
|
56 |
|
57 |
model_config: ConfigDict = {
|
58 |
"arbitrary_types_allowed": True,
|
@@ -194,9 +195,14 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
|
194 |
|
195 |
def load_flow_model(config: ModelSpec) -> Flux:
|
196 |
ckpt_path = config.ckpt_path
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
with torch.device("meta"):
|
199 |
-
model =
|
200 |
into_dtype(config.flow_dtype)
|
201 |
)
|
202 |
|
@@ -247,7 +253,7 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
|
|
247 |
|
248 |
|
249 |
class LoadedModels(BaseModel):
|
250 |
-
flow: Flux
|
251 |
ae: AutoEncoder
|
252 |
clip: HFEmbedder
|
253 |
t5: HFEmbedder
|
|
|
6 |
from modules.autoencoder import AutoEncoder, AutoEncoderParams
|
7 |
from modules.conditioner import HFEmbedder
|
8 |
from modules.flux_model import Flux, FluxParams
|
9 |
+
from modules.flux_model_f8 import Flux as FluxF8
|
10 |
from safetensors.torch import load_file as load_sft
|
11 |
from enum import StrEnum
|
12 |
from pydantic import BaseModel, ConfigDict
|
|
|
53 |
offload_text_encoder: bool = False
|
54 |
offload_vae: bool = False
|
55 |
offload_flow: bool = False
|
56 |
+
prequantized_flow: bool = False
|
57 |
|
58 |
model_config: ConfigDict = {
|
59 |
"arbitrary_types_allowed": True,
|
|
|
195 |
|
196 |
def load_flow_model(config: ModelSpec) -> Flux:
|
197 |
ckpt_path = config.ckpt_path
|
198 |
+
FluxClass = Flux
|
199 |
+
if config.prequantized_flow:
|
200 |
+
from modules.flux_model_f8 import Flux as FluxF8
|
201 |
+
|
202 |
+
FluxClass = FluxF8
|
203 |
|
204 |
with torch.device("meta"):
|
205 |
+
model = FluxClass(config.params, dtype=into_dtype(config.flow_dtype)).type(
|
206 |
into_dtype(config.flow_dtype)
|
207 |
)
|
208 |
|
|
|
253 |
|
254 |
|
255 |
class LoadedModels(BaseModel):
|
256 |
+
flow: Flux | FluxF8
|
257 |
ae: AutoEncoder
|
258 |
clip: HFEmbedder
|
259 |
t5: HFEmbedder
|