fix unloading bug
Browse files- lora_loading.py +70 -7
lora_loading.py
CHANGED
@@ -606,7 +606,7 @@ def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
|
|
606 |
|
607 |
|
608 |
def get_lora_weights(lora_path: str | StateDict):
|
609 |
-
if isinstance(lora_path, dict):
|
610 |
return lora_path, True
|
611 |
else:
|
612 |
return load_file(lora_path, "cpu"), False
|
@@ -640,10 +640,41 @@ def apply_lora_to_model(
|
|
640 |
) -> Flux:
|
641 |
has_guidance = model.params.guidance_embed
|
642 |
logger.info(f"Loading LoRA weights for {lora_path}")
|
643 |
-
lora_weights,
|
644 |
-
|
645 |
-
keys_without_ab, lora_weights = resolve_lora_state_dict(lora_weights, has_guidance)
|
646 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
647 |
for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
|
648 |
module = get_module_for_key(key, model)
|
649 |
weight, is_f8, dtype = extract_weight_from_linear(module)
|
@@ -669,10 +700,42 @@ def remove_lora_from_module(
|
|
669 |
):
|
670 |
has_guidance = model.params.guidance_embed
|
671 |
logger.info(f"Loading LoRA weights for {lora_path}")
|
672 |
-
lora_weights = get_lora_weights(lora_path)
|
673 |
-
lora_weights, _ = get_lora_weights(lora_path)
|
674 |
|
675 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
676 |
|
677 |
for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
|
678 |
module = get_module_for_key(key, model)
|
|
|
606 |
|
607 |
|
608 |
def get_lora_weights(lora_path: str | StateDict):
|
609 |
+
if isinstance(lora_path, (dict, LoraWeights)):
|
610 |
return lora_path, True
|
611 |
else:
|
612 |
return load_file(lora_path, "cpu"), False
|
|
|
640 |
) -> Flux:
|
641 |
has_guidance = model.params.guidance_embed
|
642 |
logger.info(f"Loading LoRA weights for {lora_path}")
|
643 |
+
lora_weights, already_loaded = get_lora_weights(lora_path)
|
|
|
|
|
644 |
|
645 |
+
if not already_loaded:
|
646 |
+
keys_without_ab, lora_weights = resolve_lora_state_dict(
|
647 |
+
lora_weights, has_guidance
|
648 |
+
)
|
649 |
+
elif isinstance(lora_weights, LoraWeights):
|
650 |
+
b_ = lora_weights
|
651 |
+
lora_weights = b_.weights
|
652 |
+
keys_without_ab = list(
|
653 |
+
set(
|
654 |
+
[
|
655 |
+
key.replace(".lora_A.weight", "")
|
656 |
+
.replace(".lora_B.weight", "")
|
657 |
+
.replace(".lora_A", "")
|
658 |
+
.replace(".lora_B", "")
|
659 |
+
.replace(".alpha", "")
|
660 |
+
for key in lora_weights.keys()
|
661 |
+
]
|
662 |
+
)
|
663 |
+
)
|
664 |
+
else:
|
665 |
+
lora_weights = lora_weights
|
666 |
+
keys_without_ab = list(
|
667 |
+
set(
|
668 |
+
[
|
669 |
+
key.replace(".lora_A.weight", "")
|
670 |
+
.replace(".lora_B.weight", "")
|
671 |
+
.replace(".lora_A", "")
|
672 |
+
.replace(".lora_B", "")
|
673 |
+
.replace(".alpha", "")
|
674 |
+
for key in lora_weights.keys()
|
675 |
+
]
|
676 |
+
)
|
677 |
+
)
|
678 |
for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
|
679 |
module = get_module_for_key(key, model)
|
680 |
weight, is_f8, dtype = extract_weight_from_linear(module)
|
|
|
700 |
):
|
701 |
has_guidance = model.params.guidance_embed
|
702 |
logger.info(f"Loading LoRA weights for {lora_path}")
|
703 |
+
lora_weights, already_loaded = get_lora_weights(lora_path)
|
|
|
704 |
|
705 |
+
if not already_loaded:
|
706 |
+
keys_without_ab, lora_weights = resolve_lora_state_dict(
|
707 |
+
lora_weights, has_guidance
|
708 |
+
)
|
709 |
+
elif isinstance(lora_weights, LoraWeights):
|
710 |
+
b_ = lora_weights
|
711 |
+
lora_weights = b_.weights
|
712 |
+
keys_without_ab = list(
|
713 |
+
set(
|
714 |
+
[
|
715 |
+
key.replace(".lora_A.weight", "")
|
716 |
+
.replace(".lora_B.weight", "")
|
717 |
+
.replace(".lora_A", "")
|
718 |
+
.replace(".lora_B", "")
|
719 |
+
.replace(".alpha", "")
|
720 |
+
for key in lora_weights.keys()
|
721 |
+
]
|
722 |
+
)
|
723 |
+
)
|
724 |
+
lora_scale = b_.scale
|
725 |
+
else:
|
726 |
+
lora_weights = lora_weights
|
727 |
+
keys_without_ab = list(
|
728 |
+
set(
|
729 |
+
[
|
730 |
+
key.replace(".lora_A.weight", "")
|
731 |
+
.replace(".lora_B.weight", "")
|
732 |
+
.replace(".lora_A", "")
|
733 |
+
.replace(".lora_B", "")
|
734 |
+
.replace(".alpha", "")
|
735 |
+
for key in lora_weights.keys()
|
736 |
+
]
|
737 |
+
)
|
738 |
+
)
|
739 |
|
740 |
for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
|
741 |
module = get_module_for_key(key, model)
|