dmytromishkin
commited on
Create wed.py
Browse files- hoho/wed.py +54 -0
hoho/wed.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.spatial.distance import cdist
|
2 |
+
from scipy.optimize import linear_sum_assignment
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv, ce, normalized=True, squared=False):
|
6 |
+
pd_vertices = np.array(pd_vertices)
|
7 |
+
gt_vertices = np.array(gt_vertices)
|
8 |
+
pd_edges = np.array(pd_edges)
|
9 |
+
gt_edges = np.array(gt_edges)
|
10 |
+
|
11 |
+
# Step 1: Bipartite Matching
|
12 |
+
if squared:
|
13 |
+
distances = cdist(pd_vertices, gt_vertices, metric='sqeuclidean')
|
14 |
+
else:
|
15 |
+
distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
|
16 |
+
|
17 |
+
row_ind, col_ind = linear_sum_assignment(distances)
|
18 |
+
|
19 |
+
# Step 2: Vertex Translation
|
20 |
+
|
21 |
+
if squared:
|
22 |
+
translation_costs = cv * np.sqrt(np.sum(distances[row_ind, col_ind]))
|
23 |
+
else:
|
24 |
+
translation_costs = cv * np.sum(distances[row_ind, col_ind])
|
25 |
+
|
26 |
+
# Additional: Vertex Deletion
|
27 |
+
unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
|
28 |
+
deletion_costs = cv * len(unmatched_pd_indices) # Assuming a fixed cost for vertex deletion
|
29 |
+
|
30 |
+
# Step 3: Vertex Insertion
|
31 |
+
unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
|
32 |
+
insertion_costs = cv * len(unmatched_gt_indices) # Assuming a fixed cost for vertex insertion
|
33 |
+
|
34 |
+
# Step 4: Edge Deletion and Insertion
|
35 |
+
updated_pd_edges = [(row_ind[np.where(col_ind == edge[0])[0][0]], row_ind[np.where(col_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in col_ind and edge[1] in col_ind]
|
36 |
+
pd_edges_set = set(map(tuple, updated_pd_edges))
|
37 |
+
gt_edges_set = set(map(tuple, gt_edges))
|
38 |
+
|
39 |
+
# Delete edges not in ground truth
|
40 |
+
edges_to_delete = pd_edges_set - gt_edges_set
|
41 |
+
deletion_edge_costs = ce * sum(np.linalg.norm(pd_vertices[edge[0]] - pd_vertices[edge[1]]) for edge in edges_to_delete)
|
42 |
+
|
43 |
+
# Insert missing edges from ground truth
|
44 |
+
edges_to_insert = gt_edges_set - pd_edges_set
|
45 |
+
insertion_edge_costs = ce * sum(np.linalg.norm(gt_vertices[edge[0]] - gt_vertices[edge[1]]) for edge in edges_to_insert)
|
46 |
+
|
47 |
+
# Step 5: Calculation of WED
|
48 |
+
WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
|
49 |
+
|
50 |
+
if normalized:
|
51 |
+
total_length_of_gt_edges = np.linalg.norm((gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]]), axis=1).sum()
|
52 |
+
WED = WED / total_length_of_gt_edges
|
53 |
+
|
54 |
+
return WED
|