multimodalart HF staff committed on
Commit
ce92feb
·
1 Parent(s): fc8ab35

Update lora.py

Browse files
Files changed (1) hide show
  1. lora.py +77 -26
lora.py CHANGED
@@ -5,12 +5,16 @@
5
 
6
  import math
7
  import os
8
- from typing import List, Tuple, Union
 
 
9
  import numpy as np
10
  import torch
11
  import re
12
 
13
 
 
 
14
  RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
15
 
16
 
@@ -400,7 +404,16 @@ def parse_block_lr_kwargs(nw_kwargs):
400
  return down_lr_weight, mid_lr_weight, up_lr_weight
401
 
402
 
403
- def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
 
 
 
 
 
 
 
 
 
404
  if network_dim is None:
405
  network_dim = 4 # default
406
  if network_alpha is None:
@@ -719,33 +732,36 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
719
  class LoRANetwork(torch.nn.Module):
720
  NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
721
 
722
- # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
723
- UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
724
  UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
725
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
726
  LORA_PREFIX_UNET = "lora_unet"
727
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
728
 
 
 
 
 
729
  def __init__(
730
  self,
731
- text_encoder,
732
  unet,
733
- multiplier=1.0,
734
- lora_dim=4,
735
- alpha=1,
736
- dropout=None,
737
- rank_dropout=None,
738
- module_dropout=None,
739
- conv_lora_dim=None,
740
- conv_alpha=None,
741
- block_dims=None,
742
- block_alphas=None,
743
- conv_block_dims=None,
744
- conv_block_alphas=None,
745
- modules_dim=None,
746
- modules_alpha=None,
747
- module_class=LoRAModule,
748
- varbose=False,
749
  ) -> None:
750
  """
751
  LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -783,8 +799,21 @@ class LoRANetwork(torch.nn.Module):
783
  print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
784
 
785
  # create module instances
786
- def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
787
- prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
 
 
 
 
 
 
 
 
 
 
 
 
 
788
  loras = []
789
  skipped = []
790
  for name, module in root_module.named_modules():
@@ -800,11 +829,14 @@ class LoRANetwork(torch.nn.Module):
800
 
801
  dim = None
802
  alpha = None
 
803
  if modules_dim is not None:
 
804
  if lora_name in modules_dim:
805
  dim = modules_dim[lora_name]
806
  alpha = modules_alpha[lora_name]
807
  elif is_unet and block_dims is not None:
 
808
  block_idx = get_block_index(lora_name)
809
  if is_linear or is_conv2d_1x1:
810
  dim = block_dims[block_idx]
@@ -813,6 +845,7 @@ class LoRANetwork(torch.nn.Module):
813
  dim = conv_block_dims[block_idx]
814
  alpha = conv_block_alphas[block_idx]
815
  else:
 
816
  if is_linear or is_conv2d_1x1:
817
  dim = self.lora_dim
818
  alpha = self.alpha
@@ -821,6 +854,7 @@ class LoRANetwork(torch.nn.Module):
821
  alpha = self.conv_alpha
822
 
823
  if dim is None or dim == 0:
 
824
  if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
825
  skipped.append(lora_name)
826
  continue
@@ -838,7 +872,24 @@ class LoRANetwork(torch.nn.Module):
838
  loras.append(lora)
839
  return loras, skipped
840
 
841
- self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
843
 
844
  # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -846,7 +897,7 @@ class LoRANetwork(torch.nn.Module):
846
  if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
847
  target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
848
 
849
- self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
850
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
851
 
852
  skipped = skipped_te + skipped_un
@@ -880,7 +931,6 @@ class LoRANetwork(torch.nn.Module):
880
  weights_sd = load_file(file)
881
  else:
882
  weights_sd = torch.load(file, map_location="cpu")
883
-
884
  info = self.load_state_dict(weights_sd, False)
885
  return info
886
 
@@ -961,6 +1011,7 @@ class LoRANetwork(torch.nn.Module):
961
 
962
  return lr_weight
963
 
 
964
  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
965
  self.requires_grad_(True)
966
  all_params = []
 
5
 
6
  import math
7
  import os
8
+ from typing import Dict, List, Optional, Tuple, Type, Union
9
+ from diffusers import AutoencoderKL
10
+ from transformers import CLIPTextModel
11
  import numpy as np
12
  import torch
13
  import re
14
 
15
 
16
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
17
+
18
  RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
19
 
20
 
 
404
  return down_lr_weight, mid_lr_weight, up_lr_weight
405
 
406
 
407
+ def create_network(
408
+ multiplier: float,
409
+ network_dim: Optional[int],
410
+ network_alpha: Optional[float],
411
+ vae: AutoencoderKL,
412
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
413
+ unet,
414
+ neuron_dropout: Optional[float] = None,
415
+ **kwargs,
416
+ ):
417
  if network_dim is None:
418
  network_dim = 4 # default
419
  if network_alpha is None:
 
732
  class LoRANetwork(torch.nn.Module):
733
  NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
734
 
735
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
 
736
  UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
737
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
738
  LORA_PREFIX_UNET = "lora_unet"
739
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
740
 
741
+ # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
742
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
743
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
744
+
745
  def __init__(
746
  self,
747
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
748
  unet,
749
+ multiplier: float = 1.0,
750
+ lora_dim: int = 4,
751
+ alpha: float = 1,
752
+ dropout: Optional[float] = None,
753
+ rank_dropout: Optional[float] = None,
754
+ module_dropout: Optional[float] = None,
755
+ conv_lora_dim: Optional[int] = None,
756
+ conv_alpha: Optional[float] = None,
757
+ block_dims: Optional[List[int]] = None,
758
+ block_alphas: Optional[List[float]] = None,
759
+ conv_block_dims: Optional[List[int]] = None,
760
+ conv_block_alphas: Optional[List[float]] = None,
761
+ modules_dim: Optional[Dict[str, int]] = None,
762
+ modules_alpha: Optional[Dict[str, int]] = None,
763
+ module_class: Type[object] = LoRAModule,
764
+ varbose: Optional[bool] = False,
765
  ) -> None:
766
  """
767
  LoRA network: すごく引数が多いが、パターンは以下の通り
 
799
  print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
800
 
801
  # create module instances
802
+ def create_modules(
803
+ is_unet: bool,
804
+ text_encoder_idx: Optional[int], # None, 1, 2
805
+ root_module: torch.nn.Module,
806
+ target_replace_modules: List[torch.nn.Module],
807
+ ) -> List[LoRAModule]:
808
+ prefix = (
809
+ self.LORA_PREFIX_UNET
810
+ if is_unet
811
+ else (
812
+ self.LORA_PREFIX_TEXT_ENCODER
813
+ if text_encoder_idx is None
814
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
815
+ )
816
+ )
817
  loras = []
818
  skipped = []
819
  for name, module in root_module.named_modules():
 
829
 
830
  dim = None
831
  alpha = None
832
+
833
  if modules_dim is not None:
834
+ # モジュール指定あり
835
  if lora_name in modules_dim:
836
  dim = modules_dim[lora_name]
837
  alpha = modules_alpha[lora_name]
838
  elif is_unet and block_dims is not None:
839
+ # U-Netでblock_dims指定あり
840
  block_idx = get_block_index(lora_name)
841
  if is_linear or is_conv2d_1x1:
842
  dim = block_dims[block_idx]
 
845
  dim = conv_block_dims[block_idx]
846
  alpha = conv_block_alphas[block_idx]
847
  else:
848
+ # 通常、すべて対象とする
849
  if is_linear or is_conv2d_1x1:
850
  dim = self.lora_dim
851
  alpha = self.alpha
 
854
  alpha = self.conv_alpha
855
 
856
  if dim is None or dim == 0:
857
+ # skipした情報を出力
858
  if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
859
  skipped.append(lora_name)
860
  continue
 
872
  loras.append(lora)
873
  return loras, skipped
874
 
875
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
876
+ print(text_encoders)
877
+ # create LoRA for text encoder
878
+ # 毎回すべてのモジュールを作るのは無駄なので要検討
879
+ self.text_encoder_loras = []
880
+ skipped_te = []
881
+ for i, text_encoder in enumerate(text_encoders):
882
+ if len(text_encoders) > 1:
883
+ index = i + 1
884
+ print(f"create LoRA for Text Encoder {index}:")
885
+ else:
886
+ index = None
887
+ print(f"create LoRA for Text Encoder:")
888
+
889
+ print(text_encoder)
890
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
891
+ self.text_encoder_loras.extend(text_encoder_loras)
892
+ skipped_te += skipped
893
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
894
 
895
  # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
 
897
  if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
898
  target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
899
 
900
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
901
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
902
 
903
  skipped = skipped_te + skipped_un
 
931
  weights_sd = load_file(file)
932
  else:
933
  weights_sd = torch.load(file, map_location="cpu")
 
934
  info = self.load_state_dict(weights_sd, False)
935
  return info
936
 
 
1011
 
1012
  return lr_weight
1013
 
1014
+ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1015
  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1016
  self.requires_grad_(True)
1017
  all_params = []