aredden committed · Commit 2f2c44c · 1 Parent(s): 289aa1f

Fix non-offload inference & add option to load from prequantized flux

configs/config-dev-prequant.json ADDED
@@ -0,0 +1,56 @@
+{
+    "version": "flux-dev",
+    "params": {
+        "in_channels": 64,
+        "vec_in_dim": 768,
+        "context_in_dim": 4096,
+        "hidden_size": 3072,
+        "mlp_ratio": 4.0,
+        "num_heads": 24,
+        "depth": 19,
+        "depth_single_blocks": 38,
+        "axes_dim": [
+            16,
+            56,
+            56
+        ],
+        "theta": 10000,
+        "qkv_bias": true,
+        "guidance_embed": true
+    },
+    "ae_params": {
+        "resolution": 256,
+        "in_channels": 3,
+        "ch": 128,
+        "out_ch": 3,
+        "ch_mult": [
+            1,
+            2,
+            4,
+            4
+        ],
+        "num_res_blocks": 2,
+        "z_channels": 16,
+        "scale_factor": 0.3611,
+        "shift_factor": 0.1159
+    },
+    "ckpt_path": "/big/generator-ui/flux-testing/flux/flux-fp16-acc/flux_fp8.safetensors",
+    "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+    "repo_id": "black-forest-labs/FLUX.1-dev",
+    "repo_flow": "flux1-dev.sft",
+    "repo_ae": "ae.sft",
+    "text_enc_max_length": 512,
+    "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+    "text_enc_device": "cuda:1",
+    "ae_device": "cuda:1",
+    "flux_device": "cuda:0",
+    "flow_dtype": "float16",
+    "ae_dtype": "bfloat16",
+    "text_enc_dtype": "bfloat16",
+    "flow_quantization_dtype": "qfloat8",
+    "text_enc_quantization_dtype": "qfloat8",
+    "num_to_quant": 22,
+    "compile_extras": true,
+    "compile_blocks": true,
+    "prequantized_flow": true
+}
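
A usage note (not part of the commit): the `ckpt_path`, `ae_path`, and `cuda:*` device fields above are machine-specific, so this config only works after pointing them at your own prequantized float8 flow checkpoint and autoencoder. A minimal sketch of loading it, mirroring the `__main__` block this commit changes in flux_pipeline.py:

```python
# Sketch: load the prequantized config and render one image.
# Assumes ckpt_path/ae_path in the JSON have been edited to valid local files.
from flux_pipeline import FluxPipeline

pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev-prequant.json")
jpeg_bytes = pipe.generate(
    prompt="a red fox standing in fresh snow at golden hour",  # any prompt
    height=1024,
    width=1024,
    num_steps=30,
    guidance=3.5,
    seed=10,
)  # returns an io.BytesIO containing an encoded JPEG
with open("out.jpg", "wb") as f:
    f.write(jpeg_bytes.read())
```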
float8_quantize.py CHANGED
@@ -1,3 +1,4 @@
+from typing import Any, Mapping
 import torch
 import torch.nn as nn
 from torchao.float8.float8_utils import (
@@ -16,6 +17,24 @@ except ImportError:
     CublasLinear = type(None)
 
 
+def check_scale_tensor(tensor):
+    return (
+        tensor is not None
+        and isinstance(tensor, torch.Tensor)
+        and tensor.dtype == torch.float32
+        and tensor.numel() == 1
+        and tensor != torch.zeros_like(tensor)
+    )
+
+
+def check_scale_in_state_dict(state_dict, key):
+    return key in state_dict and check_scale_tensor(state_dict[key])
+
+
+def check_scales_given_state_dict_and_keys(state_dict, keys):
+    return all(check_scale_in_state_dict(state_dict, key) for key in keys)
+
+
 class F8Linear(nn.Module):
 
     def __init__(
@@ -24,7 +43,7 @@ class F8Linear(nn.Module):
         out_features: int,
         bias: bool = True,
         device=None,
-        dtype=None,
+        dtype=torch.float16,
         float8_dtype=torch.float8_e4m3fn,
         float_weight: torch.Tensor = None,
         float_bias: torch.Tensor = None,
@@ -53,7 +72,6 @@ class F8Linear(nn.Module):
         if bias:
             self.bias = nn.Parameter(
                 torch.empty(out_features, **factory_kwargs),
-                requires_grad=bias.requires_grad,
             )
         else:
             self.register_parameter("bias", None)
@@ -78,6 +96,85 @@
             "input_scale_reciprocal", None
         )
 
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        sd = {k.replace(prefix, ""): v for k, v in state_dict.items()}
+        if "weight" in sd:
+            if (
+                "float8_data" not in sd
+                or sd["float8_data"] is None
+                and sd["weight"].shape == (self.out_features, self.in_features)
+            ):
+                # Initialize as if it's an F8Linear that needs to be quantized
+                self._parameters["weight"] = nn.Parameter(
+                    sd["weight"], requires_grad=False
+                )
+                if "bias" in sd:
+                    self._parameters["bias"] = nn.Parameter(
+                        sd["bias"], requires_grad=False
+                    )
+                self.quantize_weight()
+            elif sd["float8_data"].shape == (
+                self.out_features,
+                self.in_features,
+            ) and sd["weight"] == torch.zeros_like(sd["weight"]):
+                w = sd["weight"]
+                # Set the init values as if it's already quantized float8_data
+                self.float8_data = sd["float8_data"]
+                self._parameters["weight"] = nn.Parameter(
+                    torch.zeros(
+                        1,
+                        dtype=w.dtype,
+                        device=w.device,
+                        requires_grad=False,
+                    )
+                )
+                if "bias" in sd:
+                    self._parameters["bias"] = nn.Parameter(
+                        sd["bias"], requires_grad=False
+                    )
+                self.weight_initialized = True
+
+                # Check if scales and reciprocals are initialized
+                if all(
+                    key in sd
+                    for key in [
+                        "scale",
+                        "input_scale",
+                        "scale_reciprocal",
+                        "input_scale_reciprocal",
+                    ]
+                ):
+                    self.scale = sd["scale"].float()
+                    self.input_scale = sd["input_scale"].float()
+                    self.scale_reciprocal = sd["scale_reciprocal"].float()
+                    self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
+                    self.input_scale_initialized = True
+                    self.trial_index = self.num_scale_trials
+                else:
+                    # If scales are not initialized, reset trials
+                    self.input_scale_initialized = False
+                    self.trial_index = 0
+                    self.input_amax_trials = torch.zeros(
+                        self.num_scale_trials, requires_grad=False, dtype=torch.float32
+                    )
+            else:
+                raise RuntimeError(
+                    f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}"
+                )
+        else:
+            raise RuntimeError(
+                "Weight tensor not found or has incorrect shape in state dict"
+            )
+
     def quantize_weight(self):
         if self.weight_initialized:
             return
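
The key addition here is `_load_from_state_dict`, which lets the same `F8Linear` module restore either a plain float checkpoint (quantizing the weight on load) or an already-quantized one (adopting `float8_data` and, when present, its scales). A self-contained sketch of that branching, not the repo's code:

```python
# Sketch of the decision _load_from_state_dict makes for one layer's state dict.
import torch

def classify_f8linear_entry(sd: dict, out_features: int, in_features: int) -> str:
    weight = sd.get("weight")
    f8 = sd.get("float8_data")
    if weight is None:
        raise RuntimeError("Weight tensor not found in state dict")
    if f8 is None and weight.shape == (out_features, in_features):
        return "plain float weights -> quantize_weight() runs on load"
    if f8 is not None and f8.shape == (out_features, in_features):
        # prequantized checkpoint: 'weight' is only a 1-element placeholder
        return "prequantized float8_data (+ scales) -> used directly"
    raise RuntimeError("Unrecognized checkpoint layout")

print(classify_f8linear_entry({"weight": torch.randn(8, 4)}, 8, 4))
print(classify_f8linear_entry(
    {"weight": torch.zeros(1),
     "float8_data": torch.zeros(8, 4, dtype=torch.float8_e4m3fn)}, 8, 4))
```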
flux_pipeline.py CHANGED
@@ -92,18 +92,26 @@ class FluxPipeline:
         self.offload_text_encoder = config.offload_text_encoder
         self.offload_vae = config.offload_vae
         self.offload_flow = config.offload_flow
+        if not self.offload_flow:
+            self.model.to(self.device_flux)
+        if not self.offload_vae:
+            self.ae.to(self.device_ae)
+        if not self.offload_text_encoder:
+            self.clip.to(self.device_clip)
+            self.t5.to(self.device_t5)
 
         if self.config.compile_blocks or self.config.compile_extras:
-            print("Warmups for compile...")
-            warmup_dict = dict(
-                prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
-                height=1024,
-                width=1024,
-                num_steps=30,
-                guidance=3.5,
-                seed=10,
-            )
-            self.generate(**warmup_dict)
+            if not self.config.prequantized_flow:
+                print("Warmups for compile...")
+                warmup_dict = dict(
+                    prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
+                    height=1024,
+                    width=1024,
+                    num_steps=30,
+                    guidance=3.5,
+                    seed=10,
+                )
+                self.generate(**warmup_dict)
             to_gpu_extras = [
                 "vector_in",
                 "img_in",
@@ -247,7 +255,7 @@ class FluxPipeline:
         im = torch.vstack(images)
 
         torch.cuda.synchronize()
-        im = self.turbojpeg.encode_torch(im, quality=99)
+        im = self.img_encoder.encode_torch(im, quality=99)
         images.clear()
         return io.BytesIO(im)
 
@@ -458,6 +466,7 @@ class FluxPipeline:
 
         with torch.inference_mode():
             print("flow_quantization_dtype", config.flow_quantization_dtype)
+            print("prequantized_flow?", config.prequantized_flow)
 
             models = load_models_from_config(config)
             config = models.config
@@ -466,13 +475,14 @@
             clip_device = into_device(config.text_enc_device)
             t5_device = into_device(config.text_enc_device)
             flux_dtype = into_dtype(config.flow_dtype)
-            flow_model = models.flow.type(flux_dtype).to(
-                memory_format=torch.channels_last
-            )
+            flow_model = models.flow
 
-            flow_model = quantize_flow_transformer_and_dispatch_float8(
-                flow_model, flux_device
-            )
+            if not config.prequantized_flow:
+                flow_model = quantize_flow_transformer_and_dispatch_float8(
+                    flow_model, flux_device, offload_flow=config.offload_flow
+                )
+            else:
+                flow_model.eval().requires_grad_(False)
 
             return cls(
                 name=config.version,
@@ -492,7 +502,7 @@
 
 if __name__ == "__main__":
     pipe = FluxPipeline.load_pipeline_from_config_path(
-        "configs/config-dev-offload.json"
+        "configs/config-dev-prequant.json",
     )
     o = pipe.generate(
         prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
@@ -503,7 +513,7 @@ if __name__ == "__main__":
         seed=10,
     )
     open("out.jpg", "wb").write(o.read())
-    for x in range(10):
+    for x in range(2):
 
         o = pipe.generate(
             prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
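
An editorial sketch (not in this commit) of how the prequantized checkpoint consumed by the `else:` branch above can be produced in the first place: quantize the flow transformer once with the existing helper, then save its state dict (which then carries `float8_data` and the scale buffers per layer) with safetensors. The import locations of these helpers, the non-prequant config name, and safetensors' handling of every buffer are assumptions, not something this diff shows.

```python
# Hedged sketch: produce flux_fp8.safetensors for use with "prequantized_flow": true.
import torch
from safetensors.torch import save_file

from util import load_config_from_path, load_models_from_config  # assumed locations
from float8_quantize import quantize_flow_transformer_and_dispatch_float8  # assumed location

config = load_config_from_path("configs/config-dev.json")  # placeholder: any non-prequant config
models = load_models_from_config(config)
flow = quantize_flow_transformer_and_dispatch_float8(
    models.flow, torch.device(config.flux_device), offload_flow=False
)
state = {k: v.detach().cpu().contiguous() for k, v in flow.state_dict().items()}
save_file(state, "flux_fp8.safetensors")  # then point ckpt_path at this file
```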
modules/flux_model_f8.py ADDED
@@ -0,0 +1,491 @@
+from collections import namedtuple
+import os
+import torch
+
+DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.benchmark_limit = 20
+torch.set_float32_matmul_precision("high")
+import math
+
+from torch import Tensor, nn
+from pydantic import BaseModel
+from torch.nn import functional as F
+from float8_quantize import F8Linear
+
+try:
+    from cublas_ops import CublasLinear
+except ImportError:
+    CublasLinear = nn.Linear
+
+
+class FluxParams(BaseModel):
+    in_channels: int
+    vec_in_dim: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list[int]
+    theta: int
+    qkv_bias: bool
+    guidance_embed: bool
+
+
+# attention is always same shape each time it's called per H*W, so compile with fullgraph
+# @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+    q, k = apply_rope(q, k, pe)
+    x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
+    x = x.reshape(*x.shape[:-2], -1)
+    return x
+
+
+# @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
+def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    out = torch.stack(
+        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
+    )
+    out = out.reshape(*out.shape[:-1], 2, 2)
+    return out
+
+
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+    return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)
+
+
+class EmbedND(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        theta: int,
+        axes_dim: list[int],
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.dtype = dtype
+
+    def forward(self, ids: Tensor) -> Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [
+                rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype)
+                for i in range(n_axes)
+            ],
+            dim=-3,
+        )
+
+        return emb.unsqueeze(1)
+
+
+def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+    """
+    Create sinusoidal timestep embeddings.
+    :param t: a 1-D Tensor of N indices, one per batch element.
+        These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an (N, D) Tensor of positional embeddings.
+    """
+    t = time_factor * t
+    half = dim // 2
+    freqs = torch.exp(
+        -math.log(max_period)
+        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
+        / half
+    )
+
+    args = t[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int):
+        super().__init__()
+        self.in_layer = F8Linear(in_dim, hidden_dim, bias=True)
+        self.silu = nn.SiLU()
+        self.out_layer = F8Linear(hidden_dim, hidden_dim, bias=True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x: Tensor):
+        return F.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
+
+
+class QKNorm(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.query_norm = RMSNorm(dim)
+        self.key_norm = RMSNorm(dim)
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+        q = self.query_norm(q)
+        k = self.key_norm(k)
+        return q, k
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+
+        self.qkv = F8Linear(dim, dim * 3, bias=qkv_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = F8Linear(dim, dim)
+        self.K = 3
+        self.H = self.num_heads
+        self.KH = self.K * self.H
+
+    def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
+        B, L, D = x.shape
+        q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
+        return q, k, v
+
+    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
+        qkv = self.qkv(x)
+        q, k, v = self.rearrange_for_norm(qkv)
+        q, k = self.norm(q, k, v)
+        x = attention(q, k, v, pe=pe)
+        x = self.proj(x)
+        return x
+
+
+ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])
+
+
+class Modulation(nn.Module):
+    def __init__(self, dim: int, double: bool):
+        super().__init__()
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.lin = F8Linear(dim, self.multiplier * dim, bias=True)
+        self.act = nn.SiLU()
+
+    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+        out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+
+        return (
+            ModulationOut(*out[:3]),
+            ModulationOut(*out[3:]) if self.is_double else None,
+        )
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float,
+        qkv_bias: bool = False,
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.dtype = dtype
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_mod = Modulation(hidden_size, double=True)
+        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.img_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.img_mlp = nn.Sequential(
+            F8Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            F8Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+        self.txt_mod = Modulation(hidden_size, double=True)
+        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_mlp = nn.Sequential(
+            F8Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            F8Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+        self.K = 3
+        self.H = self.num_heads
+        self.KH = self.K * self.H
+
+    def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
+        B, L, D = x.shape
+        q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
+        return q, k, v
+
+    def forward(
+        self,
+        img: Tensor,
+        txt: Tensor,
+        vec: Tensor,
+        pe: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        img_mod1, img_mod2 = self.img_mod(vec)
+        txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+        # prepare image for attention
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = self.rearrange_for_norm(img_qkv)
+        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+
+        # prepare txt for attention
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+
+        q = torch.cat((txt_q, img_q), dim=2)
+        k = torch.cat((txt_k, img_k), dim=2)
+        v = torch.cat((txt_v, img_v), dim=2)
+
+        attn = attention(q, k, v, pe=pe)
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+        # calculate the img bloks
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp(
+            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
+        ).clamp(min=-384 * 2, max=384 * 2)
+
+        # calculate the txt bloks
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp(
+            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
+        ).clamp(min=-384 * 2, max=384 * 2)
+
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    """
+    A DiT block with parallel linear layers as described in
+    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qk_scale: float | None = None,
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = F8Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+        # proj and mlp_out
+        self.linear2 = F8Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+        self.norm = QKNorm(head_dim)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+        self.mlp_act = nn.GELU(approximate="tanh")
+        self.modulation = Modulation(hidden_size, double=False)
+
+        self.K = 3
+        self.H = self.num_heads
+        self.KH = self.K * self.H
+
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+        mod = self.modulation(vec)[0]
+        pre_norm = self.pre_norm(x)
+        x_mod = (1 + mod.scale) * pre_norm + mod.shift
+        qkv, mlp = torch.split(
+            self.linear1(x_mod),
+            [3 * self.hidden_size, self.mlp_hidden_dim],
+            dim=-1,
+        )
+        B, L, D = qkv.shape
+        q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
+        q, k = self.norm(q, k, v)
+        attn = attention(q, k, v, pe=pe)
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
+            min=-384 * 4, max=384 * 4
+        )
+        return x + mod.gate * output
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = CublasLinear(
+            hidden_size, patch_size * patch_size * out_channels, bias=True
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), CublasLinear(hidden_size, 2 * hidden_size, bias=True)
+        )
+
+    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
+
+
+class Flux(nn.Module):
+    """
+    Transformer model for flow matching on sequences.
+    """
+
+    def __init__(self, params: FluxParams, dtype: torch.dtype = torch.bfloat16):
+        super().__init__()
+
+        self.dtype = dtype
+        self.params = params
+        self.in_channels = params.in_channels
+        self.out_channels = self.in_channels
+        if params.hidden_size % params.num_heads != 0:
+            raise ValueError(
+                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+            )
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(
+                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+            )
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.pe_embedder = EmbedND(
+            dim=pe_dim,
+            theta=params.theta,
+            axes_dim=params.axes_dim,
+            dtype=self.dtype,
+        )
+        self.img_in = F8Linear(self.in_channels, self.hidden_size, bias=True)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+        self.guidance_in = (
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+            if params.guidance_embed
+            else nn.Identity()
+        )
+        self.txt_in = F8Linear(params.context_in_dim, self.hidden_size)
+
+        self.double_blocks = nn.ModuleList(
+            [
+                DoubleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    qkv_bias=params.qkv_bias,
+                    dtype=self.dtype,
+                )
+                for _ in range(params.depth)
+            ]
+        )
+
+        self.single_blocks = nn.ModuleList(
+            [
+                SingleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    dtype=self.dtype,
+                )
+                for _ in range(params.depth_single_blocks)
+            ]
+        )
+
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+    def forward(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        y: Tensor,
+        guidance: Tensor | None = None,
+    ) -> Tensor:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+        vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype))
+
+        if self.params.guidance_embed:
+            if guidance is None:
+                raise ValueError(
+                    "Didn't get guidance strength for guidance distilled model."
+                )
+            vec = vec + self.guidance_in(
+                timestep_embedding(guidance, 256).type(self.dtype)
+            )
+        vec = vec + self.vector_in(y)
+
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        # double stream blocks
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+        img = torch.cat((txt, img), 1)
+
+        # single stream blocks
+        for block in self.single_blocks:
+            img = block(img, vec=vec, pe=pe)
+
+        img = img[:, txt.shape[1] :, ...]
+        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        return img
+
+    @classmethod
+    def from_pretrained(cls, path: str, dtype: torch.dtype = torch.bfloat16) -> "Flux":
+        from util import load_config_from_path
+        from safetensors.torch import load_file
+
+        config = load_config_from_path(path)
+        with torch.device("meta"):
+            klass = cls(params=config.params, dtype=dtype).type(dtype)
+
+        ckpt = load_file(config.ckpt_path, device="cpu")
+        klass.load_state_dict(ckpt, assign=True)
+        return klass.to("cpu")
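
A usage sketch (the config path is a placeholder): `from_pretrained` builds the module on the meta device and then loads the float8 checkpoint at the config's `ckpt_path` with `assign=True`, so the checkpoint tensors are assigned in place of the meta parameters rather than copied into a full-precision model. The dtype below matches the `flow_dtype` in the new config; moving to the GPU afterwards is the caller's job.

```python
# Sketch: instantiate the F8Linear-based Flux from a prequantized checkpoint.
import torch
from modules.flux_model_f8 import Flux

flow = Flux.from_pretrained("configs/config-dev-prequant.json", dtype=torch.float16)
flow = flow.eval().requires_grad_(False).to("cuda:0")
```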
util.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from modules.autoencoder import AutoEncoder, AutoEncoderParams
 from modules.conditioner import HFEmbedder
 from modules.flux_model import Flux, FluxParams
-
+from modules.flux_model_f8 import Flux as FluxF8
 from safetensors.torch import load_file as load_sft
 from enum import StrEnum
 from pydantic import BaseModel, ConfigDict
@@ -53,6 +53,7 @@ class ModelSpec(BaseModel):
     offload_text_encoder: bool = False
     offload_vae: bool = False
     offload_flow: bool = False
+    prequantized_flow: bool = False
 
     model_config: ConfigDict = {
         "arbitrary_types_allowed": True,
@@ -194,9 +195,14 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
 
 def load_flow_model(config: ModelSpec) -> Flux:
     ckpt_path = config.ckpt_path
+    FluxClass = Flux
+    if config.prequantized_flow:
+        from modules.flux_model_f8 import Flux as FluxF8
+
+        FluxClass = FluxF8
 
     with torch.device("meta"):
-        model = Flux(config.params, dtype=into_dtype(config.flow_dtype)).type(
+        model = FluxClass(config.params, dtype=into_dtype(config.flow_dtype)).type(
             into_dtype(config.flow_dtype)
         )
 
@@ -247,7 +253,7 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
 
 
 class LoadedModels(BaseModel):
-    flow: Flux
+    flow: Flux | FluxF8
     ae: AutoEncoder
     clip: HFEmbedder
     t5: HFEmbedder
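
A quick sketch of the dispatch these hunks add: with `prequantized_flow` set in the config, `load_flow_model` constructs the F8Linear-based `FluxF8` on the meta device instead of the plain `Flux`; the remainder of the function (not shown in this diff) then loads `ckpt_path` into it. The config path below is a placeholder.

```python
# Sketch: the same loader now yields the float8 variant when prequantized_flow is true.
from util import load_config_from_path, load_flow_model
from modules.flux_model_f8 import Flux as FluxF8

config = load_config_from_path("configs/config-dev-prequant.json")
flow = load_flow_model(config)
assert config.prequantized_flow and isinstance(flow, FluxF8)
```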