# NOTE(review): removed non-code residue ("Spaces: Running on Zero", repeated)
# that was scraped into the top of this file; it is not Python and would break parsing.
import copy
import logging

import torch
from omegaconf import OmegaConf

from utils.utils import instantiate_from_config

main_logger = logging.getLogger("main_logger")
def expand_conv_kernel(pretrained_dict):
    """Inflate 2D conv weights (4D tensors) to 3D conv weights (5D) in place.

    A depth axis of size 1 is inserted at dim 2 via ``unsqueeze``. Keys under
    "first_stage_model" (the autoencoder) are left untouched. Returns the same
    dict object for convenience.
    """
    for key in list(pretrained_dict.keys()):
        weight = pretrained_dict[key]
        if weight.dim() == 4 and not key.startswith("first_stage_model"):
            pretrained_dict[key] = weight.unsqueeze(2)
    return pretrained_dict
def print_state_dict(state_dict):
    """Debug helper: dump every parameter name and its tensor shape to stdout."""
    print("====== Dumping State Dict ======")
    for name, tensor in state_dict.items():
        print(name, tensor.shape)
def load_from_pretrainedSD_checkpoint(
    model,
    pretained_ckpt,
    expand_to_3d=True,
    adapt_keyname=False,
    echo_empty_params=False,
):
    """Load pretrained Stable-Diffusion weights from a checkpoint file into `model`.

    Args:
        model: target model (possibly a 2D-inflated 3D network).
        pretained_ckpt: path to the SD checkpoint file (name kept for
            backward compatibility despite the typo).
        expand_to_3d: if True, unsqueeze 4D conv kernels to 5D (see
            expand_conv_kernel) so they fit the inflated network.
        adapt_keyname: if True, rename source keys whose submodule indices were
            shifted by inserted temporal-attention blocks.
        echo_empty_params: if True, print the names of parameters that received
            no pretrained weight.

    Returns:
        (model, empty_paras): the model with weights loaded, and the list of
        parameter names left uninitialized (including size-mismatch skips).
    """
    sd_state_dict = torch.load(pretained_ckpt, map_location="cpu")
    if "state_dict" in list(sd_state_dict.keys()):
        sd_state_dict = sd_state_dict["state_dict"]
    model_state_dict = model.state_dict()

    # delete ema_weights just for <precise param counting>
    for k in list(sd_state_dict.keys()):
        if k.startswith("model_ema"):
            del sd_state_dict[k]
    main_logger.info(
        f"Num of model params of Source:{len(sd_state_dict.keys())} VS. Target:{len(model_state_dict.keys())}"
    )

    if adapt_keyname:
        # adapting to standard 2d network: inserted temporal-attention blocks
        # shift some submodule indices, so the source keys must be renamed
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        for k in list(sd_state_dict.keys()):
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    sd_state_dict[new_key] = sd_state_dict.pop(k)
                    cnt += 1
                    # fix: without this break, a key matching a second pattern
                    # would be popped twice and raise KeyError
                    break
        main_logger.info(f"[renamed {cnt} Source keys to match Target model]")

    pretrained_dict = {
        k: v for k, v in sd_state_dict.items() if k in model_state_dict
    }  # drop extra keys
    empty_paras = [
        k for k, v in model_state_dict.items() if k not in pretrained_dict
    ]  # target keys with no pretrained weight
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(
        model_state_dict.keys()
    )

    if expand_to_3d:
        # adapting to 2d inflated network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_state_dict.update(pretrained_dict)

    # load the new state dict; on size mismatch, skip the offending parameters
    try:
        model.load_state_dict(model_state_dict)
    except RuntimeError:
        # fix: was a bare `except:` — only the shape-mismatch RuntimeError from
        # load_state_dict is recoverable here; anything else should propagate
        skipped = []
        model_dict_ori = model.state_dict()
        for n, p in model_state_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by restoring the model's original (untrained) tensor
                model_state_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_state_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"[INFO] Skip {len(skipped)} parameters because of size mismatch!"
        )
        model.load_state_dict(model_state_dict)
        empty_paras += skipped

    # only count Unet part of depth estimation model
    unet_empty_paras = [
        name for name in empty_paras if name.startswith("model.diffusion_model")
    ]
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)} [Unet:{len(unet_empty_paras)}]"
    )

    if echo_empty_params:
        print("Printing empty parameters:")
        for k in empty_paras:
            print(k)
    return model, empty_paras
# Below: written by Yingqing -------------------------------------------------------- | |
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by `config` and load weights from `ckpt`.

    Loading is non-strict: missing and unexpected keys are tolerated, and logged
    when `verbose` is set. The model is returned in eval mode.
    """
    checkpoint = torch.load(ckpt, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose and len(missing) > 0:
        main_logger.info("missing keys:")
        main_logger.info(missing)
    if verbose and len(unexpected) > 0:
        main_logger.info("unexpected keys:")
        main_logger.info(unexpected)
    model.eval()
    return model
def init_and_load_ldm_model(config_path, ckpt_path, device=None):
    """Build an LDM from a YAML config, load its .ckpt weights, and optionally move it to `device`."""
    assert config_path.endswith(".yaml"), f"config_path = {config_path}"
    assert ckpt_path.endswith(".ckpt"), f"ckpt_path = {ckpt_path}"
    ldm_config = OmegaConf.load(config_path)
    ldm_model = load_model_from_config(ldm_config, ckpt_path)
    return ldm_model if device is None else ldm_model.to(device)
def load_img_model_to_video_model(
    model,
    device=None,
    expand_to_3d=True,
    adapt_keyname=False,
    config_path="configs/latent-diffusion/txt2img-1p4B-eval.yaml",
    ckpt_path="models/ldm/text2img-large/model.ckpt",
):
    """Transfer weights from a pretrained 2D text-to-image LDM into a video model.

    Returns (model, empty_paras) exactly as produced by load_partial_weights.
    """
    image_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    return load_partial_weights(
        model,
        image_ldm.state_dict(),
        expand_to_3d=expand_to_3d,
        adapt_keyname=adapt_keyname,
    )
def load_partial_weights(
    model, pretrained_dict, expand_to_3d=True, adapt_keyname=False
):
    """Load all weights from `pretrained_dict` whose names and shapes match `model`.

    The target `model` is deep-copied and the copy receives the weights; the
    input `model` itself is not modified. Note: `pretrained_dict` IS modified
    in place (keys may be renamed/dropped, 4D conv kernels unsqueezed to 5D).

    Args:
        model: target model; a deep copy of it is loaded and returned.
        pretrained_dict: source state dict (mutated in place).
        expand_to_3d: if True, inflate 4D conv kernels to 5D (expand_conv_kernel).
        adapt_keyname: if True, rename source keys whose submodule indices were
            shifted by inserted temporal-attention blocks.

    Returns:
        (model2, empty_paras): the loaded copy of `model`, and the names of
        parameters that received no pretrained weight (including size-mismatch
        skips).
    """
    model2 = copy.deepcopy(model)
    model_dict = model.state_dict()
    model_dict_ori = copy.deepcopy(model_dict)  # pristine shapes for mismatch fallback

    main_logger.info(f"[Load pretrained LDM weights]")
    main_logger.info(
        f"Num of parameters of source model:{len(pretrained_dict.keys())} VS. target model:{len(model_dict.keys())}"
    )

    if adapt_keyname:
        # adapting to menghan's standard 2d network: inserted temporal-attention
        # blocks shift some submodule indices, so the source keys must be renamed
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        # perf fix: iterate a list snapshot so the dict can be mutated while
        # looping, instead of deep-copying every source tensor as before
        for k, v in list(pretrained_dict.items()):
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    pretrained_dict[k.replace(src_word, dst_word)] = v
                    pretrained_dict.pop(k)
                    cnt += 1
                    # fix: without this break, a key matching a second pattern
                    # would be popped twice and raise KeyError
                    break
        main_logger.info(f"--renamed {cnt} source keys to match target model.")

    pretrained_dict = {
        k: v for k, v in pretrained_dict.items() if k in model_dict
    }  # drop extra keys
    empty_paras = [
        k for k, v in model_dict.items() if k not in pretrained_dict
    ]  # target keys with no pretrained weight
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)}"
    )
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_dict.keys())

    if expand_to_3d:
        # adapting to yingqing's 2d inflation network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)

    # load the new state dict; on size mismatch, skip the offending parameters
    try:
        model2.load_state_dict(model_dict)
    except RuntimeError:
        # fix: was a bare `except:` — only the shape-mismatch RuntimeError from
        # load_state_dict is recoverable here; anything else should propagate
        skipped = []
        for n, p in model_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by restoring the model's original (untrained) tensor
                model_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"[INFO] Skip {len(skipped)} parameters because of size mismatch!"
        )
        model2.load_state_dict(model_dict)
        empty_paras += skipped

    main_logger.info(f"Empty parameters: {len(empty_paras)} ")
    main_logger.info(f"Finished.")
    return model2, empty_paras
def load_autoencoder(model, config_path=None, ckpt_path=None, device=None):
    """Overwrite `model`'s first_stage_model (autoencoder) weights with those of a pretrained LDM.

    When `config_path`/`ckpt_path` are None, the default txt2img-1p4B LDM
    config and checkpoint paths are used. Returns the mutated `model`.
    """
    if config_path is None:
        config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
    if ckpt_path is None:
        ckpt_path = "models/ldm/text2img-large/model.ckpt"
    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    # keep only the autoencoder subtree of the pretrained state dict
    autoencoder_dict = {
        name: param
        for name, param in pretrained_ldm.state_dict().items()
        if name.startswith("first_stage_model")
    }
    merged_dict = model.state_dict()
    merged_dict.update(autoencoder_dict)
    main_logger.info(f"Load [{len(autoencoder_dict)}] autoencoder parameters!")
    model.load_state_dict(merged_dict)
    return model