libauc.optimizers
An overview of the optimizers module is summarized as follows:
Optimizer | Loss Function
PESG | AUCMLoss
PDSCA | CompositionalAUCLoss
SOAP | AveragePrecisionLoss
SOPA | pAUC_CVaR_Loss
SOPAs | pAUC_DRO_Loss
SOTAs | tpAUC_KL_Loss
SONG | NDCGLoss
SogCLR | GCLoss
iSogCLR | GCLoss
MIDAM | MIDAMLoss
We also adapted some popular optimizers from the PyTorch codebase, as follows:
libauc.optimizers.pesg
- class PESG(params, loss_fn, lr=0.1, mode='sgd', clip_value=1.0, weight_decay=1e-05, epoch_decay=0.002, momentum=0.9, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, verbose=True, device=None, **kwargs)[source]
Proximal Epoch Stochastic Gradient Method (PESG) is used for optimizing the AUCMLoss. The key update steps are summarized as follows:
Initialize \(\mathbf v_0= \mathbf v_{ref}=\{\mathbf{w_0}, a_0, b_0\}, \alpha_0\geq 0\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) Compute \(\nabla_{\mathbf v} F(\mathbf v_t, \alpha_t; z_t)\) and \(\nabla_\alpha F(\mathbf v_t, \alpha_t; z_t)\).
\(\hspace{5mm}\) Update primal variables
\[\mathbf v_{t+1} = \mathbf v_{t} - \eta (\nabla_{\mathbf v} F(\mathbf v_t, \alpha_t; z_t)+ \lambda_0 (\mathbf v_t-\mathbf v_{\text{ref}})) - \lambda \eta\mathbf v_t\]
\(\hspace{5mm}\) Update dual variable
\[\alpha_{t+1}= [\alpha_{t} + \eta \nabla_\alpha F(\mathbf v_t, \alpha_t; z_t)]_+\]
\(\hspace{5mm}\) Decrease \(\eta\) by a decay factor and update \(\mathbf v_{\text{ref}}\) periodically
where \(z_t\) is the data pair \((x_t, y_t)\), \(\lambda_0\) is the epoch-level l2 penalty (i.e., epoch_decay), \(\lambda\) is the l2 penalty (i.e., weight_decay), and \(\eta\) is the learning rate.
For more details, please refer to the paper Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification.
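Read as code, one iteration of this recursion reduces to a few tensor operations. The following is a minimal plain-tensor sketch for intuition only, not the library implementation; pesg_step and its arguments are illustrative names, and the random tensors stand in for the stochastic AUCM gradients.
import torch

def pesg_step(v, v_ref, alpha, grad_v, grad_alpha, lr, epoch_decay, weight_decay):
    # Primal update: gradient step plus the epoch-level proximal term
    # lambda_0 * (v - v_ref) and the ordinary L2 penalty lambda * v.
    v_new = v - lr * (grad_v + epoch_decay * (v - v_ref)) - weight_decay * lr * v
    # Dual update: projected gradient ascent on alpha (kept non-negative).
    alpha_new = torch.clamp(alpha + lr * grad_alpha, min=0.0)
    return v_new, alpha_new

# Toy call with random tensors standing in for the AUCM gradients.
v = torch.randn(8); v_ref = v.clone(); alpha = torch.tensor(0.5)
v, alpha = pesg_step(v, v_ref, alpha, torch.randn(8), torch.tensor(0.1),
                     lr=0.1, epoch_decay=2e-3, weight_decay=1e-5)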
- Parameters:
params (iterable) – iterable of parameters to optimize
loss_fn (callable) – loss function used for optimization (default: None)
lr (float) – learning rate (default: 0.1)
mode (str) – optimization mode, ‘sgd’ or ‘adam’ (default: 'sgd')
clip_value (float, optional) – gradient clipping value (default: 1.0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-5)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 2e-3)
momentum (float, optional) – momentum factor for ‘sgd’ mode (default: 0.9)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square for ‘adam’ mode (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability for ‘adam’ mode (default: 1e-8)
amsgrad (bool, optional) – whether to use the AMSGrad variant of ‘adam’ mode from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: True)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.PESG(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
- Reference:
[1] Zhuoning Yuan, Yan Yan, Milan Sonka, and Tianbao Yang. "Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification." International Conference on Computer Vision. 2021.
- property optim_step
Return the number of optimization steps.
libauc.optimizers.pdsca
- class PDSCA(params, loss_fn, lr=0.1, lr0=None, beta1=0.99, beta2=0.999, clip_value=1.0, weight_decay=1e-05, epoch_decay=0.002, verbose=True, device='cuda', **kwargs)[source]
Primal-Dual Stochastic Compositional Adaptive Algorithm (PDSCA) is used for optimizing the CompositionalAUCLoss. For iteration \(t\), the key update steps are summarized as follows:
Initialize \(\mathbf v_0= \mathbf v_{ref}=\mathbf u_0= \{\mathbf{w_0}, a_0, b_0\}, \alpha_0 \geq 0\)
For \(t=1, \ldots, T\):
\(\hspace{5mm} \mathbf{u}_{t+1}=(1-\beta_{0}) \mathbf{u}_{t}+\beta_{0}(\mathbf{w}_{\mathbf{t}}-\eta_0 \nabla L_{CE}(\mathbf{w}_{\mathbf{t}}) ; a ; b)\)
\(\hspace{5mm}\) \(\mathbf{z}_{t+1}=(1-\beta_{1}) \mathbf{z}_{t}+\beta_{1} \nabla_{\mathbf{u}} L_{AUC}(\mathbf{u}_{t+1})\)
\(\hspace{5mm}\) \(\mathbf{v}_{t+1}=\mathbf{v}_{t}-\eta_{1} (\mathbf{z}_{t+1} + λ_0(\mathbf{w}_t-\mathbf{v}_{ref})+ λ_1\mathbf{v}_t)\)
\(\hspace{5mm}\) \(\theta_{t+1}=\theta_{t}+\eta_{1} \nabla_{\theta} L_{AUC}(\theta_{t})\)
\(\hspace{5mm}\) Decrease \(\eta_0, \eta_1\) by a decay factor and update \(\mathbf v_{\text{ref}}\) periodically
where \(\lambda_0,\lambda_1\) refer to epoch_decay and weight_decay, \(\eta_0, \eta_1\) refer to the learning rates for the inner updates (\(L_{CE}\)) and the outer updates (\(L_{AUC}\)), \(\mathbf v_t\) refers to \(\{\mathbf w_t, a_t, b_t\}\), and \(\theta\) refers to the dual variable in CompositionalAUCLoss. For more details, please refer to Compositional Training for End-to-End Deep AUC Maximization.
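As a companion to the recursion above, here is a minimal plain-tensor sketch of one PDSCA step, for intuition only and not the library implementation; pdsca_step and its arguments are illustrative names, and v bundles \(\{\mathbf w, a, b\}\) into a single flat tensor for simplicity.
import torch

def pdsca_step(v, v_ref, u, z, theta, grad_ce, grad_auc_u, grad_auc_theta,
               lr0, lr1, beta0, beta1, epoch_decay, weight_decay):
    # Inner moving average u_{t+1}: a CE-gradient-corrected copy of the parameters.
    u = (1 - beta0) * u + beta0 * (v - lr0 * grad_ce)
    # Moving average z_{t+1} of the AUC gradient evaluated at u_{t+1}.
    z = (1 - beta1) * z + beta1 * grad_auc_u
    # Primal step with the epoch-level proximal term and the L2 penalty.
    v = v - lr1 * (z + epoch_decay * (v - v_ref) + weight_decay * v)
    # Dual ascent step on theta.
    theta = theta + lr1 * grad_auc_theta
    return v, u, z, theta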
- Parameters:
params (iterable) – iterable of parameters to optimize
loss_fn (callable) – loss function used for optimization (default: None)
lr (float) – learning rate (default: 0.1)
lr0 (float, optional) – learning rate for inner updates (default: None)
beta1 (float, optional) – coefficient for updating the running average of gradient (default: 0.99)
beta2 (float, optional) – coefficient for updating the running average of gradient square (default: 0.999)
clip_value (float, optional) – gradient clipping value (default: 1.0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-5)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 2e-3)
verbose (bool, optional) – whether to print optimization progress (default: True)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: 'cuda')
Example
>>> optimizer = libauc.optimizers.PDSCA(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
- Reference:
[1] Zhuoning Yuan and Zhishuai Guo and Nitesh Chawla and Tianbao Yang. “Compositional Training for End-to-End Deep AUC Maximization.” International Conference on Learning Representations. 2022. https://openreview.net/forum?id=gPvB4pdu_Z
- property optim_step
Return the number of optimization steps.
libauc.optimizers.soap
- class SOAP(params, lr=0.001, mode='adam', clip_value=1.0, weight_decay=1e-05, epoch_decay=0, momentum=0.9, nesterov=False, dampening=0.1, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, verbose=True, device=None, **kwargs)[source]
Stochastic Optimization of AP (SOAP) is used for optimizing the AveragePrecisionLoss. The key update steps are summarized as follows:
Initialize \(\mathbf u=0, \mathbf w_0\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) For each \(\mathbf{x}_i\in \mathcal B_{+}\), update
\(\mathbf{u}^1_{\mathbf{x}_i} = (1-\gamma)\mathbf{u}^1_{\mathbf{x}_i} + \gamma \frac{1}{|\mathcal B|}\sum\limits_{x_j\in\mathcal B} \ell(\mathbf{w}_t;\mathbf{x}_j,\mathbf{x}_i)\mathbb{I}(y_j=1)\)
\(\mathbf{u}^2_{\mathbf{x}_i} = (1-\gamma)\mathbf{u}^2_{\mathbf{x}_i} + \gamma \frac{1}{|\mathcal B|}\sum\limits_{\mathbf{x}_j\in\mathcal B} \ell(\mathbf{w}_t;\mathbf{x}_j,\mathbf{x}_i)\)
\(\hspace{5mm}\) Compute (biased) Stochastic Gradient Estimator:
\(G(\mathbf{w}_t) = \frac{1}{|B_+|}\sum\limits_{\mathbf{x}_i\in\mathcal B_+} \sum\limits_{\mathbf{x}_j\in\mathcal B}\frac{(\mathbf{u}_{\mathbf{x}_i}^1 - \mathbf{u}_{\mathbf{x}_i}^2\mathbf{I}(\mathbf{y}_j=1))\nabla \ell(\mathbf{w};\mathbf{x}_j,\mathbf{x}_i) }{|B|(\mathbf{u}_{\mathbf{x}_i}^2)^2}\)
\(\hspace{5mm}\) Update \(\mathbf w_{t+1} =\mathbf w_t - \eta G(\mathbf{w}_t)\) (or Momentum/Adam style)
For more details, please refer to Stochastic optimization of areas under precision-recall curves with provable convergence.
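To make the moving-average bookkeeping concrete, below is a minimal sketch of the per-batch update of \(\mathbf u^1, \mathbf u^2\) and the resulting pairwise weights; it is for intuition only and not the library implementation, and soap_update and its arguments are illustrative names.
import torch

def soap_update(u1, u2, pos_idx, surrogate, labels, gamma):
    # surrogate[k, j] = ell(w; x_j, x_{i_k}) for the k-th positive anchor in the
    # batch and every batch item j; pos_idx maps the anchors to their global slots.
    batch_size = surrogate.shape[1]
    pos_mask = (labels == 1).float()                      # indicator I(y_j = 1)
    # u^1: moving average of the surrogate restricted to positive batch items.
    u1[pos_idx] = (1 - gamma) * u1[pos_idx] + gamma * (surrogate * pos_mask).sum(1) / batch_size
    # u^2: moving average of the surrogate over all batch items.
    u2[pos_idx] = (1 - gamma) * u2[pos_idx] + gamma * surrogate.sum(1) / batch_size
    # Per-pair weights that multiply grad ell(w; x_j, x_i) in the estimator G(w_t).
    weights = (u1[pos_idx, None] - u2[pos_idx, None] * pos_mask) / (batch_size * u2[pos_idx, None] ** 2)
    return u1, u2, weights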
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate (default: 1e-3)
mode (str, optional) – optimization mode, ‘sgd’ or ‘adam’ (default: 'adam')
clip_value (float, optional) – gradient clipping value (default: 1.0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-5)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
momentum (float, optional) – momentum factor (default: 0.9)
dampening (float, optional) – dampening for momentum (default: 0.1)
nesterov (bool, optional) – enables Nesterov momentum (default: False)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: True)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.SOAP(model.parameters(), lr=1e-3, mode='adam')
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
- Reference:
[1] Qi, Qi, Youzhi Luo, Zhao Xu, Shuiwang Ji, and Tianbao Yang. “Stochastic optimization of areas under precision-recall curves with provable convergence.” In Advances in Neural Information Processing Systems 34 (2021): 1752-1765. https://proceedings.neurips.cc/paper/2021/file/0dd1bc593a91620daecf7723d2235624-Paper.pdf
libauc.optimizers.sopa
- class SOPA(params, mode='adam', eta=1.0, lr=0.001, clip_value=1.0, weight_decay=0, epoch_decay=0, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, momentum=0.9, nesterov=False, dampening=0, verbose=False, device=None, **kwargs)[source]
Stochastic Optimization for One-way pAUC (SOPA) is used for optimizing the pAUC_CVaR_Loss. The key update steps are summarized as follows:
Initialize \(\mathbf s^1=0, \mathbf w_0\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) Sample two mini-batches \(\mathcal B_+\subset\mathcal S_+\) and \(\mathcal B_-\subset\mathcal S_-\).
\(\hspace{5mm}\) Compute \(p_{ij} =\mathbb I (\ell(h(\mathbf w_t, \mathbf x_i) - h(\mathbf w_t, \mathbf x_j)) - s^t_i> 0)\) for each positive-negative pair (\(\mathbf x_i\in\mathcal B_+, \mathbf x_j\in\mathcal B_-\))
\(\hspace{5mm}\) Update \(s^{t+1}_i =s^t_i - \frac{\eta_2}{n_+} (1 - \frac{\sum_j p_{ij}}{\beta |\mathcal B_-|} )\) for each positive data.
\(\hspace{5mm}\) Compute a gradient estimator:
\[\nabla_t = \frac{1}{\beta |\mathcal B_+||\mathcal B_-|}\sum_{\mathbf x_i\in\mathcal B_+} \sum_{\mathbf x_j\in \mathcal B_-}p_{ij}\nabla_\mathbf w L(\mathbf w_t; \mathbf x_i, \mathbf x_j)\]
\(\hspace{5mm}\) Update \(\mathbf w_{t+1} =\mathbf w_t - \eta_1 \nabla_t\) (or Momentum/Adam style)
For more details, please refer to When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee.
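The per-positive threshold update above can be written in a few lines; the following is a minimal sketch for intuition only, not the library implementation, with sopa_dual_step and its arguments as illustrative names.
import torch

def sopa_dual_step(s, pos_idx, pair_loss, eta2, beta, n_pos):
    # pair_loss[k, j] = ell(h(w; x_{i_k}) - h(w; x_j)) for the k-th positive in the
    # batch and every negative j in the batch; s stores one threshold per positive.
    p = (pair_loss - s[pos_idx, None] > 0).float()     # p_ij in {0, 1}
    # Gradient step on the per-positive CVaR thresholds s_i.
    n_neg_batch = pair_loss.shape[1]
    s[pos_idx] = s[pos_idx] - (eta2 / n_pos) * (1.0 - p.sum(1) / (beta * n_neg_batch))
    # p is reused as the 0/1 weight on each pair's gradient in the estimator.
    return s, p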
- Parameters:
params (iterable) – iterable of parameters to optimize
loss_fn (callable) – loss function used for optimization (default: None)
lr (float) – learning rate (default: 1e-3)
mode (str) – optimization mode, ‘sgd’ or ‘adam’ (default: 'adam')
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0.0)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
momentum (float, optional) – momentum factor for ‘sgd’ mode (default: 0.9)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square for ‘adam’ mode (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability for ‘adam’ mode (default: 1e-8)
amsgrad (bool, optional) – whether to use the AMSGrad variant of ‘adam’ mode from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.SOPA(model.parameters(), loss_fn=loss_fn, lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
- Reference:
[1] Zhu, Dixian and Li, Gang and Wang, Bokun and Wu, Xiaodong and Yang, Tianbao. “When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee.” In International Conference on Machine Learning, pp. 27548-27573. PMLR, 2022. https://proceedings.mlr.press/v162/zhu22g.html
libauc.optimizers.sopa_s
- class SOPAs(params, lr=0.001, mode='adam', clip_value=2.0, weight_decay=1e-05, epoch_decay=0, momentum=0.9, nesterov=False, dampening=0.1, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, verbose=True, device=None, **kwargs)[source]
Stochastic Optimization for One-way pAUC (SOPAs) is used for optimizing the pAUC_DRO_Loss. The key update steps are summarized as follows:
Initialize \(\mathbf u^0=0, \mathbf w_0\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) For each \(\mathbf{x}_i\in \mathbf{B}_{+}\), update \(u^{t}_i =(1-\gamma)u^{t-1}_{i} + \gamma \frac{1}{|\mathbf{B}_-|} \sum_{\mathbf{x}_j\in \mathbf{B}_-}\exp\left(\frac{L(\mathbf{w}_t; \mathbf{x}_i, \mathbf{x}_j)}{\lambda}\right)\)
\(\hspace{5mm}\) Let \(p_{ij} = \exp (L(\mathbf{w}_t; \mathbf{x}_i, \mathbf{x}_j)/\lambda)/u^{t}_{i}\), then compute a gradient estimator:
\(\nabla_t=\frac{1}{|\mathbf{B}_{+}|}\frac{1}{|\mathbf{B}_-|}\sum_{\mathbf{x}_i\in\mathbf{B}_{+}} \sum_{\mathbf{x}_j\in \mathbf{B}_-}p_{ij}\nabla L(\mathbf{w}_t; \mathbf{x}_i, \mathbf{x}_j)\)
\(\hspace{5mm}\) Update \(\mathbf{v}_{t}=\beta\mathbf{v}_{t-1} + (1-\beta) \nabla_t\)
\(\hspace{5mm}\) Update \(\mathbf{w}_{t+1}=\mathbf{w}_t - \eta \mathbf{v}_t\) (or Adam-style)
For more details, please refer to When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee.
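A minimal sketch of the DRO weighting step above (the moving-average normalizer and the resulting pairwise weights); illustrative names only, not the library implementation.
import torch

def sopas_weights(u, pos_idx, pair_loss, gamma, lam):
    # pair_loss[k, j] = L(w; x_{i_k}, x_j) for positives i_k and negatives j in the batch.
    exp_loss = torch.exp(pair_loss / lam)
    # Moving-average estimate u_i of the per-positive DRO normalizer.
    u[pos_idx] = (1 - gamma) * u[pos_idx] + gamma * exp_loss.mean(1)
    # Normalized weights p_ij that multiply each pair's gradient in the estimator.
    p = exp_loss / u[pos_idx, None]
    return u, p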
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate (default: 1e-3)
mode (str, optional) – optimization mode, ‘sgd’ or ‘adam’ (default: 'adam')
clip_value (float, optional) – gradient clipping value (default: 2.0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-5)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
momentum (float, optional) – momentum factor (default: 0.9)
dampening (float, optional) – dampening for momentum (default: 0.1)
nesterov (bool, optional) – enables Nesterov momentum (default: False)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: True)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.SOPAs(model.parameters(), lr=1e-3, mode='adam')
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
- Reference:
[1] Zhu, Dixian and Li, Gang and Wang, Bokun and Wu, Xiaodong and Yang, Tianbao. “When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee.” In International Conference on Machine Learning, pp. 27548-27573. PMLR, 2022. https://proceedings.mlr.press/v162/zhu22g.html
libauc.optimizers.sota_s
- class SOTAs(params, mode='adam', lr=0.001, clip_value=1.0, weight_decay=0, epoch_decay=0, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, momentum=0.9, nesterov=False, dampening=0, verbose=False, device=None, **kwargs)[source]
Stochastic Optimization for Two-way pAUC, soft version (SOTAs), is used for optimizing the tpAUC_KL_Loss. The key update steps are summarized as follows:
Initialize \(\mathbf u_0= \mathbf 0, v_0= \mathbf 0, \mathbf m_0= \mathbf 0, \mathbf w\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) Sample two mini-batches \(\mathcal B_+\subset\mathcal S_+\) and \(\mathcal B_-\subset\mathcal S_-\).
\(\hspace{5mm}\) For each \(\mathbf x_i\in\mathcal B_{+}\), update \(u^i_{t} =(1-\beta_0)u^i_{t-1} + \beta_0 \frac{1}{|B_-|} \sum_{\mathbf x_j\in \mathcal B_-}L(\mathbf w_t; \mathbf x_i, \mathbf x_j)\)
\(\hspace{5mm}\) Update \(v_{t} = (1-\beta_1)v_{t-1} + \beta_1\frac{1}{|\mathcal B_{+}|}\sum_{\mathbf x_i\in \mathcal B_{+}} f_2(u^i_{t-1})\)
\(\hspace{5mm}\) Compute \(p_{ij} = (u^i_{t-1})^{\lambda/\lambda' - 1}\exp (L(\mathbf w_t, \mathbf x_i, \mathbf x_j)/\lambda)/v_{t}\)
\(\hspace{5mm}\) Compute a gradient estimator:
\[\nabla_t=\frac{1}{|\mathcal B_{+}|}\frac{1}{|\mathcal B_-|}\sum_{\mathbf x_i\in\mathcal B_{+}} \sum_{\mathbf x_j\in \mathcal B_-}p_{ij}\nabla L(\mathbf w_t; \mathbf x_i, \mathbf x_j)\]
\(\hspace{5mm}\) Compute \(\mathbf m_{t}=(1-\beta_2)\mathbf m_{t-1} + \beta_2 \nabla_t\)
\(\hspace{5mm}\) Update \(\mathbf w_{t+1} =\mathbf w_t - \eta_1 \mathbf m_t\) (or Adam style)
For more details, please refer to the paper When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee.
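For intuition, here is a minimal sketch of how the two moving averages and the pairwise weights above fit together; it is not the library implementation, sotas_weights and its arguments are illustrative names, and f2 is passed in as a callable rather than assumed.
import torch

def sotas_weights(u, v, pos_idx, pair_loss, f2, beta0, beta1, lam, lam_prime):
    # pair_loss[k, j] = L(w; x_{i_k}, x_j) for positives i_k and negatives j in the batch.
    u_prev = u[pos_idx]
    # Outer scalar moving average v_t of f2(u_i^{t-1}) over the positive mini-batch.
    v = (1 - beta1) * v + beta1 * f2(u_prev).mean()
    # Pairwise weights p_ij, built from the previous u_i and the new v.
    p = u_prev[:, None].pow(lam / lam_prime - 1) * torch.exp(pair_loss / lam) / v
    # Inner per-positive moving average u_i of the pairwise losses.
    u[pos_idx] = (1 - beta0) * u_prev + beta0 * pair_loss.mean(1)
    return u, v, p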
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float, optional) – learning rate (default: 1e-3)
mode (str, optional) – optimization mode, ‘sgd’ or ‘adam’ (default: 'adam')
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0.0)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
momentum (float, optional) – momentum factor for ‘sgd’ mode (default: 0.9)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square for ‘adam’ mode (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability for ‘adam’ mode (default: 1e-8)
amsgrad (bool, optional) – whether to use the AMSGrad variant of ‘adam’ mode from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.SOTAs(model.parameters(), loss_fn=loss_fn, lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
- Reference:
[1] Zhu, Dixian and Li, Gang and Wang, Bokun and Wu, Xiaodong and Yang, Tianbao. “When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee.” In International Conference on Machine Learning, pp. 27548-27573. PMLR, 2022. https://proceedings.mlr.press/v162/zhu22g.html
libauc.optimizers.song
- class SONG(params, lr=<required parameter>, clip_value=1.0, weight_decay=0, epoch_decay=0, mode='sgd', momentum=0.9, dampening=0, nesterov=False, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, verbose=False, device=None, **kwargs)[source]
Stochastic Optimization for NDCG (SONG) and its top-K variant (K-SONG) is used for optimizing the NDCGLoss. The key update steps are summarized as follows:
\[\begin{align}\begin{aligned}
1. \quad & \mathbf{u}_{q,i}^{t+1} = (1-\gamma_0)\mathbf{u}_{q,i}^{t} + \gamma_0 \frac{1}{|B_q|} \sum_{x^{\prime}\in B_q} \exp(h_q(x^{\prime};\mathbf{w})-h_q(x;\mathbf{w}))\\
2. \quad & G(\mathbf{w}_t) = \frac{1}{|Q_t|} \frac{1}{|B_q^+|} \frac{1}{|B_q|} \sum_{q\in Q_t} \sum_{x_i^q\in B_q^+} \sum_{x_j^q\in B_q} \frac{1}{\mathbf{u}_{q,i}^{t+1}} \nabla_{\mathbf{w}} (h_q(x_j^q;\mathbf{w}_t)-h_q(x_i^q;\mathbf{w}_t))\\
3. \quad & m_{t+1} = \beta_1 m_{t} + (1-\beta_1) G(\mathbf{w}_t)\\
4. \quad & \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_1 m_{t+1} \quad \text{(or Adam style)}
\end{aligned}\end{align}\]
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate
mode (str) – optimization mode, ‘sgd’ or ‘adam’ (default: 'sgd')
clip_value (float, optional) – gradient clipping value (default: 1.0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0.0)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
momentum (float, optional) – momentum factor for ‘sgd’ mode (default: 0.9)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square for ‘adam’ mode (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability for ‘adam’ mode (default: 1e-8)
amsgrad (bool, optional) – whether to use the AMSGrad variant of ‘adam’ mode from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.SONG(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(predictions, batch).backward()  # loss_fn can be ListwiseCE_Loss or NDCG_Loss
>>> optimizer.step()
- Reference:
[1] Qiu, Zi-Hao, Hu, Quanqi, Zhong, Yongjian, Zhang, Lijun, and Yang, Tianbao. “Large-scale Stochastic Optimization of NDCG Surrogates for Deep Learning with Provable Convergence.” Proceedings of the 39th International Conference on Machine Learning. 2022. https://arxiv.org/abs/2202.12183
libauc.optimizers.sogclr
- class SogCLR(params, lr=<required parameter>, clip_value=10.0, weight_decay=1e-06, epoch_decay=0, mode='lars', momentum=0.9, trust_coefficient=0.001, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, verbose=False, device=None, **kwargs)[source]
Stochastic optimization for solving the GCLoss. For each iteration \(t\), the key updates for SogCLR are summarized as follows:
Initialize \(\tau, \mathbf w_0, \mathbf u_0= \mathbf 0\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) Draw a batch of \(B\) samples
\(\hspace{5mm}\) Compute \(g(\mathbf{w}_t ; \mathbf{x}_i, \mathbf{B}^-_i) = \sum_{\mathbf{x}_j \sim \mathbf{B}^-_i} \exp (h_{\mathbf{w}}(\mathbf{x}_i)^{\top} h_{\mathbf{w}}(\mathbf{x}_j) / \tau)\)
\(\hspace{5mm}\) Compute \(\mathbf{u}_{i, t}=(1-\gamma) \mathbf{u}_{i, t-1}+\gamma \frac{1}{2|\mathbf{B}_i|} g(\mathbf{w}_t ; \mathbf{x}_i, \mathbf{B}^-_i)\)
\(\hspace{5mm}\) Compute the gradient estimator \(\mathbf{m}_t = -\frac{1}{B} \sum_{\mathbf{x}_i \in \mathbf{B}} \nabla\left(h_\mathbf{w}\left(\mathbf{x}_i\right)^{\top} h_\mathbf{w}\left(\mathbf{x}_i^+\right)\right) +\frac{\tau}{\mathbf{u}_{i,t} } \nabla g\left(\mathbf{w};\mathbf{x}_i; \mathbf{B}_i^{-}\right)\)
\(\hspace{5mm}\) Update the model \(\mathbf{w_t}\) with a Momentum or Adam optimizer
For more details, please refer to Provable Stochastic Optimization for Global Contrastive Learning: Small Batch Does Not Harm Performance.
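A minimal sketch of the per-anchor moving-average normalizer above; illustrative names only, not the library implementation.
import torch

def sogclr_update_u(u, idx, neg_sim, tau, gamma):
    # neg_sim[k, j] = h_w(x_{i_k})^T h_w(x_j) between anchor i_k and its in-batch
    # negatives; idx maps the anchors to their slots in the global buffer u.
    g = torch.exp(neg_sim / tau).sum(1)               # g(w; x_i, B_i^-)
    # Moving-average estimate u_i of the contrastive normalizer, scaled by 2|B_i|
    # as in the recursion (here |B_i| is taken as the number of sampled negatives);
    # tau / u_i later scales the gradient of g in the estimator m_t.
    u[idx] = (1 - gamma) * u[idx] + gamma * g / (2 * neg_sim.shape[1])
    return u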
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate
mode (str) – optimization mode, ‘lars’ or ‘adamw’ (default: 'lars')
weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-6)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
momentum (float, optional) – momentum factor for ‘lars’ mode (default: 0.9)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square for ‘adamw’ mode (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability for ‘adamw’ mode (default: 1e-8)
amsgrad (bool, optional) – whether to use the AMSGrad variant of ‘adamw’ mode from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.SogCLR(model.parameters(), lr=0.1, mode='lars', momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target, index).backward()
>>> optimizer.step()
libauc.optimizers.isogclr
- class iSogCLR(params, lr=<required parameter>, clip_value=10.0, weight_decay=1e-06, epoch_decay=0, mode='lars', momentum=0, trust_coefficient=0.001, betas=(0.9, 0.999), eps=1e-08, amsgrad=False, verbose=False, device=None, **kwargs)[source]
Stochastic optimization for solving the GCLoss. For each iteration \(t\), the key updates for iSogCLR are summarized as follows:
Initialize \(\mathbf w_1, \mathbf{\tau}=\tau_{\text{init}}, \mathbf s_1 = \mathbf v_1 = \mathbf u_1= \mathbf 0\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) Draw a batch of \(B\) samples
\(\hspace{5mm}\) For \(\mathbf{x}_i \in \mathbf{B}\):
\(\hspace{10mm}\) Compute \(g_i (\mathbf{w_t}, \mathbf{\tau}_i^t; \mathbf{B}_i) = \frac{1}{B} \sum_{z\in\mathbf{B}_i} \exp \left(\frac{h_i(z)}{\mathbf{\tau}_i^t} \right)\)
\(\hspace{10mm}\) Update \(\mathbf{s}_i^{t+1} = (1-\beta_0) \mathbf{s}_i^{t} + \beta_0 g_i (\mathbf{w_t}, \mathbf{\tau}_i^t; \mathbf{B}_i)\)
\(\hspace{10mm}\) Compute \(G(\mathbf{\tau}_i^t) = \frac{1}{n} \left[\frac{\mathbf{\tau}_i^t}{\mathbf{s}_i^t} \nabla_{\mathbf{\tau}_i} g_i (\mathbf{w_t}, \mathbf{\tau}_i^t; \mathbf{B}_i) + \log(\mathbf{s}_i^t) + \rho \right]\)
\(\hspace{10mm}\) Update \(\mathbf{u}_i^{t+1} = (1-\beta_1) \mathbf{u}_i^{t} + \beta_1 G(\mathbf{\tau}_i^t)\)
\(\hspace{10mm}\) Update \(\mathbf{\tau}_i^{t+1} = \Pi_{\Omega}[\mathbf{\tau}_i^{t} - \eta \mathbf{u}_i^{t+1}]\)
\(\hspace{5mm}\) Compute stochastic gradient estimator \(G(\mathbf{w}_t) = \frac{1}{B} \sum_{\mathbf{x}_i \in \mathbf{B}} \frac{\mathbf{\tau}_i^t}{\mathbf{s}_i^t} \nabla_{\mathbf{w}} g_i (\mathbf{w_t}, \mathbf{\tau}_i^t; \mathbf{B}_i)\)
\(\hspace{5mm}\) Update the model \(\mathbf{w_t}\) with a Momentum or Adam optimizer
where \(h_i(z)=E(\mathcal{A}(\mathbf{x}_i))^{\top} E(z) - E(\mathcal{A}(\mathbf{x}_i))^{\top} E(\mathcal{A}^{\prime}(\mathbf{x}_i))\), \(\mathbf{B}_i = \{\mathcal{A}(\mathbf{x}), \mathcal{A}^{\prime}(\mathbf{x}): \mathcal{A},\mathcal{A}^{\prime}\in\mathcal{P},\mathbf{x}\in \mathbf{B} \backslash \mathbf{x}_i \}\), \(\Omega=\{\tau_0 \leq \tau \}\) is the constraint set for each learnable \(\mathbf{\tau}_i\), and \(\Pi\) is the projection operator.
For more details, please refer to Not All Semantics are Created Equal: Contrastive Self-supervised Learning with Automatic Temperature Individualization.
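A minimal sketch of the per-sample temperature update above (the inner loop over anchors, vectorized over a batch); illustrative names only, not the library implementation.
import torch

def isogclr_tau_step(tau, s, u, idx, margins, beta0, beta1, eta, rho, tau0, n):
    # margins[k, j] = h_{i_k}(z_j) for anchor i_k against the items z_j in B_{i_k}.
    exp_h = torch.exp(margins / tau[idx, None])
    g = exp_h.mean(1)                                  # g_i(w_t, tau_i^t; B_i)
    s_old = s[idx].clone()
    s[idx] = (1 - beta0) * s_old + beta0 * g           # moving average s_i
    # d g_i / d tau_i via the chain rule, then the gradient G(tau_i^t).
    dg_dtau = (exp_h * (-margins / tau[idx, None] ** 2)).mean(1)
    grad_tau = (tau[idx] / s_old * dg_dtau + torch.log(s_old) + rho) / n
    # Momentum buffer u_i and a projected descent step keeping tau_i >= tau0.
    u[idx] = (1 - beta1) * u[idx] + beta1 * grad_tau
    tau[idx] = torch.clamp(tau[idx] - eta * u[idx], min=tau0)
    return tau, s, u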
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate
mode (str) – optimization mode, ‘lars’ or ‘adamw’ (default: 'lars')
weight_decay (float, optional) – weight decay (L2 penalty) (default: 1e-6)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
momentum (float, optional) – momentum factor for ‘lars’ mode (default: 0)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square for ‘adamw’ mode (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability for ‘adamw’ mode (default: 1e-8)
amsgrad (bool, optional) – whether to use the AMSGrad variant of ‘adamw’ mode from the paper On the Convergence of Adam and Beyond (default: False)
verbose (bool, optional) – whether to print optimization progress (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.iSogCLR(model.parameters(), lr=0.1, mode='lars', momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target, index).backward()
>>> optimizer.step()
libauc.optimizers.midam
- class MIDAM(params, loss_fn, lr=<required parameter>, momentum=0, weight_decay=0, device=None)[source]
MIDAM (Multiple Instance Deep AUC Maximization) is used for optimizing the MIDAMLoss (softmax or attention pooling based AUC loss).
Notice that \(h(\mathbf w; \mathcal X_i)=f_2(f_1 (\mathbf w;\mathcal X_i))\) is the bag-level prediction after the pooling operation. Denote the moving-average estimate of the bag-level prediction for the i-th bag at the t-th iteration by \(s_i^t\). The gradient estimators are:
\[G^t_{1,\mathbf w} = \hat{\mathbb E}_{i\in\mathcal S_+^t}\nabla f_1(\mathbf w^t; \mathcal B_{i}^t) \nabla f_2(s^{t-1}_i)\nabla_1 f( f_2(s^{t-1}_i), a^t),\]
\[G^t_{2,\mathbf w} = \hat{\mathbb E}_{i\in\mathcal S_-^t}\nabla f_1(\mathbf w^t; \mathcal B_{i}^t) \nabla f_2(s^{t-1}_i)\nabla_1 f( f_2(s^{t-1}_i), b^t),\]
\[G^t_{3,\mathbf w} = \alpha^t \cdot\left(\hat{\mathbb E}_{i\in\mathcal S_-^t}\nabla f_1(\mathbf w^t; \mathcal B_{i}^t) \nabla f_2(s^{t-1}_i)\right. \left.- \hat{\mathbb E}_{i\in\mathcal S_+^t}\nabla f_1(\mathbf w^t; \mathcal B_{i}^t) \nabla f_2(s^{t-1}_i)\right),\]
\[G^t_{1,a} = \hat{\mathbb E}_{i\in\mathcal S_+^t} \nabla_2 f( f_2(s^{t-1}_i), a^t),\]
\[G^t_{2, b} =\hat{\mathbb E}_{i\in\mathcal S_-^t} \nabla_2 f( f_2(s^{t-1}_i), b^t),\]
\[G^t_{3,\alpha} = c+ \hat{\mathbb E}_{i\in\mathcal S_-^t}f_2(s^{t-1}_i) - \hat{\mathbb E}_{i\in\mathcal S_+^t}f_2(s^{t-1}_i),\]
The key update steps for the stochastic optimization are summarized as follows:
Initialize \(\mathbf s^0=0, \mathbf v^0=\mathbf 0, a=0, b=0, \mathbf w\)
For \(t=1, \ldots, T\):
\(\hspace{5mm}\) Sample a batch of positive bags \(\mathcal S_+^t\subset\mathcal D_+\) and a batch of negative bags \(\mathcal S_-^t\subset\mathcal D_-\).
\(\hspace{5mm}\) For each \(i \in \mathcal S^t=\mathcal S_+^t\cup \mathcal S_-^t\):
\(\hspace{5mm}\) Sample a mini-batch of instances \(\mathcal B^t_i\subset\mathcal X_i\) and update:
\[s^t_i = (1-\gamma_0)s^{t-1}_i + \gamma_0 f_1(\mathbf w^t; \mathcal B_{i}^t)\]
\(\hspace{5mm}\) Update stochastic gradient estimator of \((\mathbf w, a, b)\):
\[\mathbf v_1^t =\beta_1\mathbf v_1^{t-1} + (1-\beta_1)(G^t_{1,\mathbf w} + G^t_{2,\mathbf w} + G^t_{3,\mathbf w})\]
\[\mathbf v_2^t =\beta_1\mathbf v_2^{t-1} + (1-\beta_1)G^t_{1,a}\]
\[\mathbf v_3^t =\beta_1\mathbf v_3^{t-1} + (1-\beta_1)G^t_{2,b}\]
\(\hspace{5mm}\) Update \((\mathbf w^{t+1}, a^{t+1}, b^{t+1}) = (\mathbf w^t, a^t, b^t) - \eta \mathbf v^t\) (or Adam style)
\(\hspace{5mm}\) Update \(\alpha^{t+1} = \Pi_{\Omega}[\alpha^t + \eta' (G^t_{3,\alpha} - \alpha^t)]\)
For more details, please refer to the paper Provable Multi-instance Deep AUC Maximization with Stochastic Pooling.
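The stochastic-pooling bookkeeping above boils down to two moving-average updates; below is a minimal sketch for intuition only, not the library implementation, with midam_bag_step / midam_momentum_step and their arguments as illustrative names.
import torch

def midam_bag_step(s, bag_idx, inner_pred, gamma0):
    # Moving-average bag-level statistic s_i, fed by f1(w; B_i) evaluated on the
    # sampled mini-batch of instances of bag i (inner_pred stands in for that value).
    s[bag_idx] = (1 - gamma0) * s[bag_idx] + gamma0 * inner_pred
    return s

def midam_momentum_step(v, grads, beta1):
    # Momentum averaging of the gradient estimators for (w, a, b):
    # v_k^t = beta_1 * v_k^{t-1} + (1 - beta_1) * G_k^t, one entry per block.
    return [beta1 * vk + (1 - beta1) * gk for vk, gk in zip(v, grads)]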
- Parameters:
params (iterable) – iterable of parameters to optimize
loss_fn (callable) – loss function used for optimization (default: None)
lr (float) – learning rate
momentum (float, optional) – momentum factor (default: 0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = libauc.optimizers.MIDAM(params=model.parameters(), loss_fn=loss_fn, lr=0.1, momentum=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
- Reference:
[1] Dixian Zhu, Bokun Wang, Zhi Chen, Yaxing Wang, Milan Sonka, Xiaodong Wu, Tianbao Yang “Provable Multi-instance Deep AUC Maximization with Stochastic Pooling.” In International Conference on Machine Learning, pp. xxxxx-xxxxx. PMLR, 2023. https://prepare-arxiv?
libauc.optimizers.lars
- class LARS(params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001)[source]
LARS optimizer, with no rate scaling or weight decay for parameters <= 1D. This code is adapted from the MOCOv3 codebase.
libauc.optimizers.sgd
- class SGD(params, lr=<required parameter>, momentum=0, dampening=0, clip_value=1.0, epoch_decay=0, weight_decay=0, nesterov=False, verbose=True, device=None, **kwargs)[source]
Implements stochastic gradient descent (optionally with momentum). This code is adapted from PyTorch codebase.
Nesterov momentum is based on the formula from On the importance of initialization and momentum in deep learning.
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate
momentum (float, optional) – momentum factor (default: 0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
dampening (float, optional) – dampening for momentum (default: 0.0)
nesterov (bool, optional) – enables Nesterov momentum (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
Example
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
The implementation of SGD with Momentum/Nesterov subtly differs from Sutskever et. al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
\[\begin{split}\begin{aligned} v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, \end{aligned}\end{split}\]
where \(p\), \(g\), \(v\) and \(\mu\) denote the parameters, gradient, velocity, and momentum respectively.
This is in contrast to Sutskever et. al. and other frameworks which employ an update of the form
\[\begin{split}\begin{aligned} v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ p_{t+1} & = p_{t} - v_{t+1}. \end{aligned}\end{split}\]
The Nesterov version is analogously modified.
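Purely to illustrate the difference stated in the note, the two momentum conventions are shown side by side as plain scalar/tensor updates; the helper names are illustrative and not part of the library.
def pytorch_style_momentum(p, v, g, lr, mu):
    # v_{t+1} = mu * v_t + g_{t+1};  p_{t+1} = p_t - lr * v_{t+1}
    v = mu * v + g
    return p - lr * v, v

def sutskever_style_momentum(p, v, g, lr, mu):
    # v_{t+1} = mu * v_t + lr * g_{t+1};  p_{t+1} = p_t - v_{t+1}
    v = mu * v + lr * g
    return p - v, v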
libauc.optimizers.adam
- class Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, clip_value=1.0, epoch_decay=0, weight_decay=0, amsgrad=False, verbose=True, device=None, **kwargs)[source]
Implements Adam algorithm. This code is adapted from PyTorch codebase.
It has been proposed in Adam: A Method for Stochastic Optimization. The implementation of the L2 penalty follows changes proposed in Decoupled Weight Decay Regularization.
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
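For reference, a usage sketch in the same style as the examples above; it only assumes the constructor shown in the class signature.
Example
>>> optimizer = libauc.optimizers.Adam(model.parameters(), lr=1e-3)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()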
libauc.optimizers.adamw
- class AdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, clip_value=10.0, epoch_decay=0, weight_decay=0.01, amsgrad=False, verbose=False, device=None, **kwargs)[source]
Implements the AdamW algorithm. This code is adapted from the PyTorch codebase.
The original Adam algorithm was proposed in Adam: A Method for Stochastic Optimization. The AdamW variant was proposed in Decoupled Weight Decay Regularization.
- Parameters:
params (iterable) – iterable of parameters to optimize
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay coefficient (default: 1e-2)
epoch_decay (float, optional) – epoch decay (epoch-wise l2 penalty) (default: 0.0)
amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False)
device (torch.device, optional) – the device used for optimization, e.g., ‘cpu’ or ‘cuda’ (default: None)
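Likewise, a usage sketch for AdamW in the same style as the examples above; it only assumes the constructor shown in the class signature.
Example
>>> optimizer = libauc.optimizers.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()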