import logging
import os
import time
import pickle

import torch
import torch.nn as nn

from utilities.distributed import is_main_process

logger = logging.getLogger(__name__)

NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]

def register_norm_module(cls):
    """Class decorator that adds a custom normalization layer to NORM_MODULES."""
    NORM_MODULES.append(cls)
    return cls

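# Illustrative only: a project-specific norm layer (the class name below is
# hypothetical) can be registered so that code which treats NORM_MODULES
# specially, e.g. excluding those parameters from weight decay, also covers it:
#
#   @register_norm_module
#   class LayerNorm2d(torch.nn.LayerNorm):
#       ...
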
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
    """Match checkpoint weights to model parameters by key and shape.

    Only entries whose key exists in both state dicts and whose shapes agree
    are copied into the returned dict; everything else is logged so that
    missing, shape-mismatched, and unused checkpoint weights are easy to audit.
    """
    model_keys = sorted(model_state_dict.keys())
    ckpt_keys = sorted(ckpt_state_dict.keys())
    result_dicts = {}
    matched_log = []
    unmatched_log = []
    unloaded_log = []
    for model_key in model_keys:
        model_weight = model_state_dict[model_key]
        if model_key in ckpt_keys:
            ckpt_weight = ckpt_state_dict[model_key]
            if model_weight.shape == ckpt_weight.shape:
                result_dicts[model_key] = ckpt_weight
                ckpt_keys.remove(model_key)
                matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
            else:
                unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
        else:
            unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
    # Only the main process reports, to avoid duplicated logs in distributed runs.
    if is_main_process():
        for info in matched_log:
            logger.info(info)
        for info in unloaded_log:
            logger.warning(info)
        # Keys still left in ckpt_keys were never consumed by the model.
        for key in ckpt_keys:
            logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
        for info in unmatched_log:
            logger.warning(info)
    return result_dicts
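

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): a tiny model and a
    # hand-built "checkpoint" dict show how the helper is typically used before
    # load_state_dict. The extra key is there only to trigger the $UNUSED$ warning.
    logging.basicConfig(level=logging.INFO)
    model = torch.nn.Linear(4, 2)
    fake_ckpt = {
        "weight": torch.zeros(2, 4),
        "bias": torch.zeros(2),
        "extra.head": torch.zeros(3),
    }
    aligned = align_and_update_state_dicts(model.state_dict(), fake_ckpt)
    model.load_state_dict(aligned, strict=False)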