# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import yaml
import os.path as osp
import torch
import numpy as np
from ..dataset.mesh_util import *
from ..net.geometry import orthogonal
import cv2
import PIL.Image
from tqdm import tqdm
import os
from termcolor import colored
import pytorch_lightning as pl

def init_loss():

    losses = {
        # Cloth: Normal_recon - Normal_pred
        "cloth": {
            "weight": 1e3,
            "value": 0.0
        },
        # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
        "stiffness": {
            "weight": 1e5,
            "value": 0.0
        },
        # Cloth: det(R) = 1
        "rigid": {
            "weight": 1e5,
            "value": 0.0
        },
        # Cloth: edge length
        "edge": {
            "weight": 0,
            "value": 0.0
        },
        # Cloth: normal consistency
        "nc": {
            "weight": 0,
            "value": 0.0
        },
        # Cloth: Laplacian smoothness
        "laplacian": {
            "weight": 1e2,
            "value": 0.0
        },
        # Body: Normal_pred - Normal_smpl
        "normal": {
            "weight": 1e0,
            "value": 0.0
        },
        # Body: Silhouette_pred - Silhouette_smpl
        "silhouette": {
            "weight": 1e0,
            "value": 0.0
        },
        # Joint: reprojected joints difference
        "joint": {
            "weight": 5e0,
            "value": 0.0
        },
    }

    return losses
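
# Lightning Trainer subclass whose save_checkpoint() drops the frozen sub-modules
# (normal_filter, voxelization, reconEngine) from the state dict before writing,
# so the saved checkpoint only contains the trainable weights.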
class SubTrainer(pl.Trainer):

    def save_checkpoint(self, filepath, weights_only=False):
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            filepath: write-target file's path
            weights_only: saving model weights only
        """
        _checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)

        del_keys = []
        for key in _checkpoint["state_dict"].keys():
            for ignore_key in ["normal_filter", "voxelization", "reconEngine"]:
                if ignore_key in key:
                    del_keys.append(key)

        for key in del_keys:
            del _checkpoint["state_dict"][key]

        pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
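
# rename(): return a copy of a state dict with a single key renamed from
# old_name to new_name; all other keys keep their order and values.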
def rename(old_dict, old_name, new_name):
    new_dict = {}
    for key, value in old_dict.items():
        new_key = key if key != old_name else new_name
        new_dict[new_key] = value
    return new_dict
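
# load_normal_networks(): partially restore the normal-prediction network from a
# checkpoint, keeping only entries whose names and shapes match the current model.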
def load_normal_networks(model, normal_path):

    pretrained_dict = torch.load(normal_path, map_location=model.device)["state_dict"]
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }

    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)

    # 3. load the new state dict
    model.load_state_dict(model_dict)

    del pretrained_dict
    del model_dict

    print(colored(f"Resume Normal weights from {normal_path}", "green"))
def load_networks(model, mlp_path, normal_path=None):

    model_dict = model.state_dict()
    main_dict = {}
    normal_dict = {}

    # MLP part loading
    if os.path.exists(mlp_path) and mlp_path.endswith("ckpt"):
        main_dict = torch.load(mlp_path, map_location=model.device)["state_dict"]

        main_dict = {
            k: v
            for k, v in main_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape and
            ("reconEngine" not in k) and ("normal_filter" not in k) and
            ("voxelization" not in k)
        }
        print(colored(f"Resume MLP weights from {mlp_path}", "green"))

    # normal network part loading
    if normal_path is not None and os.path.exists(normal_path) and normal_path.endswith("ckpt"):
        normal_dict = torch.load(normal_path, map_location=model.device)["state_dict"]

        for key in normal_dict.keys():
            normal_dict = rename(normal_dict, key,
                                 key.replace("netG", "netG.normal_filter"))

        normal_dict = {
            k: v
            for k, v in normal_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        print(colored(f"Resume normal model from {normal_path}", "green"))

    model_dict.update(main_dict)
    model_dict.update(normal_dict)
    model.load_state_dict(model_dict)

    # clean unused GPU memory
    del main_dict
    del normal_dict
    del model_dict
    torch.cuda.empty_cache()
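
# reshape_sample_tensor(): repeat the query points along the batch dimension so
# every view of a multi-view batch shares the same samples,
# [B, C, N] -> [B * num_views, C, N].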
def reshape_sample_tensor(sample_tensor, num_views):
    if num_views == 1:
        return sample_tensor
    # Need to repeat sample_tensor along the batch dim num_views times
    sample_tensor = sample_tensor.unsqueeze(dim=1)
    sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
    sample_tensor = sample_tensor.view(
        sample_tensor.shape[0] * sample_tensor.shape[1],
        sample_tensor.shape[2],
        sample_tensor.shape[3],
    )
    return sample_tensor

def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Sets the learning rate to the initial LR decayed by schedule"""
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
    return lr

def compute_acc(pred, gt, thresh=0.5):
    """
    return:
        IOU, precision, and recall
    """
    with torch.no_grad():
        vol_pred = pred > thresh
        vol_gt = gt > thresh

        union = vol_pred | vol_gt
        inter = vol_pred & vol_gt

        true_pos = inter.sum().float()

        union = union.sum().float()
        if union == 0:
            union = 1

        vol_pred = vol_pred.sum().float()
        if vol_pred == 0:
            vol_pred = 1

        vol_gt = vol_gt.sum().float()
        if vol_gt == 0:
            vol_gt = 1

        return true_pos / union, true_pos / vol_pred, true_pos / vol_gt

def calc_error(opt, net, cuda, dataset, num_tests):

    if num_tests > len(dataset):
        num_tests = len(dataset)

    with torch.no_grad():
        error_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
        for idx in tqdm(range(num_tests)):
            data = dataset[idx * len(dataset) // num_tests]

            # retrieve the data
            image_tensor = data["img"].to(device=cuda)
            calib_tensor = data["calib"].to(device=cuda)
            sample_tensor = data["samples"].to(device=cuda).unsqueeze(0)
            if opt.num_views > 1:
                sample_tensor = reshape_sample_tensor(sample_tensor, opt.num_views)
            label_tensor = data["labels"].to(device=cuda).unsqueeze(0)

            res, error = net.forward(image_tensor,
                                     sample_tensor,
                                     calib_tensor,
                                     labels=label_tensor)

            IOU, prec, recall = compute_acc(res, label_tensor)

            # print(
            #     '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
            #     .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
            error_arr.append(error.item())
            IOU_arr.append(IOU.item())
            prec_arr.append(prec.item())
            recall_arr.append(recall.item())

    return (
        np.average(error_arr),
        np.average(IOU_arr),
        np.average(prec_arr),
        np.average(recall_arr),
    )

def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):

    if num_tests > len(dataset):
        num_tests = len(dataset)

    with torch.no_grad():
        error_color_arr = []

        for idx in tqdm(range(num_tests)):
            data = dataset[idx * len(dataset) // num_tests]

            # retrieve the data
            image_tensor = data["img"].to(device=cuda)
            calib_tensor = data["calib"].to(device=cuda)
            color_sample_tensor = data["color_samples"].to(device=cuda).unsqueeze(0)

            if opt.num_views > 1:
                color_sample_tensor = reshape_sample_tensor(color_sample_tensor,
                                                            opt.num_views)

            rgb_tensor = data["rgbs"].to(device=cuda).unsqueeze(0)

            netG.filter(image_tensor)
            _, errorC = netC.forward(
                image_tensor,
                netG.get_im_feat(),
                color_sample_tensor,
                calib_tensor,
                labels=rgb_tensor,
            )

            # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
            #       .format(idx, num_tests, errorG.item(), errorC.item()))
            error_color_arr.append(errorC.item())

    return np.average(error_color_arr)

# pytorch lightning training related functions
def query_func(opt, netG, features, points, proj_matrix=None):
    """
    - points: size of (bz, N, 3)
    - proj_matrix: size of (bz, 4, 4)
    return: size of (bz, 1, N)
    """
    assert len(points) == 1
    samples = points.repeat(opt.num_views, 1, 1)
    samples = samples.permute(0, 2, 1)  # [bz, 3, N]

    # view specific query
    if proj_matrix is not None:
        samples = orthogonal(samples, proj_matrix)

    calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples)

    preds = netG.query(
        features=features,
        points=samples,
        calibs=calib_tensor,
        regressor=netG.if_regressor,
    )

    if type(preds) is list:
        preds = preds[0]

    return preds

def query_func_IF(batch, netG, points):
    """
    - points: size of (bz, N, 3)
    return: size of (bz, 1, N)
    """
    batch["samples_geo"] = points
    batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points)

    preds = netG(batch)

    return preds.unsqueeze(1)
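
# isin() / in1d(): tensor analogues of np.isin / np.in1d, returning a boolean
# mask marking which elements of ar1 also appear in ar2.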
def isin(ar1, ar2):
    return (ar1[..., None] == ar2).any(-1)

def in1d(ar1, ar2):
    mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
    mask[ar2.unique()] = True
    return mask[ar1]

def batch_mean(res, key):
    return torch.stack([
        x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res
    ]).mean()
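
# tf_log_convert(): rewrite log keys from "stage_metric" to "stage/metric" so
# TensorBoard groups the scalars by their prefix.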
def tf_log_convert(log_dict):
    new_log_dict = log_dict.copy()
    for k, v in log_dict.items():
        new_log_dict[k.replace("_", "/")] = v
        del new_log_dict[k]

    return new_log_dict

def bar_log_convert(log_dict, name=None, rot=None):
    from decimal import Decimal

    new_log_dict = {}

    if name is not None:
        new_log_dict["name"] = name[0]
    if rot is not None:
        new_log_dict["rot"] = rot[0]

    for k, v in log_dict.items():
        color = "yellow"
        if "loss" in k:
            color = "red"
            k = k.replace("loss", "L")
        elif "acc" in k:
            color = "green"
            k = k.replace("acc", "A")
        elif "iou" in k:
            color = "green"
            k = k.replace("iou", "I")
        elif "prec" in k:
            color = "green"
            k = k.replace("prec", "P")
        elif "recall" in k:
            color = "green"
            k = k.replace("recall", "R")

        if "lr" not in k:
            new_log_dict[colored(k.split("_")[1], color)] = colored(f"{v:.3f}", color)
        else:
            new_log_dict[colored(k.split("_")[1], color)] = colored(f"{Decimal(str(v)):.1E}", color)

    if "loss" in new_log_dict.keys():
        del new_log_dict["loss"]

    return new_log_dict
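
# accumulate(): average per-sample test metrics over all rotations of every
# subject, grouped by dataset, producing one scalar per "<dataset>/<metric>" key.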
def accumulate(outputs, rot_num, split):

    hparam_log_dict = {}

    metrics = outputs[0].keys()
    datasets = split.keys()

    for dataset in datasets:
        for metric in metrics:
            keyword = f"{dataset}/{metric}"
            if keyword not in hparam_log_dict.keys():
                hparam_log_dict[keyword] = 0
            for idx in range(split[dataset][0] * rot_num,
                             split[dataset][1] * rot_num):
                hparam_log_dict[keyword] += outputs[idx][metric].item()
            hparam_log_dict[keyword] /= (split[dataset][1] -
                                         split[dataset][0]) * rot_num

    print(colored(hparam_log_dict, "green"))

    return hparam_log_dict

def calc_error_N(outputs, targets):
    """calculate the error of normal (IGR)

    Args:
        outputs (torch.tensor): [B, 3, N]
        targets (torch.tensor): [B, N, 3]

        # manifold loss and grad_loss in IGR paper
        grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
        normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()

    Returns:
        torch.tensor: error of valid normals on the surface
    """
    # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
    outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
    targets = targets.reshape(-1, 3)[:, 2:3]
    with_normals = targets.sum(dim=1).abs() > 0.0

    # eikonal loss
    grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
    # normals loss
    normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()

    return grad_loss * 0.0 + normal_loss

def calc_knn_acc(preds, carn_verts, labels, pick_num):
    """calculate knn accuracy

    Args:
        preds (torch.tensor): [B, 3, N]
        carn_verts (torch.tensor): [SMPLX_V_num, 3]
        labels (torch.tensor): [B, N_knn, N]
    """
    N_knn_full = labels.shape[1]
    preds = preds.permute(0, 2, 1).reshape(-1, 3)
    labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full)  # [BxN, num_knn]
    labels = labels[:, :pick_num]

    dist = torch.cdist(preds, carn_verts, p=2)  # [BxN, SMPL_V_num]
    knn = dist.topk(k=pick_num, dim=1, largest=False)[1]  # [BxN, num_knn]
    cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
    bool_col = torch.zeros_like(cat_mat)[:, 0]
    for i in range(pick_num * 2 - 1):
        bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
    acc = (bool_col > 0).sum() / len(bool_col)

    return acc

def calc_acc_seg(output, target, num_multiseg):
    from pytorch_lightning.metrics import Accuracy
    return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
                      target.flatten().cpu())
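
# add_watermark(): caption each image with its title (plus an extra label on the
# first image) via cv2.putText, then stack the images vertically and return the
# result as a CHW array.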
def add_watermark(imgs, titles):

    # Write some Text
    font = cv2.FONT_HERSHEY_SIMPLEX
    bottomLeftCornerOfText = (350, 50)
    bottomRightCornerOfText = (800, 50)
    fontScale = 1
    fontColor = (1.0, 1.0, 1.0)
    lineType = 2

    for i in range(len(imgs)):

        title = titles[i + 1]
        cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
                    fontColor, lineType)

        if i == 0:
            cv2.putText(
                imgs[i],
                str(titles[i][0]),
                bottomRightCornerOfText,
                font,
                fontScale,
                fontColor,
                lineType,
            )

    result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)

    return result
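
# make_test_gif(): for every <dataset>/<subject> folder of rendered test frames,
# collect all non-.obj/.gif images and save them as an animated out.gif.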
def make_test_gif(img_dir):

    if img_dir is not None and len(os.listdir(img_dir)) > 0:
        for dataset in os.listdir(img_dir):
            for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
                img_lst = []
                im1 = None
                for file in sorted(os.listdir(osp.join(img_dir, dataset, subject))):
                    if file[-3:] not in ["obj", "gif"]:
                        img_path = os.path.join(img_dir, dataset, subject, file)
                        if im1 is None:
                            im1 = PIL.Image.open(img_path)
                        else:
                            img_lst.append(PIL.Image.open(img_path))

                print(os.path.join(img_dir, dataset, subject, "out.gif"))
                im1.save(
                    os.path.join(img_dir, dataset, subject, "out.gif"),
                    save_all=True,
                    append_images=img_lst,
                    duration=500,
                    loop=0,
                )
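
# export_cfg(): dump the experiment config to cfg_<logger.version>.yaml inside
# `dir`, creating the directory first; the file is only written if it does not
# already exist.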
def export_cfg(logger, dir, cfg):

    cfg_export_file = osp.join(dir, f"cfg_{logger.version}.yaml")

    if not osp.exists(cfg_export_file):
        os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
        with open(cfg_export_file, "w+") as file:
            _ = yaml.dump(cfg, file)

from yacs.config import CfgNode

_VALID_TYPES = {tuple, list, str, int, float, bool}


def convert_to_dict(cfg_node, key_list=[]):
    """ Convert a config node to dictionary """
    if not isinstance(cfg_node, CfgNode):
        if type(cfg_node) not in _VALID_TYPES:
            print(
                "Key {} with value {} is not a valid type; valid types: {}".format(
                    ".".join(key_list), type(cfg_node), _VALID_TYPES), )
        return cfg_node
    else:
        cfg_dict = dict(cfg_node)
        for k, v in cfg_dict.items():
            cfg_dict[k] = convert_to_dict(v, key_list + [k])
        return cfg_dict