.. _tpauc_staco:
================================================================================================================================
Optimizing Two-Way Partial AUC on Imbalanced CIFAR10 Dataset (STACO)
================================================================================================================================
.. container:: cell markdown
| **Author**: Linli Zhou, Siqi Guo
Introduction
------------------------------------------------------------------------------------
In this tutorial, we will learn how to train a ResNet18 model by optimizing the **two-way partial AUC (TPAUC)** score with our novel :obj:`tpAUC_CVaR_Loss` loss and :obj:`STACO` optimizer `[Ref] `__ on a binary image classification task built from CIFAR-10. After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.
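The two-way partial AUC restricts the ROC curve to a region with a low false positive rate and a high true positive rate; this tutorial evaluates it with ``pauc_roc_score(..., max_fpr=0.3, min_tpr=0.7)``. When the cutoffs fall on whole samples, the normalized area of that ROC region equals the fraction of correctly ordered pairs between the hardest positives (the lowest-scored 30%) and the hardest negatives (the highest-scored 30%). The minimal NumPy sketch below illustrates this estimate. ``empirical_tpauc`` is a hypothetical helper, not LibAUC's ``pauc_roc_score``; its normalization and tie handling are assumptions, so its values may differ slightly from the library's.

```python
import numpy as np


def empirical_tpauc(y_true, y_score, max_fpr=0.3, min_tpr=0.7):
    """Empirical two-way partial AUC, normalized to [0, 1] (illustrative sketch).

    Compares the hardest positives (lowest-scored, the bottom 1 - min_tpr share)
    with the hardest negatives (highest-scored, the top max_fpr share) and returns
    the fraction of those pairs that are ranked correctly; ties count as 1/2.
    """
    y_true = np.asarray(y_true).ravel()
    y_score = np.asarray(y_score, dtype=float).ravel()
    pos = np.sort(y_score[y_true == 1])        # ascending: hardest positives first
    neg = np.sort(y_score[y_true == 0])[::-1]  # descending: hardest negatives first
    # round() guards against float noise, e.g. (1 - 0.7) is 0.30000000000000004
    k_pos = max(1, int(np.ceil(round((1.0 - min_tpr) * len(pos), 9))))
    k_neg = max(1, int(np.ceil(round(max_fpr * len(neg), 9))))
    diff = pos[:k_pos, None] - neg[None, :k_neg]  # all (hard pos, hard neg) pairs
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))


# 10 positives and 10 negatives: the 3 lowest positives meet the 3 highest negatives
y_true = np.array([1] * 10 + [0] * 10)
y_score = np.array([0.2, 0.5, 0.9, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
                    0.6, 0.4, 0.1, 0.05, 0.04, 0.03, 0.02, 0.01, 0.0, 0.0])
print(empirical_tpauc(y_true, y_score))  # 6 of 9 pairs correct, about 0.667
```

Only the hard pairs matter: raising the already well-ranked positives (0.92 to 0.98) or lowering the easy negatives would not change this score, which is why TPAUC focuses training on the difficult region of the ROC curve.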
**Reference**:
If you find this tutorial helpful in your work, please cite our `[library paper] `__ and the following paper:
.. code-block:: bibtex
@article{zhou2025stochastic,
  title={Stochastic Primal-Dual Double Block-Coordinate for Two-way Partial {AUC} Maximization},
  author={Linli Zhou and Bokun Wang and My T. Thai and Tianbao Yang},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2025},
  url={https://openreview.net/forum?id=M3kibBFP4q},
  note={}
}
Install LibAUC
------------------------------------------------------------------------------------
Let's start by installing our library. In this tutorial, we use the latest version of LibAUC, installed with ``pip install -U``.
.. container:: cell code
.. code:: python
!pip install -U libauc
Importing LibAUC
------------------------------------------------------------------------------------
Import the required libraries:
.. container:: cell code
.. code:: python
from libauc.models import resnet18
from libauc.datasets import CIFAR10
from libauc.losses import tpAUC_CVaR_loss
from libauc.optimizers import STACO
from libauc.utils import ImbalancedDataGenerator
from libauc.sampler import DualSampler # data resampling (for binary class)
from libauc.metrics import pauc_roc_score
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.nn as nn
Reproducibility
-----------------------
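The following function ``set_all_seeds`` limits the sources of randomness,
such as model initialization and data shuffling, by seeding NumPy and PyTorch,
forcing deterministic cuDNN algorithms, and disabling cuDNN benchmarking.
However, completely reproducible results are not guaranteed across PyTorch
releases, individual commits, or different platforms
`[Ref] `__.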
.. container:: cell code
.. code:: python
def set_all_seeds(SEED):
    # REPRODUCIBILITY
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_all_seeds(2023)
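``set_all_seeds`` seeds NumPy and PyTorch but not Python's built-in ``random`` module, which other code in a project may rely on. A quick way to confirm that seeding makes draws repeatable is to seed, draw, reseed, and draw again. The minimal sketch below uses only ``random`` and NumPy so it runs without PyTorch or a GPU; ``seed_python_and_numpy`` and ``draw`` are hypothetical helpers for illustration.

```python
import random

import numpy as np


def seed_python_and_numpy(seed):
    """Seed Python's built-in RNG and NumPy's global RNG (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)


def draw(seed):
    """Reseed, then take one draw from each generator."""
    seed_python_and_numpy(seed)
    return random.random(), np.random.rand(3).tolist()


first = draw(2023)
second = draw(2023)
print(first == second)  # True: identical seeds give identical draws
```

In the full notebook, calling ``random.seed(SEED)`` next to ``set_all_seeds`` covers the built-in module as well. On multi-GPU machines, ``torch.cuda.manual_seed_all`` seeds every device, whereas ``torch.cuda.manual_seed`` seeds only the current one.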
Image Dataset
------------------------------------------------------------------------------------
Now we define the data input pipeline, including data augmentation. In this tutorial, we use ``RandomCrop`` and ``RandomHorizontalFlip``. The ``pos_index_map`` maps the global dataset index of each positive sample to a local index among the positives (negatives get ``-1``). This reduces the memory cost of the loss function, which only needs to track the indices of positive samples.
.. container:: cell code
.. code:: python
class ImageDataset(Dataset):
    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        # training augmentation: random crop, random horizontal flip, resize back to image_size
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((image_size, image_size)),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
        ])
        # for loss function: global index of each positive -> local index 0..num_pos-1
        self.pos_indices = np.flatnonzero(targets == 1)
        self.pos_index_map = {}
        for i, idx in enumerate(self.pos_indices):
            self.pos_index_map[int(idx)] = i
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        image = self.images[idx]
        target = self.targets[idx]
        image = Image.fromarray(image.astype('uint8'))
        if self.mode == 'train':
            # local positive index, or -1 for negatives (O(1) dict lookup)
            idx = self.pos_index_map.get(int(idx), -1)
            image = self.transform_train(image)
        else:
            image = self.transform_test(image)
        return image, target, idx
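To see why ``pos_index_map`` saves memory, consider a toy label vector. The sketch below reuses the same construction as ``ImageDataset`` and adds a hypothetical ``state`` array that stands in for whatever per-positive variables the loss keeps. That array is an assumption for illustration and does not reproduce the internals of ``tpAUC_CVaR_loss``.

```python
import numpy as np

# Toy labels: 1 = positive, 0 = negative (hypothetical data)
targets = np.array([0, 1, 0, 1, 1, 0])

# Same construction as ImageDataset: global index -> local positive index
pos_indices = np.flatnonzero(targets == 1)
pos_index_map = {int(idx): i for i, idx in enumerate(pos_indices)}


def to_local(global_indices):
    """Map global dataset indices to local positive indices (-1 for negatives)."""
    return np.array([pos_index_map.get(int(g), -1) for g in global_indices])


# A per-positive buffer needs len(pos_indices) slots, not one per training image
state = np.zeros(len(pos_indices))
local = to_local([3, 0, 4, 1])
print(local)  # [ 1 -1  2  0]
state[local[local >= 0]] += 1.0  # update only the positives seen in this batch
print(state)  # [1. 1. 1.]
```

With the training split used here, such a buffer would hold 6,250 entries (one per positive) instead of 31,250.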
Hyperparameters
------------------------------------------------------------------------------------
.. container:: cell code
.. code:: python
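# Hyperparameters
SEED = 123
batch_size = 64
total_epochs = 60
weight_decay = 1e-4
lr = 1e-3
decay_epochs = [20, 40]  # epochs at which lr is divided by decay_factor
decay_factor = 10
theta_0 = theta_1 = 0.5  # passed to tpAUC_CVaR_loss
alpha = beta_0 = beta_1 = 1e-3  # passed to tpAUC_CVaR_loss; decayed together with lr
sampling_rate = 0.5  # fraction of positives in each mini-batch (DualSampler)
num_pos = round(sampling_rate*batch_size)  # positives per mini-batch
num_neg = batch_size - num_pos  # negatives per mini-batch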
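These hyperparameters fix both the mini-batch composition and a step schedule for the learning rate, which the training loop applies by dividing ``lr`` (and the loss's ``alpha``, ``beta_0``, ``beta_1``) by ``decay_factor`` at each epoch in ``decay_epochs``. A small pure-Python sketch of the resulting values; ``lr_at_epoch`` is a hypothetical helper that mirrors, rather than calls, ``optimizer.update_lr``.

```python
def lr_at_epoch(epoch, base_lr=1e-3, decay_epochs=(20, 40), decay_factor=10):
    """Learning rate in effect during `epoch` under step decay (hypothetical helper)."""
    # Count how many milestones have been reached, then divide once per milestone
    n_decays = sum(epoch >= milestone for milestone in decay_epochs)
    return base_lr / decay_factor ** n_decays


batch_size, sampling_rate = 64, 0.5
num_pos = round(sampling_rate * batch_size)  # positives per mini-batch
num_neg = batch_size - num_pos               # negatives per mini-batch
print(num_pos, num_neg)  # 32 32
for epoch in (0, 19, 20, 39, 40, 59):
    print(epoch, lr_at_epoch(epoch))
```

The three phases (1e-3, 1e-4, 1e-5) match the two "Reducing learning rate" messages in the training log. The log shows ``lr=0.0000`` in the last phase only because of the ``%.4f`` format.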
Loading datasets
------------------------------------------------------------------------------------
.. container:: cell code
.. code:: python
train_data, train_targets = CIFAR10(root='./data', train=True).as_array()
test_data, test_targets = CIFAR10(root='./data', train=False).as_array()
imratio = 0.2  # positives make up 20% of the training set
g = ImbalancedDataGenerator(shuffle=True, verbose=True, random_seed=0)
(train_images, train_labels) = g.transform(train_data, train_targets, imratio=imratio)
(test_images, test_labels) = g.transform(test_data, test_targets, imratio=0.5)  # balanced test set
trainSet = ImageDataset(train_images, train_labels)
testSet = ImageDataset(test_images, test_labels, mode='test')
sampler = DualSampler(dataset=trainSet, batch_size=batch_size, labels=train_labels, shuffle=True, sampling_rate=sampling_rate)
trainloader = torch.utils.data.DataLoader(trainSet, sampler=sampler, batch_size=batch_size, shuffle=False, num_workers=0)
testloader = torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=0, shuffle=False)
.. container:: output stream stdout
::
Files already downloaded and verified
Files already downloaded and verified
#SAMPLES: 31250, CLASS 0.0 COUNT: 25000, CLASS RATIO: 0.8000
#SAMPLES: 31250, CLASS 1.0 COUNT: 6250, CLASS RATIO: 0.2000
#SAMPLES: 10000, CLASS 1.0 COUNT: 5000, CLASS RATIO: 0.5000
#SAMPLES: 10000, CLASS 0.0 COUNT: 5000, CLASS RATIO: 0.5000
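The class counts printed above follow from simple arithmetic. Keeping all n_neg negatives and just enough positives that n_pos / (n_pos + n_neg) = imratio gives n_pos = imratio * n_neg / (1 - imratio) = 0.2 * 25,000 / 0.8 = 6,250, for 31,250 samples in total. A minimal sketch of such positive subsampling follows. ``subsample_positives`` is a hypothetical helper, not LibAUC's ``ImbalancedDataGenerator``, and the 25,000 / 25,000 binary split of the CIFAR-10 training set is an assumption consistent with the 25,000 negatives in the log.

```python
import numpy as np


def subsample_positives(labels, imratio, seed=0):
    """Keep all negatives and enough positives that positives make up `imratio`.

    Hypothetical helper for illustration; not LibAUC's ImbalancedDataGenerator.
    Returns the sorted indices of the kept samples.
    """
    labels = np.asarray(labels)
    pos = np.flatnonzero(labels == 1)
    neg = np.flatnonzero(labels == 0)
    # Solve n_pos / (n_pos + n_neg) = imratio for n_pos
    n_pos = int(round(imratio * len(neg) / (1 - imratio)))
    if n_pos > len(pos):
        raise ValueError("not enough positives to reach the requested imratio")
    rng = np.random.default_rng(seed)
    keep_pos = rng.choice(pos, size=n_pos, replace=False)
    return np.sort(np.concatenate([keep_pos, neg]))


# Assumed binary CIFAR-10 training labels: 25,000 negatives, 25,000 positives
labels = np.array([0] * 25000 + [1] * 25000)
kept = subsample_positives(labels, imratio=0.2)
print(len(kept), int(labels[kept].sum()))  # 31250 6250
```

With ``imratio=0.5`` the same formula keeps as many positives as negatives, which matches the balanced 5,000 / 5,000 test split in the log.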
Creating models & TPAUC Optimizer
-----------------------
.. container:: cell code
.. code:: python
set_all_seeds(SEED)
model = resnet18(pretrained=False, num_classes=1, last_activation=None)
model = model.cuda()
# Initialize the loss function and optimizer
loss_fn = tpAUC_CVaR_loss(data_length=sampler.pos_len, alpha=alpha, beta_0=beta_0, beta_1=beta_1, theta_0=theta_0, theta_1=theta_1)
optimizer = STACO(model.parameters(), loss_fn=loss_fn, mode='adam', lr=lr, weight_decay=weight_decay)
Training
-----------------------
.. container:: cell markdown
Now it's time for training.
.. container:: cell code
.. code:: python
print('Start Training')
print('-'*30)
tr_tpAUC = []
te_tpAUC = []
for epoch in range(total_epochs):
    if epoch in decay_epochs:
        # decay the learning rate and the loss's alpha/beta parameters together
        optimizer.update_lr(decay_factor=decay_factor)
        loss_fn.alpha /= decay_factor
        loss_fn.beta_0 /= decay_factor
        loss_fn.beta_1 /= decay_factor
    train_loss = 0
    model.train()
    for idx, data in enumerate(trainloader):
        train_data, train_labels, index = data
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        # index[:num_pos] assumes DualSampler places the num_pos positives first in each batch
        loss = loss_fn(y_pred, train_labels, index[:num_pos])
        train_loss = train_loss + loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loss = train_loss/(idx+1)
    # evaluation
    model.eval()
    with torch.no_grad():
        # note: trainloader uses DualSampler, so positives appear more often than in the raw 20% split
        train_pred = []
        train_true = []
        for jdx, data in enumerate(trainloader):
            train_data, train_labels, _ = data
            train_data = train_data.cuda()
            y_pred = model(train_data)
            y_prob = torch.sigmoid(y_pred)
            train_pred.append(y_prob.cpu().detach().numpy())
            train_true.append(train_labels.numpy())
        train_true = np.concatenate(train_true)
        train_pred = np.concatenate(train_pred)
        single_train_auc = pauc_roc_score(train_true, train_pred, max_fpr=0.3, min_tpr=0.7)
        test_pred = []
        test_true = []
        for jdx, data in enumerate(testloader):
            test_data, test_labels, _ = data
            test_data = test_data.cuda()
            y_pred = model(test_data)
            y_prob = torch.sigmoid(y_pred)
            test_pred.append(y_prob.cpu().detach().numpy())
            test_true.append(test_labels.numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        single_test_auc = pauc_roc_score(test_true, test_pred, max_fpr=0.3, min_tpr=0.7)
    print('Epoch=%s, Loss=%0.4f, Train_tpAUC(0.3,0.7)=%.4f, Test_tpAUC(0.3,0.7)=%.4f, lr=%.4f'%(epoch, train_loss, single_train_auc, single_test_auc, optimizer.lr))
    tr_tpAUC.append(single_train_auc)
    te_tpAUC.append(single_test_auc)
.. container:: output stream stdout
::
Start Training
------------------------------
Epoch=0, Loss=0.7769, Train_tpAUC(0.3,0.7)=0.0083, Test_tpAUC(0.3,0.7)=0.0027, lr=0.0010
Epoch=1, Loss=0.5561, Train_tpAUC(0.3,0.7)=0.0003, Test_tpAUC(0.3,0.7)=0.0000, lr=0.0010
Epoch=2, Loss=0.4426, Train_tpAUC(0.3,0.7)=0.0901, Test_tpAUC(0.3,0.7)=0.0508, lr=0.0010
Epoch=3, Loss=0.3564, Train_tpAUC(0.3,0.7)=0.1696, Test_tpAUC(0.3,0.7)=0.1078, lr=0.0010
Epoch=4, Loss=0.2897, Train_tpAUC(0.3,0.7)=0.3645, Test_tpAUC(0.3,0.7)=0.1791, lr=0.0010
Epoch=5, Loss=0.2330, Train_tpAUC(0.3,0.7)=0.3899, Test_tpAUC(0.3,0.7)=0.1958, lr=0.0010
Epoch=6, Loss=0.1927, Train_tpAUC(0.3,0.7)=0.4758, Test_tpAUC(0.3,0.7)=0.2084, lr=0.0010
Epoch=7, Loss=0.1622, Train_tpAUC(0.3,0.7)=0.3806, Test_tpAUC(0.3,0.7)=0.1483, lr=0.0010
Epoch=8, Loss=0.1318, Train_tpAUC(0.3,0.7)=0.3751, Test_tpAUC(0.3,0.7)=0.0954, lr=0.0010
Epoch=9, Loss=0.1049, Train_tpAUC(0.3,0.7)=0.6161, Test_tpAUC(0.3,0.7)=0.1791, lr=0.0010
Epoch=10, Loss=0.0851, Train_tpAUC(0.3,0.7)=0.6697, Test_tpAUC(0.3,0.7)=0.2554, lr=0.0010
Epoch=11, Loss=0.0735, Train_tpAUC(0.3,0.7)=0.6783, Test_tpAUC(0.3,0.7)=0.2324, lr=0.0010
Epoch=12, Loss=0.0570, Train_tpAUC(0.3,0.7)=0.6791, Test_tpAUC(0.3,0.7)=0.2125, lr=0.0010
Epoch=13, Loss=0.0509, Train_tpAUC(0.3,0.7)=0.7166, Test_tpAUC(0.3,0.7)=0.1990, lr=0.0010
Epoch=14, Loss=0.0432, Train_tpAUC(0.3,0.7)=0.6712, Test_tpAUC(0.3,0.7)=0.2173, lr=0.0010
Epoch=15, Loss=0.0366, Train_tpAUC(0.3,0.7)=0.6469, Test_tpAUC(0.3,0.7)=0.2462, lr=0.0010
Epoch=16, Loss=0.0299, Train_tpAUC(0.3,0.7)=0.8008, Test_tpAUC(0.3,0.7)=0.3238, lr=0.0010
Epoch=17, Loss=0.0250, Train_tpAUC(0.3,0.7)=0.7510, Test_tpAUC(0.3,0.7)=0.2185, lr=0.0010
Epoch=18, Loss=0.0211, Train_tpAUC(0.3,0.7)=0.7333, Test_tpAUC(0.3,0.7)=0.2713, lr=0.0010
Epoch=19, Loss=0.0182, Train_tpAUC(0.3,0.7)=0.7187, Test_tpAUC(0.3,0.7)=0.2233, lr=0.0010
Reducing learning rate to 0.00010 @ T=15620!
Epoch=20, Loss=0.0079, Train_tpAUC(0.3,0.7)=0.9566, Test_tpAUC(0.3,0.7)=0.3584, lr=0.0001
Epoch=21, Loss=0.0041, Train_tpAUC(0.3,0.7)=0.9705, Test_tpAUC(0.3,0.7)=0.3665, lr=0.0001
Epoch=22, Loss=0.0028, Train_tpAUC(0.3,0.7)=0.9784, Test_tpAUC(0.3,0.7)=0.3786, lr=0.0001
Epoch=23, Loss=0.0022, Train_tpAUC(0.3,0.7)=0.9817, Test_tpAUC(0.3,0.7)=0.3660, lr=0.0001
Epoch=24, Loss=0.0019, Train_tpAUC(0.3,0.7)=0.9878, Test_tpAUC(0.3,0.7)=0.3658, lr=0.0001
Epoch=25, Loss=0.0016, Train_tpAUC(0.3,0.7)=0.9879, Test_tpAUC(0.3,0.7)=0.3671, lr=0.0001
Epoch=26, Loss=0.0014, Train_tpAUC(0.3,0.7)=0.9880, Test_tpAUC(0.3,0.7)=0.3684, lr=0.0001
Epoch=27, Loss=0.0012, Train_tpAUC(0.3,0.7)=0.9902, Test_tpAUC(0.3,0.7)=0.3505, lr=0.0001
Epoch=28, Loss=0.0013, Train_tpAUC(0.3,0.7)=0.9921, Test_tpAUC(0.3,0.7)=0.3423, lr=0.0001
Epoch=29, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9916, Test_tpAUC(0.3,0.7)=0.3578, lr=0.0001
Epoch=30, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9904, Test_tpAUC(0.3,0.7)=0.3379, lr=0.0001
Epoch=31, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9933, Test_tpAUC(0.3,0.7)=0.3677, lr=0.0001
Epoch=32, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9916, Test_tpAUC(0.3,0.7)=0.3506, lr=0.0001
Epoch=33, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9921, Test_tpAUC(0.3,0.7)=0.3405, lr=0.0001
Epoch=34, Loss=0.0010, Train_tpAUC(0.3,0.7)=0.9917, Test_tpAUC(0.3,0.7)=0.3405, lr=0.0001
Epoch=35, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9919, Test_tpAUC(0.3,0.7)=0.3500, lr=0.0001
Epoch=36, Loss=0.0010, Train_tpAUC(0.3,0.7)=0.9917, Test_tpAUC(0.3,0.7)=0.3385, lr=0.0001
Epoch=37, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9922, Test_tpAUC(0.3,0.7)=0.3134, lr=0.0001
Epoch=38, Loss=0.0010, Train_tpAUC(0.3,0.7)=0.9916, Test_tpAUC(0.3,0.7)=0.3365, lr=0.0001
Epoch=39, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9928, Test_tpAUC(0.3,0.7)=0.3325, lr=0.0001
Reducing learning rate to 0.00001 @ T=31240!
Epoch=40, Loss=0.0006, Train_tpAUC(0.3,0.7)=0.9965, Test_tpAUC(0.3,0.7)=0.3450, lr=0.0000
Epoch=41, Loss=0.0004, Train_tpAUC(0.3,0.7)=0.9973, Test_tpAUC(0.3,0.7)=0.3559, lr=0.0000
Epoch=42, Loss=0.0003, Train_tpAUC(0.3,0.7)=0.9985, Test_tpAUC(0.3,0.7)=0.3526, lr=0.0000
Epoch=43, Loss=0.0003, Train_tpAUC(0.3,0.7)=0.9986, Test_tpAUC(0.3,0.7)=0.3488, lr=0.0000
Epoch=44, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9987, Test_tpAUC(0.3,0.7)=0.3589, lr=0.0000
Epoch=45, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9990, Test_tpAUC(0.3,0.7)=0.3578, lr=0.0000
Epoch=46, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9990, Test_tpAUC(0.3,0.7)=0.3590, lr=0.0000
Epoch=47, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.3574, lr=0.0000
Epoch=48, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3645, lr=0.0000
Epoch=49, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3697, lr=0.0000
Epoch=50, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3672, lr=0.0000
Epoch=51, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3669, lr=0.0000
Epoch=52, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3709, lr=0.0000
Epoch=53, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9996, Test_tpAUC(0.3,0.7)=0.3621, lr=0.0000
Epoch=54, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9997, Test_tpAUC(0.3,0.7)=0.3673, lr=0.0000
Epoch=55, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3683, lr=0.0000
Epoch=56, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9997, Test_tpAUC(0.3,0.7)=0.3755, lr=0.0000
Epoch=57, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9998, Test_tpAUC(0.3,0.7)=0.3776, lr=0.0000
Epoch=58, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9996, Test_tpAUC(0.3,0.7)=0.3722, lr=0.0000
Epoch=59, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9999, Test_tpAUC(0.3,0.7)=0.3707, lr=0.0000
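The training loop passes ``index[:num_pos]`` to the loss, which assumes that every mini-batch from ``DualSampler`` holds its ``num_pos`` positives first and its ``num_neg`` negatives after them. The log is also consistent with one pass over the negatives per epoch: the learning-rate message after 20 epochs reports ``T=15620``, i.e. 781 iterations per epoch, and 25,000 negatives / 32 negatives per batch is about 781, so the 6,250 positives would be reused roughly four times per epoch. A minimal NumPy sketch of a sampler with that layout follows; ``dual_batches`` is a hypothetical helper written under these assumptions and is not LibAUC's ``DualSampler`` implementation.

```python
import numpy as np


def dual_batches(labels, num_pos, num_neg, seed=0):
    """Yield index batches laid out as [num_pos positives, num_neg negatives].

    Hypothetical sketch: one pass over the shuffled negatives per epoch, while
    the (fewer) positives are reshuffled and recycled whenever they run out.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    pos = np.flatnonzero(labels == 1)
    neg = rng.permutation(np.flatnonzero(labels == 0))
    pos_queue = rng.permutation(pos)
    # Drop the last incomplete group of negatives
    for start in range(0, len(neg) - num_neg + 1, num_neg):
        if len(pos_queue) < num_pos:
            # Refill the positive queue with a fresh shuffle of all positives
            pos_queue = np.concatenate([pos_queue, rng.permutation(pos)])
        pos_batch, pos_queue = pos_queue[:num_pos], pos_queue[num_pos:]
        yield np.concatenate([pos_batch, neg[start:start + num_neg]])


labels = np.array([0] * 12 + [1] * 3)  # toy data: 12 negatives, 3 positives
for batch in dual_batches(labels, num_pos=2, num_neg=4):
    # batch[:2] plays the role of index[:num_pos] in the training loop
    print(labels[batch])  # [1 1 0 0 0 0] for each of the 3 batches
```

Oversampling the minority class this way keeps every mini-batch balanced, so each gradient step sees enough positive-negative pairs even though positives are only 20% of the training data.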
Visualization
-----------------------
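Now, let's look at how the two-way partial AUC scores on the training and test sets change over the epochs, and compare them with a baseline that optimizes the AUCM loss with the PESG optimizer (labeled "PESG" in the plot). The baseline scores are hard-coded from a separate run that is not part of this notebook.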
.. container:: cell code
.. code:: python
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (9,5)
x = np.arange(total_epochs)  # one point per epoch
aucm_tr_tpAUC = [0.0, 0.0036274632780175535, 0.005047971892709222, 0.16224877002854712, 0.08772204107278256, 0.09231572626711304, 0.2732654643710093, 0.26647582593555136, 0.24948327111467297, 0.3138958243403896, 0.41470878408333095, 0.43854233806173326, 0.44279415206018125, 0.41232459713094294, 0.42158369465786594, 0.28455608967939594, 0.47755676075725645, 0.4896781046120168, 0.3863189337375359, 0.41758365936793596, 0.842293465507402, 0.8849664436167597, 0.908744785930645, 0.925100663457797, 0.9326945885667315, 0.9383691079068582, 0.9493701298918441, 0.953469756067609, 0.959401363614275, 0.9611639035102166, 0.9687268876750703, 0.96784177567828, 0.9702642497143897, 0.9702752956047458, 0.9733664551144465, 0.9760006331554156, 0.9760579703978677, 0.9755799377033129, 0.9770459817114208, 0.9787550086264667, 0.9813317160205113, 0.9802740408912104, 0.9814502948096553, 0.9811205278462805, 0.9821393644867049, 0.9823378881296415, 0.9827693626631546, 0.9815768245053843, 0.9835002857701681, 0.9825182242961066, 0.9819347131316462, 0.9819089304938505, 0.9819178330190087, 0.9830069649230193, 0.982775401438961, 0.9835906895346361, 0.9843369096878757, 0.9844427527966737, 0.9829346774859711, 0.9838700118199921]
aucm_te_tpAUC = [0.0, 0.000301777777777778, 0.0, 0.050526666666666664, 0.022042666666666665, 0.006600444444444444, 0.09903022222222221, 0.10774866666666666, 0.1029888888888889, 0.08559822222222221, 0.17839955555555553, 0.16736266666666666, 0.1591968888888889, 0.14970355555555553, 0.16401288888888887, 0.10874133333333333, 0.2060488888888889, 0.20340622222222224, 0.15917866666666666, 0.15321466666666667, 0.3240817777777778, 0.3328191111111111, 0.33990511111111105, 0.33390533333333333, 0.3470675555555556, 0.3300217777777778, 0.35229422222222223, 0.32868844444444445, 0.34444044444444444, 0.33212177777777785, 0.3557311111111111, 0.33471155555555554, 0.332584, 0.3388742222222223, 0.31508044444444444, 0.3285328888888889, 0.32913555555555557, 0.35162800000000005, 0.32234244444444443, 0.31436844444444445, 0.331578888888889, 0.3387742222222222, 0.3338368888888889, 0.3277377777777778, 0.335048, 0.3345755555555555, 0.34122044444444444, 0.3344044444444444, 0.33397066666666664, 0.32639111111111113, 0.3255946666666667, 0.3287395555555556, 0.3287168888888889, 0.33072799999999997, 0.3306648888888889, 0.33826266666666666, 0.3271982222222222, 0.3275493333333333, 0.33227422222222225, 0.33158488888888893]
plt.figure()
plt.plot(x, tr_tpAUC, linestyle='--', label='STACO train', linewidth=3)
plt.plot(x, te_tpAUC, label='STACO test', linewidth=3)
plt.plot(x, aucm_tr_tpAUC, linestyle='--', label='PESG train', linewidth=3)
plt.plot(x, aucm_te_tpAUC, label='PESG test', linewidth=3)
plt.title('CIFAR-10 (20% imbalanced)',fontsize=25)
plt.legend(fontsize=15)
plt.ylabel('TPAUC(0.3,0.7)',fontsize=25)
plt.xlabel('epochs',fontsize=25)
plt.show()
.. container:: output display_data
.. image:: ./imgs/staco_tp.png
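The test curves fluctuate from epoch to epoch, so a single value read off the plot can be misleading. Reporting both the best epoch and the mean over the final epochs gives a steadier comparison. In the sketch below, ``summarize_curve`` is a hypothetical helper, and the toy scores are made up for illustration, not taken from the run above.

```python
import numpy as np


def summarize_curve(scores, last_k=10):
    """Return (best score, epoch of the best score, mean of the last `last_k` epochs)."""
    scores = np.asarray(scores, dtype=float)
    best_epoch = int(np.argmax(scores))
    return float(scores[best_epoch]), best_epoch, float(scores[-last_k:].mean())


# Toy per-epoch test scores (hypothetical values)
toy = [0.10, 0.25, 0.40, 0.35, 0.38, 0.37]
print(summarize_curve(toy, last_k=3))
```

After running the cells above, ``summarize_curve(te_tpAUC)`` and ``summarize_curve(aucm_te_tpAUC)`` compare the two methods. Choosing the best epoch by test score is optimistic; a held-out validation split would be needed to pick an epoch fairly.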