aredden committed
Commit e21ae14 · 1 Parent(s): 264acad

Improve lora implementation

Files changed (2):
  1. flux_pipeline.py +15 -4
  2. lora_loading.py +222 -81
flux_pipeline.py CHANGED
@@ -2,7 +2,7 @@ import io
 import math
 import random
 import warnings
-from typing import TYPE_CHECKING, Callable, List
+from typing import TYPE_CHECKING, Callable, List, OrderedDict, Union

 import numpy as np
 from PIL import Image
@@ -148,7 +148,9 @@ class FluxPipeline:
         random.seed(seed)
         return cuda_generator, seed

-    def load_lora(self, lora_path: str, scale: float):
+    def load_lora(
+        self, lora_path: Union[str, OrderedDict[str, torch.Tensor]], scale: float
+    ):
         """
         Loads a LoRA checkpoint into the Flux flow transformer.

@@ -156,11 +158,20 @@ class FluxPipeline:
         or loras which contain keys which start with lora_unet_[...].

         Args:
-            lora_path (str): Path to the LoRA checkpoint.
+            lora_path (str | OrderedDict[str, torch.Tensor]): Path to the LoRA checkpoint or an ordered dictionary containing the LoRA weights.
             scale (float): Scaling factor for the LoRA weights.

         """
-        self.model = lora_loading.apply_lora_to_model(self.model, lora_path, scale)
+        self.model.load_lora(lora_path, scale)
+
+    def unload_lora(self, path_or_identifier: str):
+        """
+        Unloads the LoRA checkpoint from the Flux flow transformer.
+
+        Args:
+            path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded.
+        """
+        self.model.unload_lora(path_or_identifier)

     @torch.inference_mode()
     def compile(self):
lora_loading.py CHANGED
@@ -13,7 +13,7 @@ except Exception as e:
 from float8_quantize import F8Linear
 from modules.flux_model import Flux

-path_regex = re.compile(r"\/|\\")
+path_regex = re.compile(r"/|\\")

 StateDict: TypeAlias = OrderedDict[str, torch.Tensor]

@@ -138,59 +138,126 @@ def convert_diffusers_to_flux_transformer_checkpoint(
             f"double_blocks.{i}.txt_mod.lin.weight",
         )

-        sample_q_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.to_q.lora_A.weight"
-        )
-        sample_q_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.to_q.lora_B.weight"
-        )
-
-        sample_k_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.to_k.lora_A.weight"
-        )
-        sample_k_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.to_k.lora_B.weight"
-        )
+        # Q, K, V
+        temp_dict = {}

-        sample_v_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.to_v.lora_A.weight"
-        )
-        sample_v_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.to_v.lora_B.weight"
-        )
+        expected_shape_qkv_a = None
+        expected_shape_qkv_b = None
+        expected_shape_add_qkv_a = None
+        expected_shape_add_qkv_b = None
+        dtype = None
+        device = None

-        context_q_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.add_q_proj.lora_A.weight"
-        )
-        context_q_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.add_q_proj.lora_B.weight"
-        )
+        for component in [
+            "to_q",
+            "to_k",
+            "to_v",
+            "add_q_proj",
+            "add_k_proj",
+            "add_v_proj",
+        ]:

-        context_k_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.add_k_proj.lora_A.weight"
-        )
-        context_k_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.add_k_proj.lora_B.weight"
-        )
-        context_v_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.add_v_proj.lora_A.weight"
-        )
-        context_v_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}attn.add_v_proj.lora_B.weight"
-        )
-
-        original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_A.weight"] = (
-            torch.cat([sample_q_A, sample_k_A, sample_v_A], dim=0)
-        )
-        original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_B.weight"] = (
-            torch.cat([sample_q_B, sample_k_B, sample_v_B], dim=0)
-        )
-        original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_A.weight"] = (
-            torch.cat([context_q_A, context_k_A, context_v_A], dim=0)
-        )
-        original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_B.weight"] = (
-            torch.cat([context_q_B, context_k_B, context_v_B], dim=0)
-        )
+            sample_component_A_key = (
+                f"{prefix}{block_prefix}attn.{component}.lora_A.weight"
+            )
+            sample_component_B_key = (
+                f"{prefix}{block_prefix}attn.{component}.lora_B.weight"
+            )
+            if (
+                sample_component_A_key in diffusers_state_dict
+                and sample_component_B_key in diffusers_state_dict
+            ):
+                sample_component_A = diffusers_state_dict.pop(sample_component_A_key)
+                sample_component_B = diffusers_state_dict.pop(sample_component_B_key)
+                temp_dict[f"{component}"] = [sample_component_A, sample_component_B]
+                if expected_shape_qkv_a is None and not component.startswith("add_"):
+                    expected_shape_qkv_a = sample_component_A.shape
+                    expected_shape_qkv_b = sample_component_B.shape
+                    dtype = sample_component_A.dtype
+                    device = sample_component_A.device
+                if expected_shape_add_qkv_a is None and component.startswith("add_"):
+                    expected_shape_add_qkv_a = sample_component_A.shape
+                    expected_shape_add_qkv_b = sample_component_B.shape
+                    dtype = sample_component_A.dtype
+                    device = sample_component_A.device
+            else:
+                logger.info(
+                    f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}"
+                )
+                temp_dict[f"{component}"] = [None, None]
+
+        if device is not None:
+            if expected_shape_qkv_a is not None:
+
+                if (sq := temp_dict["to_q"])[0] is not None:
+                    sample_q_A, sample_q_B = sq
+                else:
+                    sample_q_A, sample_q_B = [
+                        torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
+                        torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
+                    ]
+                if (sq := temp_dict["to_k"])[0] is not None:
+                    sample_k_A, sample_k_B = sq
+                else:
+                    sample_k_A, sample_k_B = [
+                        torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
+                        torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
+                    ]
+                if (sq := temp_dict["to_v"])[0] is not None:
+                    sample_v_A, sample_v_B = sq
+                else:
+                    sample_v_A, sample_v_B = [
+                        torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
+                        torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
+                    ]
+                original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_A.weight"] = (
+                    torch.cat([sample_q_A, sample_k_A, sample_v_A], dim=0)
+                )
+                original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_B.weight"] = (
+                    torch.cat([sample_q_B, sample_k_B, sample_v_B], dim=0)
+                )
+            if expected_shape_add_qkv_a is not None:
+
+                if (sq := temp_dict["add_q_proj"])[0] is not None:
+                    context_q_A, context_q_B = sq
+                else:
+                    context_q_A, context_q_B = [
+                        torch.zeros(
+                            expected_shape_add_qkv_a, dtype=dtype, device=device
+                        ),
+                        torch.zeros(
+                            expected_shape_add_qkv_b, dtype=dtype, device=device
+                        ),
+                    ]
+                if (sq := temp_dict["add_k_proj"])[0] is not None:
+                    context_k_A, context_k_B = sq
+                else:
+                    context_k_A, context_k_B = [
+                        torch.zeros(
+                            expected_shape_add_qkv_a, dtype=dtype, device=device
+                        ),
+                        torch.zeros(
+                            expected_shape_add_qkv_b, dtype=dtype, device=device
+                        ),
+                    ]
+                if (sq := temp_dict["add_v_proj"])[0] is not None:
+                    context_v_A, context_v_B = sq
+                else:
+                    context_v_A, context_v_B = [
+                        torch.zeros(
+                            expected_shape_add_qkv_a, dtype=dtype, device=device
+                        ),
+                        torch.zeros(
+                            expected_shape_add_qkv_b, dtype=dtype, device=device
+                        ),
+                    ]
+
+                original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_A.weight"] = (
+                    torch.cat([context_q_A, context_k_A, context_v_A], dim=0)
+                )
+                original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_B.weight"] = (
+                    torch.cat([context_q_B, context_k_B, context_v_B], dim=0)
+                )

         # qk_norm
         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
@@ -265,32 +332,73 @@ def convert_diffusers_to_flux_transformer_checkpoint(
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."
         # norm.linear -> single_blocks.0.modulation.lin
+        key_norm = f"{prefix}{block_prefix}norm.linear.weight"
         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
             original_state_dict,
             diffusers_state_dict,
-            f"{prefix}{block_prefix}norm.linear.weight",
+            key_norm,
             f"single_blocks.{i}.modulation.lin.weight",
         )

+        has_q, has_k, has_v, has_mlp = False, False, False, False
+        shape_qkv_a = None
+        shape_qkv_b = None
         # Q, K, V, mlp
         q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
         q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight")
+        if q_A is not None and q_B is not None:
+            has_q = True
+            shape_qkv_a = q_A.shape
+            shape_qkv_b = q_B.shape
         k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight")
         k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight")
+        if k_A is not None and k_B is not None:
+            has_k = True
+            shape_qkv_a = k_A.shape
+            shape_qkv_b = k_B.shape
         v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight")
         v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight")
+        if v_A is not None and v_B is not None:
+            has_v = True
+            shape_qkv_a = v_A.shape
+            shape_qkv_b = v_B.shape
         mlp_A = diffusers_state_dict.pop(
             f"{prefix}{block_prefix}proj_mlp.lora_A.weight"
         )
         mlp_B = diffusers_state_dict.pop(
             f"{prefix}{block_prefix}proj_mlp.lora_B.weight"
         )
-        original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = torch.cat(
-            [q_A, k_A, v_A, mlp_A], dim=0
-        )
-        original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = torch.cat(
-            [q_B, k_B, v_B, mlp_B], dim=0
-        )
+        if mlp_A is not None and mlp_B is not None:
+            has_mlp = True
+            shape_qkv_a = mlp_A.shape
+            shape_qkv_b = mlp_B.shape
+        if any([has_q, has_k, has_v, has_mlp]):
+            if not has_q:
+                q_A, q_B = [
+                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
+                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
+                ]
+            if not has_k:
+                k_A, k_B = [
+                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
+                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
+                ]
+            if not has_v:
+                v_A, v_B = [
+                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
+                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
+                ]
+            if not has_mlp:
+                mlp_A, mlp_B = [
+                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
+                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
+                ]
+            original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = torch.cat(
+                [q_A, k_A, v_A, mlp_A], dim=0
+            )
+            original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = torch.cat(
+                [q_B, k_B, v_B, mlp_B], dim=0
+            )

         # output projections
         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
@@ -324,9 +432,16 @@ def convert_diffusers_to_flux_transformer_checkpoint(
     return original_state_dict


-def convert_from_original_flux_checkpoint(
-    original_state_dict,
-):
+def convert_from_original_flux_checkpoint(original_state_dict: StateDict) -> StateDict:
+    """
+    Convert the state dict from the original Flux checkpoint format to the new format.
+
+    Args:
+        original_state_dict (Dict[str, torch.Tensor]): The original Flux checkpoint state dict.
+
+    Returns:
+        Dict[str, torch.Tensor]: The converted state dict in the new format.
+    """
     sd = {
         k.replace("lora_unet_", "")
         .replace("double_blocks_", "double_blocks.")
@@ -358,14 +473,39 @@ def get_module_for_key(
     return module


-def get_lora_for_key(key: str, lora_weights: dict):
+def get_lora_for_key(
+    key: str, lora_weights: dict
+) -> Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]:
+    """
+    Get LoRA weights for a specific key.
+
+    Args:
+        key (str): The key to look up in the LoRA weights.
+        lora_weights (dict): Dictionary containing LoRA weights.
+
+    Returns:
+        Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]: A tuple containing lora_A, lora_B, and alpha if found, None otherwise.
+    """
     prefix = key.split(".lora")[0]
-    lora_A = lora_weights[f"{prefix}.lora_A.weight"]
-    lora_B = lora_weights[f"{prefix}.lora_B.weight"]
-    alpha = lora_weights.get(f"{prefix}.alpha", None)
+    lora_A = lora_weights.get(f"{prefix}.lora_A.weight")
+    lora_B = lora_weights.get(f"{prefix}.lora_B.weight")
+    alpha = lora_weights.get(f"{prefix}.alpha")
+
+    if lora_A is None or lora_B is None:
+        return None
     return lora_A, lora_B, alpha


+def get_module_for_key(
+    key: str, model: Flux
+) -> F8Linear | torch.nn.Linear | CublasLinear:
+    parts = key.split(".")
+    module = model
+    for part in parts:
+        module = getattr(module, part)
+    return module
+
+
 def calculate_lora_weight(
     lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float]],
     rank: Optional[int] = None,
@@ -389,12 +529,16 @@ def calculate_lora_weight(
     w_down = lora_B.to(dtype=dtype, device=device)

     if alpha != rank:
-        w_up = w_up * (alpha / rank)
-
+        w_up = w_up * alpha / rank
     if uneven_rank:
-        fused_lora = lora_scale * torch.mm(
-            w_down.repeat_interleave(int(rank_diff), dim=1), w_up
-        )
+        # Fuse each lora instead of repeat interleave for each individual lora,
+        # seems to fuse more correctly.
+        fused_lora = torch.zeros(
+            (lora_B.shape[0], lora_A.shape[1]), device=device, dtype=dtype
+        )
+        w_up = w_up.chunk(int(rank_diff), dim=0)
+        for w_up_chunk in w_up:
+            fused_lora = fused_lora + (lora_scale * torch.mm(w_down, w_up_chunk))
     else:
         fused_lora = lora_scale * torch.mm(w_down, w_up)
     return fused_lora
@@ -445,16 +589,6 @@ def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
     lora_weights = convert_from_original_flux_checkpoint(lora_weights)
     logger.info("LoRA weights loaded")
     logger.debug("Extracting keys")
-    keys_without_ab = [
-        key.replace(".lora_A.weight", "")
-        .replace(".lora_B.weight", "")
-        .replace(".lora_A", "")
-        .replace(".lora_B", "")
-        .replace(".alpha", "")
-        for key in lora_weights.keys()
-    ]
-    logger.debug("Keys extracted")
-    keys_without_ab = list(set(keys_without_ab))
     keys_without_ab = list(
         set(
             [
@@ -463,10 +597,11 @@ def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
                 .replace(".lora_A", "")
                 .replace(".lora_B", "")
                 .replace(".alpha", "")
-                for key in keys_without_ab
+                for key in lora_weights.keys()
             ]
         )
     )
+    logger.debug("Keys extracted")
     return keys_without_ab, lora_weights


@@ -513,6 +648,9 @@ def apply_lora_to_model(
         module = get_module_for_key(key, model)
         weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
+        if lora_sd is None:
+            # Skipping LoRA application for this module
+            continue
         weight = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale)
         if is_f8:
             module.set_weight_tensor(weight.type(dtype))
@@ -540,6 +678,9 @@ def remove_lora_from_module(
         module = get_module_for_key(key, model)
         weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
+        if lora_sd is None:
+            # Skipping LoRA application for this module
+            continue
         weight = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale)
         if is_f8:
             module.set_weight_tensor(weight.type(dtype))
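
For reference, the fusion rule these helpers implement can be written standalone: a LoRA pair contributes a delta of lora_B @ lora_A, rescaled by alpha / rank when the two differ and by the user-supplied scale; applying a LoRA adds that delta to the module weight, and unloading subtracts it again. The sketch below is illustrative only (the function and argument names are not this module's API) and shows the plain even-rank path, not the chunked uneven-rank handling added in this commit.

import torch

def fuse_lora_delta(
    weight: torch.Tensor,   # (out_features, in_features) base weight
    lora_A: torch.Tensor,   # (rank, in_features) down projection
    lora_B: torch.Tensor,   # (out_features, rank) up projection
    alpha: float | None = None,
    scale: float = 1.0,
) -> torch.Tensor:
    # Illustrative fusion rule: W' = W + scale * (alpha / rank) * (B @ A)
    rank = lora_B.shape[1]
    if alpha is not None and alpha != rank:
        lora_A = lora_A * (alpha / rank)
    delta = scale * torch.mm(lora_B.to(weight.dtype), lora_A.to(weight.dtype))
    return weight + delta

# Unfusing is the inverse (weight - delta), which is what removing a loaded LoRA relies on.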