from utils.utils import instantiate_from_config
import torch
import copy
from omegaconf import OmegaConf
import logging

main_logger = logging.getLogger("main_logger")


def expand_conv_kernel(pretrained_dict):
    """expand 2d conv parameters from 4D -> 5D"""
    for k, v in pretrained_dict.items():
        if v.dim() == 4 and not k.startswith("first_stage_model"):
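            # e.g. a Conv2d weight (out_ch, in_ch, kh, kw) becomes
            # (out_ch, in_ch, 1, kh, kw): an inflated Conv3d kernel with temporal size 1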
            v = v.unsqueeze(2)
            pretrained_dict[k] = v
    return pretrained_dict


def print_state_dict(state_dict):
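    """Print every key and tensor shape in a state dict (debugging aid)."""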
    print("====== Dumping State Dict ======")
    for k, v in state_dict.items():
        print(k, v.shape)


def load_from_pretrainedSD_checkpoint(
    model,
    pretrained_ckpt,
    expand_to_3d=True,
    adapt_keyname=False,
    echo_empty_params=False,
):
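    """Load weights from a pretrained Stable Diffusion checkpoint into `model`.

    Parameters without a pretrained counterpart (or with a size mismatch) keep
    their initialized values and are returned in `empty_paras`.
    """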
    sd_state_dict = torch.load(pretrained_ckpt, map_location="cpu")
    if "state_dict" in sd_state_dict:
        sd_state_dict = sd_state_dict["state_dict"]
    model_state_dict = model.state_dict()
    # drop EMA weights so the parameter counts below are accurate
    for k in list(sd_state_dict.keys()):
        if k.startswith("model_ema"):
            del sd_state_dict[k]
    main_logger.info(
        f"Number of parameters in source checkpoint: {len(sd_state_dict.keys())} vs. target model: {len(model_state_dict.keys())}"
    )
    # print_state_dict(model_state_dict)
    # print_state_dict(sd_state_dict)

    if adapt_keyname:
        # rename keys to match the target 2D network: the added temporal-attention layers shift later block indices
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        for k in list(sd_state_dict.keys()):
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    sd_state_dict[new_key] = sd_state_dict[k]
                    del sd_state_dict[k]
                    cnt += 1
                    break  # k was deleted; stop checking further mappings
        main_logger.info(f"[renamed {cnt} Source keys to match Target model]")

    pretrained_dict = {
        k: v for k, v in sd_state_dict.items() if k in model_state_dict
    }  # drop keys the target model does not have
    empty_paras = [
        k for k, v in model_state_dict.items() if k not in pretrained_dict
    ]  # record keys with no pretrained weights
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(
        model_state_dict.keys()
    )

    if expand_to_3d:
        # adapting to 2d inflated network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_state_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model.load_state_dict(model_state_dict)
    except RuntimeError:  # some parameter sizes differ; skip those entries
        skipped = []
        model_dict_ori = model.state_dict()
        for n, p in model_state_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by using original empty paras
                model_state_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skipping parameter {n}: size {pretrained_dict[n].shape} in pretrained vs. {model_state_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"Skipped {len(skipped)} parameters because of size mismatch!"
        )
        model.load_state_dict(model_state_dict)
        empty_paras += skipped

    # separately count empty parameters belonging to the UNet (model.diffusion_model)
    unet_empty_paras = [
        name for name in empty_paras if name.startswith("model.diffusion_model")
    ]
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)} [Unet:{len(unet_empty_paras)}]"
    )
    if echo_empty_params:
        print("Printing empty parameters:")
        for k in empty_paras:
            print(k)
    return model, empty_paras


# Below: written by Yingqing --------------------------------------------------------


def load_model_from_config(config, ckpt, verbose=False):
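    """Instantiate the model described by `config` and load a Lightning-style
    checkpoint non-strictly, returning the model in eval mode."""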
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        main_logger.info("missing keys:")
        main_logger.info(m)
    if len(u) > 0 and verbose:
        main_logger.info("unexpected keys:")
        main_logger.info(u)

    model.eval()
    return model


def init_and_load_ldm_model(config_path, ckpt_path, device=None):
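    """Load an OmegaConf YAML config plus its .ckpt checkpoint; optionally move the model to `device`."""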
    assert config_path.endswith(".yaml"), f"config_path = {config_path}"
    assert ckpt_path.endswith(".ckpt"), f"ckpt_path = {ckpt_path}"
    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    if device is not None:
        model = model.to(device)
    return model


def load_img_model_to_video_model(
    model,
    device=None,
    expand_to_3d=True,
    adapt_keyname=False,
    config_path="configs/latent-diffusion/txt2img-1p4B-eval.yaml",
    ckpt_path="models/ldm/text2img-large/model.ckpt",
):
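    """Transfer weights from a pretrained image LDM into a video model.

    The default config/checkpoint paths point at the LDM txt2img-1p4B release;
    see `load_partial_weights` for how missing parameters are handled.
    """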
    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    model, empty_paras = load_partial_weights(
        model,
        pretrained_ldm.state_dict(),
        expand_to_3d=expand_to_3d,
        adapt_keyname=adapt_keyname,
    )
    return model, empty_paras


def load_partial_weights(
    model, pretrained_dict, expand_to_3d=True, adapt_keyname=False
):
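    """Copy matching weights from `pretrained_dict` into a deep copy of `model`.

    Returns the loaded copy and the list of parameter names that received no
    pretrained weights (missing from the source or size-mismatched).
    """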
    model2 = copy.deepcopy(model)
    model_dict = model.state_dict()
    model_dict_ori = copy.deepcopy(model_dict)

    main_logger.info("[Load pretrained LDM weights]")
    main_logger.info(
        f"Number of parameters in source model: {len(pretrained_dict.keys())} vs. target model: {len(model_dict.keys())}"
    )

    if adapt_keyname:
        # rename keys to match Menghan's standard 2D network: the added temporal-attention layers shift later block indices
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        newpretrained_dict = copy.deepcopy(pretrained_dict)
        for k, v in newpretrained_dict.items():
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    pretrained_dict[new_key] = v
                    pretrained_dict.pop(k)
                    cnt += 1
                    break  # k was removed; stop checking further mappings
        main_logger.info(f"--renamed {cnt} source keys to match target model.")
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items() if k in model_dict
    }  # drop keys the target model does not have
    empty_paras = [
        k for k, v in model_dict.items() if k not in pretrained_dict
    ]  # record keys with no pretrained weights
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)}"
    )
    # disable info
    # main_logger.info(f'Empty parameters: {empty_paras} ')
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_dict.keys())

    if expand_to_3d:
        # adapting to yingqing's 2d inflation network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model2.load_state_dict(model_dict)
    except RuntimeError:
        # some parameter sizes differ between checkpoint and model; skip them
        skipped = []
        for n, p in model_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by using original empty paras
                model_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skipping parameter {n}: size {pretrained_dict[n].shape} in pretrained vs. {model_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"Skipped {len(skipped)} parameters because of size mismatch!"
        )
        model2.load_state_dict(model_dict)
        empty_paras += skipped
        main_logger.info(f"Empty parameters: {len(empty_paras)}")

    main_logger.info("Finished.")
    return model2, empty_paras


def load_autoencoder(model, config_path=None, ckpt_path=None, device=None):
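    """Overwrite `model`'s first-stage (autoencoder) weights with those from a pretrained LDM."""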
    if config_path is None:
        config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
    if ckpt_path is None:
        ckpt_path = "models/ldm/text2img-large/model.ckpt"

    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    autoencoder_dict = {}
    for n, p in pretrained_ldm.state_dict().items():
        if n.startswith("first_stage_model"):
            autoencoder_dict[n] = p
    model_dict = model.state_dict()
    model_dict.update(autoencoder_dict)
    main_logger.info(f"Load [{len(autoencoder_dict)}] autoencoder parameters!")

    model.load_state_dict(model_dict)

    return model
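

# Example usage (a minimal sketch; the config path and model class below are
# hypothetical and depend on this repo's actual configs):
#
#     video_config = OmegaConf.load("configs/video_model.yaml")  # hypothetical path
#     video_model = instantiate_from_config(video_config.model)
#     video_model, empty_paras = load_img_model_to_video_model(
#         video_model, device="cpu", expand_to_3d=True, adapt_keyname=True
#     )
#     main_logger.info(f"{len(empty_paras)} parameters remain randomly initialized")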