oguzakif's picture
init repo
d4b77ac
import torch
import torch.nn as nn
import torch.nn.functional as F
def edgeLoss(preds_edges, edges):
    """
    Class-balanced binary cross-entropy between predicted edge logits and
    target edge maps.

    A pixel counts as an edge when its target value is above 0.5. For every
    sample, edge pixels are weighted by the fraction of non-edge pixels in
    that sample and non-edge pixels by the fraction of edge pixels, so the
    rarer class contributes more to the loss.

    Args:
        preds_edges: raw (pre-sigmoid) edge predictions, with shape [b, c, h, w]
        edges: target edge maps, with shape [b, c, h, w]
    Returns: Scalar tensor, the mean of the weighted per-pixel losses
    """
    edge_mask = edges.gt(0.5).float()
    _, channels, height, width = edge_mask.shape

    # Per-sample pixel counts of each class, shape [b].
    pos_count = edge_mask.sum(dim=(1, 2, 3)).float()
    neg_count = channels * height * width - pos_count
    pixel_total = pos_count + neg_count

    # Class fractions lifted to [b, 1, 1, 1] so they broadcast over the mask.
    neg_fraction = (neg_count / pixel_total)[:, None, None, None]
    pos_fraction = (pos_count / pixel_total)[:, None, None, None]

    # Edge pixels take the non-edge fraction and vice versa.
    pixel_weight = neg_fraction * edge_mask + pos_fraction * (1 - edge_mask)

    per_pixel_loss = F.binary_cross_entropy_with_logits(
        preds_edges.float(),
        edges.float(),
        weight=pixel_weight,
        reduction='none',
    )
    return per_pixel_loss.mean()
class EdgeAcc(nn.Module):
    """
    Measure the precision and recall of a predicted edge map.

    Both the prediction and the ground truth are binarised with the same
    ``threshold`` before the counts are taken.
    """

    def __init__(self, threshold=0.5):
        super(EdgeAcc, self).__init__()
        self.threshold = threshold

    def forward(self, pred_edge, gt_edge):
        """
        Args:
            pred_edge: Predicted edges, with shape [b, c, h, w]
            gt_edge: GT edges, with shape [b, c, h, w]
        Returns: The precision and the recall of the predicted edges, as
            0-dim float tensors (in that order)
        """
        labels = gt_edge > self.threshold
        preds = pred_edge > self.threshold
        relevant = torch.sum(labels.float())
        selected = torch.sum(preds.float())

        if relevant == 0 and selected == 0:
            # Nothing to find and nothing predicted: a perfect score.
            # Built from `relevant` so dtype and device match the normal path.
            return relevant.new_ones(()), relevant.new_ones(())

        # A true positive is a pixel marked as edge in both maps.
        true_positive = torch.sum((preds & labels).float())
        # The epsilon keeps the ratio finite when only one of the counts is zero.
        recall = true_positive / (relevant + 1e-8)
        precision = true_positive / (selected + 1e-8)
        return precision, recall
if __name__ == '__main__':
    # Manual check of the per-pixel class-balancing weights on two synthetic
    # edge maps: a 6x6 and a 4x4 square of ones on a 10x10 background.
    demo_edges = torch.zeros([2, 1, 10, 10])  # [b, 1, h, w] -> the extracted edges
    demo_edges[0, :, 2:8, 2:8] = 1
    demo_edges[1, :, 3:7, 3:7] = 1

    edge_mask = (demo_edges > 0.5).float()
    _, channels, height, width = edge_mask.shape

    # Per-sample counts of edge and non-edge pixels, shape [b].
    pos_count = torch.sum(edge_mask, dim=[1, 2, 3]).float()
    neg_count = channels * height * width - pos_count
    print(pos_count, neg_count)

    # The ratios have shape [b]; they must be lifted to [b, 1, 1, 1] before
    # they can be multiplied with the [b, c, h, w] mask.
    pixel_total = pos_count + neg_count
    neg_ratio = (neg_count / pixel_total)[:, None, None, None]
    pos_ratio = (pos_count / pixel_total)[:, None, None, None]
    print(neg_ratio * edge_mask + pos_ratio * (1 - edge_mask))