Source code for libauc.losses.surrogate

import torch 

__all__ = ['squared_loss', 
           'squared_hinge_loss', 
           'hinge_loss', 
           'logistic_loss', 
           'barrier_hinge_loss', 
           'get_surrogate_loss']


[docs] def squared_loss(margin, t): r""" Squared Loss. The loss can be described as: .. math:: L_\text{squared}(t, m) = (m - t)^2 where ``m`` is the margin hyper-parameter. """ return (margin - t)** 2
[docs] def squared_hinge_loss(margin, t): r""" Squared Hinge Loss. The loss can be described as: .. math:: L_\text{squared_hinge}(t, m) = \max(m - t, 0)^2 where ``m`` is the margin hyper-parameter. """ return torch.max(margin - t, torch.zeros_like(t)) ** 2
[docs] def hinge_loss(margin, t): r""" Hinge Loss. The loss can be described as: .. math:: L_\text{hinge}(t, m) = \max(m - t, 0) where ``m`` is the margin hyper-parameter. """ return torch.max(margin - t, torch.zeros_like(t))
[docs] def logistic_loss(scale, t): r""" Logistic Loss. The loss can be described as: .. math:: L_\text{logistic}(t, s) = \log(1 + e^{-st}) where ``s`` is the scaling hyper-parameter. """ return torch.log(1+torch.exp(-scale*t))
[docs] def barrier_hinge_loss(hparam, t): r""" Barrier Hinge Loss. The loss can be described as: .. math:: L_\text{barrier_hinge}(t, s, m) = \max(−s(m + t) + m, \max(s(t − m), m − t)) where ``m`` is the margin hyper-parameter and ``s`` is the the scaling hyper-parameter. Reference: .. [1] Charoenphakdee, Nontawat, Jongyeong Lee, and Masashi Sugiyama. "On symmetric losses for learning from corrupted labels." International Conference on Machine Learning. PMLR, 2019. """ m,s = hparam loss = torch.maximum(-s * (m + t) + m, torch.maximum(m - t, s* (t - rm))) return loss
[docs] def get_surrogate_loss(loss_name='squared_hinge'): r""" A wrapper to call a specific surrogate loss function. Args: loss_name (str): type of surrogate loss function to fetch, including 'squared_hinge', 'squared', 'logistic', 'barrier_hinge' (default: ``'squared_hinge'``). """ assert f'{loss_name}_loss' in __all__, f'{loss_name} is not implemented' if loss_name == 'squared_hinge': surr_loss = squared_hinge_loss elif loss_name == 'squared': surr_loss = squared_loss elif loss_name == 'logistic': surr_loss = logistic_loss elif loss_name == 'barrier_hinge': surr_loss = barrier_hinge_loss else: raise ValueError('Out of options!') return surr_loss