Commit f708e90 by aredden
Parent: 3cc2f3f

Improved precision / reduced frequency of NaN outputs; allow bf16 T5, fp32 RMSNorm, larger clamp

Files changed (4)
  1. README.md +7 -0
  2. modules/conditioner.py +7 -1
  3. modules/flux_model.py +14 -9
  4. util.py +2 -1
README.md CHANGED
@@ -79,6 +79,13 @@ pipeline = FluxPipeline.load_pipeline_from_config_path(config_path, **config_ove
 pipeline.load_lora(lora_path, scale=1.0)
 ```
 
+### Updates 09/07/24
+
+- Improve quality by ensuring that the RMSNorm layers run in fp32.
+- Raise the clamp range for single blocks & double blocks to +/-32000 to reduce deviation from expected outputs.
+- Make bf16 _not_ clamp, which improves quality; clamping isn't needed since bf16 is the expected dtype for flux. **I would now recommend always using `"flow_dtype": "bfloat16"` in the config.** It will slow things down on consumer GPUs, but not by much, since most of the compute still happens via fp8.
+- Allow the T5 model to run without any quantization by specifying `"text_enc_quantization_dtype": "bfloat16"` in the config (or `"float16"`, though that is not recommended since T5 deviates a bit when run in float16). Even with qint8/qfloat8 there is some deviation from the bf16 text encoder's outputs, so use this option if you want more accurate / expected text encoder outputs.
+
 ## Installation
 
 This repo _requires_ at least pytorch with cuda=12.4 and an ADA gpu with fp8 support, otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba:
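
The two recommended config keys can be combined. As a minimal sketch, assuming overrides are passed straight through to `FluxPipeline.load_pipeline_from_config_path` the way the README's own example suggests (the import path and config filename below are hypothetical):

```python
from flux_pipeline import FluxPipeline  # hypothetical import path

# Recommended settings from this update:
# - "flow_dtype": run the flux transformer in bf16, which needs no clamping
# - "text_enc_quantization_dtype": run T5 unquantized in bf16 for text
#   encoder outputs closer to the reference model than qint8/qfloat8 give
config_overrides = {
    "flow_dtype": "bfloat16",
    "text_enc_quantization_dtype": "bfloat16",
}

pipeline = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev.json",  # hypothetical config path
    **config_overrides,
)
```
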
modules/conditioner.py CHANGED
@@ -29,6 +29,8 @@ def auto_quantization_config(
         return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
     elif quantization_dtype == "qint2":
         return QuantoConfig(weights="int2")
+    elif quantization_dtype is None or quantization_dtype == "bfloat16":
+        return None
     else:
         raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
 
@@ -57,7 +59,11 @@ class HFEmbedder(nn.Module):
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
         auto_quant_config = (
-            auto_quantization_config(quantization_dtype) if quantization_dtype else None
+            auto_quantization_config(quantization_dtype)
+            if quantization_dtype is not None
+            and quantization_dtype != "bfloat16"
+            and quantization_dtype != "float16"
+            else None
        )
 
         # BNB will move to cuda:0 by default if not specified
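
Condensed, the new behavior is: plain bf16/fp16 dtypes (or no dtype) bypass quantization entirely, so `transformers` loads full-precision weights. A sketch under that assumption, reusing `auto_quantization_config` from the diff above (the checkpoint name is hypothetical; `torch_dtype` and `quantization_config` are standard `from_pretrained` kwargs):

```python
import torch
from transformers import T5EncoderModel

def resolve_quant_config(quantization_dtype: str | None):
    # Mirrors the new guard in HFEmbedder: bf16/fp16 (or None) mean
    # "no quantization"; qfloat8/qint8/qint4/qint2 still get a config.
    if quantization_dtype in (None, "bfloat16", "float16"):
        return None
    return auto_quantization_config(quantization_dtype)  # from modules/conditioner.py above

# The newly supported unquantized bf16 path for the T5 text encoder:
t5 = T5EncoderModel.from_pretrained(
    "path/to/t5-v1_1-xxl-encoder",  # hypothetical checkpoint
    torch_dtype=torch.bfloat16,
    quantization_config=resolve_quant_config("bfloat16"),  # -> None
)
```
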
modules/flux_model.py CHANGED
@@ -159,7 +159,7 @@ class RMSNorm(torch.nn.Module):
         self.scale = nn.Parameter(torch.ones(dim))
 
     def forward(self, x: Tensor):
-        return F.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
+        return F.rms_norm(x.float(), self.scale.shape, self.scale, eps=1e-6).to(x)
 
 
 class QKNorm(torch.nn.Module):
@@ -344,7 +344,7 @@ class DoubleStreamBlock(nn.Module):
         self.K = 3
         self.H = self.num_heads
         self.KH = self.K * self.H
-
+        self.do_clamp = dtype == torch.float16
     def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         B, L, D = x.shape
         q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
@@ -384,14 +384,16 @@
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
         img = img + img_mod2.gate * self.img_mlp(
             (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
-        ).clamp(min=-384 * 2, max=384 * 2)
+        )
 
         # calculate the txt blocks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
         txt = txt + txt_mod2.gate * self.txt_mlp(
             (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
-        ).clamp(min=-384 * 2, max=384 * 2)
-
+        )
+        if self.do_clamp:
+            img = img.clamp(min=-32000, max=32000)
+            txt = txt.clamp(min=-32000, max=32000)
         return img, txt
 
 
@@ -457,6 +459,7 @@ class SingleStreamBlock(nn.Module):
         self.K = 3
         self.H = self.num_heads
         self.KH = self.K * self.H
+        self.do_clamp = dtype == torch.float16
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         mod = self.modulation(vec)[0]
@@ -471,10 +474,12 @@
         q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
         attn = attention(q, k, v, pe=pe)
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
-            min=-384 * 4, max=384 * 4
-        )
-        return x + mod.gate * output
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        if self.do_clamp:
+            out = (x + mod.gate * output).clamp(min=-32000, max=32000)
+        else:
+            out = x + mod.gate * output
+        return out
 
 
 class LastLayer(nn.Module):
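
Taken together, the flux_model.py changes apply two small numerical-stability techniques: compute RMSNorm in fp32 regardless of the activation dtype, and clamp activations only when running in fp16, since bf16 shares float32's exponent range and won't overflow at these magnitudes. A standalone sketch of the pattern, not the repo's exact module (`F.rms_norm` requires PyTorch 2.4+):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class RMSNorm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # Do the reduction in fp32 to avoid fp16/bf16 precision loss
        # (a source of NaNs), then cast back to the input's dtype.
        return F.rms_norm(x.float(), self.scale.shape, self.scale.float(), eps=1e-6).to(x)

FP16_SAFE = 32000.0  # just under float16's max finite value of 65504

def maybe_clamp(x: Tensor, do_clamp: bool) -> Tensor:
    # Clamp only for fp16. bf16 has float32's 8-bit exponent, so values
    # that would overflow fp16 are still representable, and clamping them
    # would only distort the output.
    return x.clamp(min=-FP16_SAFE, max=FP16_SAFE) if do_clamp else x
```
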
util.py CHANGED
@@ -31,7 +31,8 @@ class QuantizationDtype(StrEnum):
     qint2 = "qint2"
     qint4 = "qint4"
     qint8 = "qint8"
-
+    bfloat16 = "bfloat16"
+    float16 = "float16"
 
 class ModelSpec(BaseModel):
     version: ModelVersion
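
Because `QuantizationDtype` is a `StrEnum`, the new members compare equal to the plain strings that appear in config files, so existing string-based checks keep working. A small sketch (the helper name is hypothetical):

```python
from util import QuantizationDtype  # this repo's enum, extended above

assert QuantizationDtype.bfloat16 == "bfloat16"  # StrEnum members are strings

def skips_text_encoder_quantization(dtype: str | None) -> bool:
    # Mirrors the guard in modules/conditioner.py: bf16/fp16 (or None)
    # load the T5 text encoder without any quantization.
    return dtype in (None, QuantizationDtype.bfloat16, QuantizationDtype.float16)

print(skips_text_encoder_quantization("bfloat16"))  # True
print(skips_text_encoder_quantization("qint8"))     # False
```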