Improved precision / reduced frequency of NaN outputs; allow bf16 T5, fp32 RMSNorm, larger clamp
Files changed:
- README.md +7 -0
- modules/conditioner.py +7 -1
- modules/flux_model.py +14 -9
- util.py +2 -1
README.md CHANGED
@@ -79,6 +79,13 @@ pipeline = FluxPipeline.load_pipeline_from_config_path(config_path, **config_ove
 pipeline.load_lora(lora_path, scale=1.0)
 ```
 
+### Updates 09/07/24
+
+- Improve quality by ensuring that the RMSNorm layers use fp32.
+- Raise the clamp range for single blocks & double blocks to +/-32000 to reduce deviation from expected outputs.
+- Make bf16 _not_ clamp, which improves quality; the clamp isn't needed since bf16 is the expected dtype for flux. **I would now recommend always using `"flow_dtype": "bfloat16"` in the config**, though it will slow things down on consumer GPUs - but not by much, since most of the compute still happens via fp8.
+- Allow the T5 model to be run without any quantization by specifying `"text_enc_quantization_dtype": "bfloat16"` in the config - or `"float16"`, though that is not recommended since T5 deviates a bit when running in float16. I noticed that even with qint8/qfloat8 there is a bit of deviation from bf16 text encoder outputs, so for those who want more accurate / expected text encoder outputs, you can use this option.
+
 ## Installation
 
 This repo _requires_ at least pytorch with cuda=12.4 and an ADA gpu with fp8 support, otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba:
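The two recommended settings from the update notes can be combined when loading the pipeline. The sketch below is illustrative rather than repo code: it assumes that extra keyword arguments to `load_pipeline_from_config_path` act as config overrides (the kwargs-splat in the README's load call suggests this), and the import path, config path, and LoRA path are placeholders.

```python
# Illustrative only - not repo code. Assumes extra keyword arguments to
# load_pipeline_from_config_path act as config overrides; paths are placeholders.
from flux_pipeline import FluxPipeline  # module name assumed

pipeline = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev.json",               # placeholder config path
    flow_dtype="bfloat16",                   # recommended above; avoids the fp16 clamp path
    text_enc_quantization_dtype="bfloat16",  # run the T5 text encoder unquantized
)

lora_path = "my_lora.safetensors"            # placeholder
pipeline.load_lora(lora_path, scale=1.0)     # as in the README snippet above
```

Per the notes above, with the flow dtype set to bfloat16 the single/double stream blocks never clamp, and the fp8 matmuls are unaffected either way.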
modules/conditioner.py CHANGED
@@ -29,6 +29,8 @@ def auto_quantization_config(
         return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
     elif quantization_dtype == "qint2":
         return QuantoConfig(weights="int2")
+    elif quantization_dtype is None or quantization_dtype == "bfloat16":
+        return None
     else:
         raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
 
@@ -57,7 +59,11 @@ class HFEmbedder(nn.Module):
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
         auto_quant_config = (
-            auto_quantization_config(quantization_dtype)
+            auto_quantization_config(quantization_dtype)
+            if quantization_dtype is not None
+            and quantization_dtype != "bfloat16"
+            and quantization_dtype != "float16"
+            else None
         )
 
         # BNB will move to cuda:0 by default if not specified
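Read together, the two hunks mean the quantization machinery is only consulted for the q* dtypes; `None`, `"bfloat16"`, and `"float16"` now skip it entirely, so the T5 encoder is loaded in a plain torch dtype. A rough standalone restatement of that gate (the helper name is illustrative, not part of the repo):

```python
# Rough restatement of the gating in HFEmbedder above; helper name is illustrative.
def uses_quantization_config(quantization_dtype: str | None) -> bool:
    """True only when auto_quantization_config should be consulted."""
    return quantization_dtype is not None and quantization_dtype not in ("bfloat16", "float16")

assert uses_quantization_config("qint8")         # quanto / bitsandbytes path
assert not uses_quantization_config("bfloat16")  # plain bf16 T5, no quant config
assert not uses_quantization_config(None)        # no dtype given -> no quantization
```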
modules/flux_model.py CHANGED
@@ -159,7 +159,7 @@ class RMSNorm(torch.nn.Module):
         self.scale = nn.Parameter(torch.ones(dim))
 
     def forward(self, x: Tensor):
-        return F.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
+        return F.rms_norm(x.float(), self.scale.shape, self.scale, eps=1e-6).to(x)
 
 
 class QKNorm(torch.nn.Module):
@@ -344,7 +344,7 @@ class DoubleStreamBlock(nn.Module):
         self.K = 3
         self.H = self.num_heads
         self.KH = self.K * self.H
-
+        self.do_clamp = dtype == torch.float16
     def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         B, L, D = x.shape
         q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
@@ -384,14 +384,16 @@ class DoubleStreamBlock(nn.Module):
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
         img = img + img_mod2.gate * self.img_mlp(
             (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
-        )
+        )
 
         # calculate the txt bloks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
         txt = txt + txt_mod2.gate * self.txt_mlp(
             (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
-        )
-
+        )
+        if self.do_clamp:
+            img = img.clamp(min=-32000, max=32000)
+            txt = txt.clamp(min=-32000, max=32000)
         return img, txt
 
 
@@ -457,6 +459,7 @@ class SingleStreamBlock(nn.Module):
         self.K = 3
         self.H = self.num_heads
         self.KH = self.K * self.H
+        self.do_clamp = dtype == torch.float16
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         mod = self.modulation(vec)[0]
@@ -471,10 +474,12 @@ class SingleStreamBlock(nn.Module):
         q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
         attn = attention(q, k, v, pe=pe)
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-
-
-
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        if self.do_clamp:
+            out = (x + mod.gate * output).clamp(min=-32000, max=32000)
+        else:
+            out = x + mod.gate * output
+        return out
 
 
 class LastLayer(nn.Module):
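The +/-32000 clamp and the `do_clamp = dtype == torch.float16` gate follow from float16's numeric range: fp16 saturates just above 65504, so a large intermediate activation becomes inf and then NaN a few ops later, while bf16 shares fp32's exponent range and needs no guard. A small standalone check (not repo code):

```python
import torch

x = torch.tensor([70000.0])                      # larger than fp16's max finite value (~65504)
print(x.to(torch.float16))                       # inf - this is how NaNs start in fp16 runs
print(x.to(torch.bfloat16))                      # finite: bf16 keeps fp32's exponent range
print(x.to(torch.float16).clamp(-32000, 32000))  # 32000.0 - pulled back to a finite value,
                                                 # so it can't turn into NaN downstream
```

The RMSNorm change is the same idea applied to precision rather than range: the normalization runs on `x.float()` and the result is cast back with `.to(x)`, so the reduction happens in fp32 regardless of the activation dtype.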
util.py CHANGED
@@ -31,7 +31,8 @@ class QuantizationDtype(StrEnum):
     qint2 = "qint2"
     qint4 = "qint4"
     qint8 = "qint8"
-
+    bfloat16 = "bfloat16"
+    float16 = "float16"
 
 class ModelSpec(BaseModel):
     version: ModelVersion
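For completeness, a quick round-trip check (not repo code) of the extended enum. Only the members visible in the hunk are reproduced here, and the comment about config validation assumes `ModelSpec` declares a field of this enum type, which the surrounding code suggests but the hunk does not show.

```python
from enum import StrEnum  # the repo's QuantizationDtype is a StrEnum (Python 3.11+)

class QuantizationDtype(StrEnum):
    qint2 = "qint2"
    qint4 = "qint4"
    qint8 = "qint8"
    bfloat16 = "bfloat16"
    float16 = "float16"

# Pydantic would coerce the string from the JSON config into the enum member,
# so "bfloat16" / "float16" now validate instead of raising.
assert QuantizationDtype("bfloat16") is QuantizationDtype.bfloat16
assert QuantizationDtype.float16 == "float16"  # StrEnum members compare equal to their values
```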