Remove unnecessary code, hide prints behind debug flag, hide warnings
- float8_quantize.py +15 -24
- flux_emphasis.py +16 -11
- flux_pipeline.py +26 -4
- modules/conditioner.py +0 -13
- quantize_swap_and_dispatch.py +0 -274
- util.py +1 -0
float8_quantize.py
CHANGED
@@ -1,4 +1,3 @@
-from typing import Any, Mapping
import torch
import torch.nn as nn
from torchao.float8.float8_utils import (
@@ -11,6 +10,7 @@ import math
from torch.compiler import is_compiling
from torch import __version__
from torch.version import cuda
+from typing import TypeVar

IS_TORCH_2_4 = __version__ < (2, 4, 9)
LT_TORCH_2_4 = __version__ < (2, 4)
@@ -29,23 +29,7 @@ try:
except ImportError:
    CublasLinear = type(None)

-
-def check_scale_tensor(tensor):
-    return (
-        tensor is not None
-        and isinstance(tensor, torch.Tensor)
-        and tensor.dtype == torch.float32
-        and tensor.numel() == 1
-        and tensor != torch.zeros_like(tensor)
-    )
-
-
-def check_scale_in_state_dict(state_dict, key):
-    return key in state_dict and check_scale_tensor(state_dict[key])
-
-
-def check_scales_given_state_dict_and_keys(state_dict, keys):
-    return all(check_scale_in_state_dict(state_dict, key) for key in keys)
+FluxType = TypeVar("FluxType", nn.Module)


class F8Linear(nn.Module):
@@ -245,6 +229,7 @@ class F8Linear(nn.Module):
        init.uniform_(self.bias, -bound, bound)
        self.quantize_weight()
        self.max_value = torch.finfo(self.float8_dtype).max
+        self.input_max_value = torch.finfo(self.input_float8_dtype).max

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.input_scale_initialized or is_compiling():
@@ -280,7 +265,7 @@
        linear: nn.Linear,
        float8_dtype=torch.float8_e4m3fn,
        input_float8_dtype=torch.float8_e5m2,
-    ):
+    ) -> "F8Linear":
        f8_lin = cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
@@ -300,7 +285,7 @@ def recursive_swap_linears(
    model: nn.Module,
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
-):
+) -> None:
    """
    Recursively swaps all nn.Linear modules in the given model with F8Linear modules.

@@ -337,23 +322,29 @@ def recursive_swap_linears(

@torch.inference_mode()
def quantize_flow_transformer_and_dispatch_float8(
-    flow_model:
+    flow_model: FluxType,
    device=torch.device("cuda"),
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    offload_flow=False,
-):
+) -> FluxType:
    """
    Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.
+
+    Iteratively pushes each module to device, evals, replaces linear layers with F8Linear except for final_layer, and quantizes.
+
+    Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory.
+
+    After dispatching, if offload_flow is True, offloads the model to cpu.
    """
-    for
+    for module in flow_model.double_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
            module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
        )
        torch.cuda.empty_cache()
-    for
+    for module in flow_model.single_blocks:
        module.to(device)
        module.eval()
        recursive_swap_linears(
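The new FluxType annotation and the per-block loop are the core of this change: each double and single block is moved to the GPU, has its nn.Linear layers swapped for F8Linear, and is quantized before the next block is touched, so peak memory stays close to a single block. A rough usage sketch, assuming a BFL-style flux transformer already constructed on CPU; the load_flow_model helper below is hypothetical and not part of this repo:

import torch
from float8_quantize import quantize_flow_transformer_and_dispatch_float8

# hypothetical loader: any flux transformer exposing .double_blocks / .single_blocks works
flow_model = load_flow_model("flux-dev", device="cpu")

flow_model = quantize_flow_transformer_and_dispatch_float8(
    flow_model,
    device=torch.device("cuda:0"),
    float8_dtype=torch.float8_e4m3fn,      # weight dtype
    input_float8_dtype=torch.float8_e5m2,  # input/activation dtype
    offload_flow=False,                    # True would move the quantized model back to cpu
)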
flux_emphasis.py
CHANGED
@@ -111,7 +111,9 @@ def parse_prompt_attention(text):
    return res


-def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
+def get_prompts_tokens_with_weights(
+    clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False
+):
    """
    Get prompt token ids and weights, this function works for both prompt and negative prompt

@@ -152,13 +154,14 @@ def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
        ).input_ids
        # so that tokenize whatever length prompt
        # the returned token is a 1d list: [320, 1125, 539, 320]
-
-
-
-
-
-
-
+        if debug:
+            print(
+                token,
+                "|FOR MODEL LEN{}|".format(maxlen),
+                clip_tokenizer.decode(
+                    token, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                ),
+            )
        # merge the new tokens to the all tokens holder: text_tokens
        text_tokens = [*text_tokens, *token]

@@ -306,6 +309,7 @@ def get_weighted_text_embeddings_flux(
    device: Optional[torch.device] = None,
    target_device: Optional[torch.device] = torch.device("cuda:0"),
    target_dtype: Optional[torch.dtype] = torch.bfloat16,
+    debug: bool = False,
):
    """
    This function can process long prompt with weights, no length limitation
@@ -350,12 +354,12 @@ def get_weighted_text_embeddings_flux(

    # tokenizer 1
    prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights(
-        tokenizer_clip, prompt
+        tokenizer_clip, prompt, debug=debug
    )

    # tokenizer 2
    prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights(
-        tokenizer_t5, prompt
+        tokenizer_t5, prompt, debug=debug
    )

    prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights(
@@ -428,7 +432,8 @@ def get_weighted_text_embeddings_flux(
        "last_hidden_state"
    ]
    t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2)
-
+    if debug:
+        print(t5_embeds.shape)
    if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1:
        t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
    txt_ids = torch.zeros(
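With the token dump gated on the new debug parameter, callers get silence by default and can opt back in per call. A small sketch of the call as changed above, assuming a CLIP tokenizer is available; the checkpoint name is only an example:

from transformers import CLIPTokenizer
from flux_emphasis import get_prompts_tokens_with_weights

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # example checkpoint

# quiet by default
tokens, weights = get_prompts_tokens_with_weights(tokenizer, "a (red:1.3) car")

# opt in to the per-chunk token/decode printout
tokens, weights = get_prompts_tokens_with_weights(tokenizer, "a (red:1.3) car", debug=True)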
flux_pipeline.py
CHANGED
@@ -3,7 +3,11 @@ import math
from typing import TYPE_CHECKING, Callable, List
from PIL import Image
import numpy as np
+import warnings

+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch

from einops import rearrange
@@ -61,6 +65,7 @@ class FluxPipeline:
        clip_device: torch.device | str = "cuda:1",
        t5_device: torch.device | str = "cuda:1",
        config: ModelSpec = None,
+        debug: bool = False,
    ):
        """
        Initialize the FluxPipeline class.
@@ -68,6 +73,7 @@ class FluxPipeline:
        This class is responsible for preparing input tensors for the Flux model, generating
        timesteps and noise, and handling device management for model offloading.
        """
+        self.debug = debug
        self.name = name
        self.device_flux = (
            flux_device
@@ -113,7 +119,7 @@ class FluxPipeline:

        if self.config.compile_blocks or self.config.compile_extras:
            if not self.config.prequantized_flow:
-
+                logger.info("Running warmups for compile...")
                warmup_dict = dict(
                    prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
                    height=768,
@@ -204,6 +210,8 @@ class FluxPipeline:
        if self.offload_text_encoder:
            self.clip.to(device=self.device_clip)
            self.t5.to(device=self.device_t5)
+
+        # get the text embeddings
        vec, txt, txt_ids = get_weighted_text_embeddings_flux(
            self,
            prompt,
@@ -211,7 +219,9 @@ class FluxPipeline:
            device=self.device_clip,
            target_device=target_device,
            target_dtype=target_dtype,
+            debug=self.debug,
        )
+        # offload text encoder to cpu if needed
        if self.offload_text_encoder:
            self.clip.to("cpu")
            self.t5.to("cpu")
@@ -494,6 +504,8 @@ class FluxPipeline:
        logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")

        generator = torch.Generator(device=self.device_flux).manual_seed(seed)
+
+        # preprocess the latent
        img, timesteps = self.preprocess_latent(
            init_image=init_image,
            height=height,
@@ -503,6 +515,8 @@ class FluxPipeline:
            generator=generator,
            num_images=num_images,
        )
+
+        # prepare inputs
        img, img_ids, vec, txt, txt_ids = map(
            lambda x: x.contiguous(),
            self.prepare(
@@ -518,8 +532,11 @@ class FluxPipeline:
            (img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
        )
        t_vec = None
+        # dispatch to gpu if offloaded
        if self.offload_flow:
            self.model.to(self.device_flux)
+
+        # perform the denoising loop
        for t_curr, t_prev in tqdm(
            zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
        ):
@@ -532,6 +549,7 @@ class FluxPipeline:
                )
            else:
                t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
+
            pred = self.model.forward(
                img=img,
                img_ids=img_ids,
@@ -544,6 +562,7 @@ class FluxPipeline:

            img = img + (t_prev - t_curr) * pred

+        # offload the model to cpu if needed
        if self.offload_flow:
            self.model.to("cpu")
            torch.cuda.empty_cache()
@@ -557,16 +576,18 @@ class FluxPipeline:

    @classmethod
    def load_pipeline_from_config_path(
-        cls, path: str, flow_model_path: str = None
+        cls, path: str, flow_model_path: str = None, debug: bool = False
    ) -> "FluxPipeline":
        with torch.inference_mode():
            config = load_config_from_path(path)
            if flow_model_path:
                config.ckpt_path = flow_model_path
-            return cls.load_pipeline_from_config(config)
+            return cls.load_pipeline_from_config(config, debug=debug)

    @classmethod
-    def load_pipeline_from_config(
+    def load_pipeline_from_config(
+        cls, config: ModelSpec, debug: bool = False
+    ) -> "FluxPipeline":
        from float8_quantize import quantize_flow_transformer_and_dispatch_float8

        with torch.inference_mode():
@@ -603,4 +624,5 @@ class FluxPipeline:
            clip_device=clip_device,
            t5_device=t5_device,
            config=config,
+            debug=debug,
        )
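The debug flag is threaded from the classmethod loaders through __init__ and into get_weighted_text_embeddings_flux, so one switch controls all of the diagnostic prints, while the warnings filters at the top of the module silence UserWarning, FutureWarning and DeprecationWarning as soon as flux_pipeline is imported. A sketch of the intended entry point; the config path is a placeholder:

from flux_pipeline import FluxPipeline  # importing this module installs the warnings filters

pipe = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev-offload.json",  # placeholder config path
    flow_model_path=None,               # optional override for the flow checkpoint
    debug=True,                         # re-enable the gated prints
)

Note that warnings.filterwarnings applies process-wide, so anything imported after flux_pipeline inherits the same suppression.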
modules/conditioner.py
CHANGED
@@ -14,19 +14,6 @@ from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig
CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")


-def into_quantization_name(quantization_dtype: str) -> str:
-    if quantization_dtype == "qfloat8":
-        return "float8"
-    elif quantization_dtype == "qint4":
-        return "int4"
-    elif quantization_dtype == "qint8":
-        return "int8"
-    elif quantization_dtype == "qint2":
-        return "int2"
-    else:
-        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
-
-
def auto_quantization_config(
    quantization_dtype: str,
) -> QuantoConfig | BitsAndBytesConfig:
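The removed into_quantization_name helper only translated quanto-style dtype strings into the names the transformers quantization configs expect; auto_quantization_config is now the sole entry point. For reference, a dict-based equivalent of the deleted mapping, shown here only as a sketch of what was dropped, not code that remains in the module:

# equivalent of the deleted helper, kept here only for reference
_QUANT_NAMES = {"qfloat8": "float8", "qint4": "int4", "qint8": "int8", "qint2": "int2"}

def into_quantization_name(quantization_dtype: str) -> str:
    try:
        return _QUANT_NAMES[quantization_dtype]
    except KeyError:
        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")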
quantize_swap_and_dispatch.py
DELETED
@@ -1,274 +0,0 @@
from fnmatch import fnmatch
from typing import List, Optional, Union

import torch
from click import secho
from cublas_ops import CublasLinear

from quanto import (
    QModuleMixin,
    quantize_module,
    QLinear,
    QConv2d,
    QLayerNorm,
)
from quanto.tensor import Optimizer, qtype, qfloat8, qint4, qint8
from torch import nn


class QuantizationDtype:
    qfloat8 = "qfloat8"
    qint2 = "qint2"
    qint4 = "qint4"
    qint8 = "qint8"


def into_qtype(qtype: QuantizationDtype) -> qtype:
    if qtype == QuantizationDtype.qfloat8:
        return qfloat8
    elif qtype == QuantizationDtype.qint4:
        return qint4
    elif qtype == QuantizationDtype.qint8:
        return qint8
    else:
        raise ValueError(f"Unknown qtype: {qtype}")


def _set_module_by_name(parent_module, name, child_module):
    module_names = name.split(".")
    if len(module_names) == 1:
        setattr(parent_module, name, child_module)
    else:
        parent_module_name = name[: name.rindex(".")]
        parent_module = parent_module.get_submodule(parent_module_name)
        setattr(parent_module, module_names[-1], child_module)


def _quantize_submodule(
    model: torch.nn.Module,
    name: str,
    module: torch.nn.Module,
    weights: Optional[Union[str, qtype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
):
    if isinstance(module, CublasLinear):
        return 0
    num_quant = 0
    qmodule = quantize_module(
        module, weights=weights, activations=activations, optimizer=optimizer
    )
    if qmodule is not None:
        _set_module_by_name(model, name, qmodule)
        # num_quant += 1
        qmodule.name = name
        for name, param in module.named_parameters():
            # Save device memory by clearing parameters
            setattr(module, name, None)
            del param
        num_quant += 1

    return num_quant


def _quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules

    Recursively quantize the submodules of the specified parent model.

    Only modules that have quantized counterparts will be quantized.

    If include patterns are specified, the submodule name must match one of them.

    If exclude patterns are specified, the submodule must not match one of them.

    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
    https://docs.python.org/3/library/fnmatch.html for more details.

    Note: quantization happens in-place and modifies the original model and its descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must match at
            least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must not match
            any patterns from the denylist.
    """
    num_quant = 0
    if include is not None:
        include = [include] if isinstance(include, str) else exclude
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(
            fnmatch(name, pattern) for pattern in include
        ):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        num_quant += _quantize_submodule(
            model,
            name,
            m,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
        )
    return num_quant


def _freeze(model):
    for name, m in model.named_modules():
        if isinstance(m, QModuleMixin):
            m.freeze()


def _is_block_compilable(module: nn.Module) -> bool:
    for module in module.modules():
        if _is_quantized(module):
            return False
    if _is_quantized(module):
        return False
    return True


def _simple_swap_linears(model: nn.Module, root_name: str = ""):
    for name, module in model.named_children():
        if (
            _is_linear(module)
            and hasattr(module, "weight")
            and module.weight is not None
            and module.weight.data is not None
        ):
            weights = module.weight.data
            bias = None
            if module.bias is not None:
                bias = module.bias.data
            with torch.device(module.weight.device):
                new_cublas = CublasLinear(
                    module.in_features,
                    module.out_features,
                    bias=bias is not None,
                    device=module.weight.device,
                    dtype=module.weight.dtype,
                )
            new_cublas.weight.data = weights
            if bias is not None:
                new_cublas.bias.data = bias
            setattr(model, name, new_cublas)
            if root_name == "":
                secho(f"Replaced {name} with CublasLinear", fg="green")
            else:
                secho(f"Replaced {root_name}.{name} with CublasLinear", fg="green")
        else:
            if root_name == "":
                _simple_swap_linears(module, str(name))
            else:
                _simple_swap_linears(module, str(root_name) + "." + str(name))


def _full_quant(
    model, max_quants=24, current_quants=0, quantization_dtype: qtype = qfloat8
):
    if current_quants < max_quants:
        current_quants += _quantize(model, quantization_dtype)
        _freeze(model)
        print(f"Quantized {current_quants} modules with {quantization_dtype}")
    return current_quants


def _is_linear(module: nn.Module) -> bool:
    return not isinstance(
        module, (QLinear, QConv2d, QLayerNorm, CublasLinear)
    ) and isinstance(module, nn.Linear)


def _is_quantized(module: nn.Module) -> bool:
    return isinstance(module, (QLinear, QConv2d, QLayerNorm))


def quantize_and_dispatch_to_device(
    flow_model: nn.Module,
    flux_device: torch.device = torch.device("cuda"),
    flux_dtype: torch.dtype = torch.float16,
    num_layers_to_quantize: int = 20,
    quantization_dtype: QuantizationDtype = QuantizationDtype.qfloat8,
    compile_blocks: bool = True,
    compile_extras: bool = True,
    quantize_extras: bool = False,
    replace_linears: bool = True,
):
    quant_type = into_qtype(quantization_dtype)
    num_quanted = 0
    flow_model = flow_model.requires_grad_(False).eval().type(flux_dtype)
    for block in flow_model.single_blocks:
        block.cuda(flux_device)
        if num_quanted < num_layers_to_quantize:
            num_quanted = _full_quant(
                block,
                num_layers_to_quantize,
                num_quanted,
                quantization_dtype=quant_type,
            )

    for block in flow_model.double_blocks:
        block.cuda(flux_device)
        if num_quanted < num_layers_to_quantize:
            num_quanted = _full_quant(
                block,
                num_layers_to_quantize,
                num_quanted,
                quantization_dtype=quant_type,
            )

    to_gpu_extras = [
        "vector_in",
        "img_in",
        "txt_in",
        "time_in",
        "guidance_in",
        "final_layer",
        "pe_embedder",
    ]

    if compile_blocks:
        for i, block in enumerate(flow_model.single_blocks):
            if _is_block_compilable(block):
                block.compile()
                secho(f"Compiled block {i}", fg="green")
        for i, block in enumerate(flow_model.double_blocks):
            if _is_block_compilable(block):
                block.compile()
                secho(f"Compiled block {i}", fg="green")

    if replace_linears:
        _simple_swap_linears(flow_model)
    for extra in to_gpu_extras:
        m_extra = getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
        if compile_extras:
            if extra in ["time_in", "vector_in", "guidance_in", "final_layer"]:
                m_extra.compile()
                secho(
                    f"Compiled extra {extra} -- {m_extra.__class__.__name__}",
                    fg="green",
                )
        elif quantize_extras:
            if not isinstance(m_extra, nn.Linear):
                _full_quant(
                    m_extra,
                    current_quants=num_quanted,
                    max_quants=num_layers_to_quantize,
                    quantization_dtype=quantization_dtype,
                )
    return flow_model
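The deleted module was the older quanto/CublasLinear path; its _quantize walked model.named_modules() and filtered candidates with Unix-style fnmatch include/exclude patterns, which the new float8 path no longer needs. A standalone sketch of that filtering idiom; the module names below are illustrative:

from fnmatch import fnmatch

names = ["double_blocks.0.img_mlp.0", "single_blocks.3.linear1", "final_layer.linear"]
include = ["double_blocks.*", "single_blocks.*"]   # allowlist patterns
exclude = ["*final_layer*"]                        # denylist patterns

selected = [
    name
    for name in names
    if any(fnmatch(name, p) for p in include)
    and not any(fnmatch(name, p) for p in exclude)
]
print(selected)  # ['double_blocks.0.img_mlp.0', 'single_blocks.3.linear1']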
util.py
CHANGED
@@ -42,6 +42,7 @@ class ModelSpec(BaseModel):
    flow_dtype: str = "float16"
    ae_dtype: str = "bfloat16"
    text_enc_dtype: str = "bfloat16"
+    # unused / deprecated
    num_to_quant: Optional[int] = 20
    quantize_extras: bool = False
    compile_extras: bool = False