import os, glob |
import numpy as np |
from human_body_prior.tools.configurations import load_config, dump_config |
import os.path as osp |
def exprdir2model(expr_dir):
    '''
    Locate the newest checkpoint and the training config of an experiment.

    Checkpoints are searched as ``<expr_dir>/snapshots/*.ckpt``. The most
    recently modified one (by file mtime) is selected. The config is the
    ``*.yaml`` file in the directory two levels above that checkpoint, which is
    ``expr_dir`` itself. If several yaml files exist, the alphabetically first
    one is used.

    :param expr_dir: path to the experiment directory.
    :return: tuple ``(model_ps, trained_weights_fname)``.
        ``model_ps`` is the config returned by ``load_config``; its
        ``logging.best_model_fname`` is set to the selected checkpoint.
        ``trained_weights_fname`` is the checkpoint path.
    :raises ValueError: if ``expr_dir`` does not exist, contains no
        ``snapshots/*.ckpt`` file, or has no ``*.yaml`` config.
    '''
    if not osp.exists(expr_dir):
        raise ValueError('Could not find the experiment directory: %s' % expr_dir)

    model_snapshots_dir = osp.join(expr_dir, 'snapshots')
    available_ckpts = sorted(glob.glob(osp.join(model_snapshots_dir, '*.ckpt')), key=osp.getmtime)
    # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``
    # and would surface as AssertionError instead of ValueError.
    if not available_ckpts:
        raise ValueError('No checkpoints found at {}'.format(model_snapshots_dir))
    trained_weights_fname = available_ckpts[-1]  # newest by modification time

    # The config lives two levels above the checkpoint: <expr>/snapshots/x.ckpt -> <expr>.
    # osp.dirname works for absolute and relative paths and for any OS separator,
    # unlike splitting on '/' and guessing whether to prefix the filesystem root.
    config_dir = osp.dirname(osp.dirname(trained_weights_fname))
    model_ps_fnames = sorted(glob.glob(osp.join(config_dir, '*.yaml')))
    if not model_ps_fnames:
        raise ValueError('No *.yaml config found at {}'.format(config_dir))
    # sorted() makes the choice deterministic when several yaml files are present.
    model_ps_fname = model_ps_fnames[0]

    model_ps = load_config(default_ps_fname=model_ps_fname)
    model_ps.logging.best_model_fname = trained_weights_fname
    return model_ps, trained_weights_fname
def load_model(expr_dir, model_code=None, remove_words_in_model_weights=None, load_only_ps=False, disable_grad=True, custom_ps=None):
    '''
    Build a model from an experiment directory and load its latest checkpoint weights.

    :param expr_dir: experiment directory. ``exprdir2model`` resolves it to the
        newest ``snapshots/*.ckpt`` and the experiment's ``*.yaml`` config.
    :param model_code: model class (or any callable) that takes the config and
        returns the model instance. Example:
        ``from supercap.train.supercap_smpl import SuperCap``, then pass ``SuperCap``.
        Required unless ``load_only_ps`` is True.
    :param remove_words_in_model_weights: optional prefix stripped from the start
        of checkpoint keys so they match the freshly built model's keys. Only a
        leading occurrence is removed.
    :param load_only_ps: if True, return only the loaded config and do not build a model.
    :param disable_grad: if True, set ``requires_grad = False`` on every model parameter.
    :param custom_ps: optional config used to build the model instead of the one
        loaded from ``expr_dir``. The weights still come from ``expr_dir``.
    :return: ``model_ps`` if ``load_only_ps`` is True. Otherwise
        ``(model_instance, model_ps)``, with the model switched to eval mode.
    :raises ValueError: if a model must be built and ``model_code`` is None.
        Also propagates anything ``exprdir2model`` raises.
    '''
    import warnings
    import torch

    model_ps, trained_weights_fname = exprdir2model(expr_dir)
    if load_only_ps:
        return model_ps
    if custom_ps is not None:
        model_ps = custom_ps

    # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
    if model_code is None:
        raise ValueError('model_code should be provided')
    model_instance = model_code(model_ps)

    if disable_grad:
        for param in model_instance.parameters():
            param.requires_grad = False

    # map_location='cpu' lets checkpoints saved on GPU load on CPU-only machines.
    # load_state_dict copies the values into the model's own tensors, so the
    # model's device is unaffected.
    state_dict = torch.load(trained_weights_fname, map_location='cpu')['state_dict']

    if remove_words_in_model_weights is not None:
        prefix = str(remove_words_in_model_weights)
        # Strip only the leading prefix. str.replace would also delete matches
        # deeper inside a key (e.g. 'model.encoder.model.w' -> 'encoder.w').
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

    model_keys = set(model_instance.state_dict().keys())  # set: O(1) membership below
    missing_in_file = model_keys.difference(state_dict)
    if missing_in_file:
        # Without this, strict=False would silently leave these at their initial values.
        warnings.warn('{} model weights not found in {}; they keep their initial values: {}'.format(
            len(missing_in_file), trained_weights_fname, sorted(missing_in_file)))

    # Checkpoint entries the model does not define are dropped deliberately.
    state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
    model_instance.load_state_dict(state_dict, strict=False)
    model_instance.eval()
    return model_instance, model_ps