from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from .base_loss import BaseLoss
def quality_focal_loss(pred: torch.Tensor,
                       target: Tuple[torch.Tensor, torch.Tensor],
                       beta: float = 2.0) -> torch.Tensor:
r"""
Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
Qualified and Distributed Bounding Boxes for Dense Object Detection
<https://arxiv.org/abs/2006.04388>`_.
:param pred: Predicted joint representation of classification
and quality (IoU) estimation with shape (N, C), C is the number of
classes.
:param target: Target category label with shape (N,)
and target quality label with shape (N,).
:param beta: The beta parameter for calculating the modulating factor.
Defaults to 2.0.
:return: Loss tensor with shape (N,).
"""
    assert len(target) == 2, (
        'target for QFL must be a tuple of two elements, including '
        'category label and quality label, respectively')
# label denotes the category id, score denotes the quality score
label, score = target
# negatives are supervised by 0 quality score
pred_sigmoid = pred.sigmoid()
scale_factor = pred_sigmoid
zerolabel = scale_factor.new_zeros(pred.shape)
loss = F.binary_cross_entropy_with_logits(
pred, zerolabel, reduction='none') * scale_factor.pow(beta)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = pred.size(1)
pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
pos_label = label[pos].long()
# positives are supervised by bbox quality (IoU) score
scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
pred[pos, pos_label], score[pos],
reduction='none') * scale_factor.abs().pow(beta)
loss = loss.sum(dim=1, keepdim=False)
return loss
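
# Usage sketch (illustrative only, not part of the original module). Shapes
# follow the docstring above; the batch size and class count are arbitrary:
#
#   pred = torch.randn(8, 4)               # logits for 8 samples, 4 fg classes
#   label = torch.randint(0, 5, (8,))      # category ids; 4 (== C) marks background
#   score = torch.rand(8)                  # IoU quality targets in [0, 1]
#   loss = quality_focal_loss(pred, (label, score), beta=2.0)   # shape (8,)
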
def quality_focal_loss_with_prob(pred: torch.Tensor,
                                 target: Tuple[torch.Tensor, torch.Tensor],
                                 beta: float = 2.0) -> torch.Tensor:
r"""
Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
Qualified and Distributed Bounding Boxes for Dense Object Detection
<https://arxiv.org/abs/2006.04388>`_.
:param pred: Predicted joint representation of classification
and quality (IoU) estimation with shape (N, C), C is the number of
classes.
:param target: Target category label with shape (N,)
and target quality label with shape (N,).
:param beta: The beta parameter for calculating the modulating factor.
Defaults to 2.0.
:return: Loss tensor with shape (N,).
"""
    assert len(target) == 2, (
        'target for QFL must be a tuple of two elements, including '
        'category label and quality label, respectively')
# label denotes the category id, score denotes the quality score
label, score = target
# negatives are supervised by 0 quality score
    # pred is already sigmoid-activated, so use it directly as the probability
    pred_sigmoid = pred
scale_factor = pred_sigmoid
zerolabel = scale_factor.new_zeros(pred.shape)
loss = F.binary_cross_entropy(
pred, zerolabel, reduction='none') * scale_factor.pow(beta)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = pred.size(1)
pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
pos_label = label[pos].long()
# positives are supervised by bbox quality (IoU) score
scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
loss[pos, pos_label] = F.binary_cross_entropy(
pred[pos, pos_label], score[pos],
reduction='none') * scale_factor.abs().pow(beta)
loss = loss.sum(dim=1, keepdim=False)
return loss
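
# Usage sketch (illustrative only): identical to `quality_focal_loss`, except
# that the prediction must already be sigmoid-activated. Shapes are arbitrary:
#
#   prob = torch.rand(8, 4).clamp(1e-4, 1 - 1e-4)   # probabilities in (0, 1)
#   label = torch.randint(0, 5, (8,))
#   score = torch.rand(8)
#   loss = quality_focal_loss_with_prob(prob, (label, score))   # shape (8,)
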
class QualityFocalLoss(BaseLoss):

    def __init__(self,
                 use_sigmoid: bool = True,
                 beta: float = 2.0,
                 activated: bool = False,
                 **kwargs):
r"""
Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss:
Learning Qualified and Distributed Bounding Boxes for Dense Object
Detection <https://arxiv.org/abs/2006.04388>`_.
:param use_sigmoid: Whether sigmoid operation is conducted in QFL.
Defaults to True.
:param beta: The beta parameter for calculating the modulating factor.
Defaults to 2.0.
:param activated: (optional) Whether the input is activated.
If True, it means the input has been activated and can be
treated as probabilities. Else, it should be treated as logits.
Defaults to False.
        :param kwargs: Extra keyword arguments forwarded to :class:`BaseLoss`.
        """
super(QualityFocalLoss, self).__init__(**kwargs)
assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
self.use_sigmoid = use_sigmoid
self.beta = beta
self.activated = activated
    def loss(self, pred: torch.Tensor,
             target: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""Forward function.
:param pred: Predicted joint representation of
classification and quality (IoU) estimation with shape (N, C),
C is the number of classes.
:param target: Target category label with shape
(N,) and target quality label with shape (N,).
:return: loss result.
"""
if self.use_sigmoid:
if self.activated:
loss_cls = quality_focal_loss_with_prob(pred, target, self.beta)
else:
loss_cls = quality_focal_loss(pred, target, self.beta)
else:
raise NotImplementedError
return loss_cls
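
# Usage sketch (illustrative only; extra BaseLoss keyword arguments such as
# reduction handling are not shown and depend on the BaseLoss implementation):
#
#   criterion = QualityFocalLoss(use_sigmoid=True, beta=2.0)
#   pred = torch.randn(8, 4)               # logits
#   label = torch.randint(0, 5, (8,))      # 4 == background
#   score = torch.rand(8)                  # IoU quality targets
#   per_sample = criterion.loss(pred, (label, score))           # shape (8,)
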
class GaussianFocalLoss(BaseLoss):
"""GaussianFocalLoss is a variant of focal loss.
More details can be found in the `paper
<https://arxiv.org/abs/1808.01244>`_
Code is modified from `kp_utils.py
<https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501
    Note that the target in GaussianFocalLoss is a Gaussian heatmap,
    not a 0/1 binary target.
"""
    def __init__(self,
                 alpha: float = 2.0,
                 gamma: float = 4.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0):
"""
:param alpha: Power of prediction.
:param gamma: Power of target for negative samples.
:param reduction: Options are "none", "mean" and "sum".
:param loss_weight: Loss weight of current loss.
"""
super(GaussianFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.loss_weight = loss_weight
    def loss(self, pred: torch.Tensor, target: torch.Tensor):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
distribution.
:param pred: The prediction.
:param target: The learning target of the prediction
in gaussian distribution.
:return: loss result.
"""
eps = 1e-12
pos_weights = target.eq(1)
neg_weights = (1 - target).pow(self.gamma)
pos_loss = -(pred + eps).log() * (1 - pred).pow(self.alpha) * pos_weights
neg_loss = -(1 - pred + eps).log() * pred.pow(self.alpha) * neg_weights
return pos_loss + neg_loss
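
# Usage sketch (illustrative only): `pred` is expected to be a probability
# heatmap (e.g. after sigmoid) and `target` a Gaussian heatmap whose peaks
# equal 1. The returned loss is element-wise; any reduction is assumed to be
# applied by the BaseLoss machinery outside this method:
#
#   heatmap_pred = torch.rand(2, 3, 64, 64).clamp(1e-4, 1 - 1e-4)
#   heatmap_gt = torch.zeros(2, 3, 64, 64)
#   heatmap_gt[0, 0, 32, 32] = 1.0        # single peak; real targets are Gaussian blobs
#   criterion = GaussianFocalLoss(alpha=2.0, gamma=4.0)
#   loss = criterion.loss(heatmap_pred, heatmap_gt)             # shape (2, 3, 64, 64)
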
def py_focal_loss_with_prob(pred: torch.Tensor,
                            target: torch.Tensor,
                            gamma: float = 2.0,
                            alpha: float = 0.25):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Different from `py_sigmoid_focal_loss`, this function accepts probability
as input.
:param pred: The prediction probability with shape (N, C),
C is the number of classes.
:param target: The learning label of the prediction.
:param gamma: The gamma for calculating the modulating
factor. Defaults to 2.0.
:param alpha: A balanced form for Focal Loss.
Defaults to 0.25.
:return: loss result.
"""
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
target = target.type_as(pred)
pt = (1 - pred) * target + pred * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy(
pred, target, reduction='none') * focal_weight
return loss
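
# Usage sketch (illustrative only): `pred` holds per-class probabilities and
# `target` holds integer labels in [0, C], where label C is the background
# class dropped by the one-hot slicing above:
#
#   prob = torch.rand(8, 4).clamp(1e-4, 1 - 1e-4)
#   labels = torch.randint(0, 5, (8,))    # 4 == background
#   loss = py_focal_loss_with_prob(prob, labels)                # shape (8, 4)
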
def py_sigmoid_focal_loss(pred: torch.Tensor,
                          target: torch.Tensor,
                          gamma: float = 2.0,
                          alpha: float = 0.25):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Different from `py_sigmoid_focal_loss`, this function accepts probability
as input.
:param pred: The prediction probability with shape (N, C),
C is the number of classes.
:param target: The learning label of the prediction.
:param gamma: The gamma for calculating the modulating
factor. Defaults to 2.0.
:param alpha: A balanced form for Focal Loss.
Defaults to 0.25.
:return: loss result.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
return loss
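
# Usage sketch (illustrative only): `pred` holds raw logits and `target` is a
# binary (one-hot style) tensor of the same shape:
#
#   logits = torch.randn(8, 4)
#   onehot = F.one_hot(torch.randint(0, 4, (8,)), num_classes=4).float()
#   loss = py_sigmoid_focal_loss(logits, onehot)                # shape (8, 4)
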
class FocalLoss(BaseLoss):

    def __init__(self,
                 use_sigmoid: bool = True,
                 gamma: float = 2.0,
                 alpha: float = 0.25,
                 activated: bool = False,
                 bg_idx: Optional[int] = None,
                 **kwargs):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
        :param use_sigmoid: Whether the prediction uses sigmoid
            instead of softmax. Defaults to True.
:param gamma: The gamma for calculating the modulating
factor. Defaults to 2.0.
:param alpha: A balanced form for Focal Loss.
Defaults to 0.25.
:param activated: Whether the input is activated.
If True, it means the input has been activated and can be
treated as probabilities. Else, it should be treated as logits.
Defaults to False.
        :param bg_idx: Index of the background class column to drop from the
            one-hot target. If None, the last column (index ``num_classes``)
            is treated as background.
        :param kwargs: Extra keyword arguments forwarded to :class:`BaseLoss`.
"""
super(FocalLoss, self).__init__(**kwargs)
assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
self.use_sigmoid = use_sigmoid
self.gamma = gamma
self.alpha = alpha
self.activated = activated
self.bg_idx = bg_idx
if use_sigmoid:
self.activation = 'sigmoid'
elif activated is False:
self.activation = 'softmax'
    def loss(self, pred: torch.Tensor, target: torch.Tensor, *args, **kwargs):
"""
:param pred: prediction.
:param target: ground truth targets.
:param args:
:param kwargs:
:return:
"""
if self.use_sigmoid:
if self.activated:
calculate_loss_func = py_focal_loss_with_prob
else:
num_classes = pred.size(1)
                # A 1-D float tensor is interpreted as a binary soft label;
                # integer labels are one-hot encoded with an extra background column.
                if isinstance(target, torch.cuda.FloatTensor) and target.ndim == 1:
                    target = torch.stack([1 - target, target], dim=1)
                else:
                    target = F.one_hot(target, num_classes=num_classes + 1)
if self.bg_idx is None:
target = target[:, :num_classes]
else:
target = target[:, [c for c in range(num_classes + 1) if c != self.bg_idx]]
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = calculate_loss_func(
pred,
target,
gamma=self.gamma,
alpha=self.alpha)
else:
raise NotImplementedError
return loss_cls
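
# Usage sketch (illustrative only; BaseLoss keyword arguments such as reduction
# or loss weighting are not shown and depend on the BaseLoss implementation):
#
#   criterion = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25)
#   logits = torch.randn(8, 4)             # 4 foreground classes
#   labels = torch.randint(0, 5, (8,))     # 4 == background, dropped internally
#   loss = criterion.loss(logits, labels)  # element-wise, shape (8, 4)
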