libauc.losses

libauc.losses.auc

An overview of the auc module can be found below:

AUCMLoss

AUC-Margin loss with squared-hinge surrogate loss for optimizing AUROC

CompositionalAUCLoss

Compositional AUC loss with squared-hinge surrogate loss for optimizing AUROC

AveragePrecisionLoss

Average Precision loss with squared-hinge surrogate loss for optimizing AUPRC

pAUCLoss

A wrapper for Partial AUC losses to optimize One-way and Two-way Partial AUROC

pAUC_CVaR_Loss

Partial AUC loss based on DRO-CVaR to optimize One-way Partial AUROC

pAUC_DRO_Loss

Partial AUC loss based on KL-DRO to optimize One-way Partial AUROC

tpAUC_KL_Loss

Partial AUC loss based on DRO-KL to optimize two-way partial AUROC

PairwiseAUCLoss

Pairwise AUC loss to optimize AUROC based on different surrogate losses

MultiLabelAUCMLoss

AUC-Margin loss with squared-hinge surrogate loss to optimize multi-label AUROC

meanAveragePrecisionLoss

Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP

MultiLabelpAUCLoss

Partial AUC loss with squared-hinge surrogate loss to optimize multi-label Partial AUROC.

class AUCMLoss(margin=1.0, imratio=None, version='v1', device=None)[source]

AUC-Margin loss with squared-hinge surrogate loss for optimizing AUROC. The objective function is defined as:

\[\begin{split}\min _{\substack{\mathbf{w} \in \mathbb{R}^d \\(a, b) \in \mathbb{R}^2}} \max _{\alpha \in \mathbb{R^+}} f(\mathbf{w}, a, b, \alpha):=\mathbb{E}_{\mathbf{z}}[F(\mathbf{w}, a, b, \alpha ; \mathbf{z})]\end{split}\]

where

\[\begin{split}F(\mathbf{w},a,b,\alpha; \mathbf{z}) &=(1-p)(h_{\mathbf{w}}(x)-a)^2\mathbb{I}_{[y=1]} +p(h_{\mathbf{w}}(x)-b)^2\mathbb{I}_{[y=-1]} \\ &+2\alpha(p(1-p)m+ p h_{\mathbf{w}}(x)\mathbb{I}_{[y=-1]}-(1-p)h_{\mathbf{w}}(x)\mathbb{I}_{[y=1]})\\ &-p(1-p)\alpha^2\end{split}\]

\(h_{\mathbf{w}}\) is the prediction scoring function, e.g., a deep neural network; \(p\) is the ratio of positive samples to all samples; \(a\) and \(b\) are the running statistics of the positive and negative predictions; \(\alpha\) is the auxiliary variable derived from the problem formulation; and \(m\) is the margin term. We denote this version of AUCMLoss as v1.

To remove the class prior \(p\) from the above formulation, the objective function can be rewritten as follows:

\[\begin{split}f(\mathbf{w},a,b,\alpha) &= \mathbb{E}_{y=1}[(h_{\mathbf{w}}(x)-a)^2] + \mathbb{E}_{y=-1}[(h_{\mathbf{w}}(x)-b)^2] \\ &+2\alpha(m + \mathbb{E}_{y=-1}[h_{\mathbf{w}}(x)] - \mathbb{E}_{y=1}[h_{\mathbf{w}}(x)])\\ &-\alpha^2\end{split}\]

We denote this version of AUCMLoss as v2. The optimization algorithm for solving the above objectives is implemented as PESG. For the derivations, please refer to the original paper [1].
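To make the two formulations concrete, the following pure-Python sketch evaluates the empirical v1 and v2 objectives on one batch for fixed \(a\), \(b\), \(\alpha\). It is an illustration, not LibAUC's implementation: it treats label 1 as positive and any other label as negative, the function name aucm_objective is hypothetical, and in the v1 case \(p\) falls back to the batch ratio when it is not given.

```python
def _mean(values):
    return sum(values) / len(values)


def aucm_objective(scores, labels, a, b, alpha, margin=1.0, p=None, version="v1"):
    """Empirical AUCM objective on one batch; label 1 is positive, any other label is negative."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y != 1]
    if version == "v2":
        # v2: class-conditional expectations, no class prior p
        return (_mean([(s - a) ** 2 for s in pos])
                + _mean([(s - b) ** 2 for s in neg])
                + 2 * alpha * (margin + _mean(neg) - _mean(pos))
                - alpha ** 2)
    # v1: average F(w, a, b, alpha; z) over the batch; p falls back to the batch ratio
    if p is None:
        p = len(pos) / len(scores)
    total = 0.0
    for s, y in zip(scores, labels):
        if y == 1:
            total += (1 - p) * (s - a) ** 2 - 2 * alpha * (1 - p) * s
        else:
            total += p * (s - b) ** 2 + 2 * alpha * p * s
    # the margin and alpha^2 terms are identical for every sample, so they are added once
    return total / len(scores) + 2 * alpha * p * (1 - p) * margin - p * (1 - p) * alpha ** 2


scores, labels = [0.8, 0.3], [1, 0]
v1 = aucm_objective(scores, labels, a=0.5, b=0.5, alpha=1.0)
v2 = aucm_objective(scores, labels, a=0.5, b=0.5, alpha=1.0, version="v2")
print(v1, v2)  # approximately 0.0325 and 0.13
```

With a class-balanced batch and \(p = 0.5\), v1 equals \(p(1-p)\) times v2, which is what the two printed values show.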

Parameters:
  • margin (float) – margin for squared-hinge surrogate loss (default: 1.0).

  • imratio (float, optional) – the ratio of the number of positive samples to the number of total samples in the training dataset. If this value is not given, the mini-batch statistics will be used instead.

  • version (str, optional) – 'v1' includes the class prior \(p\) in the objective function and 'v2' removes it (default: 'v1').

Example

>>> import torch
>>> import libauc.losses
>>> loss_fn = libauc.losses.AUCMLoss(margin=1.0)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)  # binary labels in {0, 1}
>>> loss = loss_fn(preds, target)
>>> loss.backward()

Note

To use v2 of AUCMLoss, please set version='v2'; otherwise, the default version v1 is used. The v2 version requires the use of DualSampler.

Note

Practical Tips:

  • epoch_decay, an argument of the PESG optimizer, is a regularization parameter similar to weight_decay and can be tuned in the same range.

  • For complex tasks, it is recommended to pretrain the model with a standard loss (e.g., cross-entropy) and then switch to AUCMLoss for fine-tuning with a smaller learning rate.

Reference:
class CompositionalAUCLoss(margin=1.0, k=1, version='v1', imratio=None, backend='ce', l_avg=None, l_imb=None, device=None)[source]

Compositional AUC loss with squared-hinge surrogate loss for optimizing AUROC. The objective is defined as

\[L_{\mathrm{AUC}}\left(\mathbf{w}-\alpha \nabla L_{\mathrm{CE}}(\mathbf{w})\right)\]

where \(L_{\mathrm{AUC}}\) refers to AUCMLoss, \(L_{\mathrm{CE}}\) refers to CrossEntropyLoss, and \(\alpha\) refers to the step size for the inner updates.

The optimization algorithm for solving this objective is implemented as PDSCA. For the derivations, please refer to the original paper [2].
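The composition \(L_{\mathrm{AUC}}(\mathbf{w}-\alpha \nabla L_{\mathrm{CE}}(\mathbf{w}))\) can be illustrated on a toy problem. The sketch below is hypothetical and is not LibAUC's implementation or PDSCA: it uses a one-parameter logit model \(h_w(x) = w x\), a full-batch cross-entropy gradient for the inner step, and, as a stand-in for AUCMLoss, the mean squared-hinge loss over positive/negative pairs of sigmoid scores.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def ce_grad(w, xs, ys):
    """Gradient in w of the mean binary cross-entropy for the toy logit model h_w(x) = w * x."""
    return sum((sigmoid(w * x) - y) * x for x, y in zip(xs, ys)) / len(xs)


def auc_surrogate(w, xs, ys, margin=1.0):
    """Stand-in for L_AUC: mean squared-hinge loss over all (positive, negative) pairs of sigmoid scores."""
    pos = [sigmoid(w * x) for x, y in zip(xs, ys) if y == 1]
    neg = [sigmoid(w * x) for x, y in zip(xs, ys) if y != 1]
    pair_losses = [max(0.0, margin - (sp - sn)) ** 2 for sp in pos for sn in neg]
    return sum(pair_losses) / len(pair_losses)


def compositional_auc(w, xs, ys, alpha=0.1):
    """L_AUC(w - alpha * grad L_CE(w)): the AUC loss evaluated after one inner CE step."""
    w_inner = w - alpha * ce_grad(w, xs, ys)
    return auc_surrogate(w_inner, xs, ys)


xs, ys = [1.0, -1.0], [1, 0]
print(ce_grad(0.0, xs, ys))            # -0.5
print(compositional_auc(0.0, xs, ys))  # about 0.9506, below auc_surrogate(0.0, xs, ys) == 1.0
```

In this toy, the inner cross-entropy step already separates the two classes a little, so the outer AUC loss is evaluated at a better starting point than \(w\) itself.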

Parameters:
  • margin (float) – margin for squared-hinge surrogate loss (default: 1.0).

  • imratio (float, optional) – the ratio of the number of positive samples to the number of total samples in the training dataset. If this value is not given, the mini-batch statistics will be used instead.

  • k (int, optional) – number of steps for inner updates. For example, when k is set to 2, the optimizer will alternately execute two steps optimizing CrossEntropyLoss followed by a single step optimizing AUCMLoss during training (default: 1).

  • version (str, optional) – 'v1' includes the class prior \(p\) in the objective function and 'v2' removes it (default: 'v1').

Example

>>> import torch
>>> import libauc.losses
>>> loss_fn = libauc.losses.CompositionalAUCLoss(margin=1.0, k=1)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)  # binary labels in {0, 1}
>>> loss = loss_fn(preds, target)
>>> loss.backward()

Note

As CompositionalAUCLoss is built on AUCMLoss, it also comes in two versions. To use the v2 version, please set version='v2'; otherwise, the default version v1 is used.

Note

Practical Tips:

  • By default, k is set to 1. You may consider increasing it to a larger number to potentially improve performance.

Reference:
mean(tensor)[source]
class AveragePrecisionLoss(data_len, gamma=0.9, margin=1.0, surr_loss='squared_hinge', device=None)[source]

Average Precision loss with squared-hinge surrogate loss for optimizing AUPRC. The objective is defined as

\[\min_{\mathbf{w}} P(\mathbf{w})=\frac{1}{n_{+}}\sum\limits_{y_i=1}\frac{-\sum\limits_{s=1}^n\mathbb{I}(y_s=1)\ell(\mathbf{w};\mathbf{x}_s,\mathbf{x}_i)}{\sum\limits_{s=1}^n \ell(\mathbf{w};\mathbf{x}_s,\mathbf{x}_i)}\]

where \(\ell(\mathbf{w}; \mathbf{x}_s, \mathbf{x}_i)\) is a surrogate of the discontinuous indicator function \(\mathbb{I}(h(\mathbf{x}_s)\geq h(\mathbf{x}_i))\), and \(h(\cdot)\) is the prediction function, e.g., a deep neural network.

The optimization algorithm for solving this objective is implemented as SOAP. For the derivations, please refer to the original paper [3].

This class is also aliased as APLoss.
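The role of the surrogate can be seen by evaluating \(P(\mathbf{w})\) directly. In this pure-Python sketch, passing the exact indicator recovers the negative of the usual average precision, while the squared hinge \(\ell = (m - (h(\mathbf{x}_i) - h(\mathbf{x}_s)))_+^2\) gives a smooth surrogate. That squared-hinge form is an assumption about the surrogate, the function names are hypothetical, and this is not SOAP, which estimates the inner sums with moving averages (gamma) rather than computing them over the full dataset.

```python
def squared_hinge(t, margin=1.0):
    """Surrogate for I(t <= 0), where t = h(x_i) - h(x_s)."""
    return max(0.0, margin - t) ** 2


def indicator(t):
    """The exact indicator I(h(x_s) >= h(x_i)), i.e. I(t <= 0)."""
    return 1.0 if t <= 0 else 0.0


def ap_objective(scores, labels, pair_loss):
    """P(w): mean over positives x_i of -sum_{y_s=1} l(x_s, x_i) / sum_s l(x_s, x_i)."""
    n_pos = sum(1 for y in labels if y == 1)
    total = 0.0
    for h_i, y_i in zip(scores, labels):
        if y_i != 1:
            continue
        losses = [pair_loss(h_i - h_s) for h_s in scores]
        pos_part = sum(loss for loss, y_s in zip(losses, labels) if y_s == 1)
        total += -pos_part / sum(losses)
    return total / n_pos


scores, labels = [0.9, 0.1, 0.8], [1, 1, 0]
print(-ap_objective(scores, labels, indicator))      # exact AP = (1 + 2/3) / 2
print(-ap_objective(scores, labels, squared_hinge))  # smooth surrogate of AP
```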

Parameters:
  • data_len (int) – total number of samples in the training dataset.

  • gamma (float, optional) – parameter for moving average estimator (default: 0.9).

  • surr_loss (str, optional) – the choice for surrogate loss used for problem formulation (default: 'squared_hinge').

  • margin (float, optional) – margin for the squared-hinge surrogate loss (default: 1.0).

Example

>>> import torch
>>> import libauc.losses
>>> loss_fn = libauc.losses.APLoss(data_len=data_length)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)  # binary labels in {0, 1}
>>> index = torch.randint(32, (32,), requires_grad=False)  # sample indices, must be < data_len
>>> loss = loss_fn(preds, target, index)
>>> loss.backward()

Note

To use APLoss, the dataset must return the index of each training sample together with its data and target, for example:

import torch

class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        data = self.inputs[index]
        target = self.targets[index]
        return data, target, index  # the index is passed to the loss along with data and target

Note

Practical tips:

  • gamma is best tuned in the range (0, 1); suggested values are {0.1, 0.3, 0.5, 0.7, 0.9}.

  • To further improve performance, try tuning margin in the range (0, 1]; suggested values are {0.6, 0.8, 1.0}.

Reference:
class pAUCLoss(mode='1w', **kwargs)[source]

A wrapper for partial AUC losses to optimize one-way and two-way partial AUROC. By default, one-way partial AUC (OPAUC, mode='1w') is optimized with SOPAs and two-way partial AUC (TPAUC, mode='2w') with SOTAs. Each selected loss takes the same arguments and is used in the same way as its standalone class.

Parameters:
  • mode (str) – the specific loss function to be used in the backend (default: ‘1w’).

  • **kwargs – the required arguments for the selected loss function.

Example

>>> import torch
>>> from libauc.losses import pAUCLoss
>>> loss_fn = pAUCLoss(mode='1w', data_len=data_length)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)  # binary labels in {0, 1}
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> loss = loss_fn(preds, target, index)
>>> loss.backward()
class pAUC_CVaR_Loss(data_len, pos_len, num_neg=None, margin=1.0, beta=0.2, eta=0.1, surr_loss='squared_hinge', device=None)[source]

Partial AUC loss based on DRO-CVaR to optimize one-way partial AUROC (OPAUC). The objective is defined as

\[F(\mathbf w, \mathbf s) = \frac{1}{n_+}\sum_{\mathbf x_i\in\mathcal S_+} \left(s_i + \frac{1}{\beta n_-}\sum_{\mathbf x_j\in \mathcal S_-}(L(\mathbf w; \mathbf x_i, \mathbf x_j) - s_i)_+\right)\]

where \(L(\mathbf w; \mathbf x_i, \mathbf x_j)\) is the surrogate pairwise loss for one positive and one negative sample, e.g., squared hinge loss or logistic loss, and \(\mathbf s\) is the dual variable of the DRO-CVaR formulation, which is minimized together with \(\mathbf w\). For a positive sample \(\mathbf x_i\), pairwise losses smaller than \(s_i\) are truncated to zero, so the loss focuses on the harder negatives; as a consequence, pAUC_CVaR_Loss optimizes the partial AUC over a restricted, upper-bounded FPR (false positive rate) range.

This loss optimizes OPAUC in the range [0, beta] for False Positive Rate (FPR). The optimization algorithm for solving this objective is implemented as SOPA. For the derivations, please refer to the original paper [4].
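The role of \(s_i\) can be checked numerically. For a fixed positive sample, minimizing \(s_i + \frac{1}{\beta n_-}\sum_j (L_{ij} - s_i)_+\) over \(s_i\) yields the average of the largest \(\beta n_-\) pairwise losses when \(\beta n_-\) is an integer (the conditional value-at-risk), so only the hardest negatives contribute. The pure-Python sketch below, with hypothetical function names and made-up loss values, verifies this for \(\beta n_- = 2\); it is not LibAUC's implementation, which updates \(\mathbf s\) with stochastic steps of size eta.

```python
def cvar_term(losses, s, beta):
    """s + 1/(beta * n_-) * sum_j (L_ij - s)_+ for a single positive x_i."""
    return s + sum(max(0.0, loss - s) for loss in losses) / (beta * len(losses))


def opauc_cvar_objective(pair_losses, s, beta):
    """F(w, s): pair_losses[i][j] = L(w; x_i, x_j) for positive i, negative j; s[i] is the dual variable."""
    return sum(cvar_term(row, s_i, beta) for row, s_i in zip(pair_losses, s)) / len(pair_losses)


losses = [0.1, 0.4, 0.9, 2.0, 3.0]  # pairwise losses of one positive against n_- = 5 negatives
beta = 0.4                          # beta * n_- = 2 hardest negatives
# The term is convex and piecewise linear in s, so its minimum is attained at one of the losses.
best = min(cvar_term(losses, s, beta) for s in losses)
top_k_mean = sum(sorted(losses, reverse=True)[:2]) / 2
print(best, top_k_mean)  # both 2.5, up to rounding
```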

Parameters:
  • data_len (int) – total number of samples in the training dataset.

  • pos_len (int) – total number of positive samples in the training dataset.

  • margin (float, optional) – margin term for squared-hinge surrogate loss (default: 1.0).

  • beta (float) – upper bound of False Positive Rate (FPR) used for optimizing pAUC (default: 0.2).

  • eta (float) – step size for updating the dual variables of the DRO-CVaR formulation (default: 0.1).

  • surr_loss (string, optional) – surrogate loss used in the problem formulation (default: 'squared_hinge').

Example

>>> import torch
>>> from libauc.losses import pAUC_CVaR_Loss
>>> loss_fn = pAUC_CVaR_Loss(data_len=data_length, pos_len=pos_length)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)  # binary labels in {0, 1}
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> loss = loss_fn(preds, target, index)
>>> loss.backward()

Note

To use pAUC_CVaR_Loss, the dataset must return the index of each training sample together with its data and target, for example:

import torch

class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        data = self.inputs[index]
        target = self.targets[index]
        return data, target, index  # the index is passed to the loss along with data and target

Note

Practical tips:

  • margin can be tuned in {0.1, 0.3, 0.5, 0.7, 0.9, 1.0} for better performance.

  • beta can be tuned in the range (0.1, 0.9), ideally based on task requirement for FPR.

  • eta can be tuned in {10, 1.0, 0.1, 0.01} for better performance.

Reference:
class pAUC_DRO_Loss(data_len, gamma=0.9, margin=1.0, Lambda=1.0, surr_loss='squared_hinge', device=None)[source]

Partial AUC loss based on KL-DRO to optimize one-way partial AUROC (OPAUC). Unlike the standard AUC, partial AUC focuses on a subset of difficult samples. Using distributionally robust optimization (DRO), the objective is defined as

\[\min_{\mathbf{w}}\frac{1}{n_+}\sum_{\mathbf{x}_i\in\mathbf{S}_+} \max_{\mathbf{p}\in\Delta} \left(\sum_{\mathbf{x}_j\in\mathbf{S}_-} p_j L(\mathbf{w}; \mathbf{x}_i, \mathbf{x}_j) - \lambda \text{KL}(\mathbf{p}, 1/n_-)\right)\]

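Here \(\Delta\) is the probability simplex over the \(n_-\) negative samples. Solving the inner maximization in closed form gives the following equivalent objective, from which the algorithm is developed: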

\[\min_{\mathbf{w}}\frac{1}{n_+}\sum_{\mathbf{x}_i \in \mathbf{S}_+}\lambda \log \frac{1}{n_-}\sum_{\mathbf{x}_j \in \mathbf{S}_-}\exp\left(\frac{L(\mathbf{w}; \mathbf{x}_i, \mathbf{x}_j)}{\lambda}\right)\]

where \(L(\mathbf{w}; \mathbf{x}_i, \mathbf{x}_j)\) is the surrogate pairwise loss for one positive and one negative sample, e.g., squared hinge loss, and \(\mathbf{S}_+\) and \(\mathbf{S}_-\) denote the subsets of the dataset that contain only positive and only negative samples, respectively.

The optimization algorithm for solving the above objective is implemented as SOPAs. For the derivation of the above formulation, please refer to the original paper [4].
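The reformulation follows from the closed-form solution of the inner maximization: the maximizing \(\mathbf{p}\) is the softmax of \(L/\lambda\) over the negatives, and substituting it back gives \(\lambda \log \frac{1}{n_-}\sum_j \exp(L_j/\lambda)\). The pure-Python sketch below, with hypothetical helper names and made-up losses, checks this identity for one positive sample and shows how \(\lambda\) interpolates between the mean loss (large \(\lambda\)) and the maximum loss (small \(\lambda\)). It does not reproduce the SOPAs algorithm.

```python
import math


def kl_dro_value(losses, lam):
    """Closed form: lam * log((1 / n_-) * sum_j exp(L_j / lam)), shifted by the max for stability."""
    z = [loss / lam for loss in losses]
    m = max(z)
    return lam * (m + math.log(sum(math.exp(v - m) for v in z) / len(z)))


def kl_dro_at_optimum(losses, lam):
    """sum_j p_j L_j - lam * KL(p, uniform) evaluated at p = softmax(L / lam)."""
    n = len(losses)
    z = [loss / lam for loss in losses]
    m = max(z)
    weights = [math.exp(v - m) for v in z]
    total = sum(weights)
    p = [w / total for w in weights]
    kl = sum(pj * math.log(pj * n) for pj in p if pj > 0)  # KL(p, 1/n_-)
    return sum(pj * loss for pj, loss in zip(p, losses)) - lam * kl


losses = [0.2, 1.0, 3.0]  # L(w; x_i, x_j) for one positive and three negatives (made-up values)
print(kl_dro_value(losses, 1.0), kl_dro_at_optimum(losses, 1.0))  # the two values agree
print(kl_dro_value(losses, 100.0))  # close to the mean loss, 1.4
print(kl_dro_value(losses, 0.01))   # close to the max loss, 3.0
```

This is why Lambda acts as the knob for how strongly the loss concentrates on hard negatives.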

Parameters:
  • data_len (int) – total number of samples in the training dataset.

  • gamma (float) – parameter for moving average estimator (default: 0.9).

  • surr_loss (string, optional) – surrogate loss used in the problem formulation (default: 'squared_hinge').

  • margin (float, optional) – margin for squared-hinge surrogate loss (default: 1.0).

  • Lambda (float, optional) – weight for KL divergence regularization, e.g., 0.1, 1.0, 10.0 (default: 1.0).

Example

>>> import torch
>>> import libauc.losses
>>> loss_fn = libauc.losses.pAUC_DRO_Loss(data_len=data_length, gamma=0.9, Lambda=1.0)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)  # binary labels in {0, 1}
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> loss = loss_fn(preds, target, index)
>>> loss.backward()

Note

To use pAUC_DRO_Loss, the dataset must return the index of each training sample together with its data and target, for example:

import torch

class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        data = self.inputs[index]
        target = self.targets[index]
        return data, target, index  # the index is passed to the loss along with data and target

Note

Practical tips:

  • gamma is best tuned in the range (0, 1); suggested values are {0.1, 0.3, 0.5, 0.7, 0.9}.

  • margin can be tuned in {0.1, 0.3, 0.5, 0.7, 0.9, 1.0} for better performance.

  • Lambda can be tuned in the range (0.1, 10) for better performance.

class tpAUC_KL_Loss(data_len, tau=1.0, Lambda=1.0, gammas=(0.9, 0.9), margin=1.0, surr_loss='squared_hinge', device=None)[source]

Partial AUC loss based on DRO-KL to optimize two-way partial AUROC. The objective function is defined as

\[F(\mathbf w; \phi_{kl}, \phi_{kl})= \lambda'\log \mathrm E_{\mathbf x_i\sim\mathcal S_+}\left(\mathrm E_{\mathbf x_j\sim\mathcal S_-}\exp(\frac{L(\mathbf w; \mathbf x_i,\mathbf x_j)}{\lambda})\right)^{\frac{\lambda}{\lambda'}}\]

where \(L(\mathbf w; \mathbf x_i, \mathbf x_j)\) is the surrogate pairwise loss for one positive and one negative sample, e.g., squared hinge loss or logistic loss. In this formulation, the \(\alpha\) and \(\beta\) ranges of TPAUC are handled implicitly by tuning \(\lambda\) and \(\lambda'\) (in code, \(\lambda\) is named Lambda and \(\lambda'\) is named tau). The loss focuses on both harder positive and harder negative samples, so it optimizes TPAUC in the upper-left region of the ROC curve.

The optimization algorithm for solving the above objective is implemented as SOTAs. For the derivation of the above formulation, please refer to the original paper [4].
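Numerically, the objective is a nested log-mean-exp: each positive gets a KL-DRO value over the negatives with temperature \(\lambda\), and these values are aggregated over the positives with temperature \(\lambda'\), because \(\left(\mathrm E_j e^{L_{ij}/\lambda}\right)^{\lambda/\lambda'} = \exp(v_i/\lambda')\) with \(v_i = \lambda \log \mathrm E_j e^{L_{ij}/\lambda}\). The sketch below is a hypothetical pure-Python illustration of the formula with made-up losses, not the SOTAs algorithm, which replaces the inner expectations with moving-average estimates controlled by gammas.

```python
import math


def log_mean_exp(values):
    """log((1/n) * sum exp(v)), shifted by the max for numerical stability."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def tpauc_kl_objective(pair_losses, lam=1.0, tau=1.0):
    """pair_losses[i][j] = L(w; x_i, x_j); lam is lambda (Lambda) and tau is lambda'."""
    # inner level: KL-DRO over the negatives of each positive x_i
    per_positive = [lam * log_mean_exp([loss / lam for loss in row]) for row in pair_losses]
    # outer level: KL-DRO over the positives with temperature tau
    return tau * log_mean_exp([v / tau for v in per_positive])


pair_losses = [[0.5, 1.5], [2.0, 0.1]]  # made-up losses: 2 positives x 2 negatives
print(tpauc_kl_objective(pair_losses, lam=1.0, tau=2.0))
```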

Parameters:
  • data_len (int) – total number of samples in the training dataset.

  • margin (float, optional) – margin term used in surrogate loss (default: 1.0).

  • Lambda (float, optional) – KL regularization for negative samples (default: 1.0).

  • tau (float, optional) – KL regularization for positive samples (default: 1.0).

  • gammas (Tuple[float, float], optional) – coefficients of the moving-average estimators for the composite functions (default: (0.9, 0.9)).

  • surr_loss (string, optional) – surrogate loss used in the problem formulation (default: 'squared_hinge').

Example

>>> import torch
>>> from libauc.losses import tpAUC_KL_Loss
>>> loss_fn = tpAUC_KL_Loss(data_len=data_length)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)  # binary labels in {0, 1}
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> loss = loss_fn(preds, target, index)
>>> loss.backward()

Note

To use tpAUC_KL_Loss, the dataset must return the index of each training sample together with its data and target, for example:

import torch

class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        data = self.inputs[index]
        target = self.targets[index]
        return data, target, index  # the index is passed to the loss along with data and target

Note

Practical tips:

  • gammas are best tuned in the range (0, 1); suggested values are {(0.1, 0.1), (0.5, 0.5), (0.9, 0.9)}.

  • margin can be tuned in {0.1, 0.3, 0.5, 0.7, 0.9, 1.0} for better performance.

  • Lambda and tau can be tuned in the range (0.1, 10) for better performance.

class PairwiseAUCLoss(surr_loss='logistic', hparam=1.0)[source]

Pairwise AUC loss to optimize AUROC based on different surrogate losses. This objective can be optimized with existing optimizers in LibAUC or PyTorch, such as SGD or AdamW.

Parameters:
  • surr_loss (str) – surrogate loss for optimizing the pairwise AUC loss. The available options are ‘logistic’, ‘squared’, ‘squared_hinge’, ‘barrier_hinge’ (default: 'logistic').

  • hparam (float or tuple, optional) –

    hyperparameter of the chosen surrogate loss. Its meaning depends on surr_loss:

    • squared: margin term (default: 1.0).

    • squared_hinge: margin term (default: 1.0).

    • logistic: scaling term (default: 1.0).

    • barrier_hinge: a tuple of (scale, margin) (default: (1.0, 1.0)).
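For reference, the surrogates listed above can be written as functions of the score difference \(t = h(\mathbf{x}_+) - h(\mathbf{x}_-)\). The pure-Python sketch below assumes the common forms \((m - t)^2\), \((m - t)_+^2\), and \(\log(1 + e^{-s t})\) for squared, squared_hinge, and logistic; LibAUC's exact definitions may differ, barrier_hinge is omitted because its form is not given here, and all function names are hypothetical.

```python
import math


def squared_loss(margin, t):
    return (margin - t) ** 2


def squared_hinge_loss(margin, t):
    return max(0.0, margin - t) ** 2


def logistic_loss(scale, t):
    # log(1 + exp(-scale * t)), rewritten to avoid overflow for large |scale * t|
    z = scale * t
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


SURROGATES = {"squared": squared_loss, "squared_hinge": squared_hinge_loss, "logistic": logistic_loss}


def pairwise_auc_loss(scores, labels, surr_loss="logistic", hparam=1.0):
    """Mean surrogate over all (positive, negative) pairs, with t = h(x_pos) - h(x_neg)."""
    surrogate = SURROGATES[surr_loss]
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y != 1]
    losses = [surrogate(hparam, sp - sn) for sp in pos for sn in neg]
    return sum(losses) / len(losses)


# A pair separated by more than the margin: the hinge is zero, the plain squared loss is not.
print(pairwise_auc_loss([2.0, 0.0], [1, 0], "squared_hinge"), pairwise_auc_loss([2.0, 0.0], [1, 0], "squared"))
```

The example highlights a design difference: the squared loss also penalizes pairs that are already ranked correctly by more than the margin, while the squared hinge ignores them.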

Example

>>> import torch
>>> from libauc.losses import PairwiseAUCLoss
>>> loss_fn = PairwiseAUCLoss(surr_loss='squared', hparam=0.5)
>>> y_pred = torch.randn(32, requires_grad=True)
>>> y_true = torch.empty(32, dtype=torch.long).random_(2)
>>> loss = loss_fn(y_pred, y_true)
>>> loss.backward()
class meanAveragePrecisionLoss(data_len, num_labels, margin=1.0, gamma=0.9, surr_loss='squared_hinge', device=None)[source]

Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP. This is an extension of APLoss.

Parameters:
  • data_len (int) – total number of samples in the training dataset.

  • num_labels (int) – number of unique labels (tasks) in the dataset.

  • margin (float, optional) – margin for the squared-hinge surrogate loss (default: 1.0).

  • gamma (float, optional) – parameter for the moving average estimator (default: 0.9).

  • surr_loss (str, optional) – type of surrogate loss to use. Choices are ‘squared_hinge’, ‘squared’, ‘logistic’, ‘barrier_hinge’ (default: 'squared_hinge').

This class is also aliased as mAPLoss.
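Since mAP averages AP over labels, the target metric can be sketched as follows. This pure-Python illustration computes the exact (non-surrogate) AP per label and averages it; skipping labels without positives is an assumption made here, and the function names are hypothetical, not part of LibAUC.

```python
def average_precision(scores, labels):
    """Exact AP for one label: mean over positives i of the precision at the rank of i."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    total = 0.0
    for s_i in positives:
        ranked_at_or_above = [y for s, y in zip(scores, labels) if s >= s_i]
        total += sum(1 for y in ranked_at_or_above if y == 1) / len(ranked_at_or_above)
    return total / len(positives)


def mean_average_precision(score_rows, label_rows):
    """score_rows[n][k], label_rows[n][k]: average the per-label AP over labels with a positive."""
    num_labels = len(score_rows[0])
    aps = []
    for k in range(num_labels):
        col_scores = [row[k] for row in score_rows]
        col_labels = [row[k] for row in label_rows]
        if any(y == 1 for y in col_labels):
            aps.append(average_precision(col_scores, col_labels))
    return sum(aps) / len(aps)


scores = [[0.9, 0.9], [0.8, 0.1], [0.1, 0.8]]  # 3 samples x 2 labels (made-up values)
labels = [[1, 1], [1, 1], [0, 0]]
print(mean_average_precision(scores, labels))  # (1 + 5/6) / 2
```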

Example

>>> loss_fn = meanAveragePrecisionLoss(data_len=data_length, margin=1.0, num_labels=10, gamma=0.9)
>>> y_pred = torch.randn((32,10), requires_grad=True)
>>> y_true = torch.empty((32,10), dtype=torch.long).random_(2)
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> task_ids = torch.randint(10, (32,), requires_grad=False)
>>> loss = loss_fn(y_pred, y_true, index, task_ids)
>>> loss.backward()
Reference:
class MultiLabelAUCMLoss(margin=1.0, version='v1', imratio=None, num_labels=10, device=None)[source]

AUC-Margin loss with squared-hinge surrogate loss to optimize multi-label AUROC. This is an extension of AUCMLoss.

Parameters:
  • margin (float) – margin term for the squared-hinge surrogate loss (default: 1.0).

  • num_labels (int) – number of labels for the dataset (default: 10).

  • imratio (float, optional) – the ratio of the number of positive samples to the number of total samples in the training dataset. If this value is not given, the mini-batch statistics will be used instead.

  • version (str, optional) – 'v1' includes the class prior \(p\) in the objective function and 'v2' removes it (default: 'v1').

This class is also aliased as mAUCMLoss.

Example

>>> loss_fn = MultiLabelAUCMLoss(margin=1.0, num_labels=10)
>>> y_pred = torch.randn(32, 10, requires_grad=True)
>>> y_true = torch.empty(32, 10, dtype=torch.long).random_(2)  # one binary label per task
>>> loss = loss_fn(y_pred, y_true)
>>> loss.backward()
Reference:
mean(tensor)[source]
class MultiLabelpAUCLoss(mode='1w', num_labels=10, device=None, **kwargs)[source]

Partial AUC loss with squared-hinge surrogate loss to optimize multi-label partial AUROC. This is an extension of pAUCLoss.

This class is also aliased as mPAUCLoss.

Parameters:
  • mode (str) – the specific loss function to be used in the backend (default: ‘1w’).

  • num_labels (int) – number of unique labels (tasks) in the dataset.

  • **kwargs – the required arguments for the selected loss function.

Example

>>> loss_fn = MultiLabelpAUCLoss(data_len=data_length, margin=1.0, num_labels=10)
>>> y_pred = torch.randn((32,10), requires_grad=True)
>>> y_true = torch.empty((32,10), dtype=torch.long).random_(2)
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> task_ids = torch.randint(10, (32,), requires_grad=False)
>>> loss = loss_fn(y_pred, y_true, index, task_ids)
>>> loss.backward()
Reference:

libauc.losses.ranking

An overview of the ranking module can be found below:

ListwiseCELoss

Stochastic Optimization of Listwise CE loss.

NDCGLoss

Stochastic Optimization of NDCG (SONG) and top-K NDCG (K-SONG).

class ListwiseCELoss(N, num_pos, gamma, eps=1e-10, device=None)[source]

Stochastic Optimization of Listwise CE loss. The objective function is defined as

\[F(\mathbf{w})=\frac{1}{N}\sum_{q=1}^{N} \frac{1}{N_q}\sum_{\mathbf{x}_i^q \in S_q^+} - y_i^q \ln \left(\frac{\exp(h_q(\mathbf{x}_i^q;\mathbf{w}))}{\sum_{\mathbf{x}_j^q \in S_q} \exp(h_q(\mathbf{x}_j^q;\mathbf{w})) }\right)\]

where \(h_q(\mathbf{x}_i^q;\mathbf{w})\) is the predicted score of \(\mathbf{x}_i^q\) with respect to \(q\), \(y_i^q\) is the relevance score of \(\mathbf{x}_i^q\) with respect to \(q\), \(N\) is the number of total queries, \(N_q\) is the total number of items to be ranked for query \(q\), \(S_q\) denotes the set of items to be ranked for query \(q\), and \(S_q^+\) denotes the set of relevant items for query \(q\).
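For a single query, the inner sum is a cross-entropy over a softmax of the item scores, weighted by relevance. The pure-Python sketch below evaluates the formula exactly as written above, including the \(1/N_q\) factor, for a small list of (scores, relevance) pairs; the function names are hypothetical, and the library instead works on sampled items with a moving-average estimate controlled by gamma.

```python
import math


def listwise_ce_query(scores, relevance):
    """(1 / N_q) * sum over relevant items of -y_i * log softmax(scores)_i, for one query."""
    m = max(scores)
    log_normalizer = m + math.log(sum(math.exp(s - m) for s in scores))
    n_q = len(scores)
    return sum(-y * (s - log_normalizer) for s, y in zip(scores, relevance) if y > 0) / n_q


def listwise_ce(queries):
    """F(w): average of the per-query losses over a list of (scores, relevance) pairs."""
    return sum(listwise_ce_query(s, r) for s, r in queries) / len(queries)


queries = [([2.0, 0.5, -1.0], [1, 0, 0]), ([0.0, 0.0, 0.0, 0.0], [1, 0, 0, 0])]
print(listwise_ce(queries))
```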

Parameters:
  • N (int) – number of all relevant pairs

  • num_pos (int) – number of positive items sampled for each user

  • gamma (float) – the factor for the moving average, i.e., gamma in our paper [1].

  • eps (float, optional) – a small value to avoid division by zero (default: 1e-10)

Example

>>> loss_fn = libauc.losses.ListwiseCELoss(N=1000, num_pos=10, gamma=0.1)      # assume we have 1000 relevant query-item pairs
>>> predictions = torch.randn((32, 10+20), requires_grad=True)                   # we sample 32 queries/users, and 10 positive items and 20 negative items for each query/user
>>> batch = {'user_item_id': torch.randint(low=0, high=1000-1, size=(32,10+20))} # ids for all sampled query-item pairs in the batch
>>> loss = loss_fn(predictions, batch)
>>> loss.backward()
Reference:
class NDCGLoss(N, num_user, num_item, num_pos, gamma0=0.9, gamma1=0.9, eta0=0.01, margin=1.0, topk=-1, topk_version='theo', tau_1=0.01, tau_2=0.0001, sigmoid_alpha=2.0, surrogate_loss='squared_hinge', device=None)[source]

Stochastic Optimization of NDCG (SONG) and top-K NDCG (K-SONG). The objective function of K-SONG is a bilevel optimization problem as presented below:

\[\begin{aligned}& \min_{\mathbf{w}} \frac{1}{|S|} \sum_{(q,\mathbf{x}_i^q)\in S} \psi\big(h_q(\mathbf{x}_i^q;\mathbf{w})-\hat{\lambda}_q(\mathbf{w})\big)\, f_{q,i}\big(g(\mathbf{w};\mathbf{x}_i^q,S_q)\big)\\ & \text{s.t.}\ \hat{\lambda}_q(\mathbf{w})=\arg\min_{\lambda} \frac{K+\epsilon}{N_q}\lambda + \frac{\tau_2}{2}\lambda^2 + \frac{1}{N_q} \sum_{\mathbf{x}_i^q \in S_q} \tau_1 \ln\big(1+\exp((h_q(\mathbf{x}_i^q;\mathbf{w})-\lambda)/\tau_1)\big), \quad \forall q\in\mathbf{Q}\end{aligned}\]

where \(\psi(\cdot)\) is a smooth Lipschitz continuous function that approximates \(\mathbb{I}(\cdot\ge 0)\), e.g., the sigmoid function, and \(f_{q,i}(g)\) denotes \(\frac{1}{Z_q^K}\frac{1-2^{y_i^q}}{\log_2(N_q g+1)}\). The objective for SONG is a special case of the K-SONG objective in which \(\psi(\cdot)\) is a constant function.
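To see why \(f_{q,i}\) encodes NDCG, note that if \(g\) is the exact fraction of items in \(S_q\) scored at least as high as \(\mathbf{x}_i^q\), then \(N_q g\) is the rank of \(\mathbf{x}_i^q\), and with \(Z_q^K\) taken as the ideal DCG, summing \(f_{q,i}\) over the items of a query gives \(-\mathrm{NDCG}\) (the SONG case, where \(\psi\) is constant). The pure-Python sketch below checks this with exact ranks and distinct scores; the library replaces the exact \(g\) with a smooth surrogate, and the function names are hypothetical.

```python
import math


def dcg(relevances_in_rank_order):
    """DCG with gain 2^y - 1 and discount log2(rank + 1)."""
    return sum((2 ** y - 1) / math.log2(r + 1) for r, y in enumerate(relevances_in_rank_order, start=1))


def f_qi(g, y, n_q, ideal_dcg):
    """f_{q,i}(g) = (1 / Z_q) * (1 - 2^y) / log2(N_q * g + 1)."""
    return (1 - 2 ** y) / math.log2(n_q * g + 1) / ideal_dcg


def neg_ndcg_via_f(scores, relevance):
    """Sum of f_{q,i} over one query's items, using the exact g (distinct scores assumed)."""
    n_q = len(scores)
    ideal = dcg(sorted(relevance, reverse=True))  # Z_q, the ideal DCG of the query
    total = 0.0
    for s_i, y_i in zip(scores, relevance):
        g = sum(1 for s_j in scores if s_j >= s_i) / n_q  # N_q * g is the rank of item i
        total += f_qi(g, y_i, n_q, ideal)
    return total


print(neg_ndcg_via_f([3.0, 2.0, 1.0], [2, 1, 0]))  # perfect ranking: -1.0, up to rounding
```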

Parameters:
  • N (int) – number of all relevant pairs

  • num_user (int) – number of users in the dataset

  • num_item (int) – number of items in the dataset

  • num_pos (int) – number of positive items sampled for each user

  • gamma0 (float) – the moving average factor of u_{q,i}, i.e., beta_0 in our paper, in the range (0.0, 1.0); this hyperparameter can be tuned for better performance (default: 0.9)

  • gamma1 (float, optional) – the moving average factor of s_{q} and v_{q} (default: 0.9)

  • eta0 (float, optional) – step size of lambda (default: 0.01)

  • margin (float, optional) – margin for squared hinge loss (default: 1.0)

  • topk (int, optional) – NDCG@k optimization (K-SONG) is activated if topk > 0; topk=-1 corresponds to SONG (default: -1)

  • topk_version (string, optional) – ‘theo’ or ‘prac’ (default: 'theo')

  • tau_1 (float, optional) – tau_1 in Eq. (6), tau_1 << 1 (default: 0.01)

  • tau_2 (float, optional) – tau_2 in Eq. (6), tau_2 << 1 (default: 0.0001)

  • sigmoid_alpha (float, optional) – hyperparameter of the sigmoid function, psi(x) = sigmoid(x * sigmoid_alpha) (default: 2.0)
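The lower-level problem defines \(\hat{\lambda}_q\) as a smoothed top-K threshold. Its derivative in \(\lambda\), \(\frac{K+\epsilon}{N_q} + \tau_2\lambda - \frac{1}{N_q}\sum_i \sigma\big((h_i - \lambda)/\tau_1\big)\) with \(\sigma\) the logistic sigmoid, is increasing, so the problem can be solved by bisection. The pure-Python sketch below is an illustration, not LibAUC's update rule (which uses the step size eta0); \(\epsilon\) is not exposed in the signature above, so eps=1e-3 is an assumed value, and the helper names are hypothetical. For small tau_1 and tau_2 the solution lands between the (K+1)-th and K-th largest scores.

```python
import math


def stable_sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def smoothed_topk_threshold(scores, k, eps=1e-3, tau_1=0.01, tau_2=1e-4, iters=100):
    """Solve the lower-level problem for lambda_hat_q by bisection on its (increasing) derivative."""
    n_q = len(scores)

    def derivative(lam):
        return (k + eps) / n_q + tau_2 * lam - sum(stable_sigmoid((h - lam) / tau_1) for h in scores) / n_q

    lo, hi = min(scores) - 1.0, max(scores) + 1.0
    # widen the bracket until the derivative changes sign (it tends to -inf / +inf at the ends)
    while derivative(lo) > 0:
        lo -= hi - lo
    while derivative(hi) < 0:
        hi += hi - lo
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if derivative(mid) < 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


print(smoothed_topk_threshold([5.0, 4.0, 3.0, 2.0, 1.0], k=2))  # lies between 3.0 and 4.0
```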

Example

>>> import torch
>>> import libauc.losses
>>> loss_fn = libauc.losses.NDCGLoss(N=1000, num_user=100, num_item=5000, num_pos=10, gamma0=0.1, topk=-1)  # SONG (topk=-1) / K-SONG (e.g., topk=100)
>>> predictions = torch.randn((32, 10+20), requires_grad=True)  # 32 sampled queries/users, each with 10 positive and 20 negative items
>>> batch = {
...     'rating': torch.randint(low=0, high=5, size=(32, 10+20)),            # ratings (e.g., in {0, 1, 2, 3, 4}) for each sampled query-item pair
...     'user_id': torch.randint(low=0, high=100-1, size=(32,)),              # id of each sampled query
...     'num_pos_items': torch.randint(low=0, high=1000, size=(32,)),         # number of all relevant items for each sampled query
...     'ideal_dcg': torch.rand(32),                                          # ideal DCG precomputed for each sampled query (in the range (0.0, 1.0))
...     'user_item_id': torch.randint(low=0, high=1000-1, size=(32, 10+20)),  # ids of all sampled query-item pairs in the batch
... }
>>> loss = loss_fn(predictions, batch)
>>> loss.backward()
Reference:

libauc.losses.contrastive

An overview of the contrastive module can be found below:

GCLoss

A high-level wrapper for GCLoss_v1 and GCLoss_v2.

GCLoss_v1

Stochastic Optimization of Global Contrastive Loss (GCL) and Robust Global Contrastive Loss (RGCL) for learning representations for unimodal tasks (e.g., image-image).

GCLoss_v2

Stochastic Optimization of Global Contrastive Loss (GCL) and Robust Global Contrastive Loss (RGCL) for learning representations for bimodal tasks (e.g., image-text).

class GCLoss(mode='unimodal', **kwargs)[source]

A high-level wrapper for GCLoss_v1 and GCLoss_v2.

Parameters:
  • mode (str, optional) – type of GCLoss to use. Options are ‘unimodal’ for GCLoss_v1 and ‘bimodal’ for GCLoss_v2 (default: 'unimodal').

  • **kwargs – arbitrary keyword arguments. These will be passed directly to the chosen GCLoss version’s constructor.

Example

>>> loss_fn = GCLoss(mode='bimodal', N=1000, tau=0.1)
>>> feat_img, feat_txt = torch.randn((32, 256), requires_grad=True), torch.randn((32, 256), requires_grad=True)
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> dynamic_loss = loss_fn(feat_img=feat_img, feat_txt=feat_txt, index=index)

Note

The forward method of this class simply calls the forward method of the chosen GCLoss (GCLoss_v1 or GCLoss_v2).

class GCLoss_v1(N=100000, tau=0.1, gamma=0.9, eps=1e-08, device=None, distributed=False, enable_isogclr=False, tau_min=0.05, tau_max=0.7, rho=0.3, eta=0.01, beta=0.9)[source]

Stochastic Optimization of Global Contrastive Loss (GCL) and Robust Global Contrastive Loss (RGCL) for learning representations for unimodal tasks (e.g., image-image). The objective for optimizing GCL (i.e., objective for SogCLR) is defined as

\[F(\mathbf{w}) = \frac{1}{n} \sum_{\mathbf{x}_i \in D} { \tau \log \mathbb{E}_{\mathbf{z}\in S_i^-} \exp \Big( \frac{h_i(\mathbf{z})}{\tau} \Big) },\]

and the objective for optimizing RGCL (i.e., objective for iSogCLR) is defined as

\[F(\mathbf{w},\mathbf{\tau}) = \frac{1}{n} \sum_{\mathbf{x}_i \in D} { \mathbf{\tau}_i \log \mathbb{E}_{\mathbf{z}\in S_i^-} \exp \Big( \frac{h_i(\mathbf{z})}{\mathbf{\tau}_i} \Big) + \mathbf{\tau}_i \rho },\]

where \(h_i(\mathbf{z})=E(\mathcal{A}(\mathbf{x}_i))^{\mathrm{T}}E(\mathbf{z})-E(\mathcal{A}(\mathbf{x}_i))^{\mathrm{T}}E(\mathcal{A}^{\prime}(\mathbf{x}_i))\), \(\mathcal{A}\) and \(\mathcal{A}^{\prime}\) are two data augmentation operations, \(S_i^-\) denotes all negative samples for anchor data \(\mathbf{x}_i\), and \(E(\cdot)\) represents the image encoder. In iSogCLR, \(\mathbf{\tau}_i\) is the individualized temperature for \(\mathbf{x}_i\).
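For one anchor, the GCL term is a temperature-scaled log-mean-exp over the negatives. The pure-Python sketch below computes it exactly for a handful of made-up embedding vectors, with \(h_i(\mathbf z)\) built from dot products as defined above; the helper names are hypothetical, and this is not SogCLR itself, which estimates the dataset-wide expectation with a moving average (gamma) instead of enumerating all negatives. The RGCL variant uses a per-sample \(\tau_i\) and adds \(\tau_i\rho\).

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def h_values(anchor, positive, negatives):
    """h_i(z) = E(A(x_i))^T E(z) - E(A(x_i))^T E(A'(x_i)) for each negative embedding z."""
    pos_sim = dot(anchor, positive)
    return [dot(anchor, z) - pos_sim for z in negatives]


def gcl_anchor_term(h, tau):
    """tau * log E_{z in S_i^-} exp(h_i(z) / tau), computed with a max-shift for stability."""
    scaled = [v / tau for v in h]
    m = max(scaled)
    return tau * (m + math.log(sum(math.exp(v - m) for v in scaled) / len(scaled)))


def rgcl_anchor_term(h, tau_i, rho):
    """RGCL term for one anchor: individualized temperature tau_i plus the tau_i * rho penalty."""
    return gcl_anchor_term(h, tau_i) + tau_i * rho


anchor, positive = [1.0, 0.0], [1.0, 0.0]  # embeddings of two augmented views (made up)
negatives = [[0.0, 1.0], [0.0, -1.0]]      # embeddings of negative samples (made up)
h = h_values(anchor, positive, negatives)  # [-1.0, -1.0]
print(gcl_anchor_term(h, tau=0.1), rgcl_anchor_term(h, tau_i=0.1, rho=0.3))
```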

Parameters:
  • N (int) – number of samples in the training dataset (default: 100000)

  • tau (float) – temperature parameter for the global contrastive loss. If enable_isogclr is True, this value is used as the initial value of the learnable temperature parameters (default: 0.1)

  • device (torch.device) – the device for the inputs (default: None)

  • distributed (bool) – whether to use distributed training (default: False)

  • enable_isogclr (bool, optional) – whether to enable iSogCLR. If True, then the algorithm will optimize individualized temperature parameters for all samples (default: False)

  • eta (float, optional) – the step size for updating temperature parameters in iSogCLR (default: 0.01)

  • rho (float, optional) – the hyperparameter \(\rho\) in Eq. (6) in iSogCLR [2] (default: 0.3)

  • tau_min (float, optional) – lower bound of learnable temperature in iSogCLR (default: 0.05)

  • tau_max (float, optional) – upper bound of learnable temperature in iSogCLR (default: 0.7)

  • beta (float, optional) – the momentum parameter for updating temperature parameters in iSogCLR (default: 0.9)

Example

>>> loss_fn = GCLoss_v1(N=1000, tau=0.1)
>>> img_feat1, img_feat2 = torch.randn((32, 256), requires_grad=True), torch.randn((32, 256), requires_grad=True)
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> loss = loss_fn(img_feat1, img_feat2, index)
>>> loss.backward()
Reference:
class GCLoss_v2(N=1000000, tau=0.01, gamma=0.9, cache_labels=False, rank=0, world_size=1, distributed=False, enable_isogclr=False, tau_min=0.005, tau_max=0.05, rho=6.0, eta=0.01, beta=0.9)[source]

Stochastic Optimization of Global Contrastive Loss (GCL) and Robust Global Contrastive Loss (RGCL) for learning representations for bimodal tasks (e.g., image-text). The objective for optimizing GCL (i.e., objective for SogCLR) is defined as

\[F(\mathbf{w}) = \frac{1}{n} \sum_{(\mathbf{x}_i, \mathbf{t}_i) \in D} { \tau \log \mathbb{E}_{\mathbf{t}\in T_i^-} \exp \Big( \frac{h_{\mathbf{x}_i}(\mathbf{t})}{\tau} \Big) + \tau \log \mathbb{E}_{\mathbf{x}\in I_i^-} \exp \Big( \frac{h_{\mathbf{t}_i}(\mathbf{x})}{\tau} \Big) },\]

and the objective for optimizing RGCL (i.e., objective for iSogCLR) is defined as

\[F(\mathbf{w}, \mathbf{\tau}, \mathbf{\tau}^{\prime}) = \frac{1}{n} \sum_{(\mathbf{x}_i, \mathbf{t}_i) \in D} { (\mathbf{\tau}_i + \mathbf{\tau}^{\prime}_i)\rho + \mathbf{\tau}_i \log \mathbb{E}_{\mathbf{t}\in T_i^-} \exp \Big( \frac{h_{\mathbf{x}_i}(\mathbf{t})}{\mathbf{\tau}_i} \Big) + \mathbf{\tau}^{\prime}_i \log \mathbb{E}_{\mathbf{x}\in I_i^-} \exp \Big( \frac{h_{\mathbf{t}_i}(\mathbf{x})}{\mathbf{\tau}^{\prime}_i} \Big) },\]

where \((\mathbf{x}_i, \mathbf{t}_i) \in D\) is an image-text pair, \(T_i^-\) is the set of negative texts for \(\mathbf{x}_i\), \(I_i^-\) is the set of negative images for \(\mathbf{t}_i\), \(h_{\mathbf{x}_i}(\mathbf{t})=E_I(\mathbf{x}_i)^{\mathrm{T}}E_T(\mathbf{t}) - E_I(\mathbf{x}_i)^{\mathrm{T}}E_T(\mathbf{t}_i)\), \(h_{\mathbf{t}_i}(\mathbf{x})=E_I(\mathbf{x})^{\mathrm{T}}E_T(\mathbf{t}_i) - E_I(\mathbf{x}_i)^{\mathrm{T}}E_T(\mathbf{t}_i)\), and \(E_I(\cdot)\) and \(E_T(\cdot)\) are the image and text encoders, respectively. In iSogCLR, \(\mathbf{\tau}_i\) and \(\mathbf{\tau}^{\prime}_i\) are the individualized temperatures of \(\mathbf{x}_i\) and \(\mathbf{t}_i\).
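To make the two objectives concrete, the minimal pure-Python sketch below evaluates their mini-batch estimate. It assumes a precomputed similarity matrix sim[i][j] = E_I(x_i)^T E_T(t_j) and treats the other pairs in the batch as the negatives T_i^- and I_i^-. The helpers gcl_objective and log_mean_exp are hypothetical and not part of LibAUC; the moving-average estimators (gamma) and temperature updates (eta, beta) used during SogCLR/iSogCLR training are deliberately omitted.

```python
import math


def log_mean_exp(values, tau):
    """Numerically stable tau * log(mean(exp(v / tau)))."""
    scaled = [v / tau for v in values]
    top = max(scaled)  # shift by the max before exponentiating
    mean = sum(math.exp(s - top) for s in scaled) / len(scaled)
    return tau * (top + math.log(mean))


def gcl_objective(sim, tau_img, tau_txt, rho=0.0):
    """Mini-batch estimate of the (R)GCL objective.

    sim[i][j] is the similarity between image x_i and text t_j.
    tau_img[i] / tau_txt[i] are the temperatures of x_i / t_i.
    A shared constant temperature with rho=0 gives the GCL objective.
    """
    n = len(sim)
    total = 0.0
    for i in range(n):
        # h_{x_i}(t_j) = sim[i][j] - sim[i][i] over negative texts j != i
        h_img = [sim[i][j] - sim[i][i] for j in range(n) if j != i]
        # h_{t_i}(x_j) = sim[j][i] - sim[i][i] over negative images j != i
        h_txt = [sim[j][i] - sim[i][i] for j in range(n) if j != i]
        total += (tau_img[i] + tau_txt[i]) * rho
        total += log_mean_exp(h_img, tau_img[i]) + log_mean_exp(h_txt, tau_txt[i])
    return total / n
```

With tau_img and tau_txt set to the same constant and rho=0 the function reduces to the GCL objective; per-sample temperatures and a positive rho give the RGCL objective.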

Parameters:
  • N (int) – number of samples in the training dataset (default: 1000000)

  • tau (float) – temperature parameter for the global contrastive loss. If iSogCLR is enabled, this value is used to initialize the learnable temperature parameters (default: 0.01)

  • gamma (float) – the moving average factor for the dynamic loss, in the range (0.0, 1.0) (default: 0.9)

  • cache_labels (bool) – whether to cache labels for mini-batch data (default: False)

  • rank (int) – unique ID given to a process for distributed training (default: 0)

  • world_size (int) – total number of processes for distributed training (default: 1)

  • distributed (bool) – whether to use distributed training (default: False)

  • enable_isogclr (bool, optional) – whether to enable iSogCLR. If True, then the algorithm will optimize individualized temperature parameters for all samples (default: False)

  • eta (float, optional) – the step size for updating temperature parameters in iSogCLR (default: 0.01)

  • rho (float, optional) – the hyperparameter \(\rho\) in Eq. (6) in iSogCLR [2] (default: 6.0)

  • tau_min (float, optional) – lower bound of learnable temperature in iSogCLR (default: 0.005)

  • tau_max (float, optional) – upper bound of learnable temperature in iSogCLR (default: 0.05)

  • beta (float, optional) – the momentum parameter for updating temperature parameters in iSogCLR (default: 0.9)

Example

>>> loss_fn = GCLoss_v2(N=1000, tau=0.1)
>>> img_feat, txt_feat = torch.randn((32, 256), requires_grad=True), torch.randn((32, 256), requires_grad=True)
>>> index = torch.randint(32, (32,), requires_grad=False)
>>> loss = loss_fn(img_feat, txt_feat, index)
>>> loss.backward()
Reference:

libauc.losses.mil

An overview of the mil module can be found below:

MIDAMLoss

A high-level wrapper for MIDAM_softmax_pooling_loss and MIDAM_attention_pooling_loss.

MIDAM_attention_pooling_loss

Multiple Instance Deep AUC Maximization with stochastic Attention (MIDAM-att) Pooling for optimizing AUROC under the Multiple Instance Learning (MIL) setting.

MIDAM_softmax_pooling_loss

Multiple Instance Deep AUC Maximization with stochastic Smoothed-MaX (MIDAM-smx) Pooling.

class MIDAMLoss(mode='attention', **kwargs)

A high-level wrapper for MIDAM_softmax_pooling_loss and MIDAM_attention_pooling_loss.

Example

>>> import torch
>>> from libauc.losses.mil import MIDAMLoss
>>> loss_fn = MIDAMLoss(mode='softmax', data_len=32, margin=1.0)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)
>>> # in practice, index should be the indices of your data (bag-index for multiple instance learning).
>>> loss = loss_fn(exps=preds, y_true=target, index=torch.arange(32))
>>> loss.backward()
>>> loss_fn = MIDAMLoss(mode='attention', data_len=32, margin=1.0)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> denoms = torch.rand(32, 1, requires_grad=True) + 0.01
>>> target = torch.empty(32, dtype=torch.long).random_(2)
>>> # in practice, index should be the indices of your data (bag-index for multiple instance learning).
>>> # denoms should be the stochastic denominator values output from your model.
>>> loss = loss_fn(sn=preds, sd=denoms, y_true=target, index=torch.arange(32))
>>> loss.backward()
Reference:
update_smoothing(decay_factor)
class MIDAM_attention_pooling_loss(data_len, margin=1.0, gamma=0.9, device=None)

Multiple Instance Deep AUC Maximization with stochastic Attention (MIDAM-att) Pooling optimizes AUROC under the Multiple Instance Learning (MIL) setting. The Attention Pooling is defined as

\[h(\mathbf w; \mathcal X) = \sigma(\mathbf w_c^{\top}E(\mathbf w; \mathcal X)) = \sigma\left(\sum_{\mathbf x\in\mathcal X}\frac{\exp(g(\mathbf w; \mathbf x))\delta(\mathbf w;\mathbf x)}{\sum_{\mathbf x'\in\mathcal X}\exp(g(\mathbf w; \mathbf x'))}\right),\]

where \(g(\mathbf w;\mathbf x)\) is a parametric function, e.g., \(g(\mathbf w; \mathbf x)=\mathbf w_a^{\top}\text{tanh}(V e(\mathbf w_e; \mathbf x))\) with \(V\in\mathbb R^{m\times d_o}\) and \(\mathbf w_a\in\mathbb R^m\), and \(\delta(\mathbf w;\mathbf x) = \mathbf w_c^{\top}e(\mathbf w_e; \mathbf x)\) is the prediction score of each instance, which is combined with the attention weights. We optimize the following AUC loss with the Attention Pooling:

\[\begin{split}\min_{\mathbf w\in\mathbb R^d,(a,b)\in\mathbb R^2}\max_{\alpha\in\Omega}F\left(\mathbf w,a,b,\alpha\right)&:= \underbrace{\hat{\mathbb E}_{i\in\mathcal D_+}\left[(h(\mathbf w; \mathcal X_i) - a)^2 \right]}_{F_1(\mathbf w, a)} \\ &+ \underbrace{\hat{\mathbb E}_{i\in\mathcal D_-}\left[(h(\mathbf w; \mathcal X_i) - b)^2 \right]}_{F_2(\mathbf w, b)} \\ &+ \underbrace{2\alpha (c+ \hat{\mathbb E}_{i\in\mathcal D_-}h(\mathbf w; \mathcal X_i) - \hat{\mathbb E}_{i\in\mathcal D_+}h(\mathbf w; \mathcal X_i)) - \alpha^2}_{F_3(\mathbf w, \alpha)},\end{split}\]

Here \(c\) is the margin. The optimization algorithm for solving the above objective is implemented as MIDAM; its stochastic pooling only needs a subset of the instances of each bag in the mini-batch. For more details about the formulation, please refer to the original paper [1].
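The attention pooling and the min-max objective above can be sketched in plain Python for bags whose attention logits g(w; x) and instance scores δ(w; x) are already computed. The helpers attention_pooling and midam_objective are hypothetical illustrations, not LibAUC functions: the pooling here uses every instance of a bag (MIDAM-att instead works with sampled instances and moving averages), and the margin c, the variables a and b, and the dual variable α are passed in as plain numbers.

```python
import math


def attention_pooling(g_scores, delta_scores):
    """h(w; X) = sigmoid(sum_x softmax(g)_x * delta(w; x)) for one bag.

    g_scores[k]: attention logit g(w; x_k) of the k-th instance
    delta_scores[k]: instance prediction score delta(w; x_k)
    """
    top = max(g_scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(g - top) for g in g_scores]
    pooled = sum(w * d for w, d in zip(weights, delta_scores)) / sum(weights)
    return 1.0 / (1.0 + math.exp(-pooled))


def midam_objective(h_pos, h_neg, a, b, alpha, c):
    """F(w, a, b, alpha) from pooled predictions of positive / negative bags."""
    mean_pos = sum(h_pos) / len(h_pos)
    mean_neg = sum(h_neg) / len(h_neg)
    f1 = sum((h - a) ** 2 for h in h_pos) / len(h_pos)       # F_1(w, a)
    f2 = sum((h - b) ** 2 for h in h_neg) / len(h_neg)       # F_2(w, b)
    f3 = 2 * alpha * (c + mean_neg - mean_pos) - alpha ** 2  # F_3(w, alpha)
    return f1 + f2 + f3


# Usage: two positive and two negative bags of (g, delta) instance pairs
pos_bags = [[(0.0, 2.0), (1.0, 3.0)], [(0.5, 1.0)]]
neg_bags = [[(0.0, -1.0), (0.0, -2.0)], [(2.0, -0.5), (0.0, 0.5)]]
h_pos = [attention_pooling([g for g, _ in bag], [d for _, d in bag]) for bag in pos_bags]
h_neg = [attention_pooling([g for g, _ in bag], [d for _, d in bag]) for bag in neg_bags]
print(midam_objective(h_pos, h_neg, a=1.0, b=0.0, alpha=0.5, c=1.0))
```

Since 2αq - α² peaks at α = q with q = c + mean_neg - mean_pos, maximizing over an unconstrained α turns F_3 into q². If Ω is taken to be α ≥ 0 (an assumption, since Ω is not specified here), the maximum becomes max(q, 0)², a squared-hinge penalty on the margin violation.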

Parameters:
  • data_len (int) – number of training samples.

  • margin (float, optional) – margin parameter for AUC loss (default: 1.0).

  • gamma (float, optional) – moving average parameter for the numerator and denominator in the attention calculation (default: 0.9).

  • device (torch.device, optional) – the device used for computing loss, e.g., ‘cpu’ or ‘cuda’ (default: None)

Example

>>> import torch
>>> from libauc.losses.mil import MIDAM_attention_pooling_loss
>>> loss_fn = MIDAM_attention_pooling_loss(data_len=32, margin=1.0, gamma=0.9)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> denoms = torch.rand(32, 1, requires_grad=True) + 0.01
>>> target = torch.empty(32, dtype=torch.long).random_(2)
>>> # in practice, index should be the indices of your data (bag-index for multiple instance learning).
>>> # denoms should be the stochastic denominator values output from your model.
>>> loss = loss_fn(sn=preds, sd=denoms, y_true=target, index=torch.arange(32))
>>> loss.backward()
Reference:

Note

To use MIDAM_attention_pooling_loss, we need to track the index of each sample in the training dataset. To do so, see the example below:

import torch

class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        data = self.inputs[index]
        target = self.targets[index]
        # also return the sample index so the loss can update its per-sample state
        return data, target, index

Note

Practical tips:

  • gamma should be tuned in the range (0, 1) for better performance; suggested values are {0.1, 0.3, 0.5, 0.7, 0.9}.

  • margin can be tuned over {0.1, 0.3, 0.5, 0.7, 0.9, 1.0} for better performance.

update_smoothing(decay_factor)
class MIDAM_softmax_pooling_loss(data_len, margin=1.0, tau=0.1, gamma=0.9, device=None)

Multiple Instance Deep AUC Maximization with stochastic Smoothed-MaX (MIDAM-smx) Pooling. This loss optimizes AUROC under the Multiple Instance Learning (MIL) setting. The Smoothed-MaX Pooling is defined as

\[h(\mathbf w; \mathcal X) = \tau \log\left(\frac{1}{|\mathcal X|}\sum_{\mathbf x\in\mathcal X}\exp(\phi(\mathbf w; \mathbf x)/\tau)\right)\]

where \(\phi(\mathbf w;\mathbf x)\) is the prediction score for instance \(\mathbf x\) and \(\tau>0\) is a hyperparameter. We optimize the following AUC loss with the Smoothed-MaX Pooling:

\[\begin{split}\min_{\mathbf w\in\mathbb R^d,(a,b)\in\mathbb R^2}\max_{\alpha\in\Omega}F\left(\mathbf w,a,b,\alpha\right)&:= \underbrace{\hat{\mathbb E}_{i\in\mathcal D_+}\left[(h(\mathbf w; \mathcal X_i) - a)^2 \right]}_{F_1(\mathbf w, a)} \\ &+ \underbrace{\hat{\mathbb E}_{i\in\mathcal D_-}\left[(h(\mathbf w; \mathcal X_i) - b)^2 \right]}_{F_2(\mathbf w, b)} \\ &+ \underbrace{2\alpha (c+ \hat{\mathbb E}_{i\in\mathcal D_-}h(\mathbf w; \mathcal X_i) - \hat{\mathbb E}_{i\in\mathcal D_+}h(\mathbf w; \mathcal X_i)) - \alpha^2}_{F_3(\mathbf w, \alpha)},\end{split}\]

Here \(c\) is the margin. The optimization algorithm for solving the above objective is implemented as MIDAM; its stochastic pooling only needs a subset of the instances of each bag in the mini-batch. For more details about the formulation, please refer to the original paper [1].
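The smoothed-max pooling interpolates between mean pooling (large τ) and max pooling (τ → 0). The hypothetical helper below evaluates the formula for one bag of instance scores φ(w; x), using a log-sum-exp shift for numerical stability; it is a sketch of the definition, not LibAUC's implementation.

```python
import math


def smoothed_max_pooling(scores, tau):
    """h(w; X) = tau * log(mean_x exp(phi(w; x) / tau)) for one bag."""
    scaled = [s / tau for s in scores]
    top = max(scaled)  # log-sum-exp shift avoids overflow for small tau
    mean = sum(math.exp(s - top) for s in scaled) / len(scaled)
    return tau * (top + math.log(mean))


scores = [1.0, 2.0, 3.0]
for tau in (0.01, 1.0, 1000.0):
    # small tau approaches max(scores), large tau approaches mean(scores)
    print(tau, smoothed_max_pooling(scores, tau))
```

For any τ > 0 the result lies between the mean and the maximum of the instance scores, which is why tau acts as a knob between the two pooling extremes.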

Parameters:
  • data_len (int) – number of training samples.

  • margin (float, optional) – margin parameter for AUC loss (default: 1.0).

  • tau (float) – temperature parameter for smoothed max pooling (default: 0.1).

  • gamma (float, optional) – moving average parameter for pooling operation (default: 0.9).

  • device (torch.device, optional) – the device used for computing loss, e.g., ‘cpu’ or ‘cuda’ (default: None)

Example

>>> import torch
>>> from libauc.losses.mil import MIDAM_softmax_pooling_loss
>>> loss_fn = MIDAM_softmax_pooling_loss(data_len=32, margin=1.0, tau=0.1, gamma=0.9)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32, dtype=torch.long).random_(2)
>>> # in practice, index should be the indices of your data (bag-index for multiple instance learning).
>>> loss = loss_fn(exps=preds, y_true=target, index=torch.arange(32))
>>> loss.backward()
Reference:

Note

To use MIDAM_softmax_pooling_loss, we need to track the index of each sample in the training dataset. To do so, see the example below:

import torch

class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        data = self.inputs[index]
        target = self.targets[index]
        # also return the sample index so the loss can update its per-sample state
        return data, target, index
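Tracking the index matters because the stochastic pooling keeps a running estimate per bag. The class below is a hypothetical sketch of such an estimator for the smoothed-max pooling: each call receives the scores of the instances sampled from one bag, updates that bag's moving average of exp(φ/τ) with weight gamma, and returns τ log of the running value. The update convention (gamma weighting the new mini-batch estimate) and initializing with the first estimate are assumptions, not LibAUC's documented behavior.

```python
import math


class MovingAverageSmoothedMax:
    """Hypothetical per-bag estimator for the smoothed-max pooling.

    Keeps s_i, a running estimate of mean_x exp(phi(w; x) / tau), for every
    bag index i, because each step only sees some instances of a bag.
    """

    def __init__(self, data_len, tau=0.1, gamma=0.9):
        self.tau = tau
        self.gamma = gamma
        self.s = [None] * data_len  # one running estimate per bag index

    def update(self, index, sampled_scores):
        # mini-batch estimate of the inner mean from the sampled instance scores
        est = sum(math.exp(v / self.tau) for v in sampled_scores) / len(sampled_scores)
        if self.s[index] is None:
            self.s[index] = est  # assumed initialization: first estimate
        else:
            # assumed convention: gamma weights the new mini-batch estimate
            self.s[index] = (1 - self.gamma) * self.s[index] + self.gamma * est
        # pooled bag prediction h = tau * log(s_i)
        return self.tau * math.log(self.s[index])


tracker = MovingAverageSmoothedMax(data_len=4, tau=0.5, gamma=0.9)
print(tracker.update(2, [0.1, 0.4]))  # first visit to bag 2
print(tracker.update(2, [0.3]))       # later visit with a different sample
```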

Note

Practical tips:

  • gamma should be tuned in the range (0, 1) for better performance; suggested values are {0.1, 0.3, 0.5, 0.7, 0.9}.

  • margin can be tuned over {0.1, 0.3, 0.5, 0.7, 0.9, 1.0} for better performance.

  • tau can be tuned in the range (0.1, 10) for better performance.

update_smoothing(decay_factor)

libauc.losses.losses

An overview of the losses module can be found below:

CrossEntropyLoss

Cross-Entropy loss with a sigmoid function.

FocalLoss

Focal loss with a sigmoid function.

class CrossEntropyLoss

Cross-Entropy loss with a sigmoid function. This implementation is based on PyTorch's built-in binary_cross_entropy_with_logits, which applies the sigmoid and the binary cross-entropy in a single step.
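Combining the sigmoid and the cross-entropy avoids overflow for large scores; the PyTorch documentation attributes this to the log-sum-exp trick. The per-sample computation can be sketched in plain Python as below. bce_with_logits is a hypothetical helper for illustrating the formula, not the function LibAUC calls.

```python
import math


def bce_with_logits(logit, target):
    """-[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))] for one sample.

    Uses the rewrite max(z, 0) - z * y + log(1 + exp(-|z|)),
    which never exponentiates a large positive number.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


print(bce_with_logits(2.0, 1.0))     # confident and correct: small loss
print(bce_with_logits(1000.0, 0.0))  # confident and wrong: loss equal to the logit
```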

Example

>>> import torch
>>> from libauc.losses.losses import CrossEntropyLoss
>>> loss_fn = CrossEntropyLoss()
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32).random_(2)  # binary 0/1 labels as floats
>>> loss = loss_fn(preds, target)
>>> loss.backward()
Reference:

https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html

class FocalLoss(alpha=0.25, gamma=2, device=None)

Focal loss with a sigmoid function.

Parameters:
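  • alpha (float) – weighting factor in the range (0, 1) to balance positive vs. negative examples (default: 0.25).

  • gamma (float) – exponent of the modulating factor (1 - p_t) to balance easy vs. hard examples (default: 2).

  • device (torch.device, optional) – the device used for computing loss, e.g., ‘cpu’ or ‘cuda’ (default: None)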
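The roles of alpha and gamma are easiest to see in the standard sigmoid focal loss of Lin et al., FL = -α_t (1 - p_t)^γ log(p_t). The per-sample sketch below follows that textbook form; focal_loss is a hypothetical helper, and LibAUC's exact weighting of positives and negatives may differ from it.

```python
import math


def focal_loss(logit, target, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss for one sample in the form of Lin et al.

    target is 1 for a positive and 0 for a negative example.
    """
    p = 1.0 / (1.0 + math.exp(-logit))        # predicted probability of the positive class
    p_t = p if target == 1 else 1.0 - p       # probability assigned to the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma shrinks the loss of well-classified (easy) examples
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


print(focal_loss(4.0, 1))  # easy positive: heavily down-weighted
print(focal_loss(-1.0, 1))  # hard positive: keeps most of its loss
```

Setting gamma=0 recovers an alpha-weighted binary cross-entropy, so gamma controls how strongly easy examples are discounted.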

Example

>>> import torch
>>> from libauc.losses.losses import FocalLoss
>>> loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
>>> preds = torch.randn(32, 1, requires_grad=True)
>>> target = torch.empty(32).random_(2)  # binary 0/1 labels as floats
>>> loss = loss_fn(preds, target)
>>> loss.backward()
Reference:

libauc.losses.surrogate

An overview of the surrogate module can be found below:

barrier_hinge_loss

Barrier Hinge Loss.

get_surrogate_loss

A wrapper to call a specific surrogate loss function.

hinge_loss

Hinge Loss.

logistic_loss

Logistic Loss.

squared_hinge_loss

Squared Hinge Loss.

squared_loss

Squared Loss.

barrier_hinge_loss(hparam, t)

Barrier Hinge Loss. The loss can be described as:

\[L_\text{barrier\_hinge}(t, s, m) = \max(-s(m + t) + m, \max(s(t - m), m - t))\]

where m is the margin hyper-parameter and s is the scaling hyper-parameter.
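The nested maxima are easier to read when evaluated at a few points. The scalar sketch below takes t, m and s as separate arguments for clarity; this signature is hypothetical (the library function takes hparam and t). For s ≥ 1 the loss coincides with the hinge m - t on [-m, m], while outside that interval the slope-s terms eventually dominate and act as barriers.

```python
def barrier_hinge(t, m, s):
    """max(-s(m + t) + m, max(s(t - m), m - t)) with margin m and scale s."""
    return max(-s * (m + t) + m, max(s * (t - m), m - t))


for t in (-3.0, -1.0, 0.0, 1.0, 2.0):
    print(t, barrier_hinge(t, m=1.0, s=2.0))
```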

Reference:
get_surrogate_loss(loss_name='squared_hinge')

A wrapper to call a specific surrogate loss function.

Parameters:

loss_name (str) – type of surrogate loss function to fetch, including ‘squared_hinge’, ‘squared’, ‘logistic’, ‘barrier_hinge’ (default: 'squared_hinge').

hinge_loss(margin, t)

Hinge Loss. The loss can be described as:

\[L_\text{hinge}(t, m) = \max(m - t, 0)\]

where m is the margin hyper-parameter.

logistic_loss(scale, t)

Logistic Loss. The loss can be described as:

\[L_\text{logistic}(t, s) = \log(1 + e^{-st})\]

where s is the scaling hyper-parameter.

squared_hinge_loss(margin, t)

Squared Hinge Loss. The loss can be described as:

\[L_\text{squared\_hinge}(t, m) = \max(m - t, 0)^2\]

where m is the margin hyper-parameter.

squared_loss(margin, t)

Squared Loss. The loss can be described as:

\[L_\text{squared}(t, m) = (m - t)^2\]

where m is the margin hyper-parameter.
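For comparison, the margin- and scale-based surrogates above fit in a few lines of plain Python. These are illustrative scalar re-implementations (the library functions operate on tensors), and the SURROGATES dictionary is a hypothetical name-based lookup in the spirit of get_surrogate_loss, not its actual implementation. Here t can be read as any score to be pushed above the margin, e.g. a positive-minus-negative prediction difference.

```python
import math


def hinge(m, t):
    return max(m - t, 0.0)


def squared_hinge(m, t):
    return max(m - t, 0.0) ** 2


def squared(m, t):
    return (m - t) ** 2


def logistic(s, t):
    # log1p is more accurate than log(1 + x) when exp(-s * t) is tiny
    return math.log1p(math.exp(-s * t))


# Hypothetical lookup table keyed by loss name
SURROGATES = {"hinge": hinge, "squared_hinge": squared_hinge,
              "squared": squared, "logistic": logistic}

for name, fn in SURROGATES.items():
    # evaluate each surrogate with margin/scale 1.0 at a few inputs
    print(name, [round(fn(1.0, t), 3) for t in (-1.0, 0.0, 1.0, 2.0)])
```

Unlike the hinge variants, the squared loss also penalizes t > m, so it pulls scores toward the margin rather than only above it.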