Source code for cosense3d.modules.losses.edl

import torch
import torch.nn.functional as F

from cosense3d.modules.losses import BaseLoss


def relu_evidence(y):
    return F.relu(y)


def exp_evidence(y):
    return torch.exp(torch.clamp(y, -6, 6))


def softplus_evidence(y):
    return F.softplus(y)

def kl_divergence(alpha, num_classes):
    device = alpha.device
    ones = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    sum_alpha = torch.sum(alpha, dim=1, keepdim=True)
    first_term = (
        torch.lgamma(sum_alpha)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        # + torch.lgamma(ones).sum(dim=1, keepdim=True)
        - torch.lgamma(ones.sum(dim=1, keepdim=True))
    )
    second_term = (
        (alpha - ones)
        .mul(torch.digamma(alpha) - torch.digamma(sum_alpha))
        .sum(dim=1, keepdim=True)
    )
    kl = first_term + second_term
    return kl

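# Illustrative sanity check (not part of the original module): a uniform
# Dirichlet (all alpha equal to 1) has zero KL divergence to the uniform
# prior, while concentrating evidence on one class increases it.
# >>> kl_divergence(torch.ones(2, 3), num_classes=3)               # ~0 for both rows
# >>> kl_divergence(torch.tensor([[10., 1., 1.]]), num_classes=3)  # > 0
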
def loglikelihood_loss(y, alpha):
    S = torch.sum(alpha, dim=1, keepdim=True)
    loglikelihood_err = torch.sum((y - (alpha / S)) ** 2, dim=1, keepdim=True)
    loglikelihood_var = torch.sum(
        alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True
    )
    loglikelihood = loglikelihood_err + loglikelihood_var
    return loglikelihood

def mse_loss(y, alpha, epoch_num, num_classes, annealing_step):
    loglikelihood = loglikelihood_loss(y, alpha)
    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes)
    return loglikelihood + kl_div

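# Illustrative example (not from the original module): the KL term is annealed
# linearly with the epoch number, with an effective weight of
# min(1, epoch_num / annealing_step).
# >>> y = F.one_hot(torch.tensor([0, 2]), 3).float()
# >>> alpha = torch.tensor([[5., 1., 1.], [1., 1., 5.]])
# >>> mse_loss(y, alpha, epoch_num=0, num_classes=3, annealing_step=10)   # KL weight 0.0
# >>> mse_loss(y, alpha, epoch_num=20, num_classes=3, annealing_step=10)  # KL weight capped at 1.0
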
def edl_mse_loss(preds, tgt, n_cls, temp, annealing_step, model_label='edl'):
    """
    Calculate the evidential (EDL) loss.

    :param preds: (N, n_cls) logits for each class.
    :param tgt: (N,) labels with values in 0...(n_cls - 1), where negative
        values mark ignored samples, or (N, n_cls) one-hot/soft targets.
    :param n_cls: (int) number of classes, including background.
    :param temp: current temperature (epoch number) for annealing the
        KL-divergence term of the loss.
    :param annealing_step: maximum annealing step.
    :param model_label: (str) a name to distinguish EDL losses of different modules.
    :return: dict containing the loss and the classification accuracy.
    """
    evidence = relu_evidence(preds)
    if len(tgt.shape) == 1:
        cared = tgt >= 0
        evidence = evidence[cared]
        tgt = tgt[cared]
        tgt_onehot = F.one_hot(tgt.long(), n_cls).float()
    elif len(tgt.shape) == 2 and tgt.shape[1] > 1:
        cared = (tgt >= 0).all(dim=-1)
        evidence = evidence[cared]
        tgt_onehot = tgt[cared]
    else:
        raise NotImplementedError
    alpha = evidence + 1
    loss = mse_loss(tgt_onehot, alpha, temp, n_cls, annealing_step).mean()

    ss = evidence.detach()
    tt = tgt_onehot.detach()
    acc = (torch.argmax(ss, dim=1) == torch.argmax(tt, dim=1)).sum() / len(tt) * 100
    loss_dict = {
        f'{model_label}_loss': loss,
        f'{model_label}_ac': acc,
    }
    # Uncomment to log recall of all classes
    # for cls in [1, 2]:
    #     loss_dict[f'acc{cls}'] = torch.logical_and(
    #         torch.argmax(ss, dim=1) == cls, tt == cls).sum() \
    #         / max((tt == cls).sum(), 1) * 100
    return loss_dict

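# Usage sketch (illustrative; shapes and values are assumptions): class logits
# for 4 samples over 3 classes, hard integer labels where -1 marks an ignored
# sample.
# >>> preds = torch.randn(4, 3)
# >>> tgt = torch.tensor([0, 2, -1, 1])   # the -1 entry is filtered out
# >>> out = edl_mse_loss(preds, tgt, n_cls=3, temp=5, annealing_step=50)
# >>> out['edl_loss'], out['edl_ac']
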
def evidence_to_conf_unc(evidence, edl=True):
    if edl:
        # used edl loss
        alpha = evidence + 1
        S = torch.sum(alpha, dim=-1, keepdim=True)
        conf = torch.div(alpha, S)
        K = evidence.shape[-1]
        unc = torch.div(K, S)
        # conf = torch.sqrt(conf * (1 - unc))
        unc = unc.squeeze(dim=-1)
    else:
        # use entropy as uncertainty
        entropy = -evidence * torch.log2(evidence)
        unc = entropy.sum(dim=-1)
        # conf = torch.sqrt(evidence * (1 - unc.unsqueeze(-1)))
        conf = evidence
    return conf, unc

def pred_to_conf_unc(preds, activation='relu', edl=True):
    if callable(activation):
        evidence = activation(preds)
    elif activation == 'relu':
        evidence = relu_evidence(preds)
    elif activation == 'exp':
        evidence = exp_evidence(preds)
    elif activation == 'sigmoid':
        evidence = preds.sigmoid()
    elif activation == 'softmax':
        evidence = preds.softmax(dim=-1)
    else:
        evidence = preds

    if edl:
        alpha = evidence + 1
        S = torch.sum(alpha, dim=-1, keepdim=True)
        conf = torch.div(alpha, S)
        K = evidence.shape[-1]
        unc = torch.div(K, S)
        # conf = torch.sqrt(conf * (1 - unc))
        unc = unc.squeeze(dim=-1)
    else:
        # use entropy as uncertainty
        entropy = -evidence * torch.log2(evidence)
        unc = entropy.sum(dim=-1)
        # conf = torch.sqrt(evidence * (1 - unc.unsqueeze(-1)))
        conf = evidence
    return conf, unc

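# Usage sketch (illustrative): with edl=True the returned confidences are the
# expected Dirichlet class probabilities (each row sums to 1), and the
# uncertainty is K / S, approaching 1 when the network produces little evidence.
# >>> logits = torch.randn(4, 3)
# >>> conf, unc = pred_to_conf_unc(logits, activation='relu', edl=True)
# >>> conf.sum(dim=-1)   # all ones
# >>> unc.shape          # torch.Size([4])
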
class EDLLoss(BaseLoss):
    def __init__(self, n_cls: int, annealing_step: int, **kwargs):
        """
        Evidential loss.

        :param n_cls: number of classes, including background.
        :param annealing_step: maximum temperature annealing step for the
            KL regularization term of the EDL loss.
        :param kwargs: keyword arguments forwarded to BaseLoss.
        """
        super().__init__(**kwargs)
        self.n_cls = n_cls
        self.annealing_step = annealing_step
        if self.activation == 'relu':
            self.activation = relu_evidence
        elif self.activation == 'exp':
            self.activation = exp_evidence
        else:
            self.activation = None

    def loss(self, preds, tgt, temp, n_cls_override=None):
        if self.activation is None:
            evidence = preds
        else:
            evidence = self.activation(preds)
        if len(tgt.shape) == 1:
            cared = tgt >= 0
            evidence = evidence[cared]
            tgt = tgt[cared]
            tgt_onehot = F.one_hot(tgt.long(), self.n_cls).float()
        elif len(tgt.shape) == 2 and tgt.shape[1] > 1:
            cared = (tgt >= 0).all(dim=-1)
            evidence = evidence[cared]
            tgt_onehot = tgt[cared]
        else:
            raise NotImplementedError
        alpha = evidence + 1
        n_cls = self.n_cls if n_cls_override is None else n_cls_override
        loss = mse_loss(tgt_onehot, alpha, temp, n_cls, self.annealing_step)
        return loss
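
# Usage sketch (hedged: the exact BaseLoss keyword arguments, e.g. how the
# 'activation' attribute is populated before __init__ checks it, depend on
# cosense3d.modules.losses.BaseLoss and are assumed here).
# >>> edl = EDLLoss(n_cls=3, annealing_step=50, activation='relu')
# >>> preds = torch.randn(8, 3)
# >>> tgt = torch.randint(0, 3, (8,))
# >>> per_sample_loss = edl.loss(preds, tgt, temp=10)   # shape (8, 1)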