Spaces:
Sleeping
Sleeping
File size: 2,668 Bytes
d16b52d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from typing import Dict, List, Optional, Union
import torch
from diffusers.loaders.lora import LoraLoaderMixin
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.utils import USE_PEFT_BACKEND
class LoraLoaderWithWarmup(LoraLoaderMixin):
    """LoRA loader mixin that also manages LoRA weights on a second, "warm-up" UNet.

    Every public method first delegates to ``LoraLoaderMixin``, which handles the
    text encoder(s) and the main (streaming) UNet. It then repeats the UNet part
    of the operation on the warm-up UNet returned by ``_get_unet_warmup``.
    """

    # Attribute name under which the warm-up UNet is expected on the pipeline.
    unet_warmup_name = "unet_warmup"

    def _get_unet_warmup(self):
        """Return the warm-up UNet module.

        Prefers the ``unet_warmup`` attribute. Otherwise it falls back to the
        attribute named by ``unet_warmup_name``, so a subclass can override that
        name. Raises ``AttributeError`` if neither attribute exists.
        """
        if hasattr(self, "unet_warmup"):
            return self.unet_warmup
        return getattr(self, self.unet_warmup_name)

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name=None,
        **kwargs,
    ):
        """Load LoRA weights into the text encoder, the streaming UNet and the warm-up UNet.

        Args:
            pretrained_model_name_or_path_or_dict: A model id / path accepted by
                ``lora_state_dict``, or an already-loaded state dict. A dict
                argument is never modified by this method's own loading pass.
            adapter_name: Forwarded both to the parent loader and to
                ``load_lora_into_unet`` for the warm-up UNet, so both UNets use
                the same adapter name.
            **kwargs: Forwarded unchanged to the parent ``load_lora_weights`` and
                to ``lora_state_dict``. ``low_cpu_mem_usage`` (default
                ``_LOW_CPU_MEM_USAGE_DEFAULT``) is additionally passed to
                ``load_lora_into_unet`` for the warm-up UNet.
        """
        # Text encoder(s) and the streaming UNet are handled by the parent class.
        super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)

        # ``lora_state_dict`` may pop/rename keys of a dict it is handed (e.g. when
        # converting non-diffusers LoRA formats). Work on a shallow copy so this
        # second pass does not mutate the caller's state dict.
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # NOTE(review): for a path / model id the checkpoint is resolved and loaded
        # a second time here; the parent does not expose the state dict it built.
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=self._get_unet_warmup(),
            low_cpu_mem_usage=low_cpu_mem_usage,
            adapter_name=adapter_name,
            _pipeline=self,
        )

    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        """Fuse LoRA weights into the text encoder, streaming UNet and warm-up UNet.

        Args:
            fuse_unet: Whether to fuse into the UNets (streaming via the parent,
                warm-up here).
            fuse_text_encoder: Whether to fuse into the text encoder. This is
                handled entirely by the parent.
            lora_scale: Scale forwarded to the parent and to the warm-up UNet's
                ``fuse_lora``.
            safe_fusing: Forwarded to the parent and to the warm-up UNet's
                ``fuse_lora``.
            adapter_names: Adapters to fuse. Forwarded to the parent and to the
                warm-up UNet's ``fuse_lora``.
        """
        # Keyword arguments keep this call correct even if the parent reorders parameters.
        super().fuse_lora(
            fuse_unet=fuse_unet,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
        )
        if fuse_unet:
            self._get_unet_warmup().fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)

    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
        """Undo ``fuse_lora`` on the text encoder, streaming UNet and warm-up UNet.

        Args:
            unfuse_unet: Whether to unfuse the UNets (streaming via the parent,
                warm-up here).
            unfuse_text_encoder: Whether to unfuse the text encoder. This is
                handled entirely by the parent.
        """
        super().unfuse_lora(unfuse_unet=unfuse_unet, unfuse_text_encoder=unfuse_text_encoder)
        if not unfuse_unet:
            return

        unet_warmup = self._get_unet_warmup()
        if not USE_PEFT_BACKEND:
            unet_warmup.unfuse_lora()
            return

        # With the PEFT backend, each LoRA-wrapped layer is unmerged individually.
        # The import is kept local so peft is only required when that backend is active.
        from peft.tuners.tuners_utils import BaseTunerLayer

        for module in unet_warmup.modules():
            if isinstance(module, BaseTunerLayer):
                module.unmerge()
|