aredden committed
Commit cc15333
1 Parent(s): 61a425a

fix unloading bug

Files changed (1):
  1. lora_loading.py +70 -7
lora_loading.py CHANGED
@@ -606,7 +606,7 @@ def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
 
 
 def get_lora_weights(lora_path: str | StateDict):
-    if isinstance(lora_path, dict):
+    if isinstance(lora_path, (dict, LoraWeights)):
         return lora_path, True
     else:
         return load_file(lora_path, "cpu"), False
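
Note: the widened isinstance check lets get_lora_weights accept an in-memory LoraWeights container (returned as-is and flagged as already loaded) in addition to a raw state dict or a file path. LoraWeights is defined or imported elsewhere in the module; the sketch below is a minimal stand-in inferred only from the attributes this commit actually reads (.weights and .scale), so any further fields are assumptions.

# Minimal LoraWeights stand-in; only .weights and .scale are exercised
# by the code in this commit. The path field is an illustrative guess.
from dataclasses import dataclass

@dataclass
class LoraWeights:
    weights: dict            # raw LoRA state dict (lora_A / lora_B / alpha tensors)
    scale: float = 1.0       # strength the LoRA was fused into the model with
    path: str | None = None  # hypothetical: source file, if loaded from disk
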
@@ -640,10 +640,41 @@ def apply_lora_to_model(
 ) -> Flux:
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
-    lora_weights, _ = get_lora_weights(lora_path)
+    lora_weights, already_loaded = get_lora_weights(lora_path)
 
-    keys_without_ab, lora_weights = resolve_lora_state_dict(lora_weights, has_guidance)
-
+    if not already_loaded:
+        keys_without_ab, lora_weights = resolve_lora_state_dict(
+            lora_weights, has_guidance
+        )
+    elif isinstance(lora_weights, LoraWeights):
+        b_ = lora_weights
+        lora_weights = b_.weights
+        keys_without_ab = list(
+            set(
+                [
+                    key.replace(".lora_A.weight", "")
+                    .replace(".lora_B.weight", "")
+                    .replace(".lora_A", "")
+                    .replace(".lora_B", "")
+                    .replace(".alpha", "")
+                    for key in lora_weights.keys()
+                ]
+            )
+        )
+    else:
+        lora_weights = lora_weights
+        keys_without_ab = list(
+            set(
+                [
+                    key.replace(".lora_A.weight", "")
+                    .replace(".lora_B.weight", "")
+                    .replace(".lora_A", "")
+                    .replace(".lora_B", "")
+                    .replace(".alpha", "")
+                    for key in lora_weights.keys()
+                ]
+            )
+        )
     for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
         module = get_module_for_key(key, model)
         weight, is_f8, dtype = extract_weight_from_linear(module)
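
The elif and else branches above rebuild the list of target module keys straight from the state-dict keys rather than re-running resolve_lora_state_dict: each lora_A / lora_B / alpha entry is stripped of its suffix and deduplicated through a set. A small worked example with hypothetical Flux-style keys, using the same replace chain as the diff:

# Three entries describing one linear layer collapse to a single module key.
keys = [
    "double_blocks.0.img_attn.qkv.lora_A.weight",
    "double_blocks.0.img_attn.qkv.lora_B.weight",
    "double_blocks.0.img_attn.qkv.alpha",
]
keys_without_ab = list(
    set(
        key.replace(".lora_A.weight", "")
        .replace(".lora_B.weight", "")
        .replace(".lora_A", "")
        .replace(".lora_B", "")
        .replace(".alpha", "")
        for key in keys
    )
)
print(keys_without_ab)  # ['double_blocks.0.img_attn.qkv']
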
@@ -669,10 +700,42 @@ def remove_lora_from_module(
 ):
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
-    lora_weights = get_lora_weights(lora_path)
-    lora_weights, _ = get_lora_weights(lora_path)
+    lora_weights, already_loaded = get_lora_weights(lora_path)
 
-    keys_without_ab, lora_weights = resolve_lora_state_dict(lora_weights, has_guidance)
+    if not already_loaded:
+        keys_without_ab, lora_weights = resolve_lora_state_dict(
+            lora_weights, has_guidance
+        )
+    elif isinstance(lora_weights, LoraWeights):
+        b_ = lora_weights
+        lora_weights = b_.weights
+        keys_without_ab = list(
+            set(
+                [
+                    key.replace(".lora_A.weight", "")
+                    .replace(".lora_B.weight", "")
+                    .replace(".lora_A", "")
+                    .replace(".lora_B", "")
+                    .replace(".alpha", "")
+                    for key in lora_weights.keys()
+                ]
+            )
+        )
+        lora_scale = b_.scale
+    else:
+        lora_weights = lora_weights
+        keys_without_ab = list(
+            set(
+                [
+                    key.replace(".lora_A.weight", "")
+                    .replace(".lora_B.weight", "")
+                    .replace(".lora_A", "")
+                    .replace(".lora_B", "")
+                    .replace(".alpha", "")
+                    for key in lora_weights.keys()
+                ]
+            )
+        )
 
     for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
         module = get_module_for_key(key, model)
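
This hunk is the unloading fix named in the commit message. Previously, remove_lora_from_module bound the whole (weights, flag) tuple from get_lora_weights before immediately rebinding it, and it always re-ran resolve_lora_state_dict; combined with the old isinstance(lora_path, dict) check, passing back the LoraWeights object tracked for an applied LoRA would have fallen through to load_file as though it were a path. The new branches return the object untouched, unwrap .weights, and restore the fuse-time scale. A condensed sketch of the repaired path; the .safetensors name is a placeholder and the constructor matches the stand-in sketched earlier:

# Unloading with an already-loaded container, using only behavior
# visible in this diff.
lora = LoraWeights(weights=load_file("style.safetensors", "cpu"), scale=0.8)

lora_weights, already_loaded = get_lora_weights(lora)  # returns (lora, True)
assert already_loaded
state_dict = lora_weights.weights  # raw keys, no re-resolution needed
lora_scale = lora_weights.scale    # unfuse with the same strength it was fused at
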
 