from typing import Dict, List, Optional, Union

import torch
from diffusers.loaders.lora import LoraLoaderMixin
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.utils import USE_PEFT_BACKEND


class LoraLoaderWithWarmup(LoraLoaderMixin):
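    """LoRA loader mixin that also handles a separate warm-up UNet.

    ``load_lora_weights``, ``fuse_lora``, and ``unfuse_lora`` first delegate to
    the standard diffusers ``LoraLoaderMixin`` (text encoder + streaming UNet)
    and then repeat the operation on the pipeline's warm-up UNet, looked up via
    ``unet_warmup_name``.
    """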
    unet_warmup_name = "unet_warmup"

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name=None,
        **kwargs,
    ):
        # load LoRA weights into the text encoder and the streaming UNet via the parent loader
        super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)

        # load the same LoRA weights into the warm-up UNet
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup,
            low_cpu_mem_usage=low_cpu_mem_usage,
            adapter_name=adapter_name,
            _pipeline=self,
        )

    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        # fuse LoRA into the text encoder and the streaming UNet
        super().fuse_lora(fuse_unet, fuse_text_encoder, lora_scale, safe_fusing, adapter_names)

        # fuse LoRA into the warm-up UNet
        if fuse_unet:
            unet_warmup = (
                getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup
            )
            unet_warmup.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)

    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
        # unfuse LoRA from the text encoder and the streaming UNet
        super().unfuse_lora(unfuse_unet, unfuse_text_encoder)

        # unfuse LoRA from the warm-up UNet
        if unfuse_unet:
            unet_warmup = (
                getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup
            )
            if not USE_PEFT_BACKEND:
                unet_warmup.unfuse_lora()
            else:
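                # with the PEFT backend, manually unmerge the LoRA weights from every tuner layer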
                from peft.tuners.tuners_utils import BaseTunerLayer

                for module in unet_warmup.modules():
                    if isinstance(module, BaseTunerLayer):
                        module.unmerge()
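

# Usage sketch (illustrative; the pipeline class and checkpoint paths below are
# hypothetical -- any diffusers pipeline that mixes in LoraLoaderWithWarmup and
# exposes both a streaming `unet` and a `unet_warmup` module should work):
#
#   pipe = StreamingWarmupPipeline.from_pretrained("path/to/base-model")
#   pipe.load_lora_weights("path/to/lora.safetensors")  # text encoder, unet, and unet_warmup
#   pipe.fuse_lora(lora_scale=1.0)                      # merge LoRA into both UNets
#   ...                                                 # run inference
#   pipe.unfuse_lora()                                  # restore the original weights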