Spaces:
Runtime error
Runtime error
File size: 2,848 Bytes
31f2f28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import os.path as osp
import math
import abc
from torch.utils.data import DataLoader
import torch.optim
import torchvision.transforms as transforms
from timer import Timer
from logger import colorlogger
from torch.nn.parallel.data_parallel import DataParallel
from config import cfg
from SMPLer_X import get_model
# ddp
import torch.distributed as dist
from torch.utils.data import DistributedSampler
import torch.utils.data.distributed
from utils.distribute_utils import (
get_rank, is_main_process, time_synchronized, get_group_idx, get_process_groups
)
class Base(metaclass=abc.ABCMeta):
    """Shared scaffolding for runners: an epoch counter, timers and a logger.

    Subclasses must implement ``_make_batch_generator`` and ``_make_model``.
    The metaclass is declared in the class header because Python 3 ignores the
    Python 2 style ``__metaclass__`` class attribute. With the old form,
    ``@abc.abstractmethod`` had no effect. Now ``Base`` itself, and any subclass
    that leaves one of the abstract methods unimplemented, cannot be
    instantiated.

    Args:
        log_name: log file name, passed to ``colorlogger`` together with
            ``cfg.log_dir``.
    """

    def __init__(self, log_name='logs.txt'):
        self.cur_epoch = 0

        # Timers (presumably total / GPU / data-read time, judging by their names).
        self.tot_timer = Timer()
        self.gpu_timer = Timer()
        self.read_timer = Timer()

        # Logger
        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

    @abc.abstractmethod
    def _make_batch_generator(self):
        """Build the data pipeline; subclasses must override."""
        return

    @abc.abstractmethod
    def _make_model(self):
        """Build the model; subclasses must override."""
        return
class Demoer(Base):
    """Inference runner: builds the UBody demo loader and loads a pretrained model.

    Args:
        test_epoch: optional epoch identifier. It is converted with ``int()``
            and stored on ``self.test_epoch``; ``None`` (the default) is stored
            as ``None``, so the attribute always exists.
    """

    def __init__(self, test_epoch=None):
        # Always define the attribute so later reads never hit AttributeError.
        self.test_epoch = int(test_epoch) if test_epoch is not None else None
        super().__init__(log_name='test_logs.txt')

    def _make_batch_generator(self, demo_scene):
        """Create the demo dataset and a non-shuffled DataLoader over it.

        Args:
            demo_scene: forwarded to ``UBody`` as its third positional argument.

        Sets ``self.testset`` (the dataset) and ``self.batch_generator``
        (the DataLoader).
        """
        self.logger.info("Creating dataset...")
        # Local import: the dataset module is only loaded when this method runs.
        from data.UBody.UBody import UBody
        testset_loader = UBody(transforms.ToTensor(), "demo", demo_scene)
        batch_generator = DataLoader(
            dataset=testset_loader,
            batch_size=cfg.num_gpus * cfg.test_batch_size,
            # Keep dataset order; presumably the sample index passed to
            # _evaluate relies on this. TODO confirm.
            shuffle=False,
            num_workers=cfg.num_thread,
            pin_memory=True,
        )
        self.testset = testset_loader
        self.batch_generator = batch_generator

    @staticmethod
    def _remap_checkpoint_key(key):
        """Translate a checkpoint parameter name to this model's naming.

        Prepends ``module.`` (the prefix ``DataParallel`` puts on parameter
        names) when the key does not already start with it. It then renames
        these sub-module names:

        - ``module.backbone`` -> ``module.encoder``
        - ``body_rotation_net`` -> ``body_regressor``
        - ``hand_rotation_net`` -> ``hand_regressor``

        The prefix test is ``startswith`` rather than a substring check. A
        substring check would leave names like ``attn_module.w`` unprefixed,
        so they would never match the wrapped model.
        """
        if not key.startswith('module.'):
            key = 'module.' + key
        return (key.replace('module.backbone', 'module.encoder')
                   .replace('body_rotation_net', 'body_regressor')
                   .replace('hand_rotation_net', 'hand_regressor'))

    def _make_model(self):
        """Build the test-mode model and load ``cfg.pretrained_model_path`` into it.

        The model is wrapped in ``DataParallel`` and moved to ``cfg.device``.
        Weights are read from the checkpoint's ``'network'`` entry, and their
        names are remapped with ``_remap_checkpoint_key``. Loading is
        non-strict; any missing or unexpected parameter names are logged rather
        than silently ignored.

        Sets ``self.model`` (in eval mode).
        """
        self.logger.info('Load checkpoint from {}'.format(cfg.pretrained_model_path))

        # prepare network
        self.logger.info("Creating graph...")
        model = get_model('test')
        model = DataParallel(model).to(cfg.device)

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(cfg.pretrained_model_path, map_location=cfg.device)

        from collections import OrderedDict
        new_state_dict = OrderedDict(
            (self._remap_checkpoint_key(k), v) for k, v in ckpt['network'].items()
        )

        # strict=False tolerates mismatches, so report them instead of hiding them.
        incompatible = model.load_state_dict(new_state_dict, strict=False)
        if incompatible.missing_keys:
            self.logger.info('Missing keys when loading checkpoint ({}): {}'.format(
                len(incompatible.missing_keys), incompatible.missing_keys))
        if incompatible.unexpected_keys:
            self.logger.info('Unexpected keys when loading checkpoint ({}): {}'.format(
                len(incompatible.unexpected_keys), incompatible.unexpected_keys))

        model.eval()
        self.model = model

    def _evaluate(self, outs, cur_sample_idx):
        """Delegate evaluation of model outputs to ``self.testset.evaluate``.

        Args:
            outs: model outputs for the current batch (format defined by the
                dataset's ``evaluate``; not visible here).
            cur_sample_idx: forwarded unchanged; presumably the index of the
                first sample in the batch. TODO confirm.

        Returns:
            Whatever ``self.testset.evaluate`` returns.
        """
        eval_result = self.testset.evaluate(outs, cur_sample_idx)
        return eval_result
|