========================================
Optimizing Partial AUROC Loss (pAUCLoss)
========================================

.. container:: cell markdown

   | **Author**: Zhuoning Yuan, Dixian Zhu, Gang Li, Tianbao Yang

Introduction
-----------------------
In this tutorial, you will learn how to quickly train a ResNet18 model by optimizing the **One-way Partial AUC (OPAUC)** with our novel :obj:`pAUCLoss` `[ref] `__ on a binary image classification task built from the CIFAR-10 dataset. Note that :obj:`pAUCLoss` is a wrapper around different types of partial AUC losses. It currently supports two modes:

- :obj:`pAUCLoss('1w')`: optimizes the One-way Partial AUC, using :obj:`pAUC_DRO_Loss` as the backend and the :obj:`SOPAs` optimizer.
- :obj:`pAUCLoss('2w')`: optimizes the Two-way Partial AUC, using :obj:`tpAUC_KL_Loss` as the backend and the :obj:`SOTAs` optimizer.

This wrapper makes it easy to switch between partial AUC losses for different scenarios. For the original tutorials, please refer to :ref:`SOPAs `, :ref:`SOPA `, and :ref:`SOTAs `. After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.
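
To make the difference between the two modes concrete, here is a minimal NumPy sketch of the empirical one-way and two-way partial AUC on toy scores. It assumes the common pairwise definitions: one-way pAUC compares all positives against the hardest ``max_fpr`` fraction of negatives (highest scores), and two-way pAUC additionally keeps only the hardest ``1 - min_tpr`` fraction of positives (lowest scores). The helpers ``empirical_opauc`` and ``empirical_tpauc`` are hypothetical and not part of LibAUC; the values are unnormalized.

.. code:: python

   import numpy as np

   def empirical_opauc(pos, neg, max_fpr):
       # One-way pAUC: all positives vs. the top max_fpr fraction of negatives
       k = max(1, int(np.floor(len(neg) * max_fpr)))
       hard_neg = np.sort(neg)[::-1][:k]
       return float(np.mean(pos[:, None] > hard_neg[None, :]))

   def empirical_tpauc(pos, neg, min_tpr, max_fpr):
       # Two-way pAUC: lowest-scored (1 - min_tpr) fraction of positives
       # vs. highest-scored max_fpr fraction of negatives
       k_pos = max(1, int(np.floor(len(pos) * (1 - min_tpr))))
       k_neg = max(1, int(np.floor(len(neg) * max_fpr)))
       hard_pos = np.sort(pos)[:k_pos]
       hard_neg = np.sort(neg)[::-1][:k_neg]
       return float(np.mean(hard_pos[:, None] > hard_neg[None, :]))

   pos = np.array([0.9, 0.7, 0.4, 0.35])
   neg = np.array([0.8, 0.3, 0.2, 0.1])
   print(empirical_opauc(pos, neg, max_fpr=0.5))                # 0.75
   print(empirical_tpauc(pos, neg, min_tpr=0.5, max_fpr=0.5))   # 0.5

Because the two-way version discards the easy positives as well, it is lower here: the two remaining positives (0.35 and 0.4) both lose to the hardest negative (0.8).
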

**Reference**:

If you find this tutorial helpful in your work, please cite our `library paper `__ and the following paper:

.. code-block:: bibtex

   @inproceedings{zhu2022auc,
     title={When auc meets dro: Optimizing partial auc for deep learning with non-convex convergence guarantee},
     author={Zhu, Dixian and Li, Gang and Wang, Bokun and Wu, Xiaodong and Yang, Tianbao},
     booktitle={International Conference on Machine Learning},
     pages={27548--27573},
     year={2022},
     organization={PMLR}
   }

Install LibAUC
------------------------------------------------------------------------------------
Let's start by installing our library. In this tutorial, we use the latest version of LibAUC, installed with ``pip install -U``.

.. container:: cell code

   .. code:: python

      !pip install -U libauc

Importing LibAUC
-----------------------
Import the required packages:

.. container:: cell code

   .. code:: python

      from libauc.losses import pAUCLoss
      from libauc.optimizers import SOPAs
      from libauc.models import resnet18 as ResNet18
      from libauc.datasets import CIFAR10
      from libauc.utils import ImbalancedDataGenerator
      from libauc.sampler import DualSampler
      from libauc.metrics import auc_roc_score

      import torchvision.transforms as transforms
      from torch.utils.data import Dataset
      import numpy as np
      import torch
      from PIL import Image

Reproducibility
-----------------------
The following function ``set_all_seeds`` limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases `[Ref] `__.

.. container:: cell code

   .. code:: python

      import random

      def set_all_seeds(SEED):
          # Seed the Python, NumPy and PyTorch (CPU and CUDA) generators
          random.seed(SEED)
          np.random.seed(SEED)
          torch.manual_seed(SEED)
          torch.cuda.manual_seed(SEED)
          # Ask cuDNN for deterministic algorithms and disable auto-tuning
          torch.backends.cudnn.deterministic = True
          torch.backends.cudnn.benchmark = False

      set_all_seeds(2023)
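
A quick way to check that seeding works is to reseed and compare draws. This sketch covers only Python's ``random`` module and NumPy's global generator (``seeded_draws`` is a hypothetical helper); GPU kernels may still be nondeterministic even with the cuDNN flags set.

.. code:: python

   import random
   import numpy as np

   def seeded_draws(seed):
       # Seed Python's and NumPy's global generators, then draw from both
       random.seed(seed)
       np.random.seed(seed)
       return [random.random() for _ in range(3)], np.random.rand(3)

   py_a, np_a = seeded_draws(2023)
   py_b, np_b = seeded_draws(2023)
   print(py_a == py_b, np.array_equal(np_a, np_b))  # True True
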
Loading datasets
-----------------------
.. container:: cell markdown

   In this step, we use
   `CIFAR10 `__ as the
   benchmark dataset. Before loading the data into a ``DataLoader``, we
   construct an imbalanced binary version of CIFAR10 with
   ``ImbalancedDataGenerator``. Specifically, it first randomly splits
   the training data by class ID (10 classes for CIFAR10) into two even
   portions that serve as the positive and negative classes, and then it
   randomly removes samples from the positive class to make the data
   imbalanced. The test set is binarized in the same way but kept
   balanced (``imratio=0.5``). Here, ``imratio`` denotes the ratio of the
   number of positive examples to the total number of examples.

.. container:: cell code

   .. code:: python

      train_data, train_targets = CIFAR10(root='./data', train=True).as_array()
      test_data, test_targets = CIFAR10(root='./data', train=False).as_array()

      imratio = 0.2  # 20% of the training examples are positive
      generator = ImbalancedDataGenerator(verbose=True, random_seed=2023)
      (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio)
      (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)
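
The arithmetic behind ``imratio`` is simple: with ``n_neg`` negatives kept intact, the number of positives to keep is ``n_neg * imratio / (1 - imratio)``. The NumPy sketch below is a hypothetical re-implementation for illustration only (``make_binary_imbalanced`` is not a LibAUC function). It assigns the first half of the class IDs to the negative class, whereas the library may choose the split differently.

.. code:: python

   import numpy as np

   def make_binary_imbalanced(targets, imratio, num_classes=10, seed=0):
       # Hypothetical sketch: classes [0, num_classes/2) -> negative (0), rest -> positive (1)
       rng = np.random.default_rng(seed)
       labels = (targets >= num_classes // 2).astype(np.int64)
       neg_idx = np.flatnonzero(labels == 0)
       pos_idx = np.flatnonzero(labels == 1)
       # Keep enough positives so that n_pos / (n_pos + n_neg) == imratio
       n_pos = int(round(len(neg_idx) * imratio / (1 - imratio)))
       kept_pos = rng.choice(pos_idx, size=min(n_pos, len(pos_idx)), replace=False)
       keep = np.sort(np.concatenate([neg_idx, kept_pos]))
       return keep, labels[keep]

   targets = np.repeat(np.arange(10), 100)  # toy labels: 100 samples per class
   keep, labels = make_binary_imbalanced(targets, imratio=0.2)
   print(len(keep), labels.mean())  # 625 samples, 20% positive

Applied to CIFAR-10's 50,000 training images, this rule would keep 25,000 negatives and 6,250 positives.
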
.. container:: cell markdown

   Next, we define the data input pipeline, including data
   augmentation. For training images, we apply ``RandomCrop`` and
   ``RandomHorizontalFlip`` and then resize the crop back to
   ``image_size``; test images are only converted to tensors and
   resized.

.. container:: cell code

   .. code:: python

      class ImageDataset(Dataset):
          def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
              self.images = images.astype(np.uint8)
              self.targets = targets
              self.mode = mode
              # Training: random crop_size x crop_size crop + horizontal flip, then resize
              self.transform_train = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.RandomCrop((crop_size, crop_size), padding=None),
                  transforms.RandomHorizontalFlip(),
                  transforms.Resize((image_size, image_size)),
              ])
              # Evaluation: no augmentation
              self.transform_test = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Resize((image_size, image_size)),
              ])

          def __len__(self):
              return len(self.images)

          def __getitem__(self, idx):
              image = Image.fromarray(self.images[idx])
              target = self.targets[idx]
              if self.mode == 'train':
                  image = self.transform_train(image)
              else:
                  image = self.transform_test(image)
              # The sample index is also returned; the loss uses it to look up per-sample estimates
              return image, target, idx
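
With ``crop_size=30`` and ``image_size=32``, each training image is cropped to 30×30 at a random position and then scaled back to 32×32, so the model always sees 32×32 inputs. The NumPy sketch below mimics only the crop-and-flip part on an H×W×C array; ``random_crop_and_flip`` is a hypothetical helper, not torchvision's implementation, and resizing is omitted.

.. code:: python

   import numpy as np

   def random_crop_and_flip(image, crop_size, rng):
       # image: H x W x C array; returns a crop_size x crop_size x C array
       h, w = image.shape[:2]
       top = rng.integers(0, h - crop_size + 1)
       left = rng.integers(0, w - crop_size + 1)
       crop = image[top:top + crop_size, left:left + crop_size]
       # Flip along the width axis with probability 0.5
       if rng.random() < 0.5:
           crop = crop[:, ::-1]
       return crop

   rng = np.random.default_rng(2023)
   image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
   out = random_crop_and_flip(image, crop_size=30, rng=rng)
   print(out.shape)  # (30, 30, 3)
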
.. container:: cell markdown

   We define the datasets, the ``DualSampler`` and the data loaders
   here. By default, we use a ``batch_size`` of 64 and oversample the
   minority (positive) class to ``pos:neg=1:1`` in each mini-batch by
   setting ``sampling_rate=0.5``. ``trainloader_eval`` iterates over the
   training set without augmentation and is used only for evaluation.

.. container:: cell code

   .. code:: python

      batch_size = 64
      sampling_rate = 0.5  # fraction of positives in each mini-batch

      trainSet = ImageDataset(train_images, train_labels)
      trainSet_eval = ImageDataset(train_images, train_labels, mode='test')
      testSet = ImageDataset(test_images, test_labels, mode='test')

      sampler = DualSampler(trainSet, batch_size, sampling_rate=sampling_rate)
      trainloader = torch.utils.data.DataLoader(trainSet, batch_size=batch_size, sampler=sampler, num_workers=2)
      trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=batch_size, shuffle=False, num_workers=2)
      testloader = torch.utils.data.DataLoader(testSet, batch_size=batch_size, shuffle=False, num_workers=2)
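
With ``sampling_rate=0.5`` and ``batch_size=64``, each mini-batch holds 32 positive and 32 negative samples, so positives are revisited far more often than negatives within an epoch. The sketch below illustrates that batch composition with NumPy; ``dual_batch_indices`` is hypothetical and does not reproduce ``DualSampler``'s actual indexing or epoch-length logic.

.. code:: python

   import numpy as np

   def dual_batch_indices(labels, batch_size, sampling_rate, rng):
       # Hypothetical sketch: a fixed number of positives per batch (drawn with
       # replacement, since they are the minority), the rest filled with negatives
       pos_idx = np.flatnonzero(labels == 1)
       neg_idx = np.flatnonzero(labels == 0)
       n_pos = int(batch_size * sampling_rate)
       batch_pos = rng.choice(pos_idx, size=n_pos, replace=True)
       batch_neg = rng.choice(neg_idx, size=batch_size - n_pos, replace=False)
       batch = np.concatenate([batch_pos, batch_neg])
       rng.shuffle(batch)
       return batch

   rng = np.random.default_rng(2023)
   labels = np.array([1] * 20 + [0] * 80)  # 20% positives, like imratio=0.2
   batch = dual_batch_indices(labels, batch_size=64, sampling_rate=0.5, rng=rng)
   print(len(batch), labels[batch].sum())  # 64 32
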
Configuration
-----------------------
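.. container:: cell markdown

   We use the following hyper-parameters:

.. container:: cell code

   .. code:: python

      lr = 1e-3               # initial learning rate
      margin = 0.6            # margin of the pairwise surrogate loss
      gamma = 0.1             # moving-average parameter of the loss's per-sample estimator
      Lambda = 1.0            # temperature of the KL-DRO objective in pAUC_DRO_Loss
      weight_decay = 1e-5     # L2 regularization
      total_epoch = 60
      decay_epoch = [30, 45]  # epochs at which the learning rate is divided by 10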
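
The step schedule implied by ``lr``, ``decay_epoch`` and the ``optimizer.update_lr(decay_factor=10)`` call in the training loop can be written as a small function. ``lr_at_epoch`` is a hypothetical helper for illustration; it matches the ``lr`` column of the training log (0.001 until epoch 29, 0.0001 from epoch 30, 0.00001 from epoch 45).

.. code:: python

   def lr_at_epoch(epoch, base_lr=1e-3, decay_epoch=(30, 45), decay_factor=10):
       # Divide the learning rate by decay_factor once per milestone already reached
       n_decays = sum(epoch >= milestone for milestone in decay_epoch)
       return base_lr / decay_factor ** n_decays

   for epoch in (0, 29, 30, 44, 45, 59):
       print(epoch, lr_at_epoch(epoch))
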
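Model, Loss and Optimizer
-------------------------

.. container:: cell code

   .. code:: python

      # ResNet18 with a single output logit; the sigmoid is applied in the training loop
      model = ResNet18(pretrained=False, last_activation=None, num_classes=1)
      model = model.cuda()

      # '1w' selects pAUC_DRO_Loss; data_len sizes the per-sample buffers indexed by `index`
      loss_fn = pAUCLoss('1w', data_len=len(trainSet), margin=margin, gamma=gamma, Lambda=Lambda)
      optimizer = SOPAs(model.parameters(), mode='adam', lr=lr, weight_decay=weight_decay)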
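
``pAUC_DRO_Loss`` is based on a KL-regularized distributionally robust objective: for each positive, the pairwise losses against negatives are aggregated with a soft maximum, ``Lambda * log(mean(exp(L / Lambda)))``, which puts more weight on hard negatives. The NumPy sketch below evaluates this objective on one mini-batch, assuming a squared-hinge surrogate with ``margin``. It omits the moving-average estimator controlled by ``gamma`` that the library uses for stochastic optimization, and the function names are hypothetical.

.. code:: python

   import numpy as np

   def squared_hinge(margin, pos_score, neg_scores):
       # Pairwise surrogate: max(margin - (s_pos - s_neg), 0)^2
       return np.maximum(margin - (pos_score - neg_scores), 0.0) ** 2

   def kl_dro_opauc_objective(pos_scores, neg_scores, margin=0.6, Lambda=1.0):
       # Per positive: Lambda * log(mean_j exp(L_ij / Lambda)), averaged over positives.
       # The max is subtracted before exponentiating for numerical stability.
       values = []
       for s in pos_scores:
           z = squared_hinge(margin, s, neg_scores) / Lambda
           m = z.max()
           values.append(Lambda * (m + np.log(np.mean(np.exp(z - m)))))
       return float(np.mean(values))

   pos = np.array([0.9])
   neg = np.array([0.1, 0.8])  # pairwise losses: 0.0 and 0.25
   for lam in (100.0, 1.0, 0.01):
       print(lam, kl_dro_opauc_objective(pos, neg, Lambda=lam))

As ``Lambda`` grows, the objective approaches the average pairwise loss (0.125 here, an AUC-like average over all negatives); as it shrinks, it approaches the loss of the hardest negative (0.25), which is what concentrates training on the low-FPR region.
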
Training
-----------------------
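.. container:: cell markdown

   Now it's time to train. After each epoch, we evaluate the one-way
   partial AUC on both the training and test sets, restricted to a false
   positive rate (FPR) of at most 0.3, i.e., FPR ≤ 0.3. The learning rate
   is divided by 10 at the epochs listed in ``decay_epoch``.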
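
A partial AUC over FPR ≤ 0.3 can be reported either as the raw area (at most 0.3) or in a standardized form. scikit-learn's ``roc_auc_score(..., max_fpr=...)`` returns the McClish-standardized value, which maps a random classifier to 0.5 and a perfect one to 1.0. Whether LibAUC's ``auc_roc_score`` applies exactly this standardization is an assumption to verify in its documentation; since the logged ``train_pauc`` and ``test_pauc`` values exceed 0.3, they are not raw areas. The helper ``mcclish_standardize`` below is hypothetical.

.. code:: python

   def mcclish_standardize(partial_area, max_fpr):
       # Map the raw ROC area on [0, max_fpr] to [0.5, 1]
       min_area = 0.5 * max_fpr ** 2  # area under the chance diagonal up to max_fpr
       max_area = max_fpr             # area of a perfect ROC curve up to max_fpr
       return 0.5 * (1 + (partial_area - min_area) / (max_area - min_area))

   print(mcclish_standardize(0.3, 0.3))             # 1.0 (perfect ranking)
   print(mcclish_standardize(0.5 * 0.3 ** 2, 0.3))  # 0.5 (random ranking)
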

.. container:: cell code

   .. code:: python

      print('Start Training')
      print('-' * 30)

      train_log, test_log = [], []
      for epoch in range(total_epoch):
          # Step decay: divide the learning rate by 10 at each milestone epoch
          if epoch in decay_epoch:
              optimizer.update_lr(decay_factor=10)

          train_loss = []
          model.train()
          for idx, (data, targets, index) in enumerate(trainloader):
              data, targets, index = data.cuda(), targets.cuda(), index.cuda()
              y_pred = model(data)
              y_prob = torch.sigmoid(y_pred)
              # `index` identifies each sample for the loss's per-sample estimator
              loss = loss_fn(y_prob, targets, index)
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
              train_loss.append(loss.item())
          train_loss = np.mean(train_loss)

          # ---- evaluation ----
          model.eval()
          with torch.no_grad():
              # evaluation on the training set (no augmentation)
              train_pred_list, train_true_list = [], []
              for train_data, train_targets, _ in trainloader_eval:
                  y_prob = torch.sigmoid(model(train_data.cuda()))
                  train_pred_list.append(y_prob.cpu().numpy())
                  train_true_list.append(train_targets.numpy())
              train_true = np.concatenate(train_true_list)
              train_pred = np.concatenate(train_pred_list)
              train_pauc = auc_roc_score(train_true, train_pred, max_fpr=0.3)
              train_log.append(train_pauc)

              # evaluation on the test set
              test_pred_list, test_true_list = [], []
              for test_data, test_targets, _ in testloader:
                  y_prob = torch.sigmoid(model(test_data.cuda()))
                  test_pred_list.append(y_prob.cpu().numpy())
                  test_true_list.append(test_targets.numpy())
              test_true = np.concatenate(test_true_list)
              test_pred = np.concatenate(test_pred_list)
              val_pauc = auc_roc_score(test_true, test_pred, max_fpr=0.3)
              test_log.append(val_pauc)

          # print results
          print("epoch: %s, train_loss: %.4f, train_pauc: %.4f, test_pauc: %.4f, lr: %.5f"
                % (epoch, train_loss, train_pauc, val_pauc, optimizer.lr))

.. container:: output stream stdout

   ::

      Start Training
      ------------------------------
      epoch: 0, train_loss: 1.7896, train_pauc: 0.6491, test_pauc: 0.6448, lr: 0.00100
      epoch: 1, train_loss: 0.5043, train_pauc: 0.7353, test_pauc: 0.7124, lr: 0.00100
      epoch: 2, train_loss: 0.3066, train_pauc: 0.7430, test_pauc: 0.7172, lr: 0.00100
      epoch: 3, train_loss: 0.2271, train_pauc: 0.8185, test_pauc: 0.7732, lr: 0.00100
      epoch: 4, train_loss: 0.1825, train_pauc: 0.8344, test_pauc: 0.7796, lr: 0.00100
      epoch: 5, train_loss: 0.1573, train_pauc: 0.8176, test_pauc: 0.7662, lr: 0.00100
      epoch: 6, train_loss: 0.1341, train_pauc: 0.8846, test_pauc: 0.8118, lr: 0.00100
      epoch: 7, train_loss: 0.1233, train_pauc: 0.8811, test_pauc: 0.8107, lr: 0.00100
      epoch: 8, train_loss: 0.1087, train_pauc: 0.8662, test_pauc: 0.7902, lr: 0.00100
      epoch: 9, train_loss: 0.0950, train_pauc: 0.8982, test_pauc: 0.8201, lr: 0.00100
      epoch: 10, train_loss: 0.0830, train_pauc: 0.9113, test_pauc: 0.8254, lr: 0.00100
      epoch: 11, train_loss: 0.0726, train_pauc: 0.9411, test_pauc: 0.8336, lr: 0.00100
      epoch: 12, train_loss: 0.0619, train_pauc: 0.9570, test_pauc: 0.8474, lr: 0.00100
      epoch: 13, train_loss: 0.0559, train_pauc: 0.9542, test_pauc: 0.8393, lr: 0.00100
      epoch: 14, train_loss: 0.0480, train_pauc: 0.9034, test_pauc: 0.7886, lr: 0.00100
      epoch: 15, train_loss: 0.0451, train_pauc: 0.9257, test_pauc: 0.8108, lr: 0.00100
      epoch: 16, train_loss: 0.0360, train_pauc: 0.9719, test_pauc: 0.8481, lr: 0.00100
      epoch: 17, train_loss: 0.0366, train_pauc: 0.9671, test_pauc: 0.8428, lr: 0.00100
      epoch: 18, train_loss: 0.0305, train_pauc: 0.9709, test_pauc: 0.8400, lr: 0.00100
      epoch: 19, train_loss: 0.0306, train_pauc: 0.9689, test_pauc: 0.8373, lr: 0.00100
      epoch: 20, train_loss: 0.0275, train_pauc: 0.9820, test_pauc: 0.8499, lr: 0.00100
      epoch: 21, train_loss: 0.0241, train_pauc: 0.9861, test_pauc: 0.8499, lr: 0.00100
      epoch: 22, train_loss: 0.0225, train_pauc: 0.9663, test_pauc: 0.8257, lr: 0.00100
      epoch: 23, train_loss: 0.0210, train_pauc: 0.9529, test_pauc: 0.8133, lr: 0.00100
      epoch: 24, train_loss: 0.0194, train_pauc: 0.9766, test_pauc: 0.8435, lr: 0.00100
      epoch: 25, train_loss: 0.0181, train_pauc: 0.9660, test_pauc: 0.8241, lr: 0.00100
      epoch: 26, train_loss: 0.0189, train_pauc: 0.9869, test_pauc: 0.8505, lr: 0.00100
      epoch: 27, train_loss: 0.0163, train_pauc: 0.9916, test_pauc: 0.8516, lr: 0.00100
      epoch: 28, train_loss: 0.0161, train_pauc: 0.9850, test_pauc: 0.8454, lr: 0.00100
      epoch: 29, train_loss: 0.0155, train_pauc: 0.9757, test_pauc: 0.8357, lr: 0.00100
      Reducing learning rate to 0.00010 @ T=23430!
      epoch: 30, train_loss: 0.0069, train_pauc: 0.9972, test_pauc: 0.8674, lr: 0.00010
      epoch: 31, train_loss: 0.0031, train_pauc: 0.9981, test_pauc: 0.8669, lr: 0.00010
      epoch: 32, train_loss: 0.0025, train_pauc: 0.9991, test_pauc: 0.8712, lr: 0.00010
      epoch: 33, train_loss: 0.0014, train_pauc: 0.9991, test_pauc: 0.8702, lr: 0.00010
      epoch: 34, train_loss: 0.0014, train_pauc: 0.9994, test_pauc: 0.8738, lr: 0.00010
      epoch: 35, train_loss: 0.0010, train_pauc: 0.9994, test_pauc: 0.8728, lr: 0.00010
      epoch: 36, train_loss: 0.0010, train_pauc: 0.9994, test_pauc: 0.8715, lr: 0.00010
      epoch: 37, train_loss: 0.0009, train_pauc: 0.9994, test_pauc: 0.8725, lr: 0.00010
      epoch: 38, train_loss: 0.0010, train_pauc: 0.9994, test_pauc: 0.8714, lr: 0.00010
      epoch: 39, train_loss: 0.0007, train_pauc: 0.9993, test_pauc: 0.8711, lr: 0.00010
      epoch: 40, train_loss: 0.0007, train_pauc: 0.9996, test_pauc: 0.8723, lr: 0.00010
      epoch: 41, train_loss: 0.0005, train_pauc: 0.9994, test_pauc: 0.8711, lr: 0.00010
      epoch: 42, train_loss: 0.0006, train_pauc: 0.9995, test_pauc: 0.8715, lr: 0.00010
      epoch: 43, train_loss: 0.0005, train_pauc: 0.9994, test_pauc: 0.8694, lr: 0.00010
      epoch: 44, train_loss: 0.0006, train_pauc: 0.9996, test_pauc: 0.8735, lr: 0.00010
      Reducing learning rate to 0.00001 @ T=35145!
      epoch: 45, train_loss: 0.0005, train_pauc: 0.9996, test_pauc: 0.8742, lr: 0.00001
      epoch: 46, train_loss: 0.0003, train_pauc: 0.9997, test_pauc: 0.8743, lr: 0.00001
      epoch: 47, train_loss: 0.0002, train_pauc: 0.9996, test_pauc: 0.8737, lr: 0.00001
      epoch: 48, train_loss: 0.0003, train_pauc: 0.9997, test_pauc: 0.8747, lr: 0.00001
      epoch: 49, train_loss: 0.0003, train_pauc: 0.9997, test_pauc: 0.8741, lr: 0.00001
      epoch: 50, train_loss: 0.0002, train_pauc: 0.9997, test_pauc: 0.8754, lr: 0.00001
      epoch: 51, train_loss: 0.0002, train_pauc: 0.9997, test_pauc: 0.8745, lr: 0.00001
      epoch: 52, train_loss: 0.0002, train_pauc: 0.9997, test_pauc: 0.8749, lr: 0.00001
      epoch: 53, train_loss: 0.0002, train_pauc: 0.9997, test_pauc: 0.8748, lr: 0.00001
      epoch: 54, train_loss: 0.0002, train_pauc: 0.9998, test_pauc: 0.8757, lr: 0.00001
      epoch: 55, train_loss: 0.0002, train_pauc: 0.9997, test_pauc: 0.8749, lr: 0.00001
      epoch: 56, train_loss: 0.0002, train_pauc: 0.9997, test_pauc: 0.8744, lr: 0.00001
      epoch: 57, train_loss: 0.0002, train_pauc: 0.9998, test_pauc: 0.8749, lr: 0.00001
      epoch: 58, train_loss: 0.0002, train_pauc: 0.9998, test_pauc: 0.8743, lr: 0.00001
      epoch: 59, train_loss: 0.0002, train_pauc: 0.9998, test_pauc: 0.8744, lr: 0.00001

Visualization
-----------------------
Finally, let's plot the learning curves of pAUC (FPR ≤ 0.3) on the training and test sets.

.. container:: cell code

   .. code:: python

      import matplotlib.pyplot as plt

      plt.rcParams["figure.figsize"] = (9, 5)
      x = np.arange(len(train_log))
      plt.figure()
      plt.plot(x, train_log, linestyle='-', label='pAUC_DRO Training', linewidth=3)
      plt.plot(x, test_log, linestyle='-', label='pAUC_DRO Test', linewidth=3)
      plt.title('CIFAR-10 (20% imbalanced)', fontsize=25)
      plt.legend(fontsize=15)
      plt.ylabel('pAUC(FPR≤0.3)', fontsize=25)
      plt.xlabel('Epoch', fontsize=25)
      plt.show()

.. container:: output display_data

   .. image:: ./imgs/training_pauc.png