|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def edgeLoss(preds_edges, edges, threshold=0.5):
    """Class-balanced binary cross-entropy loss for edge maps.

    A target pixel counts as an edge (positive) when its value exceeds
    ``threshold``. For every sample, positive pixels are weighted by the
    fraction of negative pixels in that sample, and negative pixels by the
    fraction of positive pixels, so sparse edges are not swamped by the
    background.

    Args:
        preds_edges: Edge logits (passed to
            ``F.binary_cross_entropy_with_logits``), with shape [b, c, h, w].
            Any shape [b, ...] of rank >= 2 is accepted; the balancing
            weights are computed per sample over all non-batch dimensions.
        edges: Edge targets with the same shape as ``preds_edges``; used both
            to build the balancing weights and directly as the BCE target.
        threshold: Target value above which a pixel is treated as an edge
            when computing the balancing weights. Defaults to 0.5.

    Returns:
        Scalar tensor: the weighted per-element BCE averaged over every
        element of the batch.

    Raises:
        ValueError: If ``edges`` has fewer than 2 dimensions.
    """
    if edges.dim() < 2:
        raise ValueError(
            "edges must have shape [b, ...] with rank >= 2, got {}".format(
                tuple(edges.shape)))

    mask = (edges > threshold).float()

    # Per-sample counts over every non-batch dim; keepdim=True yields a
    # [b, 1, ..., 1] tensor that broadcasts against ``mask`` directly.
    reduce_dims = tuple(range(1, mask.dim()))
    num_pos = torch.sum(mask, dim=reduce_dims, keepdim=True)
    num_total = mask.shape[1:].numel()
    num_neg = num_total - num_pos

    neg_weights = num_neg / num_total
    pos_weights = num_pos / num_total

    # Cross-assignment is intentional: positives get the negative fraction
    # (large when edges are rare) and negatives get the positive fraction.
    weight = neg_weights * mask + pos_weights * (1 - mask)

    losses = F.binary_cross_entropy_with_logits(
        preds_edges.float(), edges.float(), weight=weight, reduction='none')
    return torch.mean(losses)
|
|
|
|
|
class EdgeAcc(nn.Module):
    """
    Measure the precision and recall of a predicted edge map.

    Both the prediction and the ground truth are binarized with the same
    ``threshold`` (strict greater-than). NOTE(review): this presumably
    expects ``pred_edge`` to be a probability map (e.g. post-sigmoid) rather
    than raw logits — confirm against callers.
    """

    def __init__(self, threshold=0.5):
        super(EdgeAcc, self).__init__()
        # Binarization cut-off applied to both the prediction and the GT.
        self.threshold = threshold

    def forward(self, pred_edge, gt_edge):
        """
        Args:
            pred_edge: Predicted edges, with shape [b, c, h, w]
                (any shape matching ``gt_edge`` works; the metrics are
                computed over all elements of the batch together).
            gt_edge: GT edges, with the same shape as ``pred_edge``.

        Returns:
            A ``(precision, recall)`` pair of 0-dim float32 tensors. If
            neither the prediction nor the GT contains any edge pixel, both
            are 1.
        """
        labels = gt_edge > self.threshold
        preds = pred_edge > self.threshold

        relevant = torch.sum(labels.float())
        selected = torch.sum(preds.float())

        if relevant == 0 and selected == 0:
            # Nothing to find and nothing predicted: perfect score. Derive
            # the constant from ``relevant`` so dtype/device match the
            # regular return path.
            one = relevant.new_ones(())
            return one, one.clone()

        true_positive = torch.sum((preds & labels).float())
        # The epsilon guards the case where only one of the sums is zero.
        recall = true_positive / (relevant + 1e-8)
        precision = true_positive / (selected + 1e-8)
        return precision, recall
|
|
|
|
|
if __name__ == '__main__':
    # Demo: build two toy 10x10 edge maps with square edge regions, print
    # the per-sample positive/negative pixel counts, then print the
    # class-balancing weight map derived from them.
    demo_edges = torch.zeros([2, 1, 10, 10])
    demo_edges[0, :, 2:8, 2:8] = 1
    demo_edges[1, :, 3:7, 3:7] = 1

    is_edge = (demo_edges > 0.5).float()
    _, channels, height, width = is_edge.shape
    pos_count = is_edge.sum(dim=[1, 2, 3]).float()
    neg_count = channels * height * width - pos_count
    print(pos_count, neg_count)

    total = pos_count + neg_count
    neg_frac = (neg_count / total).view(-1, 1, 1, 1)
    pos_frac = (pos_count / total).view(-1, 1, 1, 1)
    print(neg_frac * is_edge + pos_frac * (1 - is_edge))
|
|
|
|
|
|