aredden committed
Commit: c2ecfb5 · 1 Parent(s): 9eceba0

Add img2img, more options, gradio interface
.gitignore CHANGED
@@ -1 +1,12 @@
  __pycache__
+ *.jpg
+ *.png
+ *.jpeg
+ *.gif
+ *.bmp
+ *.webp
+ *.mp4
+ *.mp3
+ *.mp3
+ *.txt
+ .copilotignore
api.py CHANGED
@@ -17,6 +17,8 @@ class GenerateArgs(BaseModel):
      seed: Optional[int] = Field(
          default_factory=lambda: np.random.randint(0, 2**32 - 1), gt=0, lt=2**32 - 1
      )
+     strength: Optional[float] = 1.0
+     init_image: Optional[str] = None
 
 
  @app.post("/generate")
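
For reference, a minimal client sketch against the extended schema (illustrative only, not part of this commit): it assumes the server runs on main.py's default host/port, that prompt and the other generation fields mirror the pipeline's generate() signature, and that /generate returns the encoded JPEG bytes; init_image is sent base64-encoded, which is what flux_pipeline.py decodes.

import base64

import requests  # hypothetical client dependency; any HTTP client works

with open("init.jpg", "rb") as f:  # hypothetical source image for img2img
    init_b64 = base64.standard_b64encode(f.read()).decode()

payload = {
    "prompt": "a misty forest at dawn",
    "strength": 0.6,         # new field: how strongly to re-noise the init image
    "init_image": init_b64,  # new field: base64-encoded source image
}
resp = requests.post("http://localhost:8088/generate", json=payload)
with open("img2img.jpg", "wb") as f:
    f.write(resp.content)  # assuming the endpoint returns the encoded JPEG bytes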
configs/config-dev-gigaquant.json ADDED
@@ -0,0 +1,52 @@
1
+ {
2
+ "version": "flux-dev",
3
+ "params": {
4
+ "in_channels": 64,
5
+ "vec_in_dim": 768,
6
+ "context_in_dim": 4096,
7
+ "hidden_size": 3072,
8
+ "mlp_ratio": 4.0,
9
+ "num_heads": 24,
10
+ "depth": 19,
11
+ "depth_single_blocks": 38,
12
+ "axes_dim": [
13
+ 16,
14
+ 56,
15
+ 56
16
+ ],
17
+ "theta": 10000,
18
+ "qkv_bias": true,
19
+ "guidance_embed": true
20
+ },
21
+ "ae_params": {
22
+ "resolution": 256,
23
+ "in_channels": 3,
24
+ "ch": 128,
25
+ "out_ch": 3,
26
+ "ch_mult": [
27
+ 1,
28
+ 2,
29
+ 4,
30
+ 4
31
+ ],
32
+ "num_res_blocks": 2,
33
+ "z_channels": 16,
34
+ "scale_factor": 0.3611,
35
+ "shift_factor": 0.1159
36
+ },
37
+ "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
38
+ "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
39
+ "repo_id": "black-forest-labs/FLUX.1-dev",
40
+ "repo_flow": "flux1-dev.sft",
41
+ "repo_ae": "ae.sft",
42
+ "text_enc_max_length": 512,
43
+ "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
44
+ "text_enc_device": "cuda:1",
45
+ "ae_device": "cuda:1",
46
+ "flux_device": "cuda:0",
47
+ "flow_dtype": "float16",
48
+ "ae_dtype": "bfloat16",
49
+ "text_enc_dtype": "bfloat16",
50
+ "num_to_quant": 8000,
51
+ "quantize_extras": true
52
+ }
configs/config-dev.json CHANGED
@@ -47,5 +47,7 @@
      "flow_dtype": "float16",
      "ae_dtype": "bfloat16",
      "text_enc_dtype": "bfloat16",
-     "num_to_quant": 20
+     "num_to_quant": 22,
+     "compile_extras": false,
+     "compile_blocks": false
  }
cublas_linear.py CHANGED
@@ -1,152 +1 @@
1
- import math
2
- from typing import Literal, Optional
3
-
4
- import torch
5
- from torch.nn import functional as F
6
-
7
- from cublas_ops_ext import _simt_hgemv
8
- from cublas_ops_ext import cublas_hgemm_axbT as _cublas_hgemm_axbT
9
- from cublas_ops_ext import cublas_hgemm_batched_simple as _cublas_hgemm_batched_simple
10
- from cublas_ops_ext import (
11
- cublaslt_hgemm_batched_simple as _cublaslt_hgemm_batched_simple,
12
- )
13
- from cublas_ops_ext import cublaslt_hgemm_simple as _cublaslt_hgemm_simple
14
- from torch import Tensor, nn
15
-
16
- global has_moved
17
- has_moved = {idx: False for idx in range(torch.cuda.device_count())}
18
-
19
-
20
- class StaticState:
21
- workspace = {
22
- idx: torch.empty((1024 * 1024 * 8,), dtype=torch.uint8)
23
- for idx in range(torch.cuda.device_count())
24
- }
25
- workspace_size = workspace[0].nelement()
26
- bias_g = {
27
- idx: torch.tensor([], dtype=torch.float16)
28
- for idx in range(torch.cuda.device_count())
29
- }
30
-
31
- @classmethod
32
- def get(cls, __name: str, device: torch.device) -> torch.Any:
33
- global has_moved
34
- idx = device.index if device.index is not None else 0
35
- if not has_moved[idx]:
36
- cls.workspace[idx] = cls.workspace[idx].cuda(idx)
37
- cls.bias_g[idx] = cls.bias_g[idx].cuda(idx)
38
- has_moved[idx] = True
39
- if "bias" in __name:
40
- return cls.bias_g[idx]
41
- if "workspace" in __name:
42
- return cls.workspace[idx]
43
- if "workspace_size" in __name:
44
- return cls.workspace_size
45
-
46
-
47
- @torch.no_grad()
48
- def hgemv_simt(vec: torch.HalfTensor, mat: torch.HalfTensor, block_dim_x: int = 32):
49
- prev_dims = vec.shape[:-1]
50
- out = _simt_hgemv(mat, vec.view(-1, 1), block_dim_x=block_dim_x).view(
51
- *prev_dims, -1
52
- )
53
- return out
54
-
55
-
56
- @torch.no_grad()
57
- def cublas_half_matmul_batched_simple(a: torch.Tensor, b: torch.Tensor):
58
- out = _cublas_hgemm_batched_simple(a, b)
59
- return out
60
-
61
-
62
- @torch.no_grad()
63
- def cublas_half_matmul_simple(a: torch.Tensor, b: torch.Tensor):
64
- out = _cublas_hgemm_axbT(b, a)
65
- return out
66
-
67
-
68
- @torch.no_grad()
69
- def cublaslt_fused_half_matmul_simple(
70
- a: torch.Tensor,
71
- b: torch.Tensor,
72
- bias: Optional[torch.Tensor] = None,
73
- epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE",
74
- ):
75
- if bias is None:
76
- bias = StaticState.get("bias", a.device)
77
- out = _cublaslt_hgemm_simple(
78
- a, b, bias, epilogue_str, StaticState.get("workspace", a.device)
79
- )
80
- return out
81
-
82
-
83
- @torch.no_grad()
84
- def cublaslt_fused_half_matmul_batched_simple(
85
- a: torch.Tensor,
86
- b: torch.Tensor,
87
- bias: Optional[torch.Tensor] = None,
88
- epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE",
89
- ):
90
- if bias is None:
91
- bias = StaticState.get("bias", a.device)
92
- out = _cublaslt_hgemm_batched_simple(
93
- a, b, bias, epilogue_str, StaticState.get("workspace", a.device)
94
- )
95
- return out
96
-
97
-
98
- class CublasLinear(nn.Linear):
99
- def __init__(
100
- self,
101
- in_features,
102
- out_features,
103
- bias=True,
104
- device=None,
105
- dtype=torch.float16,
106
- epilogue_str="NONE",
107
- ):
108
- super().__init__(
109
- in_features, out_features, bias=bias, device=device, dtype=dtype
110
- )
111
- self._epilogue_str = epilogue_str
112
- self.has_bias = bias
113
- self.has_checked_weight = False
114
-
115
- def forward(self, x: Tensor) -> Tensor:
116
- if not self.has_checked_weight:
117
- if not self.weight.dtype == torch.float16:
118
- self.to(dtype=torch.float16)
119
- self.has_checked_weight = True
120
- out_dtype = x.dtype
121
- needs_convert = out_dtype != torch.float16
122
- if needs_convert:
123
- x = x.type(torch.float16)
124
-
125
- use_cublasLt = self.has_bias or self._epilogue_str != "NONE"
126
- if x.ndim == 1:
127
- x = x.unsqueeze(0)
128
- if math.prod(x.shape) == x.shape[-1]:
129
- out = F.linear(x, self.weight, bias=self.bias)
130
- if self._epilogue_str == "RELU":
131
- return F.relu(out)
132
- elif self._epilogue_str == "GELU":
133
- return F.gelu(out)
134
- if needs_convert:
135
- return out.type(out_dtype)
136
- return out
137
- if use_cublasLt:
138
- leading_dims = x.shape[:-1]
139
- x = x.reshape(-1, x.shape[-1])
140
- out = cublaslt_fused_half_matmul_simple(
141
- x, self.weight, bias=self.bias.data, epilogue_str=self._epilogue_str
142
- )
143
- if needs_convert:
144
- return out.view(*leading_dims, out.shape[-1]).type(out_dtype)
145
- return out.view(*leading_dims, out.shape[-1])
146
- else:
147
- leading_dims = x.shape[:-1]
148
- x = x.reshape(-1, x.shape[-1])
149
- out = cublas_half_matmul_simple(x, self.weight)
150
- if needs_convert:
151
- return out.view(*leading_dims, out.shape[-1]).type(out_dtype)
152
- return out.view(*leading_dims, out.shape[-1])
 
1
+ from cublas_ops import CublasLinear
flux_emphasis.py ADDED
@@ -0,0 +1,448 @@
1
+ from typing import TYPE_CHECKING, Optional
2
+ from pydash import flatten
3
+
4
+ import torch
5
+ from transformers.models.clip.tokenization_clip import CLIPTokenizer
6
+ from einops import repeat
7
+
8
+ if TYPE_CHECKING:
9
+ from flux_pipeline import FluxPipeline
10
+
11
+
12
+ def parse_prompt_attention(text):
13
+ """
14
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
15
+ Accepted tokens are:
16
+ (abc) - increases attention to abc by a multiplier of 1.1
17
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
18
+ [abc] - decreases attention to abc by a multiplier of 1.1
19
+ \\( - literal character '('
20
+ \\[ - literal character '['
21
+ \\) - literal character ')'
22
+ \\] - literal character ']'
23
+ \\ - literal character '\'
24
+ anything else - just text
25
+
26
+ >>> parse_prompt_attention('normal text')
27
+ [['normal text', 1.0]]
28
+ >>> parse_prompt_attention('an (important) word')
29
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
30
+ >>> parse_prompt_attention('(unbalanced')
31
+ [['unbalanced', 1.1]]
32
+ >>> parse_prompt_attention('\\(literal\\]')
33
+ [['(literal]', 1.0]]
34
+ >>> parse_prompt_attention('(unnecessary)(parens)')
35
+ [['unnecessaryparens', 1.1]]
36
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
37
+ [['a ', 1.0],
38
+ ['house', 1.5730000000000004],
39
+ [' ', 1.1],
40
+ ['on', 1.0],
41
+ [' a ', 1.1],
42
+ ['hill', 0.55],
43
+ [', sun, ', 1.1],
44
+ ['sky', 1.4641000000000006],
45
+ ['.', 1.1]]
46
+ """
47
+ import re
48
+
49
+ re_attention = re.compile(
50
+ r"""
51
+ \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
52
+ \)|]|[^\\()\[\]:]+|:
53
+ """,
54
+ re.X,
55
+ )
56
+
57
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
58
+
59
+ res = []
60
+ round_brackets = []
61
+ square_brackets = []
62
+
63
+ round_bracket_multiplier = 1.1
64
+ square_bracket_multiplier = 1 / 1.1
65
+
66
+ def multiply_range(start_position, multiplier):
67
+ for p in range(start_position, len(res)):
68
+ res[p][1] *= multiplier
69
+
70
+ for m in re_attention.finditer(text):
71
+ text = m.group(0)
72
+ weight = m.group(1)
73
+
74
+ if text.startswith("\\"):
75
+ res.append([text[1:], 1.0])
76
+ elif text == "(":
77
+ round_brackets.append(len(res))
78
+ elif text == "[":
79
+ square_brackets.append(len(res))
80
+ elif weight is not None and len(round_brackets) > 0:
81
+ multiply_range(round_brackets.pop(), float(weight))
82
+ elif text == ")" and len(round_brackets) > 0:
83
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
84
+ elif text == "]" and len(square_brackets) > 0:
85
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
86
+ else:
87
+ parts = re.split(re_break, text)
88
+ for i, part in enumerate(parts):
89
+ if i > 0:
90
+ res.append(["BREAK", -1])
91
+ res.append([part, 1.0])
92
+
93
+ for pos in round_brackets:
94
+ multiply_range(pos, round_bracket_multiplier)
95
+
96
+ for pos in square_brackets:
97
+ multiply_range(pos, square_bracket_multiplier)
98
+
99
+ if len(res) == 0:
100
+ res = [["", 1.0]]
101
+
102
+ # merge runs of identical weights
103
+ i = 0
104
+ while i + 1 < len(res):
105
+ if res[i][1] == res[i + 1][1]:
106
+ res[i][0] += res[i + 1][0]
107
+ res.pop(i + 1)
108
+ else:
109
+ i += 1
110
+
111
+ return res
112
+
113
+
114
+ def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
115
+ """
116
+ Get prompt token ids and weights, this function works for both prompt and negative prompt
117
+
118
+ Args:
119
+ pipe (CLIPTokenizer)
120
+ A CLIPTokenizer
121
+ prompt (str)
122
+ A prompt string with weights
123
+
124
+ Returns:
125
+ text_tokens (list)
126
+ A list contains token ids
127
+ text_weight (list)
128
+ A list containing the corresponding weight of each token id
129
+
130
+ Example:
131
+ import torch
132
+ from transformers import CLIPTokenizer
133
+
134
+ clip_tokenizer = CLIPTokenizer.from_pretrained(
135
+ "stablediffusionapi/deliberate-v2"
136
+ , subfolder = "tokenizer"
137
+ , dtype = torch.float16
138
+ )
139
+
140
+ token_id_list, token_weight_list = get_prompts_tokens_with_weights(
141
+ clip_tokenizer = clip_tokenizer
142
+ ,prompt = "a (red:1.5) cat"*70
143
+ )
144
+ """
145
+ texts_and_weights = parse_prompt_attention(prompt)
146
+ text_tokens, text_weights = [], []
147
+ maxlen = clip_tokenizer.model_max_length
148
+ for word, weight in texts_and_weights:
149
+ # tokenize and discard the starting and the ending token
150
+ token = clip_tokenizer(
151
+ word, truncation=False, padding=False, add_special_tokens=False
152
+ ).input_ids
153
+ # so that tokenize whatever length prompt
154
+ # the returned token is a 1d list: [320, 1125, 539, 320]
155
+ print(
156
+ token,
157
+ "|FOR MODEL LEN{}|".format(maxlen),
158
+ clip_tokenizer.decode(
159
+ token, skip_special_tokens=True, clean_up_tokenization_spaces=True
160
+ ),
161
+ )
162
+ # merge the new tokens to the all tokens holder: text_tokens
163
+ text_tokens = [*text_tokens, *token]
164
+
165
+ # each token chunk will come with one weight, like ['red cat', 2.0]
166
+ # need to expand weight for each token.
167
+ chunk_weights = [weight] * len(token)
168
+
169
+ # append the weight back to the weight holder: text_weights
170
+ text_weights = [*text_weights, *chunk_weights]
171
+ return text_tokens, text_weights
172
+
173
+
174
+ def group_tokens_and_weights(
175
+ token_ids: list,
176
+ weights: list,
177
+ pad_last_block=False,
178
+ bos=49406,
179
+ eos=49407,
180
+ max_length=77,
181
+ pad_tokens=True,
182
+ ):
183
+ """
184
+ Produce tokens and weights in groups and pad the missing tokens
185
+
186
+ Args:
187
+ token_ids (list)
188
+ The token ids from tokenizer
189
+ weights (list)
190
+ The weights list from function get_prompts_tokens_with_weights
191
+ pad_last_block (bool)
192
+ Control if fill the last token list to 75 tokens with eos
193
+ Returns:
194
+ new_token_ids (2d list)
195
+ new_weights (2d list)
196
+
197
+ Example:
198
+ token_groups,weight_groups = group_tokens_and_weights(
199
+ token_ids = token_id_list
200
+ , weights = token_weight_list
201
+ )
202
+ """
203
+ max_len = max_length - 2 if max_length < 77 else max_length
204
+ # this will be a 2d list
205
+ new_token_ids = []
206
+ new_weights = []
207
+ while len(token_ids) >= max_len:
208
+ # get the first 75 tokens
209
+ head_75_tokens = [token_ids.pop(0) for _ in range(max_len)]
210
+ head_75_weights = [weights.pop(0) for _ in range(max_len)]
211
+
212
+ # extract token ids and weights
213
+
214
+ if pad_tokens:
215
+ if bos is not None:
216
+ temp_77_token_ids = [bos] + head_75_tokens + [eos]
217
+ temp_77_weights = [1.0] + head_75_weights + [1.0]
218
+ else:
219
+ temp_77_token_ids = head_75_tokens + [eos]
220
+ temp_77_weights = head_75_weights + [1.0]
221
+
222
+ # add 77 token and weights chunk to the holder list
223
+ new_token_ids.append(temp_77_token_ids)
224
+ new_weights.append(temp_77_weights)
225
+
226
+ # padding the left
227
+ if len(token_ids) > 0:
228
+ if pad_tokens:
229
+ padding_len = max_len - len(token_ids) if pad_last_block else 0
230
+
231
+ temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
232
+ new_token_ids.append(temp_77_token_ids)
233
+
234
+ temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
235
+ new_weights.append(temp_77_weights)
236
+ else:
237
+ new_token_ids.append(token_ids)
238
+ new_weights.append(weights)
239
+ return new_token_ids, new_weights
240
+
241
+
242
+ def standardize_tensor(
243
+ input_tensor: torch.Tensor, target_mean: float, target_std: float
244
+ ) -> torch.Tensor:
245
+ """
246
+ This function standardizes an input tensor so that it has a specific mean and standard deviation.
247
+
248
+ Parameters:
249
+ input_tensor (torch.Tensor): The tensor to standardize.
250
+ target_mean (float): The target mean for the tensor.
251
+ target_std (float): The target standard deviation for the tensor.
252
+
253
+ Returns:
254
+ torch.Tensor: The standardized tensor.
255
+ """
256
+
257
+ # First, compute the mean and std of the input tensor
258
+ mean = input_tensor.mean()
259
+ std = input_tensor.std()
260
+
261
+ # Then, standardize the tensor to have a mean of 0 and std of 1
262
+ standardized_tensor = (input_tensor - mean) / std
263
+
264
+ # Finally, scale the tensor to the target mean and std
265
+ output_tensor = standardized_tensor * target_std + target_mean
266
+
267
+ return output_tensor
268
+
269
+
270
+ def apply_weights(
271
+ prompt_tokens: torch.Tensor,
272
+ weight_tensor: torch.Tensor,
273
+ token_embedding: torch.Tensor,
274
+ eos_token_id: int,
275
+ pad_last_block: bool = True,
276
+ ) -> torch.FloatTensor:
277
+ mean = token_embedding.mean()
278
+ std = token_embedding.std()
279
+ if pad_last_block:
280
+ pooled_tensor = token_embedding[
281
+ torch.arange(token_embedding.shape[0], device=token_embedding.device),
282
+ (
283
+ prompt_tokens.to(dtype=torch.int, device=token_embedding.device)
284
+ == eos_token_id
285
+ )
286
+ .int()
287
+ .argmax(dim=-1),
288
+ ]
289
+ else:
290
+ pooled_tensor = token_embedding[:, -1]
291
+
292
+ for j in range(len(weight_tensor)):
293
+ if weight_tensor[j] != 1.0:
294
+ token_embedding[:, j] = (
295
+ pooled_tensor
296
+ + (token_embedding[:, j] - pooled_tensor) * weight_tensor[j]
297
+ )
298
+ return standardize_tensor(token_embedding, mean, std)
299
+
300
+
301
+ @torch.inference_mode()
302
+ def get_weighted_text_embeddings_flux(
303
+ pipe: "FluxPipeline",
304
+ prompt: str = "",
305
+ num_images_per_prompt: int = 1,
306
+ device: Optional[torch.device] = None,
307
+ target_device: Optional[torch.device] = torch.device("cuda:0"),
308
+ target_dtype: Optional[torch.dtype] = torch.bfloat16,
309
+ ):
310
+ """
311
+ This function can process long prompt with weights, no length limitation
312
+ for Stable Diffusion XL
313
+
314
+ Args:
315
+ pipe (StableDiffusionPipeline)
316
+ prompt (str)
317
+ prompt_2 (str)
318
+ neg_prompt (str)
319
+ neg_prompt_2 (str)
320
+ num_images_per_prompt (int)
321
+ device (torch.device)
322
+ Returns:
323
+ prompt_embeds (torch.Tensor)
324
+ neg_prompt_embeds (torch.Tensor)
325
+ """
326
+ device = device or pipe._execution_device
327
+
328
+ eos = pipe.clip.tokenizer.eos_token_id
329
+ eos_2 = pipe.t5.tokenizer.eos_token_id
330
+ bos = pipe.clip.tokenizer.bos_token_id
331
+ bos_2 = pipe.t5.tokenizer.bos_token_id
332
+
333
+ clip = pipe.clip.hf_module
334
+ t5 = pipe.t5.hf_module
335
+
336
+ tokenizer_clip = pipe.clip.tokenizer
337
+ tokenizer_t5 = pipe.t5.tokenizer
338
+
339
+ t5_length = 512 if pipe.name == "flux-dev" else 256
340
+ clip_length = 77
341
+
342
+ tokenizer_t5(
343
+ prompt,
344
+ add_special_tokens=True,
345
+ padding="max_length",
346
+ truncation=True,
347
+ max_length=t5_length,
348
+ return_tensors="pt",
349
+ )
350
+
351
+ # tokenizer 1
352
+ prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights(
353
+ tokenizer_clip, prompt
354
+ )
355
+
356
+ # tokenizer 2
357
+ prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights(
358
+ tokenizer_t5, prompt
359
+ )
360
+
361
+ prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights(
362
+ prompt_tokens_clip,
363
+ prompt_weights_clip,
364
+ pad_last_block=True,
365
+ bos=bos,
366
+ eos=eos,
367
+ max_length=clip_length,
368
+ )
369
+ prompt_tokens_t5_grouped, prompt_weights_t5_grouped = group_tokens_and_weights(
370
+ prompt_tokens_t5,
371
+ prompt_weights_t5,
372
+ pad_last_block=True,
373
+ bos=bos_2,
374
+ eos=eos_2,
375
+ max_length=t5_length,
376
+ pad_tokens=False,
377
+ )
378
+ prompt_tokens_t5 = flatten(prompt_tokens_t5_grouped)
379
+ prompt_weights_t5 = flatten(prompt_weights_t5_grouped)
380
+ prompt_tokens_clip = flatten(prompt_tokens_clip_grouped)
381
+ prompt_weights_clip = flatten(prompt_weights_clip_grouped)
382
+
383
+ prompt_tokens_clip = tokenizer_clip.decode(
384
+ prompt_tokens_clip, skip_special_tokens=True, clean_up_tokenization_spaces=True
385
+ )
386
+ prompt_tokens_clip = tokenizer_clip(
387
+ prompt_tokens_clip,
388
+ add_special_tokens=True,
389
+ padding="max_length",
390
+ truncation=True,
391
+ max_length=clip_length,
392
+ return_tensors="pt",
393
+ ).input_ids.to(device)
394
+ prompt_tokens_t5 = tokenizer_t5.decode(
395
+ prompt_tokens_t5, skip_special_tokens=True, clean_up_tokenization_spaces=True
396
+ )
397
+ prompt_tokens_t5 = tokenizer_t5(
398
+ prompt_tokens_t5,
399
+ add_special_tokens=True,
400
+ padding="max_length",
401
+ truncation=True,
402
+ max_length=t5_length,
403
+ return_tensors="pt",
404
+ ).input_ids.to(device)
405
+
406
+ prompt_weights_t5 = torch.cat(
407
+ [
408
+ torch.tensor(prompt_weights_t5, dtype=torch.float32),
409
+ torch.full(
410
+ (t5_length - torch.tensor(prompt_weights_t5).numel(),),
411
+ 1.0,
412
+ dtype=torch.float32,
413
+ ),
414
+ ],
415
+ dim=0,
416
+ ).to(device)
417
+
418
+ clip_embeds = clip(
419
+ prompt_tokens_clip, output_hidden_states=True, attention_mask=None
420
+ )["pooler_output"]
421
+ if clip_embeds.shape[0] == 1 and num_images_per_prompt > 1:
422
+ clip_embeds = repeat(clip_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
423
+
424
+ weight_tensor_t5 = torch.tensor(
425
+ flatten(prompt_weights_t5), dtype=torch.float32, device=device
426
+ )
427
+ t5_embeds = t5(prompt_tokens_t5, output_hidden_states=True, attention_mask=None)[
428
+ "last_hidden_state"
429
+ ]
430
+ t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2)
431
+ print(t5_embeds.shape)
432
+ if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1:
433
+ t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
434
+ txt_ids = torch.zeros(
435
+ num_images_per_prompt,
436
+ t5_embeds.shape[1],
437
+ 3,
438
+ device=target_device,
439
+ dtype=target_dtype,
440
+ )
441
+ t5_embeds = t5_embeds.to(target_device, dtype=target_dtype)
442
+ clip_embeds = clip_embeds.to(target_device, dtype=target_dtype)
443
+
444
+ return (
445
+ clip_embeds,
446
+ t5_embeds,
447
+ txt_ids,
448
+ )
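
To illustrate the attention syntax the parser above handles, a small sketch (assuming the repository's dependencies are installed, since flux_emphasis pulls in torch, transformers, einops, and pydash at import time):

from flux_emphasis import parse_prompt_attention

# (word) boosts weight by 1.1x, [word] divides by 1.1, (word:1.5) sets it explicitly
for text, weight in parse_prompt_attention("a (red:1.5) cat on a [cluttered] desk"):
    print(f"{weight:.3f}  {text!r}")
# -> 1.000 'a ', 1.500 'red', 1.000 ' cat on a ', 0.909 'cluttered', 1.000 ' desk'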
flux_impl.py DELETED
@@ -1,272 +0,0 @@
1
- import io
2
- from typing import List
3
-
4
- import torch
5
- from torch import nn
6
-
7
- torch.backends.cuda.matmul.allow_tf32 = True
8
- torch.backends.cudnn.allow_tf32 = True
9
- torch.backends.cudnn.benchmark = True
10
- torch.backends.cudnn.benchmark_limit = 20
11
- torch.set_float32_matmul_precision("high")
12
- from torch._dynamo import config
13
- from torch._inductor import config as ind_config
14
-
15
- config.cache_size_limit = 10000000000
16
- ind_config.force_fuse_int_mm_with_mul = True
17
-
18
- from loguru import logger
19
- from torchao.quantization.quant_api import int8_weight_only, quantize_
20
-
21
- from cublas_linear import CublasLinear as F16Linear
22
- from modules.flux_model import RMSNorm
23
- from sampling import denoise, get_noise, get_schedule, prepare, unpack
24
- from turbojpeg_imgs import TurboImage
25
- from util import (
26
- ModelSpec,
27
- into_device,
28
- into_dtype,
29
- load_config_from_path,
30
- load_models_from_config,
31
- )
32
-
33
-
34
- class Model:
35
- def __init__(
36
- self,
37
- name,
38
- offload=False,
39
- clip=None,
40
- t5=None,
41
- model=None,
42
- ae=None,
43
- dtype=torch.bfloat16,
44
- verbose=False,
45
- flux_device="cuda:0",
46
- ae_device="cuda:1",
47
- clip_device="cuda:1",
48
- t5_device="cuda:1",
49
- ):
50
-
51
- self.name = name
52
- self.device_flux = (
53
- flux_device
54
- if isinstance(flux_device, torch.device)
55
- else torch.device(flux_device)
56
- )
57
- self.device_ae = (
58
- ae_device
59
- if isinstance(ae_device, torch.device)
60
- else torch.device(ae_device)
61
- )
62
- self.device_clip = (
63
- clip_device
64
- if isinstance(clip_device, torch.device)
65
- else torch.device(clip_device)
66
- )
67
- self.device_t5 = (
68
- t5_device
69
- if isinstance(t5_device, torch.device)
70
- else torch.device(t5_device)
71
- )
72
- self.dtype = dtype
73
- self.offload = offload
74
- self.clip = clip
75
- self.t5 = t5
76
- self.model = model
77
- self.ae = ae
78
- self.rng = torch.Generator(device="cpu")
79
- self.turbojpeg = TurboImage()
80
- self.verbose = verbose
81
-
82
- @torch.inference_mode()
83
- def generate(
84
- self,
85
- prompt,
86
- width=720,
87
- height=1023,
88
- num_steps=24,
89
- guidance=3.5,
90
- seed=None,
91
- ):
92
- if num_steps is None:
93
- num_steps = 4 if self.name == "flux-schnell" else 50
94
-
95
- # allow for packing and conversion to latent space
96
- height = 16 * (height // 16)
97
- width = 16 * (width // 16)
98
-
99
- if seed is None:
100
- seed = self.rng.seed()
101
- logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
102
-
103
- x = get_noise(
104
- 1,
105
- height,
106
- width,
107
- device=self.device_t5,
108
- dtype=torch.bfloat16,
109
- seed=seed,
110
- )
111
- inp = prepare(self.t5, self.clip, x, prompt=prompt)
112
- timesteps = get_schedule(
113
- num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")
114
- )
115
- for k in inp:
116
- inp[k] = inp[k].to(self.device_flux).type(self.dtype)
117
-
118
- # denoise initial noise
119
- x = denoise(
120
- self.model,
121
- **inp,
122
- timesteps=timesteps,
123
- guidance=guidance,
124
- dtype=self.dtype,
125
- device=self.device_flux,
126
- )
127
- inp.clear()
128
- timesteps.clear()
129
- torch.cuda.empty_cache()
130
- x = x.to(self.device_ae)
131
-
132
- # decode latents to pixel space
133
- x = unpack(x.float(), height, width)
134
- with torch.autocast(
135
- device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
136
- ):
137
- x = self.ae.decode(x)
138
-
139
- # bring into PIL format and save
140
- x = x.clamp(-1, 1)
141
- num_images = x.shape[0]
142
- images: List[torch.Tensor] = []
143
- for i in range(num_images):
144
- x = x[i].permute(1, 2, 0).add(1.0).mul(127.5).type(torch.uint8).contiguous()
145
- images.append(x)
146
- if len(images) == 1:
147
- im = images[0]
148
- else:
149
- im = torch.vstack(images)
150
-
151
- im = self.turbojpeg.encode_torch(im, quality=95)
152
- images.clear()
153
- return io.BytesIO(im)
154
-
155
-
156
- def quant_module(module, running_sum_quants=0, device_index=0):
157
- if isinstance(module, nn.Linear) and not isinstance(module, F16Linear):
158
- module.cuda(device_index)
159
- module.compile()
160
- quantize_(module, int8_weight_only())
161
- running_sum_quants += 1
162
- elif isinstance(module, F16Linear):
163
- module.cuda(device_index)
164
- elif isinstance(module, nn.Conv2d):
165
- module.cuda(device_index)
166
- elif isinstance(module, nn.Embedding):
167
- module.cuda(device_index)
168
- elif isinstance(module, nn.ConvTranspose2d):
169
- module.cuda(device_index)
170
- elif isinstance(module, nn.Conv1d):
171
- module.cuda(device_index)
172
- elif isinstance(module, nn.Conv3d):
173
- module.cuda(device_index)
174
- elif isinstance(module, nn.ConvTranspose3d):
175
- module.cuda(device_index)
176
- elif isinstance(module, nn.RMSNorm):
177
- module.cuda(device_index)
178
- elif isinstance(module, RMSNorm):
179
- module.cuda(device_index)
180
- elif isinstance(module, nn.LayerNorm):
181
- module.cuda(device_index)
182
- return running_sum_quants
183
-
184
-
185
- def full_quant(model, max_quants=24, current_quants=0, device_index=0):
186
- for module in model.modules():
187
- if current_quants < max_quants:
188
- current_quants = quant_module(
189
- module, current_quants, device_index=device_index
190
- )
191
- return current_quants
192
-
193
-
194
- @torch.inference_mode()
195
- def load_pipeline_from_config_path(path: str) -> Model:
196
- config = load_config_from_path(path)
197
- return load_pipeline_from_config(config)
198
-
199
-
200
- @torch.inference_mode()
201
- def load_pipeline_from_config(config: ModelSpec) -> Model:
202
- models = load_models_from_config(config)
203
- config = models.config
204
- num_quanted = 0
205
- max_quanted = config.num_to_quant
206
- flux_device = into_device(config.flux_device)
207
- ae_device = into_device(config.ae_device)
208
- clip_device = into_device(config.text_enc_device)
209
- t5_device = into_device(config.text_enc_device)
210
- flux_dtype = into_dtype(config.flow_dtype)
211
- device_index = flux_device.index or 0
212
- flow_model = models.flow.requires_grad_(False).eval().type(flux_dtype)
213
- for block in flow_model.single_blocks:
214
- block.cuda(flux_device)
215
- if num_quanted < max_quanted:
216
- num_quanted = quant_module(
217
- block.linear1, num_quanted, device_index=device_index
218
- )
219
-
220
- for block in flow_model.double_blocks:
221
- block.cuda(flux_device)
222
- if num_quanted < max_quanted:
223
- num_quanted = full_quant(
224
- block, max_quanted, num_quanted, device_index=device_index
225
- )
226
-
227
- to_gpu_extras = [
228
- "vector_in",
229
- "img_in",
230
- "txt_in",
231
- "time_in",
232
- "guidance_in",
233
- "final_layer",
234
- "pe_embedder",
235
- ]
236
- for extra in to_gpu_extras:
237
- getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
238
- return Model(
239
- name=config.version,
240
- clip=models.clip,
241
- t5=models.t5,
242
- model=flow_model,
243
- ae=models.ae,
244
- dtype=flux_dtype,
245
- verbose=False,
246
- flux_device=flux_device,
247
- ae_device=ae_device,
248
- clip_device=clip_device,
249
- t5_device=t5_device,
250
- )
251
-
252
-
253
- if __name__ == "__main__":
254
- pipe = load_pipeline_from_config_path("config-dev.json")
255
- o = pipe.generate(
256
- prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
257
- height=1024,
258
- width=1024,
259
- seed=13456,
260
- num_steps=24,
261
- guidance=3.0,
262
- )
263
- open("out.jpg", "wb").write(o.read())
264
- o = pipe.generate(
265
- prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
266
- height=1024,
267
- width=1024,
268
- seed=7,
269
- num_steps=24,
270
- guidance=3.0,
271
- )
272
- open("out2.jpg", "wb").write(o.read())
flux_pipeline.py ADDED
@@ -0,0 +1,461 @@
1
+ import base64
2
+ import io
3
+ import math
4
+ from typing import TYPE_CHECKING, Callable, List
5
+ from PIL import Image
6
+ from einops import rearrange, repeat
7
+ import numpy as np
8
+
9
+ import torch
10
+
11
+ from flux_emphasis import get_weighted_text_embeddings_flux
12
+
13
+ torch.backends.cuda.matmul.allow_tf32 = True
14
+ torch.backends.cudnn.allow_tf32 = True
15
+ torch.backends.cudnn.benchmark = True
16
+ torch.backends.cudnn.benchmark_limit = 20
17
+ torch.set_float32_matmul_precision("high")
18
+ from torch._dynamo import config
19
+ from torch._inductor import config as ind_config
20
+ from pybase64 import standard_b64decode
21
+
22
+ config.cache_size_limit = 10000000000
23
+ ind_config.force_fuse_int_mm_with_mul = True
24
+
25
+ from loguru import logger
26
+ from turbojpeg_imgs import TurboImage
27
+ from torchvision.transforms import functional as TF
28
+ from tqdm import tqdm
29
+ from util import (
30
+ ModelSpec,
31
+ into_device,
32
+ into_dtype,
33
+ load_config_from_path,
34
+ load_models_from_config,
35
+ )
36
+
37
+
38
+ if TYPE_CHECKING:
39
+ from modules.conditioner import HFEmbedder
40
+ from modules.flux_model import Flux
41
+ from modules.autoencoder import AutoEncoder
42
+
43
+
44
+ class FluxPipeline:
45
+ def __init__(
46
+ self,
47
+ name: str,
48
+ offload: bool = False,
49
+ clip: "HFEmbedder" = None,
50
+ t5: "HFEmbedder" = None,
51
+ model: "Flux" = None,
52
+ ae: "AutoEncoder" = None,
53
+ dtype: torch.dtype = torch.bfloat16,
54
+ verbose: bool = False,
55
+ flux_device: torch.device | str = "cuda:0",
56
+ ae_device: torch.device | str = "cuda:1",
57
+ clip_device: torch.device | str = "cuda:1",
58
+ t5_device: torch.device | str = "cuda:1",
59
+ config: ModelSpec = None,
60
+ ):
61
+
62
+ self.name = name
63
+ self.device_flux = (
64
+ flux_device
65
+ if isinstance(flux_device, torch.device)
66
+ else torch.device(flux_device)
67
+ )
68
+ self.device_ae = (
69
+ ae_device
70
+ if isinstance(ae_device, torch.device)
71
+ else torch.device(ae_device)
72
+ )
73
+ self.device_clip = (
74
+ clip_device
75
+ if isinstance(clip_device, torch.device)
76
+ else torch.device(clip_device)
77
+ )
78
+ self.device_t5 = (
79
+ t5_device
80
+ if isinstance(t5_device, torch.device)
81
+ else torch.device(t5_device)
82
+ )
83
+ self.dtype = dtype
84
+ self.offload = offload
85
+ self.clip: "HFEmbedder" = clip
86
+ self.t5: "HFEmbedder" = t5
87
+ self.model: "Flux" = model
88
+ self.ae: "AutoEncoder" = ae
89
+ self.rng = torch.Generator(device="cpu")
90
+ self.turbojpeg = TurboImage()
91
+ self.verbose = verbose
92
+ self.ae_dtype = torch.bfloat16
93
+ self.config = config
94
+
95
+ @torch.inference_mode()
96
+ def prepare(
97
+ self,
98
+ img: torch.Tensor,
99
+ prompt: str | list[str],
100
+ target_device: torch.device = torch.device("cuda:0"),
101
+ target_dtype: torch.dtype = torch.float16,
102
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
103
+ bs, c, h, w = img.shape
104
+ if bs == 1 and not isinstance(prompt, str):
105
+ bs = len(prompt)
106
+ img = img.unfold(2, 2, 2).unfold(3, 2, 2).permute(0, 2, 3, 1, 4, 5)
107
+ img = img.reshape(img.shape[0], -1, img.shape[3] * img.shape[4] * img.shape[5])
108
+ assert img.shape == (
109
+ bs,
110
+ (h // 2) * (w // 2),
111
+ c * 2 * 2,
112
+ ), f"{img.shape} != {(bs, (h//2)*(w//2), c*2*2)}"
113
+ if img.shape[0] == 1 and bs > 1:
114
+ img = img[None].repeat_interleave(bs, dim=0)
115
+
116
+ img_ids = torch.zeros(
117
+ h // 2, w // 2, 3, device=target_device, dtype=target_dtype
118
+ )
119
+ img_ids[..., 1] = (
120
+ img_ids[..., 1]
121
+ + torch.arange(h // 2, device=target_device, dtype=target_dtype)[:, None]
122
+ )
123
+ img_ids[..., 2] = (
124
+ img_ids[..., 2]
125
+ + torch.arange(w // 2, device=target_device, dtype=target_dtype)[None, :]
126
+ )
127
+
128
+ img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
129
+ vec, txt, txt_ids = get_weighted_text_embeddings_flux(
130
+ self,
131
+ prompt,
132
+ num_images_per_prompt=bs,
133
+ device=self.device_clip,
134
+ target_device=target_device,
135
+ target_dtype=target_dtype,
136
+ )
137
+ return img, img_ids, vec, txt, txt_ids
138
+
139
+ @torch.inference_mode()
140
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
141
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
142
+
143
+ def get_lin_function(
144
+ self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
145
+ ) -> Callable[[float], float]:
146
+ m = (y2 - y1) / (x2 - x1)
147
+ b = y1 - m * x1
148
+ return lambda x: m * x + b
149
+
150
+ @torch.inference_mode()
151
+ def get_schedule(
152
+ self,
153
+ num_steps: int,
154
+ image_seq_len: int,
155
+ base_shift: float = 0.5,
156
+ max_shift: float = 1.15,
157
+ shift: bool = True,
158
+ ) -> list[float]:
159
+ # extra step for zero
160
+ timesteps = torch.linspace(1, 0, num_steps + 1)
161
+
162
+ # shifting the schedule to favor high timesteps for higher signal images
163
+ if shift:
164
+ # eastimate mu based on linear estimation between two points
165
+ mu = self.get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
166
+ timesteps = self.time_shift(mu, 1.0, timesteps)
167
+
168
+ return timesteps.tolist()
169
+
170
+ @torch.inference_mode()
171
+ def get_noise(
172
+ self,
173
+ num_samples: int,
174
+ height: int,
175
+ width: int,
176
+ generator: torch.Generator,
177
+ dtype=None,
178
+ device=None,
179
+ ):
180
+ if device is None:
181
+ device = self.device_flux
182
+ if dtype is None:
183
+ dtype = self.dtype
184
+ return torch.randn(
185
+ num_samples,
186
+ 16,
187
+ # allow for packing
188
+ 2 * math.ceil(height / 16),
189
+ 2 * math.ceil(width / 16),
190
+ device=device,
191
+ dtype=dtype,
192
+ generator=generator,
193
+ requires_grad=False,
194
+ )
195
+
196
+ @torch.inference_mode()
197
+ def into_bytes(self, x: torch.Tensor) -> io.BytesIO:
198
+ # bring into PIL format and save
199
+ x = x.clamp(-1, 1)
200
+ num_images = x.shape[0]
201
+ images: List[torch.Tensor] = []
202
+ for i in range(num_images):
203
+ x = x[i].permute(1, 2, 0).add(1.0).mul(127.5).type(torch.uint8).contiguous()
204
+ images.append(x)
205
+ if len(images) == 1:
206
+ im = images[0]
207
+ else:
208
+ im = torch.vstack(images)
209
+
210
+ im = self.turbojpeg.encode_torch(im, quality=95)
211
+ images.clear()
212
+ return io.BytesIO(im)
213
+
214
+ @torch.inference_mode()
215
+ def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
216
+ x = x.to(self.device_ae)
217
+ x = self.unpack(x.float(), height, width)
218
+ with torch.autocast(
219
+ device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
220
+ ):
221
+ x = self.ae.decode(x)
222
+ return x
223
+
224
+ def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
225
+ return rearrange(
226
+ x,
227
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
228
+ h=math.ceil(height / 16),
229
+ w=math.ceil(width / 16),
230
+ ph=2,
231
+ pw=2,
232
+ )
233
+
234
+ @torch.inference_mode()
235
+ def resize_center_crop(
236
+ self, img: torch.Tensor, height: int, width: int
237
+ ) -> torch.Tensor:
238
+ img = TF.resize(img, min(width, height))
239
+ img = TF.center_crop(img, (height, width))
240
+ return img
241
+
242
+ @torch.inference_mode()
243
+ def preprocess_latent(
244
+ self,
245
+ init_image: torch.Tensor | np.ndarray = None,
246
+ height: int = 720,
247
+ width: int = 1024,
248
+ num_steps: int = 20,
249
+ strength: float = 1.0,
250
+ generator: torch.Generator = None,
251
+ num_images: int = 1,
252
+ ) -> tuple[torch.Tensor, List[float]]:
253
+ # prepare input
254
+
255
+ if init_image is not None:
256
+ if isinstance(init_image, np.ndarray):
257
+ init_image = torch.from_numpy(init_image)
258
+
259
+ init_image = (
260
+ init_image.permute(2, 0, 1)
261
+ .contiguous()
262
+ .to(self.device_ae, dtype=self.ae_dtype)
263
+ .div(127.5)
264
+ .sub(1)[None, ...]
265
+ )
266
+ init_image = self.resize_center_crop(init_image, height, width)
267
+ with torch.autocast(
268
+ device_type=self.device_ae.type,
269
+ dtype=torch.bfloat16,
270
+ cache_enabled=False,
271
+ ):
272
+ init_image = (
273
+ self.ae.encode(init_image)
274
+ .to(dtype=self.dtype, device=self.device_flux)
275
+ .repeat(num_images, 1, 1, 1)
276
+ )
277
+
278
+ x = self.get_noise(
279
+ num_images,
280
+ height,
281
+ width,
282
+ device=self.device_flux,
283
+ dtype=self.dtype,
284
+ generator=generator,
285
+ )
286
+ timesteps = self.get_schedule(
287
+ num_steps=num_steps,
288
+ image_seq_len=x.shape[-1] * x.shape[-2] // 4,
289
+ shift=(self.name != "flux-schnell"),
290
+ )
291
+ if init_image is not None:
292
+ t_idx = int((1 - strength) * num_steps)
293
+ t = timesteps[t_idx]
294
+ timesteps = timesteps[t_idx:]
295
+ x = t * x + (1.0 - t) * init_image
296
+ return x, timesteps
297
+
298
+ @torch.inference_mode()
299
+ def generate(
300
+ self,
301
+ prompt: str,
302
+ width: int = 720,
303
+ height: int = 1024,
304
+ num_steps: int = 24,
305
+ guidance: float = 3.5,
306
+ seed: int | None = None,
307
+ init_image: torch.Tensor | str | None = None,
308
+ strength: float = 1.0,
309
+ silent: bool = False,
310
+ num_images: int = 1,
311
+ return_seed: bool = False,
312
+ ) -> io.BytesIO:
313
+ num_steps = 4 if self.name == "flux-schnell" else num_steps
314
+
315
+ if isinstance(init_image, str):
316
+ try:
317
+ init_image = Image.open(init_image)
318
+ except Exception as e:
319
+ init_image = Image.open(io.BytesIO(standard_b64decode(init_image)))
320
+ init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
321
+
322
+ # allow for packing and conversion to latent space
323
+ height = 16 * (height // 16)
324
+ width = 16 * (width // 16)
325
+ if isinstance(seed, str):
326
+ seed = int(seed)
327
+ if seed is None:
328
+ seed = self.rng.seed()
329
+ logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
330
+
331
+ generator = torch.Generator(device=self.device_flux).manual_seed(seed)
332
+ img, timesteps = self.preprocess_latent(
333
+ init_image=init_image,
334
+ height=height,
335
+ width=width,
336
+ num_steps=num_steps,
337
+ strength=strength,
338
+ generator=generator,
339
+ num_images=num_images,
340
+ )
341
+ img, img_ids, vec, txt, txt_ids = self.prepare(
342
+ img=img,
343
+ prompt=prompt,
344
+ target_device=self.device_flux,
345
+ target_dtype=self.dtype,
346
+ )
347
+
348
+ # this is ignored for schnell
349
+ guidance_vec = torch.full(
350
+ (img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
351
+ )
352
+ t_vec = None
353
+ for t_curr, t_prev in tqdm(
354
+ zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
355
+ ):
356
+ if t_vec is None:
357
+ t_vec = torch.full(
358
+ (img.shape[0],),
359
+ t_curr,
360
+ dtype=self.dtype,
361
+ device=self.device_flux,
362
+ )
363
+ else:
364
+ t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
365
+ pred = self.model.forward(
366
+ img=img,
367
+ img_ids=img_ids,
368
+ txt=txt,
369
+ txt_ids=txt_ids,
370
+ y=vec,
371
+ timesteps=t_vec,
372
+ guidance=guidance_vec,
373
+ )
374
+
375
+ img = img + (t_prev - t_curr) * pred
376
+
377
+ torch.cuda.empty_cache()
378
+
379
+ # decode latents to pixel space
380
+ img = self.vae_decode(img, height, width)
381
+
382
+ if return_seed:
383
+ return self.into_bytes(img), seed
384
+ return self.into_bytes(img)
385
+
386
+ @classmethod
387
+ def load_pipeline_from_config_path(cls, path: str) -> "FluxPipeline":
388
+ with torch.inference_mode():
389
+ config = load_config_from_path(path)
390
+ return cls.load_pipeline_from_config(config)
391
+
392
+ @classmethod
393
+ def load_pipeline_from_config(cls, config: ModelSpec) -> "FluxPipeline":
394
+ from quantize_swap_and_dispatch import quantize_and_dispatch_to_device
395
+
396
+ with torch.inference_mode():
397
+
398
+ models = load_models_from_config(config)
399
+ config = models.config
400
+ num_layers_to_quantize = config.num_to_quant
401
+ flux_device = into_device(config.flux_device)
402
+ ae_device = into_device(config.ae_device)
403
+ clip_device = into_device(config.text_enc_device)
404
+ t5_device = into_device(config.text_enc_device)
405
+ flux_dtype = into_dtype(config.flow_dtype)
406
+ flow_model = models.flow
407
+
408
+ flow_model = quantize_and_dispatch_to_device(
409
+ flow_model=flow_model,
410
+ flux_device=flux_device,
411
+ flux_dtype=flux_dtype,
412
+ num_layers_to_quantize=num_layers_to_quantize,
413
+ compile_extras=config.compile_extras,
414
+ compile_blocks=config.compile_blocks,
415
+ quantize_extras=config.quantize_extras,
416
+ )
417
+
418
+ return cls(
419
+ name=config.version,
420
+ clip=models.clip,
421
+ t5=models.t5,
422
+ model=flow_model,
423
+ ae=models.ae,
424
+ dtype=flux_dtype,
425
+ verbose=False,
426
+ flux_device=flux_device,
427
+ ae_device=ae_device,
428
+ clip_device=clip_device,
429
+ t5_device=t5_device,
430
+ config=config,
431
+ )
432
+
433
+
434
+ if __name__ == "__main__":
435
+ pipe = FluxPipeline.load_pipeline_from_config_path(
436
+ "configs/config-dev-gigaquant.json"
437
+ )
438
+ o = pipe.generate(
439
+ prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
440
+ height=1024,
441
+ width=1024,
442
+ num_steps=24,
443
+ guidance=3.0,
444
+ )
445
+ open("out.jpg", "wb").write(o.read())
446
+ o = pipe.generate(
447
+ prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
448
+ height=1024,
449
+ width=1024,
450
+ num_steps=24,
451
+ guidance=3.0,
452
+ )
453
+ open("out2.jpg", "wb").write(o.read())
454
+ o = pipe.generate(
455
+ prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
456
+ height=1024,
457
+ width=1024,
458
+ num_steps=24,
459
+ guidance=3.0,
460
+ )
461
+ open("out3.jpg", "wb").write(o.read())
main.py CHANGED
@@ -1,27 +1,43 @@
1
  import argparse
2
  import uvicorn
3
  from api import app
4
- from flux_impl import load_pipeline_from_config, load_pipeline_from_config_path
5
  from util import load_config, ModelVersion
6
 
7
 
8
  def parse_args():
9
  parser = argparse.ArgumentParser(description="Launch Flux API server")
10
  parser.add_argument(
 
11
  "--config-path",
12
  type=str,
13
  help="Path to the configuration file, if not provided, the model will be loaded from the command line arguments",
14
  )
15
  parser.add_argument(
16
- "--port", type=int, default=8088, help="Port to run the server on"
17
  )
18
  parser.add_argument(
19
- "--host", type=str, default="0.0.0.0", help="Host to run the server on"
20
  )
21
- parser.add_argument("--flow-model-path", type=str, help="Path to the flow model")
22
- parser.add_argument("--text-enc-path", type=str, help="Path to the text encoder")
23
- parser.add_argument("--autoencoder-path", type=str, help="Path to the autoencoder")
24
  parser.add_argument(
 
25
  "--model-version",
26
  type=str,
27
  choices=["flux-dev", "flux-schnell"],
@@ -29,29 +45,40 @@ def parse_args():
29
  help="Choose model version",
30
  )
31
  parser.add_argument(
 
32
  "--flux-device",
33
  type=str,
34
  default="cuda:0",
35
  help="Device to run the flow model on",
36
  )
37
  parser.add_argument(
 
38
  "--text-enc-device",
39
  type=str,
40
  default="cuda:0",
41
  help="Device to run the text encoder on",
42
  )
43
  parser.add_argument(
 
44
  "--autoencoder-device",
45
  type=str,
46
  default="cuda:0",
47
  help="Device to run the autoencoder on",
48
  )
49
  parser.add_argument(
 
50
  "--num-to-quant",
51
  type=int,
52
  default=20,
53
  help="Number of linear layers in flow transformer (the 'unet') to quantize",
54
  )
55
 
56
  return parser.parse_args()
57
 
@@ -60,7 +87,7 @@ def main():
60
  args = parse_args()
61
 
62
  if args.config_path:
63
- app.state.model = load_pipeline_from_config_path(args.config_path)
64
  else:
65
  model_version = (
66
  ModelVersion.flux_dev
@@ -79,8 +106,10 @@ def main():
79
  text_enc_dtype="bfloat16",
80
  ae_dtype="bfloat16",
81
  num_to_quant=args.num_to_quant,
 
 
82
  )
83
- app.state.model = load_pipeline_from_config(config)
84
 
85
  uvicorn.run(app, host=args.host, port=args.port)
86
 
 
1
  import argparse
2
  import uvicorn
3
  from api import app
4
+ from flux_pipeline import FluxPipeline
5
  from util import load_config, ModelVersion
6
 
7
 
8
  def parse_args():
9
  parser = argparse.ArgumentParser(description="Launch Flux API server")
10
  parser.add_argument(
11
+ "-c",
12
  "--config-path",
13
  type=str,
14
  help="Path to the configuration file, if not provided, the model will be loaded from the command line arguments",
15
  )
16
  parser.add_argument(
17
+ "-p",
18
+ "--port",
19
+ type=int,
20
+ default=8088,
21
+ help="Port to run the server on",
22
+ )
23
+ parser.add_argument(
24
+ "-H",
25
+ "--host",
26
+ type=str,
27
+ default="0.0.0.0",
28
+ help="Host to run the server on",
29
+ )
30
+ parser.add_argument(
31
+ "-f", "--flow-model-path", type=str, help="Path to the flow model"
32
+ )
33
+ parser.add_argument(
34
+ "-t", "--text-enc-path", type=str, help="Path to the text encoder"
35
  )
36
  parser.add_argument(
37
+ "-a", "--autoencoder-path", type=str, help="Path to the autoencoder"
38
  )
 
 
 
39
  parser.add_argument(
40
+ "-m",
41
  "--model-version",
42
  type=str,
43
  choices=["flux-dev", "flux-schnell"],
 
45
  help="Choose model version",
46
  )
47
  parser.add_argument(
48
+ "-F",
49
  "--flux-device",
50
  type=str,
51
  default="cuda:0",
52
  help="Device to run the flow model on",
53
  )
54
  parser.add_argument(
55
+ "-T",
56
  "--text-enc-device",
57
  type=str,
58
  default="cuda:0",
59
  help="Device to run the text encoder on",
60
  )
61
  parser.add_argument(
62
+ "-A",
63
  "--autoencoder-device",
64
  type=str,
65
  default="cuda:0",
66
  help="Device to run the autoencoder on",
67
  )
68
  parser.add_argument(
69
+ "-q",
70
  "--num-to-quant",
71
  type=int,
72
  default=20,
73
  help="Number of linear layers in flow transformer (the 'unet') to quantize",
74
  )
75
+ parser.add_argument(
76
+ "-C",
77
+ "--compile",
78
+ action="store_true",
79
+ default=False,
80
+ help="Compile the flow model with extra optimizations",
81
+ )
82
 
83
  return parser.parse_args()
84
 
 
87
  args = parse_args()
88
 
89
  if args.config_path:
90
+ app.state.model = FluxPipeline.load_pipeline_from_config_path(args.config_path)
91
  else:
92
  model_version = (
93
  ModelVersion.flux_dev
 
106
  text_enc_dtype="bfloat16",
107
  ae_dtype="bfloat16",
108
  num_to_quant=args.num_to_quant,
109
+ compile_extras=args.compile,
110
+ compile_blocks=args.compile,
111
  )
112
+ app.state.model = FluxPipeline.load_pipeline_from_config(config)
113
 
114
  uvicorn.run(app, host=args.host, port=args.port)
115
 
main_gr.py ADDED
@@ -0,0 +1,132 @@
1
+ import torch
2
+
3
+ from flux_pipeline import FluxPipeline
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+
8
+ def create_demo(
9
+ config_path: str,
10
+ ):
11
+ generator = FluxPipeline.load_pipeline_from_config_path(config_path)
12
+
13
+ def generate_image(
14
+ prompt,
15
+ width,
16
+ height,
17
+ num_steps,
18
+ guidance,
19
+ seed,
20
+ init_image,
21
+ image2image_strength,
22
+ add_sampling_metadata,
23
+ ):
24
+
25
+ seed = int(seed)
26
+ if seed == -1:
27
+ seed = None
28
+ out = generator.generate(
29
+ prompt,
30
+ width,
31
+ height,
32
+ num_steps=num_steps,
33
+ guidance=guidance,
34
+ seed=seed,
35
+ init_image=init_image,
36
+ strength=image2image_strength,
37
+ silent=False,
38
+ num_images=1,
39
+ return_seed=True,
40
+ )
41
+ image_bytes = out[0]
42
+ return Image.open(image_bytes), str(out[1]), None
43
+
44
+ is_schnell = generator.config.version == "flux-schnell"
45
+
46
+ with gr.Blocks() as demo:
47
+ gr.Markdown(f"# Flux Image Generation Demo - Model: {generator.config.version}")
48
+
49
+ with gr.Row():
50
+ with gr.Column():
51
+ prompt = gr.Textbox(
52
+ label="Prompt",
53
+ value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
54
+ )
55
+ do_img2img = gr.Checkbox(
56
+ label="Image to Image", value=False, interactive=not is_schnell
57
+ )
58
+ init_image = gr.Image(label="Input Image", visible=False)
59
+ image2image_strength = gr.Slider(
60
+ 0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
61
+ )
62
+
63
+ with gr.Accordion("Advanced Options", open=False):
64
+ width = gr.Slider(128, 8192, 1152, step=16, label="Width")
65
+ height = gr.Slider(128, 8192, 640, step=16, label="Height")
66
+ num_steps = gr.Slider(
67
+ 1, 50, 4 if is_schnell else 20, step=1, label="Number of steps"
68
+ )
69
+ guidance = gr.Slider(
70
+ 1.0,
71
+ 10.0,
72
+ 3.5,
73
+ step=0.1,
74
+ label="Guidance",
75
+ interactive=not is_schnell,
76
+ )
77
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
78
+ add_sampling_metadata = gr.Checkbox(
79
+ label="Add sampling parameters to metadata?", value=True
80
+ )
81
+
82
+ generate_btn = gr.Button("Generate")
83
+
84
+ with gr.Column(min_width="960px"):
85
+ output_image = gr.Image(label="Generated Image")
86
+ seed_output = gr.Number(label="Used Seed")
87
+ warning_text = gr.Textbox(label="Warning", visible=False)
88
+ # download_btn = gr.File(label="Download full-resolution")
89
+
90
+ def update_img2img(do_img2img):
91
+ return {
92
+ init_image: gr.update(visible=do_img2img),
93
+ image2image_strength: gr.update(visible=do_img2img),
94
+ }
95
+
96
+ do_img2img.change(
97
+ update_img2img, do_img2img, [init_image, image2image_strength]
98
+ )
99
+
100
+ generate_btn.click(
101
+ fn=generate_image,
102
+ inputs=[
103
+ prompt,
104
+ width,
105
+ height,
106
+ num_steps,
107
+ guidance,
108
+ seed,
109
+ init_image,
110
+ image2image_strength,
111
+ add_sampling_metadata,
112
+ ],
113
+ outputs=[output_image, seed_output, warning_text],
114
+ )
115
+
116
+ return demo
117
+
118
+
119
+ if __name__ == "__main__":
120
+ import argparse
121
+
122
+ parser = argparse.ArgumentParser(description="Flux")
123
+ parser.add_argument(
124
+ "--config", type=str, default="configs/config-dev.json", help="Config file path"
125
+ )
126
+ parser.add_argument(
127
+ "--share", action="store_true", help="Create a public link to your demo"
128
+ )
129
+ args = parser.parse_args()
130
+
131
+ demo = create_demo(args.config)
132
+ demo.launch(share=args.share)
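
For completeness, a short sketch of launching the Gradio demo programmatically, equivalent to what the file's __main__ block does via argparse (assuming a valid config path):

from main_gr import create_demo

demo = create_demo("configs/config-dev.json")  # loads the FluxPipeline from the config
demo.launch(share=False)  # share=True creates a public gradio.live link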
modules/conditioner.py CHANGED
@@ -2,7 +2,7 @@ from torch import Tensor, nn
  import torch
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
- from transformers.utils.quantization_config import BitsAndBytesConfig
+ from transformers.utils.quantization_config import BitsAndBytesConfig, QuantoConfig
 
 
  class HFEmbedder(nn.Module):
@@ -30,8 +30,8 @@ class HFEmbedder(nn.Module):
          version,
          **hf_kwargs,
          device_map={"": device},
-         quantization_config=BitsAndBytesConfig(
-             load_in_4bit=True,
+         quantization_config=QuantoConfig(
+             weights="float8",
          ),
      )
 
modules/flux_model.py CHANGED
@@ -1,23 +1,25 @@
 
 
1
  import torch
2
 
 
3
  torch.backends.cuda.matmul.allow_tf32 = True
4
  torch.backends.cudnn.allow_tf32 = True
5
  torch.backends.cudnn.benchmark = True
6
  torch.backends.cudnn.benchmark_limit = 20
7
  torch.set_float32_matmul_precision("high")
8
  import math
9
- from dataclasses import dataclass
10
 
11
- from cublas_linear import CublasLinear as F16Linear
12
- from einops.layers.torch import Rearrange
13
  from torch import Tensor, nn
14
  from torch._dynamo import config
15
  from torch._inductor import config as ind_config
16
- from xformers.ops import memory_efficient_attention
17
  from pydantic import BaseModel
 
18
 
19
  config.cache_size_limit = 10000000000
20
- ind_config.force_fuse_int_mm_with_mul = True
 
21
 
22
 
23
  class FluxParams(BaseModel):
@@ -35,17 +37,16 @@ class FluxParams(BaseModel):
35
  guidance_embed: bool
36
 
37
 
38
- @torch.compile(mode="reduce-overhead")
 
39
  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
40
  q, k = apply_rope(q, k, pe)
41
- x = memory_efficient_attention(
42
- q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
43
- )
44
  x = x.reshape(*x.shape[:-2], -1)
45
  return x
46
 
47
 
48
- @torch.compile(mode="reduce-overhead")
49
  def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
50
  scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
51
  omega = 1.0 / (theta**scale)
@@ -119,30 +120,21 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
119
  class MLPEmbedder(nn.Module):
120
  def __init__(self, in_dim: int, hidden_dim: int):
121
  super().__init__()
122
- self.in_layer = F16Linear(in_dim, hidden_dim, bias=True)
123
  self.silu = nn.SiLU()
124
- self.out_layer = F16Linear(hidden_dim, hidden_dim, bias=True)
125
 
126
  def forward(self, x: Tensor) -> Tensor:
127
  return self.out_layer(self.silu(self.in_layer(x)))
128
 
129
 
130
- @torch.compile(mode="reduce-overhead", dynamic=True)
131
- def calculation(
132
- x,
133
- ):
134
- rrms = torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + 1e-6)
135
- x = x * rrms
136
- return x
137
-
138
-
139
  class RMSNorm(torch.nn.Module):
140
  def __init__(self, dim: int):
141
  super().__init__()
142
  self.scale = nn.Parameter(torch.ones(dim))
143
 
144
  def forward(self, x: Tensor):
145
- return calculation(x) * self.scale
146
 
147
 
148
  class QKNorm(torch.nn.Module):
@@ -163,25 +155,28 @@ class SelfAttention(nn.Module):
163
  self.num_heads = num_heads
164
  head_dim = dim // num_heads
165
 
166
- self.qkv = F16Linear(dim, dim * 3, bias=qkv_bias)
167
  self.norm = QKNorm(head_dim)
168
- self.proj = F16Linear(dim, dim)
169
- self.rearrange = Rearrange("B L (K H D) -> K B H L D", K=3, H=num_heads)
 
 
 
 
 
 
 
170
 
171
  def forward(self, x: Tensor, pe: Tensor) -> Tensor:
172
  qkv = self.qkv(x)
173
- q, k, v = self.rearrange(qkv)
174
  q, k = self.norm(q, k, v)
175
  x = attention(q, k, v, pe=pe)
176
  x = self.proj(x)
177
  return x
178
 
179
 
180
- @dataclass
181
- class ModulationOut:
182
- shift: Tensor
183
- scale: Tensor
184
- gate: Tensor
185
 
186
 
187
  class Modulation(nn.Module):
@@ -225,9 +220,9 @@ class DoubleStreamBlock(nn.Module):
225
 
226
  self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
227
  self.img_mlp = nn.Sequential(
228
- F16Linear(hidden_size, mlp_hidden_dim, bias=True),
229
  nn.GELU(approximate="tanh"),
230
- F16Linear(mlp_hidden_dim, hidden_size, bias=True),
231
  )
232
 
233
  self.txt_mod = Modulation(hidden_size, double=True)
@@ -238,13 +233,18 @@ class DoubleStreamBlock(nn.Module):
238
 
239
  self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
240
  self.txt_mlp = nn.Sequential(
241
- (F16Linear(hidden_size, mlp_hidden_dim, bias=True)),
242
  nn.GELU(approximate="tanh"),
243
- (F16Linear(mlp_hidden_dim, hidden_size, bias=True)),
244
- )
245
- self.rearrange_for_norm = Rearrange(
246
- "B L (K H D) -> K B H L D", K=3, H=num_heads
247
  )
 
 
 
 
 
 
 
 
248
 
249
  def forward(
250
  self,
@@ -316,7 +316,7 @@ class SingleStreamBlock(nn.Module):
316
  # qkv and mlp_in
317
  self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
318
  # proj and mlp_out
319
- self.linear2 = F16Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
320
 
321
  self.norm = QKNorm(head_dim)
322
 
@@ -325,9 +325,10 @@ class SingleStreamBlock(nn.Module):
325
 
326
  self.mlp_act = nn.GELU(approximate="tanh")
327
  self.modulation = Modulation(hidden_size, double=False)
328
- self.rearrange_for_norm = Rearrange(
329
- "B L (K H D) -> K B H L D", K=3, H=num_heads
330
- )
 
331
 
332
  def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
333
  mod = self.modulation(vec)[0]
@@ -338,7 +339,8 @@ class SingleStreamBlock(nn.Module):
338
  [3 * self.hidden_size, self.mlp_hidden_dim],
339
  dim=-1,
340
  )
341
- q, k, v = self.rearrange_for_norm(qkv)
 
342
  q, k = self.norm(q, k, v)
343
  attn = attention(q, k, v, pe=pe)
344
  output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
@@ -394,7 +396,7 @@ class Flux(nn.Module):
394
  axes_dim=params.axes_dim,
395
  dtype=self.dtype,
396
  )
397
- self.img_in = F16Linear(self.in_channels, self.hidden_size, bias=True)
398
  self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
399
  self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
400
  self.guidance_in = (
@@ -402,7 +404,7 @@ class Flux(nn.Module):
402
  if params.guidance_embed
403
  else nn.Identity()
404
  )
405
- self.txt_in = F16Linear(params.context_in_dim, self.hidden_size)
406
 
407
  self.double_blocks = nn.ModuleList(
408
  [
@@ -464,10 +466,13 @@ class Flux(nn.Module):
464
  ids = torch.cat((txt_ids, img_ids), dim=1)
465
  pe = self.pe_embedder(ids)
466
 
467
- for i, block in enumerate(self.double_blocks):
 
468
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
469
 
470
  img = torch.cat((txt, img), 1)
 
 
471
  for block in self.single_blocks:
472
  img = block(img, vec=vec, pe=pe)
473
 
@@ -476,17 +481,14 @@ class Flux(nn.Module):
476
  return img
477
 
478
  @classmethod
479
- def from_safetensors(
480
- self,
481
- model_path: str,
482
- model_params: FluxParams,
483
- dtype: torch.dtype = torch.bfloat16,
484
- device: torch.device = torch.device(
485
- "cuda" if torch.cuda.is_available() else "cpu"
486
- ),
487
- ):
488
 
489
- model = Flux(params=model_params, dtype=dtype)
490
- model.load_state_dict(model_path.state_dict())
491
- model.to(device)
492
- return model
 
1
+ from collections import namedtuple
2
+ import os
3
  import torch
4
 
5
+ DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
6
  torch.backends.cuda.matmul.allow_tf32 = True
7
  torch.backends.cudnn.allow_tf32 = True
8
  torch.backends.cudnn.benchmark = True
9
  torch.backends.cudnn.benchmark_limit = 20
10
  torch.set_float32_matmul_precision("high")
11
  import math
 
12
 
 
 
13
  from torch import Tensor, nn
14
  from torch._dynamo import config
15
  from torch._inductor import config as ind_config
16
+ from xformers.ops import memory_efficient_attention_forward
17
  from pydantic import BaseModel
18
+ from torch.nn import functional as F
19
 
20
  config.cache_size_limit = 10000000000
21
+ ind_config.compile_threads = os.cpu_count()
22
+ ind_config.shape_padding = True
23
 
24
 
25
  class FluxParams(BaseModel):
 
37
  guidance_embed: bool
38
 
39
 
40
+ # attention is always called with the same shapes for a given H*W, so compile with fullgraph
41
+ @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE)
42
  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
43
  q, k = apply_rope(q, k, pe)
44
+ x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
 
 
45
  x = x.reshape(*x.shape[:-2], -1)
46
  return x
47
 
48
 
49
+ @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE)
50
  def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
51
  scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
52
  omega = 1.0 / (theta**scale)
 
120
  class MLPEmbedder(nn.Module):
121
  def __init__(self, in_dim: int, hidden_dim: int):
122
  super().__init__()
123
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
124
  self.silu = nn.SiLU()
125
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
126
 
127
  def forward(self, x: Tensor) -> Tensor:
128
  return self.out_layer(self.silu(self.in_layer(x)))
129
 
130
 
 
 
 
 
 
 
 
 
 
131
  class RMSNorm(torch.nn.Module):
132
  def __init__(self, dim: int):
133
  super().__init__()
134
  self.scale = nn.Parameter(torch.ones(dim))
135
 
136
  def forward(self, x: Tensor):
137
+ return F.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
138
 
139
 
140
  class QKNorm(torch.nn.Module):
 
155
  self.num_heads = num_heads
156
  head_dim = dim // num_heads
157
 
158
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
159
  self.norm = QKNorm(head_dim)
160
+ self.proj = nn.Linear(dim, dim)
161
+ self.K = 3
162
+ self.H = self.num_heads
163
+ self.KH = self.K * self.H
164
+
165
+ def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
166
+ B, L, D = x.shape
167
+ q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
168
+ return q, k, v
169
 
170
  def forward(self, x: Tensor, pe: Tensor) -> Tensor:
171
  qkv = self.qkv(x)
172
+ q, k, v = self.rearrange_for_norm(qkv)
173
  q, k = self.norm(q, k, v)
174
  x = attention(q, k, v, pe=pe)
175
  x = self.proj(x)
176
  return x
177
 
178
 
179
+ ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"])
 
 
 
 
180
 
181
 
182
  class Modulation(nn.Module):
 
220
 
221
  self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
222
  self.img_mlp = nn.Sequential(
223
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
224
  nn.GELU(approximate="tanh"),
225
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
226
  )
227
 
228
  self.txt_mod = Modulation(hidden_size, double=True)
 
233
 
234
  self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
235
  self.txt_mlp = nn.Sequential(
236
+ (nn.Linear(hidden_size, mlp_hidden_dim, bias=True)),
237
  nn.GELU(approximate="tanh"),
238
+ (nn.Linear(mlp_hidden_dim, hidden_size, bias=True)),
 
 
 
239
  )
240
+ self.K = 3
241
+ self.H = self.num_heads
242
+ self.KH = self.K * self.H
243
+
244
+ def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
245
+ B, L, D = x.shape
246
+ q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
247
+ return q, k, v
248
 
249
  def forward(
250
  self,
 
316
  # qkv and mlp_in
317
  self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
318
  # proj and mlp_out
319
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
320
 
321
  self.norm = QKNorm(head_dim)
322
 
 
325
 
326
  self.mlp_act = nn.GELU(approximate="tanh")
327
  self.modulation = Modulation(hidden_size, double=False)
328
+
329
+ self.K = 3
330
+ self.H = self.num_heads
331
+ self.KH = self.K * self.H
332
 
333
  def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
334
  mod = self.modulation(vec)[0]
 
339
  [3 * self.hidden_size, self.mlp_hidden_dim],
340
  dim=-1,
341
  )
342
+ B, L, D = qkv.shape
343
+ q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
344
  q, k = self.norm(q, k, v)
345
  attn = attention(q, k, v, pe=pe)
346
  output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)).clamp(
 
396
  axes_dim=params.axes_dim,
397
  dtype=self.dtype,
398
  )
399
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
400
  self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
401
  self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
402
  self.guidance_in = (
 
404
  if params.guidance_embed
405
  else nn.Identity()
406
  )
407
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
408
 
409
  self.double_blocks = nn.ModuleList(
410
  [
 
466
  ids = torch.cat((txt_ids, img_ids), dim=1)
467
  pe = self.pe_embedder(ids)
468
 
469
+ # double stream blocks
470
+ for block in self.double_blocks:
471
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
472
 
473
  img = torch.cat((txt, img), 1)
474
+
475
+ # single stream blocks
476
  for block in self.single_blocks:
477
  img = block(img, vec=vec, pe=pe)
478
 
 
481
  return img
482
 
483
  @classmethod
484
+ def from_pretrained(cls, path: str, dtype: torch.dtype = torch.bfloat16) -> "Flux":
485
+ from util import load_config_from_path
486
+ from safetensors.torch import load_file
487
+
488
+ config = load_config_from_path(path)
489
+ with torch.device("meta"):
490
+ klass = cls(params=config.params, dtype=dtype).type(dtype)
 
 
491
 
492
+ ckpt = load_file(config.ckpt_path, device="cpu")
493
+ klass.load_state_dict(ckpt, assign=True)
494
+ return klass.to("cpu")
 
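Editor's note: a minimal sketch of the new Flux.from_pretrained entry point, assuming the path argument is a config JSON such as configs/config-dev.json whose ckpt_path points at the safetensors weights; setting DISABLE_COMPILE=1 in the environment skips the torch.compile decorators above.

import torch
from modules.flux_model import Flux

model = Flux.from_pretrained("configs/config-dev.json", dtype=torch.bfloat16)
model = model.to("cuda:0")  # from_pretrained returns the model on CPU; move it when ready
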
quantize_swap_and_dispatch.py ADDED
@@ -0,0 +1,241 @@
1
+ from fnmatch import fnmatch
2
+ from typing import List, Optional, Union
3
+
4
+ import torch
5
+ from click import secho
6
+ from cublas_ops import CublasLinear
7
+
8
+ from quanto.nn import QModuleMixin, quantize_module, QLinear, QConv2d, QLayerNorm
9
+ from quanto.tensor import Optimizer, qtype, qfloat8
10
+ from torch import nn
11
+
12
+
13
+ def _set_module_by_name(parent_module, name, child_module):
14
+ module_names = name.split(".")
15
+ if len(module_names) == 1:
16
+ setattr(parent_module, name, child_module)
17
+ else:
18
+ parent_module_name = name[: name.rindex(".")]
19
+ parent_module = parent_module.get_submodule(parent_module_name)
20
+ setattr(parent_module, module_names[-1], child_module)
21
+
22
+
23
+ def _quantize_submodule(
24
+ model: torch.nn.Module,
25
+ name: str,
26
+ module: torch.nn.Module,
27
+ weights: Optional[Union[str, qtype]] = None,
28
+ activations: Optional[Union[str, qtype]] = None,
29
+ optimizer: Optional[Optimizer] = None,
30
+ ):
31
+ if isinstance(module, CublasLinear):
32
+ return 0
33
+ num_quant = 0
34
+ qmodule = quantize_module(
35
+ module, weights=weights, activations=activations, optimizer=optimizer
36
+ )
37
+ if qmodule is not None:
38
+ _set_module_by_name(model, name, qmodule)
39
+ # num_quant += 1
40
+ qmodule.name = name
41
+ for name, param in module.named_parameters():
42
+ # Save device memory by clearing parameters
43
+ setattr(module, name, None)
44
+ del param
45
+ num_quant += 1
46
+
47
+ return num_quant
48
+
49
+
50
+ def _quantize(
51
+ model: torch.nn.Module,
52
+ weights: Optional[Union[str, qtype]] = None,
53
+ activations: Optional[Union[str, qtype]] = None,
54
+ optimizer: Optional[Optimizer] = None,
55
+ include: Optional[Union[str, List[str]]] = None,
56
+ exclude: Optional[Union[str, List[str]]] = None,
57
+ ):
58
+ """Quantize the specified model submodules
59
+
60
+ Recursively quantize the submodules of the specified parent model.
61
+
62
+ Only modules that have quantized counterparts will be quantized.
63
+
64
+ If include patterns are specified, the submodule name must match one of them.
65
+
66
+ If exclude patterns are specified, the submodule must not match one of them.
67
+
68
+ Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
69
+ https://docs.python.org/3/library/fnmatch.html for more details.
70
+
71
+ Note: quantization happens in-place and modifies the original model and its descendants.
72
+
73
+ Args:
74
+ model (`torch.nn.Module`): the model whose submodules will be quantized.
75
+ weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
76
+ activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
77
+ include (`Optional[Union[str, List[str]]]`):
78
+ Patterns constituting the allowlist. If provided, module names must match at
79
+ least one pattern from the allowlist.
80
+ exclude (`Optional[Union[str, List[str]]]`):
81
+ Patterns constituting the denylist. If provided, module names must not match
82
+ any patterns from the denylist.
83
+ """
84
+ num_quant = 0
85
+ if include is not None:
86
+ include = [include] if isinstance(include, str) else include
87
+ if exclude is not None:
88
+ exclude = [exclude] if isinstance(exclude, str) else exclude
89
+ for name, m in model.named_modules():
90
+ if include is not None and not any(
91
+ fnmatch(name, pattern) for pattern in include
92
+ ):
93
+ continue
94
+ if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
95
+ continue
96
+ num_quant += _quantize_submodule(
97
+ model,
98
+ name,
99
+ m,
100
+ weights=weights,
101
+ activations=activations,
102
+ optimizer=optimizer,
103
+ )
104
+ return num_quant
105
+
106
+
107
+ def _freeze(model):
108
+ for name, m in model.named_modules():
109
+ if isinstance(m, QModuleMixin):
110
+ m.freeze()
111
+
112
+
113
+ def _is_block_compilable(module: nn.Module) -> bool:
114
+ for module in module.modules():
115
+ if _is_quantized(module):
116
+ return False
117
+ if _is_quantized(module):
118
+ return False
119
+ return True
120
+
121
+
122
+ def _simple_swap_linears(model: nn.Module, root_name: str = ""):
123
+ for name, module in model.named_children():
124
+ if _is_linear(module):
125
+ weights = module.weight.data
126
+ bias = None
127
+ if module.bias is not None:
128
+ bias = module.bias.data
129
+ with torch.device(module.weight.device):
130
+ new_cublas = CublasLinear(
131
+ module.in_features,
132
+ module.out_features,
133
+ bias=bias is not None,
134
+ device=module.weight.device,
135
+ dtype=module.weight.dtype,
136
+ )
137
+ new_cublas.weight.data = weights
138
+ if bias is not None:
139
+ new_cublas.bias.data = bias
140
+ setattr(model, name, new_cublas)
141
+ if root_name == "":
142
+ secho(f"Replaced {name} with CublasLinear", fg="green")
143
+ else:
144
+ secho(f"Replaced {root_name}.{name} with CublasLinear", fg="green")
145
+ else:
146
+ if root_name == "":
147
+ _simple_swap_linears(module, str(name))
148
+ else:
149
+ _simple_swap_linears(module, str(root_name) + "." + str(name))
150
+
151
+
152
+ def _full_quant(
153
+ model, max_quants=24, current_quants=0, quantization_dtype: qtype = qfloat8
154
+ ):
155
+ if current_quants < max_quants:
156
+ current_quants += _quantize(model, quantization_dtype)
157
+ _freeze(model)
158
+ print(f"Quantized {current_quants} modules")
159
+ return current_quants
160
+
161
+
162
+ def _is_linear(module: nn.Module) -> bool:
163
+ return not isinstance(
164
+ module, (QLinear, QConv2d, QLayerNorm, CublasLinear)
165
+ ) and isinstance(module, nn.Linear)
166
+
167
+
168
+ def _is_quantized(module: nn.Module) -> bool:
169
+ return isinstance(module, (QLinear, QConv2d, QLayerNorm))
170
+
171
+
172
+ def quantize_and_dispatch_to_device(
173
+ flow_model: nn.Module,
174
+ flux_device: torch.device = torch.device("cuda"),
175
+ flux_dtype: torch.dtype = torch.float16,
176
+ num_layers_to_quantize: int = 20,
177
+ quantization_dtype: qtype = qfloat8,
178
+ compile_blocks: bool = True,
179
+ compile_extras: bool = True,
180
+ quantize_extras: bool = False,
181
+ ):
182
+ num_quanted = 0
183
+ flow_model = flow_model.requires_grad_(False).eval().type(flux_dtype)
184
+ for block in flow_model.single_blocks:
185
+ block.cuda(flux_device)
186
+ if num_quanted < num_layers_to_quantize:
187
+ num_quanted = _full_quant(
188
+ block,
189
+ num_layers_to_quantize,
190
+ num_quanted,
191
+ quantization_dtype=quantization_dtype,
192
+ )
193
+
194
+ for block in flow_model.double_blocks:
195
+ block.cuda(flux_device)
196
+ if num_quanted < num_layers_to_quantize:
197
+ num_quanted = _full_quant(
198
+ block,
199
+ num_layers_to_quantize,
200
+ num_quanted,
201
+ quantization_dtype=quantization_dtype,
202
+ )
203
+
204
+ to_gpu_extras = [
205
+ "vector_in",
206
+ "img_in",
207
+ "txt_in",
208
+ "time_in",
209
+ "guidance_in",
210
+ "final_layer",
211
+ "pe_embedder",
212
+ ]
213
+
214
+ if compile_blocks:
215
+ for i, block in enumerate(flow_model.single_blocks):
216
+ if _is_block_compilable(block):
217
+ block.compile()
218
+ secho(f"Compiled block {i}", fg="green")
219
+ for i, block in enumerate(flow_model.double_blocks):
220
+ if _is_block_compilable(block):
221
+ block.compile()
222
+ secho(f"Compiled block {i}", fg="green")
223
+
224
+ _simple_swap_linears(flow_model)
225
+ for extra in to_gpu_extras:
226
+ m_extra = getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
227
+ if compile_blocks:
228
+ if extra in ["time_in", "vector_in", "guidance_in", "final_layer"]:
229
+ m_extra.compile()
230
+ secho(
231
+ f"Compiled extra {extra} -- {m_extra.__class__.__name__}",
232
+ fg="green",
233
+ )
234
+ elif quantize_extras:
235
+ _full_quant(
236
+ m_extra,
237
+ current_quants=num_quanted,
238
+ max_quants=num_layers_to_quantize,
239
+ quantization_dtype=quantization_dtype,
240
+ )
241
+ return flow_model
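
Editor's note: a minimal sketch of wiring this helper into model loading, assuming flow_model is a Flux instance with weights already loaded; the values mirror the defaults above and the num_to_quant / compile_* flags in configs/config-dev.json.

import torch
from quanto.tensor import qfloat8
from quantize_swap_and_dispatch import quantize_and_dispatch_to_device

flow_model = quantize_and_dispatch_to_device(
    flow_model,                      # assumed to be an already-loaded Flux module
    flux_device=torch.device("cuda:0"),
    flux_dtype=torch.float16,
    num_layers_to_quantize=22,       # num_to_quant in config-dev.json
    quantization_dtype=qfloat8,
    compile_blocks=False,
    compile_extras=False,
    quantize_extras=False,
)
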
sampling.py CHANGED
@@ -32,7 +32,12 @@ def get_noise(
32
 
33
  @torch.inference_mode()
34
  def prepare(
35
- t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]
 
 
 
 
 
36
  ) -> dict[str, Tensor]:
37
  bs, c, h, w = img.shape
38
  if bs == 1 and not isinstance(prompt, str):
@@ -42,28 +47,34 @@ def prepare(
42
  if img.shape[0] == 1 and bs > 1:
43
  img = repeat(img, "1 ... -> bs ...", bs=bs)
44
 
45
- img_ids = torch.zeros(h // 2, w // 2, 3)
46
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
47
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
 
 
 
 
 
 
48
  img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
49
 
50
  if isinstance(prompt, str):
51
  prompt = [prompt]
52
- txt = t5(prompt)
53
  if txt.shape[0] == 1 and bs > 1:
54
  txt = repeat(txt, "1 ... -> bs ...", bs=bs)
55
- txt_ids = torch.zeros(bs, txt.shape[1], 3)
56
 
57
- vec = clip(prompt)
58
  if vec.shape[0] == 1 and bs > 1:
59
  vec = repeat(vec, "1 ... -> bs ...", bs=bs)
60
 
61
  return {
62
  "img": img,
63
- "img_ids": img_ids.to(img.device),
64
- "txt": txt.to(img.device),
65
- "txt_ids": txt_ids.to(img.device),
66
- "vec": vec.to(img.device),
67
  }
68
 
69
 
@@ -116,11 +127,6 @@ def denoise(
116
  from tqdm import tqdm
117
 
118
  # this is ignored for schnell
119
- img = img.to(device=device, dtype=dtype)
120
- img_ids = img_ids.to(device=device, dtype=dtype)
121
- txt = txt.to(device=device, dtype=dtype)
122
- txt_ids = txt_ids.to(device=device, dtype=dtype)
123
- vec = vec.to(device=device, dtype=dtype)
124
  guidance_vec = torch.full((img.shape[0],), guidance, device=device, dtype=dtype)
125
  for t_curr, t_prev in tqdm(
126
  zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
 
32
 
33
  @torch.inference_mode()
34
  def prepare(
35
+ t5: HFEmbedder,
36
+ clip: HFEmbedder,
37
+ img: Tensor,
38
+ prompt: str | list[str],
39
+ target_device: torch.device = torch.device("cuda:0"),
40
+ target_dtype: torch.dtype = torch.float16,
41
  ) -> dict[str, Tensor]:
42
  bs, c, h, w = img.shape
43
  if bs == 1 and not isinstance(prompt, str):
 
47
  if img.shape[0] == 1 and bs > 1:
48
  img = repeat(img, "1 ... -> bs ...", bs=bs)
49
 
50
+ img_ids = torch.zeros(h // 2, w // 2, 3, device=target_device, dtype=target_dtype)
51
+ img_ids[..., 1] = (
52
+ img_ids[..., 1]
53
+ + torch.arange(h // 2, device=target_device, dtype=target_dtype)[:, None]
54
+ )
55
+ img_ids[..., 2] = (
56
+ img_ids[..., 2]
57
+ + torch.arange(w // 2, device=target_device, dtype=target_dtype)[None, :]
58
+ )
59
  img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
60
 
61
  if isinstance(prompt, str):
62
  prompt = [prompt]
63
+ txt = t5(prompt).to(target_device, dtype=target_dtype)
64
  if txt.shape[0] == 1 and bs > 1:
65
  txt = repeat(txt, "1 ... -> bs ...", bs=bs)
66
+ txt_ids = torch.zeros(bs, txt.shape[1], 3, device=target_device, dtype=target_dtype)
67
 
68
+ vec = clip(prompt).to(target_device, dtype=target_dtype)
69
  if vec.shape[0] == 1 and bs > 1:
70
  vec = repeat(vec, "1 ... -> bs ...", bs=bs)
71
 
72
  return {
73
  "img": img,
74
+ "img_ids": img_ids,
75
+ "txt": txt,
76
+ "txt_ids": txt_ids,
77
+ "vec": vec,
78
  }
79
 
80
 
 
127
  from tqdm import tqdm
128
 
129
  # this is ignored for schnell
 
 
 
 
 
130
  guidance_vec = torch.full((img.shape[0],), guidance, device=device, dtype=dtype)
131
  for t_curr, t_prev in tqdm(
132
  zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
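
Editor's note: a minimal sketch of calling the updated prepare() with the new target_device / target_dtype arguments, assuming t5 and clip are the HFEmbedder instances and latent is a latent-space image tensor (for example the output of get_noise); the target values should match flux_device and flow_dtype in the config.

import torch
from sampling import prepare

inputs = prepare(
    t5,
    clip,
    img=latent,                      # latent-space image of shape (bs, c, h, w)
    prompt="a photo of a forest clearing at dawn",
    target_device=torch.device("cuda:0"),
    target_dtype=torch.float16,
)
# inputs["img_ids"], inputs["txt"], inputs["txt_ids"] and inputs["vec"] now live on cuda:0
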
util.py CHANGED
@@ -36,6 +36,9 @@ class ModelSpec(BaseModel):
36
  ae_dtype: str = "bfloat16"
37
  text_enc_dtype: str = "bfloat16"
38
  num_to_quant: Optional[int] = 20
 
 
 
39
 
40
  model_config: ConfigDict = {
41
  "arbitrary_types_allowed": True,
@@ -93,6 +96,8 @@ def load_config(
93
  ae_dtype: str = "bfloat16",
94
  text_enc_dtype: str = "bfloat16",
95
  num_to_quant: Optional[int] = 20,
 
 
96
  ):
97
  text_enc_device = str(parse_device(text_enc_device))
98
  ae_device = str(parse_device(ae_device))
@@ -144,6 +149,8 @@ def load_config(
144
  text_enc_dtype=text_enc_dtype,
145
  text_enc_max_length=512 if name == ModelVersion.flux_dev else 256,
146
  num_to_quant=num_to_quant,
 
 
147
  )
148
 
149
 
 
36
  ae_dtype: str = "bfloat16"
37
  text_enc_dtype: str = "bfloat16"
38
  num_to_quant: Optional[int] = 20
39
+ quantize_extras: bool = False
40
+ compile_extras: bool = False
41
+ compile_blocks: bool = False
42
 
43
  model_config: ConfigDict = {
44
  "arbitrary_types_allowed": True,
 
96
  ae_dtype: str = "bfloat16",
97
  text_enc_dtype: str = "bfloat16",
98
  num_to_quant: Optional[int] = 20,
99
+ compile_extras: bool = False,
100
+ compile_blocks: bool = False,
101
  ):
102
  text_enc_device = str(parse_device(text_enc_device))
103
  ae_device = str(parse_device(ae_device))
 
149
  text_enc_dtype=text_enc_dtype,
150
  text_enc_max_length=512 if name == ModelVersion.flux_dev else 256,
151
  num_to_quant=num_to_quant,
152
+ compile_extras=compile_extras,
153
+ compile_blocks=compile_blocks,
154
  )
155
 
156
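
Editor's note: the new ModelSpec fields can be read straight from the JSON config. A minimal sketch using load_config_from_path, the helper that Flux.from_pretrained relies on above:

from util import load_config_from_path

spec = load_config_from_path("configs/config-dev.json")
print(spec.num_to_quant, spec.compile_blocks, spec.compile_extras, spec.quantize_extras)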