import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


def gmof(x, sigma):
    """Geman-McClure error function.

    Args:
        x (torch.Tensor): The input tensor.
        sigma (float): The sigma value used in the calculation.

    Returns:
        torch.Tensor: The computed Geman-McClure error.
    """
    x_squared = x**2
    sigma_squared = sigma**2
    return (sigma_squared * x_squared) / (sigma_squared + x_squared)
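
# A minimal behavioural sketch for gmof (illustrative values, rounded): the
# robustifier is roughly quadratic for residuals much smaller than sigma and
# saturates towards sigma**2 for large residuals, which bounds the influence
# of outliers.
#
#     residual = torch.tensor([0.1, 1.0, 10.0])
#     gmof(residual, sigma=1.0)
#     # -> tensor([0.0099, 0.5000, 0.9901]) (approximately)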


@weighted_loss
def mse_loss(pred, target):
    """Wrapper for Mean Squared Error (MSE) loss.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: MSE loss.
    """
    return F.mse_loss(pred, target, reduction='none')


@weighted_loss
def smooth_l1_loss(pred, target):
    """Wrapper for Smooth L1 loss.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: Smooth L1 loss.
    """
    return F.smooth_l1_loss(pred, target, reduction='none')


@weighted_loss
def l1_loss(pred, target):
    """Wrapper for L1 loss.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: L1 loss.
    """
    return F.l1_loss(pred, target, reduction='none')
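
# The weighted_loss decorator (imported from .utils) wraps these elementwise
# functions with per-sample weighting and reduction handling. Its exact
# signature is defined in .utils; as used later in this module, the decorated
# functions are called roughly as follows (a sketch, not the canonical API):
#
#     loss = mse_loss(pred, target, weight, reduction='mean', avg_factor=None)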


@weighted_loss
def mse_loss_with_gmof(pred, target, sigma):
    """Extended MSE Loss with Geman-McClure function applied.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.
        sigma (float): The sigma value for the Geman-McClure function.

    Returns:
        torch.Tensor: The loss value.
    """
    # Compute the elementwise squared error, then pass it through the
    # Geman-McClure robustifier so large errors saturate towards sigma**2.
    loss = F.mse_loss(pred, target, reduction='none')
    loss = gmof(loss, sigma)
    return loss
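
# A hedged usage sketch for the robustified MSE (the sigma value is arbitrary
# and purely illustrative; whether extra keyword arguments such as sigma are
# forwarded depends on the weighted_loss implementation in .utils):
#
#     loss = mse_loss_with_gmof(pred, target, weight=None, sigma=1.0,
#                               reduction='mean')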


@LOSSES.register_module()
class MSELoss(nn.Module):
    """Mean Squared Error (MSE) Loss.

    Args:
        reduction (str, optional): The method to reduce the loss to a scalar.
            Options are 'none', 'mean', and 'sum'. Defaults to 'mean'.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        self.reduction = 'none' if reduction is None else reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function to compute loss.

        Args:
            pred (torch.Tensor): Predictions.
            target (torch.Tensor): Ground truth.
            weight (torch.Tensor, optional): Optional weight per sample.
            avg_factor (int, optional): Factor for averaging the loss.
            reduction_override (str, optional): Option to override the
                reduction method.

        Returns:
            torch.Tensor: Calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * mse_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss
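
# A minimal usage sketch (the names `pred`, `target` and `confidences` are
# placeholders; in practice the loss is usually built from a config dict via
# the LOSSES registry):
#
#     criterion = MSELoss(reduction='mean', loss_weight=1.0)
#     loss = criterion(pred, target)                      # plain MSE
#     loss = criterion(pred, target, weight=confidences)  # weighted MSE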


@LOSSES.register_module()
class KinematicLoss(nn.Module):
    """Kinematic Loss for hierarchical motion prediction.

    Args:
        reduction (str, optional): Reduction method ('none', 'mean', or
            'sum'). Defaults to 'mean'.
        loss_type (str, optional): The type of loss to use ('mse',
            'smooth_l1', 'l1'). Defaults to 'mse'.
        loss_weight (list[float], optional): List of weights for each stage
            of the hierarchy. Defaults to [1.0].
    """

    def __init__(self, reduction='mean', loss_type='mse', loss_weight=[1.0]):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        self.reduction = 'none' if reduction is None else reduction
        self.loss_weight = loss_weight
        self.num_stages = len(loss_weight)

        if loss_type == 'mse':
            self.loss_func = mse_loss
        elif loss_type == 'smooth_l1':
            self.loss_func = smooth_l1_loss
        elif loss_type == 'l1':
            self.loss_func = l1_loss
        else:
            raise ValueError(f'Unknown loss type: {loss_type}')

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function for hierarchical kinematic loss.

        Args:
            pred (torch.Tensor): The prediction tensor.
            target (torch.Tensor): The target tensor.
            weight (torch.Tensor, optional): Weights for each prediction.
                Defaults to None.
            avg_factor (int, optional): Factor to average the loss.
                Defaults to None.
            reduction_override (str, optional): Override reduction method.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated hierarchical loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        total_loss = 0
        pred_t = pred.clone()
        target_t = target.clone()

        for i in range(self.num_stages):
            stage_loss = self.loss_weight[i] * self.loss_func(
                pred_t, target_t, weight, reduction=reduction,
                avg_factor=avg_factor)
            total_loss += stage_loss

            # Move to the next stage by taking first-order differences along
            # dim 1 while keeping the first element unchanged, so successive
            # stages penalize progressively higher-order differences (e.g.
            # positions, then velocities, then accelerations).
            pred_t = torch.cat(
                (pred_t[:, :1, :], pred_t[:, 1:] - pred_t[:, :-1]), dim=1)
            target_t = torch.cat(
                (target_t[:, :1, :], target_t[:, 1:] - target_t[:, :-1]),
                dim=1)

        return total_loss
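
# A minimal, hedged usage sketch (tensor shapes are illustrative; dim 1 is
# assumed to be the axis along which successive differences are meaningful,
# e.g. time):
#
#     criterion = KinematicLoss(loss_type='l1', loss_weight=[1.0, 0.5, 0.25])
#     pred = torch.randn(8, 16, 45)   # (batch, frames, features)
#     target = torch.randn(8, 16, 45)
#     loss = criterion(pred, target)  # 3 stages: values, 1st- and 2nd-order
#                                     # differences along dim 1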