File size: 5,612 Bytes
6f08eef
 
 
 
2633f6b
 
 
 
4a7e4e0
2633f6b
 
4a7e4e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2633f6b
4a7e4e0
5cd2bb7
6f08eef
 
4a7e4e0
 
 
 
 
 
 
 
 
2633f6b
6f08eef
 
 
 
4a7e4e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f08eef
4a7e4e0
6f08eef
4a7e4e0
 
6f08eef
 
 
4a7e4e0
6f08eef
 
 
4a7e4e0
6f08eef
 
2633f6b
 
 
 
6f08eef
 
 
2633f6b
 
 
 
 
6f08eef
 
 
 
 
 
 
4a7e4e0
 
6f08eef
 
 
 
4a7e4e0
 
6f08eef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import numpy as np 


def zeromean_normalize(vertices):
    """Center a vertex set on its centroid, then rescale each vertex individually.

    Despite the name, the rescaling divides every centered vertex by its *own*
    Euclidean norm (plus 1e-6 to avoid division by zero). This projects all
    vertices onto (approximately) the unit sphere rather than applying one
    global scale. The original author flagged this as not the intended
    normalization; it is kept as-is so metric values stay comparable.

    Args:
        vertices: array-like of shape (n, d).

    Returns:
        A new float array of shape (n, d).
    """
    pts = np.array(vertices)
    centered = pts - pts.mean(axis=0)
    # keepdims gives an (n, 1) column so each row is divided by its own radius.
    radii = np.linalg.norm(centered, axis=1, keepdims=True)
    return centered / (radii + 1e-6)

def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
    """Affinely align one vertex set to another by matching mean and spread.

    The source vertices are centered on their mean and divided by their
    standard deviation. They are then multiplied by the target's standard
    deviation and shifted to the target's mean.

    Args:
        verts_to_transform: numpy array of vertices to move, shape (n, d).
        target_verts: numpy array of reference vertices, shape (m, d).
        single_scale: if True, each set's per-axis stds are collapsed into one
            isotropic scale (the Euclidean norm of the std vector). Otherwise
            each axis is scaled independently.

    Returns:
        The transformed copy of ``verts_to_transform``.
    """
    def _spread(points):
        # Per-axis std; degenerate axes (zero or NaN) get 1 so they pass
        # through unscaled instead of producing inf/NaN.
        sigma = np.std(points, axis=0)
        degenerate = (sigma == 0) | np.isnan(sigma)
        if np.any(degenerate):
            sigma[degenerate] = 1
        return sigma

    src_center = verts_to_transform.mean(axis=0)
    dst_center = target_verts.mean(axis=0)
    src_scale = _spread(verts_to_transform)
    dst_scale = _spread(target_verts)

    if single_scale:
        src_scale = np.linalg.norm(src_scale)
        dst_scale = np.linalg.norm(dst_scale)

    standardized = (verts_to_transform - src_center) / src_scale
    return standardized * dst_scale + dst_center


def _canonical_edge(u, v):
    """Return an order-independent key for the undirected edge ``(u, v)``.

    ``tuple(set(edge))`` is not suitable here: set iteration order can depend
    on insertion order when hashes collide, so ``(0, 8)`` and ``(8, 0)`` may
    produce different tuples.
    """
    u, v = int(u), int(v)
    return (u, v) if u <= v else (v, u)


def _register_similarity(pd_vertices, gt_vertices):
    """Fit a rotation, translation and isotropic scale mapping pd onto gt.

    The seven parameters are optimized with bounded ``scipy.optimize.minimize``.
    The objective is the total Hungarian-matched Euclidean distance, and the
    search starts from the identity transform. The rotation-vector
    parameterization requires 3-D vertices.

    Returns:
        ``pd_vertices`` transformed by the fitted parameters.
    """
    from scipy.optimize import minimize
    from scipy.spatial.transform import Rotation as R

    def transform(x, verts):
        # x = [rotation vector (3), translation (3), scale (1)]
        rotation = R.from_rotvec(x[:3])
        return x[6] * rotation.apply(verts) + x[3:6]

    def cost_function(x):
        distances = cdist(transform(x, pd_vertices), gt_vertices, metric='euclidean')
        row_ind, col_ind = linear_sum_assignment(distances)
        return np.sum(distances[row_ind, col_ind])

    x0 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
    # Rotation-vector components in [-pi, pi], translation in [-500, 500],
    # scale in [0.1, 3] (keeps the scale strictly positive).
    bounds = [(-np.pi, np.pi)] * 3 + [(-500, 500)] * 3 + [(0.1, 3)]
    res = minimize(cost_function, x0, bounds=bounds)
    return transform(res.x, pd_vertices)


def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1000.0, ce=1.0, normalized=True, prenorm=False, preregister=True, register=False, single_scale=True):
    """Compute the Wireframe Edit Distance between a predicted and a GT wireframe.

    The predicted vertices are optionally normalized and aligned to the ground
    truth. They are then matched one-to-one to ground-truth vertices with the
    Hungarian algorithm on Euclidean distance. The result is the sum of:

    * translation: total distance between matched vertex pairs;
    * vertex deletion / insertion: ``cv`` per unmatched predicted / GT vertex;
    * edge deletion: ``ce`` times the length (in the aligned predicted
      coordinates) of each predicted edge that, mapped through the matching,
      is not a GT edge;
    * edge insertion: ``ce`` times the length of each GT edge missing from the
      mapped predicted edges.

    Edges are undirected and de-duplicated. A predicted edge with an unmatched
    endpoint is dropped without being charged.

    Args:
        pd_vertices, gt_vertices: array-like vertex coordinates, shape (n, d)
            and (m, d). ``register=True`` requires d == 3.
        pd_edges, gt_edges: array-like of integer index pairs into the
            respective vertex arrays; may be empty.
        cv: cost per inserted or deleted vertex.
        ce: weight applied to inserted/deleted edge lengths.
        normalized: divide the result by the total length of the GT edges.
            If that total is zero (e.g. no GT edges), the numpy division
            yields a non-finite value.
        prenorm: apply ``zeromean_normalize`` to both vertex sets first.
        preregister: align pd to gt with ``preregister_mean_std``.
        register: additionally fit a similarity transform (see
            ``_register_similarity``).
        single_scale: forwarded to ``preregister_mean_std``.

    Returns:
        The (optionally normalized) WED.
    """
    pd_vertices = np.array(pd_vertices)
    gt_vertices = np.array(gt_vertices)
    if pd_vertices.size == 0:
        # An empty prediction must still be 2-D for cdist; every GT vertex
        # and edge then counts as an insertion.
        pd_vertices = pd_vertices.reshape(0, gt_vertices.shape[1])
    has_pd = len(pd_vertices) > 0

    # Edges as (k, 2) ints so empty edge lists still index cleanly.
    pd_edges = np.asarray(pd_edges, dtype=int).reshape(-1, 2)
    gt_edges = np.asarray(gt_edges, dtype=int).reshape(-1, 2)

    # Step 0: prenormalize / preregister / register (alignment is pointless,
    # and produces NaN warnings, on an empty prediction).
    if prenorm:
        if has_pd:
            pd_vertices = zeromean_normalize(pd_vertices)
        gt_vertices = zeromean_normalize(gt_vertices)
    if preregister and has_pd:
        pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
    if register and has_pd:
        pd_vertices = _register_similarity(pd_vertices, gt_vertices)

    # Step 1: bipartite matching; each returned row/col index is unique.
    distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
    row_ind, col_ind = linear_sum_assignment(distances)

    # Step 2: vertex translation, deletion and insertion.
    translation_costs = np.sum(distances[row_ind, col_ind])
    deletion_costs = cv * (len(pd_vertices) - len(row_ind))
    insertion_costs = cv * (len(gt_vertices) - len(col_ind))

    # Step 3: express predicted edges in GT index space via the matching.
    pd_to_gt = {int(p): int(g) for p, g in zip(row_ind, col_ind)}
    gt_to_pd = {g: p for p, g in pd_to_gt.items()}
    pd_edges_set = {
        _canonical_edge(pd_to_gt[a], pd_to_gt[b])
        for a, b in pd_edges.tolist()
        if a in pd_to_gt and b in pd_to_gt
    }
    gt_edges_set = {_canonical_edge(a, b) for a, b in gt_edges.tolist()}

    # Step 4: edge deletion, measured on the predicted vertices the GT-space
    # endpoints were matched from, and edge insertion, measured on GT.
    edges_to_delete = pd_edges_set - gt_edges_set
    deletion_edge_costs = ce * sum(
        np.linalg.norm(pd_vertices[gt_to_pd[u]] - pd_vertices[gt_to_pd[v]])
        for u, v in edges_to_delete
    )
    edges_to_insert = gt_edges_set - pd_edges_set
    insertion_edge_costs = ce * sum(
        np.linalg.norm(gt_vertices[u] - gt_vertices[v])
        for u, v in edges_to_insert
    )

    # Step 5: total.
    wed = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs

    if normalized:
        total_length_of_gt_edges = np.linalg.norm(
            gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]], axis=1
        ).sum()
        wed = wed / total_length_of_gt_edges

    return wed