Commit 0f3134f · 1 Parent(s): 4a2503e
aredden committed

Remove unnecessary code, hide prints behind debug flag, hide warnings
float8_quantize.py CHANGED
@@ -1,4 +1,3 @@
-from typing import Any, Mapping
 import torch
 import torch.nn as nn
 from torchao.float8.float8_utils import (
@@ -11,6 +10,7 @@ import math
 from torch.compiler import is_compiling
 from torch import __version__
 from torch.version import cuda
+from typing import TypeVar
 
 IS_TORCH_2_4 = __version__ < (2, 4, 9)
 LT_TORCH_2_4 = __version__ < (2, 4)
@@ -29,23 +29,7 @@ try:
 except ImportError:
     CublasLinear = type(None)
 
-
-def check_scale_tensor(tensor):
-    return (
-        tensor is not None
-        and isinstance(tensor, torch.Tensor)
-        and tensor.dtype == torch.float32
-        and tensor.numel() == 1
-        and tensor != torch.zeros_like(tensor)
-    )
-
-
-def check_scale_in_state_dict(state_dict, key):
-    return key in state_dict and check_scale_tensor(state_dict[key])
-
-
-def check_scales_given_state_dict_and_keys(state_dict, keys):
-    return all(check_scale_in_state_dict(state_dict, key) for key in keys)
+FluxType = TypeVar("FluxType", nn.Module)
 
 
 class F8Linear(nn.Module):
@@ -245,6 +229,7 @@ class F8Linear(nn.Module):
             init.uniform_(self.bias, -bound, bound)
         self.quantize_weight()
         self.max_value = torch.finfo(self.float8_dtype).max
+        self.input_max_value = torch.finfo(self.input_float8_dtype).max
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.input_scale_initialized or is_compiling():
@@ -280,7 +265,7 @@ class F8Linear(nn.Module):
         linear: nn.Linear,
         float8_dtype=torch.float8_e4m3fn,
         input_float8_dtype=torch.float8_e5m2,
-    ):
+    ) -> "F8Linear":
        f8_lin = cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
@@ -300,7 +285,7 @@ def recursive_swap_linears(
     model: nn.Module,
     float8_dtype=torch.float8_e4m3fn,
     input_float8_dtype=torch.float8_e5m2,
-):
+) -> None:
     """
     Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
 
@@ -337,23 +322,29 @@ def recursive_swap_linears(
 
 @torch.inference_mode()
 def quantize_flow_transformer_and_dispatch_float8(
-    flow_model: nn.Module,
+    flow_model: FluxType,
     device=torch.device("cuda"),
     float8_dtype=torch.float8_e4m3fn,
     input_float8_dtype=torch.float8_e5m2,
     offload_flow=False,
-):
+) -> FluxType:
     """
     Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.
+
+    Iteratively pushes each module to device, evals, replaces linear layers with F8Linear except for final_layer, and quantizes.
+
+    Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory.
+
+    After dispatching, if offload_flow is True, offloads the model to cpu.
     """
-    for i, module in enumerate(flow_model.double_blocks):
+    for module in flow_model.double_blocks:
         module.to(device)
         module.eval()
         recursive_swap_linears(
             module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
         )
         torch.cuda.empty_cache()
-    for i, module in enumerate(flow_model.single_blocks):
+    for module in flow_model.single_blocks:
         module.to(device)
         module.eval()
         recursive_swap_linears(
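
To make the in-place swap concrete, here is a minimal sketch (not part of the commit): the toy network and the CUDA assumption are mine, only the recursive_swap_linears signature and the float8 dtypes come from the diff above.

# Hypothetical toy usage of recursive_swap_linears; assumes a CUDA torch build
# with float8 dtypes available. Every nn.Linear child is replaced in place by F8Linear.
import torch
import torch.nn as nn
from float8_quantize import recursive_swap_linears

toy = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 16)).cuda()
recursive_swap_linears(
    toy, float8_dtype=torch.float8_e4m3fn, input_float8_dtype=torch.float8_e5m2
)
with torch.inference_mode():
    out = toy(torch.randn(4, 64, device="cuda"))  # runs through the swapped F8Linear layers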
flux_emphasis.py CHANGED
@@ -111,7 +111,9 @@ def parse_prompt_attention(text):
     return res
 
 
-def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
+def get_prompts_tokens_with_weights(
+    clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False
+):
     """
     Get prompt token ids and weights, this function works for both prompt and negative prompt
 
@@ -152,13 +154,14 @@ def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
         ).input_ids
         # so that tokenize whatever length prompt
         # the returned token is a 1d list: [320, 1125, 539, 320]
-        print(
-            token,
-            "|FOR MODEL LEN{}|".format(maxlen),
-            clip_tokenizer.decode(
-                token, skip_special_tokens=True, clean_up_tokenization_spaces=True
-            ),
-        )
+        if debug:
+            print(
+                token,
+                "|FOR MODEL LEN{}|".format(maxlen),
+                clip_tokenizer.decode(
+                    token, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                ),
+            )
         # merge the new tokens to the all tokens holder: text_tokens
         text_tokens = [*text_tokens, *token]
 
@@ -306,6 +309,7 @@ def get_weighted_text_embeddings_flux(
     device: Optional[torch.device] = None,
     target_device: Optional[torch.device] = torch.device("cuda:0"),
     target_dtype: Optional[torch.dtype] = torch.bfloat16,
+    debug: bool = False,
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -350,12 +354,12 @@ def get_weighted_text_embeddings_flux(
 
     # tokenizer 1
     prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights(
-        tokenizer_clip, prompt
+        tokenizer_clip, prompt, debug=debug
    )
 
     # tokenizer 2
     prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights(
-        tokenizer_t5, prompt
+        tokenizer_t5, prompt, debug=debug
     )
 
     prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights(
@@ -428,7 +432,8 @@ def get_weighted_text_embeddings_flux(
         "last_hidden_state"
     ]
     t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2)
-    print(t5_embeds.shape)
+    if debug:
+        print(t5_embeds.shape)
     if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1:
         t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
     txt_ids = torch.zeros(
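
A small smoke test for the new debug flag follows; it is a sketch rather than repo code: the CLIP checkpoint name and the weighted-prompt string are assumptions, while the function name, its (tokens, weights) return and the debug parameter come from the diff.

# Hypothetical check that debug=True re-enables the per-chunk token printout.
from transformers import CLIPTokenizer
from flux_emphasis import get_prompts_tokens_with_weights

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
tokens, weights = get_prompts_tokens_with_weights(
    clip_tokenizer, "a (photo:1.2) of a cat", debug=True  # silent when debug=False
)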
flux_pipeline.py CHANGED
@@ -3,7 +3,11 @@ import math
 from typing import TYPE_CHECKING, Callable, List
 from PIL import Image
 import numpy as np
+import warnings
 
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", category=DeprecationWarning)
 import torch
 
 from einops import rearrange
@@ -61,6 +65,7 @@ class FluxPipeline:
         clip_device: torch.device | str = "cuda:1",
         t5_device: torch.device | str = "cuda:1",
         config: ModelSpec = None,
+        debug: bool = False,
     ):
         """
         Initialize the FluxPipeline class.
@@ -68,6 +73,7 @@ class FluxPipeline:
         This class is responsible for preparing input tensors for the Flux model, generating
         timesteps and noise, and handling device management for model offloading.
         """
+        self.debug = debug
         self.name = name
         self.device_flux = (
             flux_device
@@ -113,7 +119,7 @@ class FluxPipeline:
 
         if self.config.compile_blocks or self.config.compile_extras:
             if not self.config.prequantized_flow:
-                print("Warmups for compile...")
+                logger.info("Running warmups for compile...")
                 warmup_dict = dict(
                     prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
                     height=768,
@@ -204,6 +210,8 @@ class FluxPipeline:
         if self.offload_text_encoder:
             self.clip.to(device=self.device_clip)
             self.t5.to(device=self.device_t5)
+
+        # get the text embeddings
         vec, txt, txt_ids = get_weighted_text_embeddings_flux(
             self,
             prompt,
@@ -211,7 +219,9 @@ class FluxPipeline:
             device=self.device_clip,
             target_device=target_device,
             target_dtype=target_dtype,
+            debug=self.debug,
         )
+        # offload text encoder to cpu if needed
         if self.offload_text_encoder:
             self.clip.to("cpu")
             self.t5.to("cpu")
@@ -494,6 +504,8 @@ class FluxPipeline:
         logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
 
         generator = torch.Generator(device=self.device_flux).manual_seed(seed)
+
+        # preprocess the latent
         img, timesteps = self.preprocess_latent(
             init_image=init_image,
             height=height,
@@ -503,6 +515,8 @@ class FluxPipeline:
             generator=generator,
             num_images=num_images,
         )
+
+        # prepare inputs
         img, img_ids, vec, txt, txt_ids = map(
             lambda x: x.contiguous(),
             self.prepare(
@@ -518,8 +532,11 @@ class FluxPipeline:
             (img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
         )
         t_vec = None
+        # dispatch to gpu if offloaded
         if self.offload_flow:
             self.model.to(self.device_flux)
+
+        # perform the denoising loop
         for t_curr, t_prev in tqdm(
             zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
         ):
@@ -532,6 +549,7 @@ class FluxPipeline:
                 )
             else:
                 t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
+
             pred = self.model.forward(
                 img=img,
                 img_ids=img_ids,
@@ -544,6 +562,7 @@ class FluxPipeline:
 
             img = img + (t_prev - t_curr) * pred
 
+        # offload the model to cpu if needed
         if self.offload_flow:
             self.model.to("cpu")
             torch.cuda.empty_cache()
@@ -557,16 +576,18 @@ class FluxPipeline:
 
     @classmethod
     def load_pipeline_from_config_path(
-        cls, path: str, flow_model_path: str = None
+        cls, path: str, flow_model_path: str = None, debug: bool = False
     ) -> "FluxPipeline":
         with torch.inference_mode():
             config = load_config_from_path(path)
             if flow_model_path:
                 config.ckpt_path = flow_model_path
-            return cls.load_pipeline_from_config(config)
+            return cls.load_pipeline_from_config(config, debug=debug)
 
     @classmethod
-    def load_pipeline_from_config(cls, config: ModelSpec) -> "FluxPipeline":
+    def load_pipeline_from_config(
+        cls, config: ModelSpec, debug: bool = False
+    ) -> "FluxPipeline":
         from float8_quantize import quantize_flow_transformer_and_dispatch_float8
 
         with torch.inference_mode():
@@ -603,4 +624,5 @@ class FluxPipeline:
             clip_device=clip_device,
             t5_device=t5_device,
             config=config,
+            debug=debug,
         )
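
A hedged end-to-end sketch of the new debug plumbing; the config path is a placeholder and not something this commit adds, while both classmethod signatures are taken from the diff.

# Hypothetical caller: debug defaults to False, so the tokenizer/embedding prints stay hidden.
from flux_pipeline import FluxPipeline

pipe = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev.json",  # placeholder path
    debug=False,                # set True to restore the verbose token/embedding prints
)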
modules/conditioner.py CHANGED
@@ -14,19 +14,6 @@ from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig
 CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")
 
 
-def into_quantization_name(quantization_dtype: str) -> str:
-    if quantization_dtype == "qfloat8":
-        return "float8"
-    elif quantization_dtype == "qint4":
-        return "int4"
-    elif quantization_dtype == "qint8":
-        return "int8"
-    elif quantization_dtype == "qint2":
-        return "int2"
-    else:
-        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
-
-
 def auto_quantization_config(
     quantization_dtype: str,
 ) -> QuantoConfig | BitsAndBytesConfig:
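
For context, a hedged sketch of how the surviving helper might be used; the encoder class and checkpoint name are assumptions, only auto_quantization_config's signature appears in the diff.

# Hypothetical usage; "some-org/t5-encoder-checkpoint" is a placeholder, not a real repo id.
from transformers import T5EncoderModel
from modules.conditioner import auto_quantization_config

qconfig = auto_quantization_config("qfloat8")  # returns QuantoConfig | BitsAndBytesConfig
t5 = T5EncoderModel.from_pretrained(
    "some-org/t5-encoder-checkpoint", quantization_config=qconfig
)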
quantize_swap_and_dispatch.py DELETED
@@ -1,274 +0,0 @@
-from fnmatch import fnmatch
-from typing import List, Optional, Union
-
-import torch
-from click import secho
-from cublas_ops import CublasLinear
-
-from quanto import (
-    QModuleMixin,
-    quantize_module,
-    QLinear,
-    QConv2d,
-    QLayerNorm,
-)
-from quanto.tensor import Optimizer, qtype, qfloat8, qint4, qint8
-from torch import nn
-
-
-class QuantizationDtype:
-    qfloat8 = "qfloat8"
-    qint2 = "qint2"
-    qint4 = "qint4"
-    qint8 = "qint8"
-
-
-def into_qtype(qtype: QuantizationDtype) -> qtype:
-    if qtype == QuantizationDtype.qfloat8:
-        return qfloat8
-    elif qtype == QuantizationDtype.qint4:
-        return qint4
-    elif qtype == QuantizationDtype.qint8:
-        return qint8
-    else:
-        raise ValueError(f"Unknown qtype: {qtype}")
-
-
-def _set_module_by_name(parent_module, name, child_module):
-    module_names = name.split(".")
-    if len(module_names) == 1:
-        setattr(parent_module, name, child_module)
-    else:
-        parent_module_name = name[: name.rindex(".")]
-        parent_module = parent_module.get_submodule(parent_module_name)
-        setattr(parent_module, module_names[-1], child_module)
-
-
-def _quantize_submodule(
-    model: torch.nn.Module,
-    name: str,
-    module: torch.nn.Module,
-    weights: Optional[Union[str, qtype]] = None,
-    activations: Optional[Union[str, qtype]] = None,
-    optimizer: Optional[Optimizer] = None,
-):
-    if isinstance(module, CublasLinear):
-        return 0
-    num_quant = 0
-    qmodule = quantize_module(
-        module, weights=weights, activations=activations, optimizer=optimizer
-    )
-    if qmodule is not None:
-        _set_module_by_name(model, name, qmodule)
-        # num_quant += 1
-        qmodule.name = name
-        for name, param in module.named_parameters():
-            # Save device memory by clearing parameters
-            setattr(module, name, None)
-            del param
-        num_quant += 1
-
-    return num_quant
-
-
-def _quantize(
-    model: torch.nn.Module,
-    weights: Optional[Union[str, qtype]] = None,
-    activations: Optional[Union[str, qtype]] = None,
-    optimizer: Optional[Optimizer] = None,
-    include: Optional[Union[str, List[str]]] = None,
-    exclude: Optional[Union[str, List[str]]] = None,
-):
-    """Quantize the specified model submodules
-
-    Recursively quantize the submodules of the specified parent model.
-
-    Only modules that have quantized counterparts will be quantized.
-
-    If include patterns are specified, the submodule name must match one of them.
-
-    If exclude patterns are specified, the submodule must not match one of them.
-
-    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
-    https://docs.python.org/3/library/fnmatch.html for more details.
-
-    Note: quantization happens in-place and modifies the original model and its descendants.
-
-    Args:
-        model (`torch.nn.Module`): the model whose submodules will be quantized.
-        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
-        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
-        include (`Optional[Union[str, List[str]]]`):
-            Patterns constituting the allowlist. If provided, module names must match at
-            least one pattern from the allowlist.
-        exclude (`Optional[Union[str, List[str]]]`):
-            Patterns constituting the denylist. If provided, module names must not match
-            any patterns from the denylist.
-    """
-    num_quant = 0
-    if include is not None:
-        include = [include] if isinstance(include, str) else exclude
-    if exclude is not None:
-        exclude = [exclude] if isinstance(exclude, str) else exclude
-    for name, m in model.named_modules():
-        if include is not None and not any(
-            fnmatch(name, pattern) for pattern in include
-        ):
-            continue
-        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
-            continue
-        num_quant += _quantize_submodule(
-            model,
-            name,
-            m,
-            weights=weights,
-            activations=activations,
-            optimizer=optimizer,
-        )
-    return num_quant
-
-
-def _freeze(model):
-    for name, m in model.named_modules():
-        if isinstance(m, QModuleMixin):
-            m.freeze()
-
-
-def _is_block_compilable(module: nn.Module) -> bool:
-    for module in module.modules():
-        if _is_quantized(module):
-            return False
-    if _is_quantized(module):
-        return False
-    return True
-
-
-def _simple_swap_linears(model: nn.Module, root_name: str = ""):
-    for name, module in model.named_children():
-        if (
-            _is_linear(module)
-            and hasattr(module, "weight")
-            and module.weight is not None
-            and module.weight.data is not None
-        ):
-            weights = module.weight.data
-            bias = None
-            if module.bias is not None:
-                bias = module.bias.data
-            with torch.device(module.weight.device):
-                new_cublas = CublasLinear(
-                    module.in_features,
-                    module.out_features,
-                    bias=bias is not None,
-                    device=module.weight.device,
-                    dtype=module.weight.dtype,
-                )
-            new_cublas.weight.data = weights
-            if bias is not None:
-                new_cublas.bias.data = bias
-            setattr(model, name, new_cublas)
-            if root_name == "":
-                secho(f"Replaced {name} with CublasLinear", fg="green")
-            else:
-                secho(f"Replaced {root_name}.{name} with CublasLinear", fg="green")
-        else:
-            if root_name == "":
-                _simple_swap_linears(module, str(name))
-            else:
-                _simple_swap_linears(module, str(root_name) + "." + str(name))
-
-
-def _full_quant(
-    model, max_quants=24, current_quants=0, quantization_dtype: qtype = qfloat8
-):
-    if current_quants < max_quants:
-        current_quants += _quantize(model, quantization_dtype)
-        _freeze(model)
-        print(f"Quantized {current_quants} modules with {quantization_dtype}")
-    return current_quants
-
-
-def _is_linear(module: nn.Module) -> bool:
-    return not isinstance(
-        module, (QLinear, QConv2d, QLayerNorm, CublasLinear)
-    ) and isinstance(module, nn.Linear)
-
-
-def _is_quantized(module: nn.Module) -> bool:
-    return isinstance(module, (QLinear, QConv2d, QLayerNorm))
-
-
-def quantize_and_dispatch_to_device(
-    flow_model: nn.Module,
-    flux_device: torch.device = torch.device("cuda"),
-    flux_dtype: torch.dtype = torch.float16,
-    num_layers_to_quantize: int = 20,
-    quantization_dtype: QuantizationDtype = QuantizationDtype.qfloat8,
-    compile_blocks: bool = True,
-    compile_extras: bool = True,
-    quantize_extras: bool = False,
-    replace_linears: bool = True,
-):
-    quant_type = into_qtype(quantization_dtype)
-    num_quanted = 0
-    flow_model = flow_model.requires_grad_(False).eval().type(flux_dtype)
-    for block in flow_model.single_blocks:
-        block.cuda(flux_device)
-        if num_quanted < num_layers_to_quantize:
-            num_quanted = _full_quant(
-                block,
-                num_layers_to_quantize,
-                num_quanted,
-                quantization_dtype=quant_type,
-            )
-
-    for block in flow_model.double_blocks:
-        block.cuda(flux_device)
-        if num_quanted < num_layers_to_quantize:
-            num_quanted = _full_quant(
-                block,
-                num_layers_to_quantize,
-                num_quanted,
-                quantization_dtype=quant_type,
-            )
-
-    to_gpu_extras = [
-        "vector_in",
-        "img_in",
-        "txt_in",
-        "time_in",
-        "guidance_in",
-        "final_layer",
-        "pe_embedder",
-    ]
-
-    if compile_blocks:
-        for i, block in enumerate(flow_model.single_blocks):
-            if _is_block_compilable(block):
-                block.compile()
-                secho(f"Compiled block {i}", fg="green")
-        for i, block in enumerate(flow_model.double_blocks):
-            if _is_block_compilable(block):
-                block.compile()
-                secho(f"Compiled block {i}", fg="green")
-
-    if replace_linears:
-        _simple_swap_linears(flow_model)
-    for extra in to_gpu_extras:
-        m_extra = getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
-        if compile_extras:
-            if extra in ["time_in", "vector_in", "guidance_in", "final_layer"]:
-                m_extra.compile()
-                secho(
-                    f"Compiled extra {extra} -- {m_extra.__class__.__name__}",
-                    fg="green",
-                )
-        elif quantize_extras:
-            if not isinstance(m_extra, nn.Linear):
-                _full_quant(
-                    m_extra,
-                    current_quants=num_quanted,
-                    max_quants=num_layers_to_quantize,
-                    quantization_dtype=quantization_dtype,
-                )
-    return flow_model
- return flow_model
util.py CHANGED
@@ -42,6 +42,7 @@ class ModelSpec(BaseModel):
     flow_dtype: str = "float16"
     ae_dtype: str = "bfloat16"
     text_enc_dtype: str = "bfloat16"
+    # unused / deprecated
     num_to_quant: Optional[int] = 20
     quantize_extras: bool = False
     compile_extras: bool = False