import glob
import os
import os.path as osp

import numpy as np

from human_body_prior.tools.configurations import load_config, dump_config

def exprdir2model(expr_dir):
    '''
    Resolve an experiment directory into its training configuration and latest checkpoint.

    The directory is expected to contain a *.yaml configuration file and a snapshots/
    sub-directory holding one or more *.ckpt checkpoints.

    :param expr_dir: path to the experiment directory
    :return: (model_ps, trained_weights_fname) where model_ps is the loaded configuration
             with logging.best_model_fname set to the most recently written checkpoint
    '''
    if not os.path.exists(expr_dir):
        raise ValueError('Could not find the experiment directory: %s' % expr_dir)

    # Pick the most recently modified checkpoint from the snapshots sub-directory
    model_snapshots_dir = osp.join(expr_dir, 'snapshots')
    available_ckpts = sorted(glob.glob(osp.join(model_snapshots_dir, '*.ckpt')), key=osp.getmtime)
    assert len(available_ckpts) > 0, 'No checkpoints found at {}'.format(model_snapshots_dir)
    trained_weights_fname = available_ckpts[-1]

    # The *.yaml configuration lives two levels above the checkpoint, i.e. inside expr_dir;
    # try an absolute-path join first and fall back to a relative one.
    model_ps_fname = glob.glob(osp.join('/', '/'.join(trained_weights_fname.split('/')[:-2]), '*.yaml'))
    if len(model_ps_fname) == 0:
        model_ps_fname = glob.glob(osp.join('/'.join(trained_weights_fname.split('/')[:-2]), '*.yaml'))
    assert len(model_ps_fname) > 0, 'No configuration (*.yaml) found at {}'.format(expr_dir)

    model_ps_fname = model_ps_fname[0]
    model_ps = load_config(default_ps_fname=model_ps_fname)

    model_ps.logging.best_model_fname = trained_weights_fname

    return model_ps, trained_weights_fname
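
# A minimal sketch of calling exprdir2model directly; the path below is a placeholder, and the
# expected on-disk layout (a *.yaml configuration in expr_dir plus a snapshots/ folder of
# *.ckpt files) is inferred from the function above, not prescribed by it.
#
#     model_ps, ckpt_path = exprdir2model('/path/to/expr_dir')
#     assert model_ps.logging.best_model_fname == ckpt_path
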
def load_model(expr_dir, model_code=None, remove_words_in_model_weights=None, load_only_ps=False,
               disable_grad=True, custom_ps=None):
    '''
    Instantiate a model class with the configuration of an experiment and load its latest weights.

    :param expr_dir: path to the experiment directory (see exprdir2model)
    :param model_code: the model class to instantiate with the loaded configuration; e.g.
           from supercap.train.supercap_smpl import SuperCap, then pass SuperCap to this function
    :param remove_words_in_model_weights: optional prefix to strip from the checkpoint's
           state_dict keys (useful when the weights were saved from a wrapper module)
    :param load_only_ps: if True, return only the loaded configuration without building the model
    :param disable_grad: if True, freeze all parameters of the instantiated model
    :param custom_ps: optional configuration to use instead of the one found in expr_dir
    :return: (model_instance, model_ps)
    '''
    import torch

    model_ps, trained_weights_fname = exprdir2model(expr_dir)
    if load_only_ps: return model_ps
    if custom_ps is not None: model_ps = custom_ps
    assert model_code is not None, 'model_code should be provided'

    model_instance = model_code(model_ps)
    if disable_grad:
        # Freeze the model so it can only be used for inference
        for param in model_instance.parameters():
            param.requires_grad = False

    state_dict = torch.load(trained_weights_fname)['state_dict']
    if remove_words_in_model_weights is not None:
        # Strip the given prefix (e.g. the attribute name of a wrapper module) from checkpoint keys
        words = '{}'.format(remove_words_in_model_weights)
        state_dict = {k.replace(words, '') if k.startswith(words) else k: v for k, v in state_dict.items()}

    # Keep only the weights the instantiated model actually expects; the two set differences
    # below are handy to inspect when debugging key mismatches between checkpoint and model.
    instance_model_keys = list(model_instance.state_dict().keys())
    trained_model_keys = list(state_dict.keys())
    wts_in_model_not_in_file = set(instance_model_keys).difference(set(trained_model_keys))
    wts_in_file_not_in_model = set(trained_model_keys).difference(set(instance_model_keys))

    state_dict = {k: v for k, v in state_dict.items() if k in instance_model_keys}
    model_instance.load_state_dict(state_dict, strict=False)
    model_instance.eval()

    return model_instance, model_ps
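

if __name__ == '__main__':
    # A minimal usage sketch, not exercised anywhere in this module: the VPoser import path,
    # the experiment directory, and the 'vp_model.' key prefix below are assumptions chosen
    # to illustrate the load_model API, not values required by it.
    from human_body_prior.models.vposer_model import VPoser

    expr_dir = '/path/to/vposer_expr_dir'
    vposer, ps = load_model(expr_dir,
                            model_code=VPoser,
                            remove_words_in_model_weights='vp_model.',
                            disable_grad=True)
    print('Loaded weights from {}'.format(ps.logging.best_model_fname))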