import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', [
    'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
    'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
])


class SigmoidFocalLossFunction(Function):
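    """Autograd wrapper around the ``_ext`` sigmoid focal loss kernels.

    ``forward`` expects 2-D predictions of shape (N, num_classes) and integer
    class targets of shape (N, ); an optional 1-D per-class ``weight`` of
    length num_classes may be given. Gradients are only computed for
    ``input``.
    """
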
    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        return g.op(
            'mmcv::MMCVSigmoidFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):

        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        output = input.new_zeros(input.size())

        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        ctx.save_for_backward(input, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, target, weight = ctx.saved_tensors

        grad_input = input.new_zeros(input.size())

        ext_module.sigmoid_focal_loss_backward(
            input,
            target,
            weight,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input.size(0)
        return grad_input, None, None, None, None, None


sigmoid_focal_loss = SigmoidFocalLossFunction.apply


class SigmoidFocalLoss(nn.Module):
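    """A module wrapper around ``sigmoid_focal_loss``.

    Args:
        gamma (float): Focusing parameter of the focal loss.
        alpha (float): Balancing parameter of the focal loss.
        weight (torch.Tensor, optional): 1-D per-class weights of length
            num_classes. Defaults to None.
        reduction (str): One of ``'none'``, ``'mean'`` or ``'sum'``.
            Defaults to ``'mean'``.

    Example (illustrative sketch only; it assumes the compiled ``_ext``
    extension and a CUDA device are available, with shapes inferred from the
    asserts in ``SigmoidFocalLossFunction.forward``):

        >>> loss_fn = SigmoidFocalLoss(gamma=2.0, alpha=0.25).cuda()
        >>> pred = torch.randn(8, 4, device='cuda', requires_grad=True)
        >>> target = torch.randint(0, 4, (8, ), device='cuda')
        >>> loss = loss_fn(pred, target)
        >>> loss.backward()
    """
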
    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s


class SoftmaxFocalLossFunction(Function):
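    """Autograd wrapper around the ``_ext`` softmax focal loss kernels.

    ``forward`` normalizes the 2-D input with a numerically stable softmax
    (subtracting the per-row maximum before exponentiation) and passes the
    resulting probabilities to the extension kernel; the softmax output, not
    the raw input, is saved for the backward pass. Gradients are only
    computed for ``input``.
    """
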
    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        return g.op(
            'mmcv::MMCVSoftmaxFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):

        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        channel_stats, _ = torch.max(input, dim=1)
        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
        input_softmax.exp_()

        channel_stats = input_softmax.sum(dim=1)
        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)

        output = input.new_zeros(input.size(0))
        ext_module.softmax_focal_loss_forward(
            input_softmax,
            target,
            weight,
            output,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        ctx.save_for_backward(input_softmax, target, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_softmax, target, weight = ctx.saved_tensors
        buff = input_softmax.new_zeros(input_softmax.size(0))
        grad_input = input_softmax.new_zeros(input_softmax.size())

        ext_module.softmax_focal_loss_backward(
            input_softmax,
            target,
            weight,
            buff,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input_softmax.size(0)
        return grad_input, None, None, None, None, None


softmax_focal_loss = SoftmaxFocalLossFunction.apply


class SoftmaxFocalLoss(nn.Module):
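    """A module wrapper around ``softmax_focal_loss``.

    Args:
        gamma (float): Focusing parameter of the focal loss.
        alpha (float): Balancing parameter of the focal loss.
        weight (torch.Tensor, optional): 1-D per-class weights of length
            num_classes. Defaults to None.
        reduction (str): One of ``'none'``, ``'mean'`` or ``'sum'``.
            Defaults to ``'mean'``.

    Example (illustrative sketch only; it assumes the compiled ``_ext``
    extension and a CUDA device are available):

        >>> loss_fn = SoftmaxFocalLoss(gamma=2.0, alpha=0.25).cuda()
        >>> logits = torch.randn(8, 4, device='cuda', requires_grad=True)
        >>> target = torch.randint(0, 4, (8, ), device='cuda')
        >>> loss = loss_fn(logits, target)
        >>> loss.backward()
    """
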
    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SoftmaxFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return softmax_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s