File size: 5,612 Bytes
6f08eef 2633f6b 4a7e4e0 2633f6b 4a7e4e0 2633f6b 4a7e4e0 5cd2bb7 6f08eef 4a7e4e0 2633f6b 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef 2633f6b 6f08eef 2633f6b 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import numpy as np
def zeromean_normalize(vertices):
    """Centre a point set on its mean and scale every point to (almost) unit length.

    Args:
        vertices: array-like of points, one point per row.

    Returns:
        A new float ndarray of the same shape. Row i is the centred point i
        divided by (its own Euclidean norm + 1e-6).

    NOTE: because each row is divided by its *own* norm, all points end up on
    the unit sphere; the original author flagged this as not the intended
    normalization. It is kept as-is because ``compute_WED(prenorm=True)``
    results depend on it.
    """
    points = np.array(vertices)
    centred = points - points.mean(axis=0)
    # Per-row length, kept as a column so it broadcasts across coordinates.
    radii = np.linalg.norm(centred, axis=1, keepdims=True)
    return centred / (radii + 1e-6)
def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
    """Roughly align one point set to another by matching mean and spread.

    ``verts_to_transform`` is centred on its mean, divided by its standard
    deviation, multiplied by the standard deviation of ``target_verts`` and
    shifted to the target mean.

    Args:
        verts_to_transform: array-like of points (one point per row) to move.
        target_verts: array-like of reference points (one point per row).
        single_scale: if True, each set gets one isotropic scale, namely the
            Euclidean norm of its per-axis standard deviations; otherwise
            every axis is scaled independently.

    Returns:
        ndarray with the same shape as ``verts_to_transform``.
    """
    verts_to_transform = np.asarray(verts_to_transform)
    target_verts = np.asarray(target_verts)

    mu_in = verts_to_transform.mean(axis=0)
    mu_target = target_verts.mean(axis=0)
    std_in = np.std(verts_to_transform, axis=0)
    std_target = np.std(target_verts, axis=0)

    if single_scale:
        # Collapse to one scale *before* the zero guard: a flat axis (e.g.
        # perfectly planar input) must not be replaced by an arbitrary 1.0,
        # which would inflate the norm and distort the alignment.
        std_in = np.linalg.norm(std_in)
        std_target = np.linalg.norm(std_target)

    # Zero or NaN spread (single point, degenerate axis, empty input) would
    # break the division; fall back to a neutral scale of 1.
    std_in = np.where((std_in == 0) | np.isnan(std_in), 1.0, std_in)
    std_target = np.where((std_target == 0) | np.isnan(std_target), 1.0, std_target)

    return (verts_to_transform - mu_in) / std_in * std_target + mu_target
def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1000.0, ce=1.0, normalized=True, prenorm=False, preregister=True, register=False, single_scale=True):
pd_vertices = np.array(pd_vertices)
gt_vertices = np.array(gt_vertices)
# Step 0: Prenormalize / preregister
if prenorm:
pd_vertices = zeromean_normalize(pd_vertices)
gt_vertices = zeromean_normalize(gt_vertices)
if preregister:
pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
pd_edges = np.array(pd_edges)
gt_edges = np.array(gt_edges)
# Step 0.5: Register
if register:
# find the optimal rotation, translation, and scale
from scipy.spatial.transform import Rotation as R
from scipy.optimize import minimize
def transform(x, pd_vertices):
# x is a 7-element vector, first 3 elements are the rotation vector, next 3 elements are the translation vector, finally scale
rotation = R.from_rotvec(x[:3])
translation = x[3:6]
scale = x[6]
return scale * rotation.apply(pd_vertices) + translation
def cost_function(x, pd_vertices, gt_vertices):
pd_vertices_transformed = transform(x, pd_vertices)
distances = cdist(pd_vertices_transformed, gt_vertices, metric='euclidean')
row_ind, col_ind = linear_sum_assignment(distances)
translation_costs = np.sum(distances[row_ind, col_ind])
return translation_costs
x0 = np.array([0, 0, 0, 0, 0, 0, 1])
# minimize subject to scale > 1e-6
# res = minimize(cost_function, x0, args=(pd_vertices, gt_vertices), constraints={'type': 'ineq', 'fun': lambda x: x[6] - 1e-6})
res = minimize(cost_function, x0, args=(pd_vertices, gt_vertices), bounds=[(-np.pi, np.pi), (-np.pi, np.pi), (-np.pi, np.pi), (-500, 500), (-500, 500), (-500, 500), (0.1, 3)])
# print("scale:", res.x)
pd_vertices = transform(res.x, pd_vertices)
# Step 1: Bipartite Matching
distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
row_ind, col_ind = linear_sum_assignment(distances)
# Step 2: Vertex Translation
translation_costs = np.sum(distances[row_ind, col_ind])
# Additional: Vertex Deletion
unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
deletion_costs = cv * len(unmatched_pd_indices)
# Step 3: Vertex Insertion
unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
insertion_costs = cv * len(unmatched_gt_indices)
# Step 4: Edge Deletion and Insertion
updated_pd_edges = [(col_ind[np.where(row_ind == edge[0])[0][0]], col_ind[np.where(row_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in row_ind and edge[1] in row_ind]
pd_edges_set = set(map(tuple, [set(edge) for edge in updated_pd_edges]))
gt_edges_set = set(map(tuple, [set(edge) for edge in gt_edges]))
# Delete edges not in ground truth
edges_to_delete = pd_edges_set - gt_edges_set
#deletion_edge_costs = ce * sum(np.linalg.norm(pd_vertices[edge[0]] - pd_vertices[edge[1]]) for edge in edges_to_delete)
vert_tf = [np.where(col_ind == v)[0][0] if v in col_ind else 0 for v in range(len(gt_vertices))]
deletion_edge_costs = ce * sum(np.linalg.norm(pd_vertices[vert_tf[edge[0]]] - pd_vertices[vert_tf[edge[1]]]) for edge in edges_to_delete)
# Insert missing edges from ground truth
edges_to_insert = gt_edges_set - pd_edges_set
insertion_edge_costs = ce * sum(np.linalg.norm(gt_vertices[edge[0]] - gt_vertices[edge[1]]) for edge in edges_to_insert)
# Step 5: Calculation of WED
WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
# print("translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs")
# print(translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs)
if normalized:
total_length_of_gt_edges = np.linalg.norm((gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]]), axis=1).sum()
WED = WED / total_length_of_gt_edges
# print ("Total length", total_length_of_gt_edges)
return WED |