Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmpretrain.registry import MODELS | |
| from .utils import weight_reduce_loss | |
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None):
    """Compute the (hard-label) cross entropy loss.

    The per-element loss comes from ``F.cross_entropy`` with
    ``reduction='none'``; sample weighting and the final reduction are
    delegated to ``weight_reduce_loss``.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The gt label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is cast
            to float before being applied. Defaults to None.
        reduction (str): The method used to reduce the loss.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Keep the loss un-reduced here so that sample weights can still be
    # applied before the reduction.
    per_sample = F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')

    sample_weight = None if weight is None else weight.float()

    return weight_reduce_loss(
        per_sample,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def soft_cross_entropy(pred,
                       label,
                       weight=None,
                       reduction='mean',
                       class_weight=None,
                       avg_factor=None):
    """Compute the cross entropy loss against soft (float) labels.

    The loss for one sample is ``-sum(label * log_softmax(pred))`` over the
    last dimension, optionally scaled per class before the sum.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The gt label of the prediction with shape (N, C).
            When using "mixup", the label can be float.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is cast
            to float before being applied. Defaults to None.
        reduction (str): The method used to reduce the loss.
            Defaults to 'mean'.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Defaults to None.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    log_probs = F.log_softmax(pred, dim=-1)
    per_class = -label * log_probs
    if class_weight is not None:
        # Scale each class term before collapsing the class dimension.
        per_class *= class_weight
    per_sample = per_class.sum(dim=-1)

    sample_weight = weight.float() if weight is not None else None

    # Sample weighting and the reduction are handled by the shared helper.
    return weight_reduce_loss(
        per_sample,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         pos_weight=None):
    r"""Calculate the binary CrossEntropy loss with logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*).
        label (torch.Tensor): The gt label with shape (N, \*). It is cast to
            float before the loss is computed.
        weight (torch.Tensor, optional): Sample-wise weight of loss with shape
            (N, ). Defaults to None.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none' , loss
            is same shape as pred and label. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Default None.
        pos_weight (torch.Tensor, optional): The positive weight for each
            class with shape (C), C is the number of classes. Default None.

    Returns:
        torch.Tensor: The calculated loss
    """
    # pred and label must have the same number of dimensions, otherwise the
    # element-wise loss below would rely on implicit broadcasting.
    assert pred.dim() == label.dim()

    if class_weight is not None:
        # Tile the class weight over the batch explicitly instead of relying
        # on automatic broadcast; a 1-D weight of length C becomes (N, C).
        N = pred.size()[0]
        class_weight = class_weight.repeat(N, 1)

    loss = F.binary_cross_entropy_with_logits(
        pred,
        label.float(),  # only accepts float type tensor
        weight=class_weight,
        pos_weight=pos_weight,
        reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        assert weight.dim() == 1
        # Give the sample weight one trailing singleton axis per non-batch
        # dimension of pred, i.e. (N, 1, ..., 1). Broadcasting aligns trailing
        # axes, so the former fixed (N, 1) shape only matched the batch axis
        # for 2-D pred. For 1-D pred the weight keeps its shape (N, ).
        weight = weight.float().reshape((-1, ) + (1, ) * (pred.dim() - 1))

    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss module with selectable criterion.

    The criterion is chosen once at construction time:
    ``binary_cross_entropy`` when ``use_sigmoid`` is set,
    ``soft_cross_entropy`` when ``use_soft`` is set, and ``cross_entropy``
    otherwise.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Defaults to False.
        use_soft (bool): Whether to use the soft version of CrossEntropyLoss.
            Cannot be combined with ``use_sigmoid``. Defaults to False.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class with
            shape (C), C is the number of classes. Default None.
        pos_weight (List[float], optional): The positive weight for each
            class with shape (C), C is the number of classes. Only enabled in
            BCE loss when ``use_sigmoid`` is True. Default None.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_soft=False,
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.use_soft = use_soft
        # The sigmoid and soft variants are mutually exclusive.
        assert not self.use_soft or not self.use_sigmoid, \
            'use_sigmoid and use_soft could not be set simultaneously'

        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.pos_weight = pos_weight

        # Resolve the criterion once so that forward() only dispatches.
        if self.use_sigmoid:
            criterion = binary_cross_entropy
        elif self.use_soft:
            criterion = soft_cross_entropy
        else:
            criterion = cross_entropy
        self.cls_criterion = criterion

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted classification loss.

        Args:
            cls_score (torch.Tensor): The prediction passed to the criterion.
            label (torch.Tensor): The gt label passed to the criterion.
            weight (torch.Tensor, optional): Loss weight forwarded to the
                criterion as its third positional argument. Defaults to None.
            avg_factor (int, optional): Average factor forwarded to the
                criterion. Defaults to None.
            reduction_override (str, optional): One of "none", "mean" or
                "sum"; when given it replaces the reduction configured at
                construction time. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to the criterion.

        Returns:
            torch.Tensor: The criterion output multiplied by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Materialise the configured weights as tensors matching cls_score's
        # dtype and device.
        class_weight = None
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)

        # only BCE loss has pos_weight
        if self.use_sigmoid and self.pos_weight is not None:
            kwargs['pos_weight'] = cls_score.new_tensor(self.pos_weight)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * raw_loss