Source code for cosense3d.modules.losses.base_loss

import torch
from torch import nn


[docs]class BaseLoss(nn.Module): def __init__(self, reduction: str = 'mean', activation: str = 'none', loss_weight: float = 1.0): """ :param reduction: (optional) the method to reduce the loss. :param activation: options are "none", "mean" and "sum". :param loss_weight: (optional) the weight of loss. """ super().__init__() self.reduction = reduction self.loss_weight = loss_weight self.activation = activation @property def name(self): return self.__class__.__name__
[docs] def loss(self, *args, **kwargs): raise NotImplementedError
[docs] def forward(self, preds: torch.Tensor, targets: torch.Tensor, weight: torch.Tensor=None, avg_factor: int=None, reduction_override: str=None, *args, **kwargs) -> torch.Tensor: """ :param preds: prediction tensor. :param targets: target tensor. :param weight: The weight of loss for each prediction. Defaults to None. :param avg_factor: Average factor that is used to average the loss. Defaults to None. :param reduction_override: The reduction method used to override the original reduction method of the loss. Defaults to None. :param args: additional arguments. :param kwargs: :return: weighted loss. """ loss = self.loss(preds, targets, *args, **kwargs) # if weight is specified, apply element-wise weight if weight is not None: loss = loss * weight assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) # if avg_factor is not specified, just reduce the loss if avg_factor is None: if reduction == 'mean': loss = loss.mean() elif reduction == 'sum': loss = loss.sum() else: # if reduction is mean, then average the loss by avg_factor if reduction == 'mean': # Avoid causing ZeroDivisionError when avg_factor is 0.0, # i.e., all labels of an image belong to ignore index. eps = torch.finfo(torch.float32).eps loss = loss.sum() / (avg_factor + eps) # if reduction is 'none', then do nothing, otherwise raise an error elif reduction != 'none': raise ValueError('avg_factor can not be used with reduction="sum"') return self.loss_weight * loss