.. _tpauc_staco:

================================================================================================================================
Optimizing Two-Way Partial AUC on Imbalanced CIFAR10 Dataset (STACO)
================================================================================================================================
------------------------------------------------------------------------------------

.. container:: cell markdown

   | **Author**: Linli Zhou, Siqi Guo

Introduction
------------------------------------------------------------------------------------

In this tutorial, we will learn how to quickly train a ResNet18 model by optimizing the **two-way partial AUC (TPAUC)** score using our novel :obj:`tpAUC_CVaR_loss` and :obj:`STACO` optimizer `[Ref] `__ on a binary image classification task on CIFAR10. After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.

**Reference**: If you find this tutorial helpful in your work, please cite our `[library paper] `__ and the following paper:

.. code-block:: RST

   @article{
       zhou2025stochastic,
       title={Stochastic Primal-Dual Double Block-Coordinate for Two-way Partial {AUC} Maximization},
       author={Linli Zhou and Bokun Wang and My T. Thai and Tianbao Yang},
       journal={Transactions on Machine Learning Research},
       issn={2835-8856},
       year={2025},
       url={https://openreview.net/forum?id=M3kibBFP4q},
       note={}
   }

Install LibAUC
------------------------------------------------------------------------------------

Let's start by installing our library. In this tutorial, we will use the latest version of LibAUC via ``pip install -U``.

.. container:: cell code

   .. code:: python

      !pip install -U libauc

Importing LibAUC
------------------------------------------------------------------------------------

Import the required libraries.

.. container:: cell code

   .. code:: python

      from libauc.models import resnet18
      from libauc.datasets import CIFAR10
      from libauc.losses import tpAUC_CVaR_loss
      from libauc.optimizers import STACO
      from libauc.utils import ImbalancedDataGenerator
      from libauc.sampler import DualSampler  # data resampling (for binary classes)
      from libauc.metrics import pauc_roc_score

      import torch
      from PIL import Image
      import numpy as np
      import torchvision.transforms as transforms
      from torch.utils.data import Dataset
      import torch.nn as nn

Reproducibility
-----------------------

The following function ``set_all_seeds`` limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases `[Ref] `__.

.. container:: cell code

   .. code:: python

      def set_all_seeds(SEED):
          # REPRODUCIBILITY
          np.random.seed(SEED)
          torch.manual_seed(SEED)
          torch.cuda.manual_seed(SEED)
          torch.backends.cudnn.deterministic = True
          torch.backends.cudnn.benchmark = False

      set_all_seeds(2023)

Image Dataset
------------------------------------------------------------------------------------

Now we define the data input pipeline, including data augmentations. In this tutorial, we use ``RandomCrop`` and ``RandomHorizontalFlip``. The ``pos_index_map`` maps each sample's global index to a local index among the positive samples, which reduces the memory cost of the loss function since only the indices of positive samples need to be tracked.
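To make the mapping concrete, here is a toy example with hypothetical labels (not the CIFAR10 data), mirroring the ``pos_index_map`` logic used in the dataset class below:

```python
import numpy as np

# Toy labels: positions 1, 3, 4 are positive (hypothetical example).
targets = np.array([0, 1, 0, 1, 1, 0])

# Map each positive sample's global index to a compact local index,
# so the loss only needs a buffer of size num_positives (3 here),
# not of size len(targets).
pos_indices = np.flatnonzero(targets == 1)
pos_index_map = {int(g): l for l, g in enumerate(pos_indices)}

print(pos_index_map)  # {1: 0, 3: 1, 4: 2}
```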
.. container:: cell code

   .. code:: python

      class ImageDataset(Dataset):
          def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
              self.images = images.astype(np.uint8)
              self.targets = targets
              self.mode = mode
              self.transform_train = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.RandomCrop((crop_size, crop_size), padding=None),
                  transforms.RandomHorizontalFlip(),
                  transforms.Resize((image_size, image_size)),
              ])
              self.transform_test = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Resize((image_size, image_size)),
              ])
              # for the loss function: map global index -> local positive index
              self.pos_indices = np.flatnonzero(targets == 1)
              self.pos_index_map = {}
              for i, idx in enumerate(self.pos_indices):
                  self.pos_index_map[idx] = i

          def __len__(self):
              return len(self.images)

          def __getitem__(self, idx):
              image = self.images[idx]
              target = self.targets[idx]
              image = Image.fromarray(image.astype('uint8'))
              if self.mode == 'train':
                  # local positive index for the loss, or -1 for negatives
                  idx = self.pos_index_map[idx] if idx in self.pos_index_map else -1
                  image = self.transform_train(image)
              else:
                  image = self.transform_test(image)
              return image, target, idx

HyperParameters
------------------------------------------------------------------------------------

.. container:: cell code

   .. code:: python

      # HyperParameters
      SEED = 123
      batch_size = 64
      total_epochs = 60
      weight_decay = 1e-4
      lr = 1e-3
      decay_epochs = [20, 40]
      decay_factor = 10

      theta_0 = theta_1 = 0.5
      alpha = beta_0 = beta_1 = 1e-3

      sampling_rate = 0.5
      num_pos = round(sampling_rate * batch_size)
      num_neg = batch_size - num_pos
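With ``sampling_rate=0.5``, every batch is half positives and half negatives. The training loop later slices ``index[:num_pos]``, which assumes ``DualSampler`` yields the positive samples first within each batch. The sketch below illustrates that layout with hypothetical index pools (it is not the library's implementation):

```python
import numpy as np

batch_size = 64
sampling_rate = 0.5
num_pos = round(sampling_rate * batch_size)  # 32 positives per batch
num_neg = batch_size - num_pos               # 32 negatives per batch

# Hypothetical index pools: assume the first 6250 samples are positive,
# matching imratio=0.2 on 31250 training samples.
rng = np.random.default_rng(0)
pos_pool = np.arange(0, 6250)
neg_pool = np.arange(6250, 31250)

# One sampled batch: positives first, then negatives, so index[:num_pos]
# recovers exactly the positive-sample indices needed by the loss.
batch = np.concatenate([rng.choice(pos_pool, num_pos, replace=False),
                        rng.choice(neg_pool, num_neg, replace=False)])
print(batch[:num_pos].max() < 6250)  # True: first half are positives
```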
Loading datasets
------------------------------------------------------------------------------------

.. container:: cell code

   .. code:: python

      train_data, train_targets = CIFAR10(root='./data', train=True).as_array()
      test_data, test_targets = CIFAR10(root='./data', train=False).as_array()

      imratio = 0.2  # fraction of positive samples in the training set
      g = ImbalancedDataGenerator(shuffle=True, verbose=True, random_seed=0)
      (train_images, train_labels) = g.transform(train_data, train_targets, imratio=imratio)
      (test_images, test_labels) = g.transform(test_data, test_targets, imratio=0.5)

      trainSet = ImageDataset(train_images, train_labels)
      testSet = ImageDataset(test_images, test_labels, mode='test')

      sampler = DualSampler(dataset=trainSet, batch_size=batch_size, labels=train_labels,
                            shuffle=True, sampling_rate=sampling_rate)
      trainloader = torch.utils.data.DataLoader(trainSet, sampler=sampler, batch_size=batch_size,
                                                shuffle=False, num_workers=0)
      testloader = torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=0, shuffle=False)

.. container:: output stream stdout

   ::

      Files already downloaded and verified
      Files already downloaded and verified
      #SAMPLES: 31250, CLASS 0.0 COUNT: 25000, CLASS RATIO: 0.8000
      #SAMPLES: 31250, CLASS 1.0 COUNT: 6250, CLASS RATIO: 0.2000
      #SAMPLES: 10000, CLASS 1.0 COUNT: 5000, CLASS RATIO: 0.5000
      #SAMPLES: 10000, CLASS 0.0 COUNT: 5000, CLASS RATIO: 0.5000

Creating models & TPAUC Optimizer
------------------------------------------------------------------------------------

.. container:: cell code

   .. code:: python

      set_all_seeds(SEED)
      model = resnet18(pretrained=False, num_classes=1, last_activation=None)
      model = model.cuda()

      # Initialize the loss function and optimizer
      loss_fn = tpAUC_CVaR_loss(data_length=sampler.pos_len, alpha=alpha, beta_0=beta_0,
                                beta_1=beta_1, theta_0=theta_0, theta_1=theta_1)
      optimizer = STACO(model.parameters(), loss_fn=loss_fn, mode='adam', lr=lr,
                        weight_decay=weight_decay)
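The "CVaR" in the loss name refers to Conditional Value at Risk: rather than averaging all losses, only the largest fraction is averaged, which concentrates training on the hardest examples. Below is a minimal numpy sketch of the CVaR of a loss vector, for intuition only; the actual :obj:`tpAUC_CVaR_loss` works on pairwise surrogate losses and maintains stochastic estimates controlled by ``alpha``, ``beta_0``, and ``beta_1``:

```python
import numpy as np

def cvar(losses, level):
    """Average of the largest `level`-fraction of losses (CVaR)."""
    k = max(1, int(np.ceil(level * len(losses))))
    top_k = np.sort(losses)[-k:]  # the k hardest examples
    return top_k.mean()

losses = np.array([0.1, 0.2, 0.3, 0.4, 1.0])
# mean of the two largest losses: (0.4 + 1.0) / 2 = 0.7
print(cvar(losses, 0.4))
```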
Training
-----------------------

.. container:: cell markdown

   Now it's time for training.

.. container:: cell code

   .. code:: python

      print('Start Training')
      print('-' * 30)

      tr_tpAUC = []
      te_tpAUC = []
      for epoch in range(total_epochs):
          if epoch in decay_epochs:
              optimizer.update_lr(decay_factor=decay_factor)
              loss_fn.alpha /= decay_factor
              loss_fn.beta_0 /= decay_factor
              loss_fn.beta_1 /= decay_factor

          train_loss = 0
          model.train()
          for idx, data in enumerate(trainloader):
              train_data, train_labels, index = data
              train_data, train_labels = train_data.cuda(), train_labels.cuda()
              y_pred = model(train_data)
              loss = loss_fn(y_pred, train_labels, index[:num_pos])
              train_loss = train_loss + loss.cpu().detach().numpy()
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
          train_loss = train_loss / (idx + 1)

          # evaluation
          model.eval()
          with torch.no_grad():
              train_pred = []
              train_true = []
              for jdx, data in enumerate(trainloader):
                  train_data, train_labels, _ = data
                  train_data = train_data.cuda()
                  y_pred = model(train_data)
                  y_prob = torch.sigmoid(y_pred)
                  train_pred.append(y_prob.cpu().detach().numpy())
                  train_true.append(train_labels.numpy())
              train_true = np.concatenate(train_true)
              train_pred = np.concatenate(train_pred)
              single_train_auc = pauc_roc_score(train_true, train_pred, max_fpr=0.3, min_tpr=0.7)

              test_pred = []
              test_true = []
              for jdx, data in enumerate(testloader):
                  test_data, test_labels, _ = data
                  test_data = test_data.cuda()
                  y_pred = model(test_data)
                  y_prob = torch.sigmoid(y_pred)
                  test_pred.append(y_prob.cpu().detach().numpy())
                  test_true.append(test_labels.numpy())
              test_true = np.concatenate(test_true)
              test_pred = np.concatenate(test_pred)
              single_test_auc = pauc_roc_score(test_true, test_pred, max_fpr=0.3, min_tpr=0.7)

              print('Epoch=%s, Loss=%0.4f, Train_tpAUC(0.3,0.7)=%.4f, Test_tpAUC(0.3,0.7)=%.4f, lr=%.4f'
                    % (epoch, train_loss, single_train_auc, single_test_auc, optimizer.lr))

              tr_tpAUC.append(single_train_auc)
              te_tpAUC.append(single_test_auc)
.. container:: output stream stdout

   ::

      Start Training
      ------------------------------
      Epoch=0, Loss=0.7769, Train_tpAUC(0.3,0.7)=0.0083, Test_tpAUC(0.3,0.7)=0.0027, lr=0.0010
      Epoch=1, Loss=0.5561, Train_tpAUC(0.3,0.7)=0.0003, Test_tpAUC(0.3,0.7)=0.0000, lr=0.0010
      Epoch=2, Loss=0.4426, Train_tpAUC(0.3,0.7)=0.0901, Test_tpAUC(0.3,0.7)=0.0508, lr=0.0010
      Epoch=3, Loss=0.3564, Train_tpAUC(0.3,0.7)=0.1696, Test_tpAUC(0.3,0.7)=0.1078, lr=0.0010
      Epoch=4, Loss=0.2897, Train_tpAUC(0.3,0.7)=0.3645, Test_tpAUC(0.3,0.7)=0.1791, lr=0.0010
      Epoch=5, Loss=0.2330, Train_tpAUC(0.3,0.7)=0.3899, Test_tpAUC(0.3,0.7)=0.1958, lr=0.0010
      Epoch=6, Loss=0.1927, Train_tpAUC(0.3,0.7)=0.4758, Test_tpAUC(0.3,0.7)=0.2084, lr=0.0010
      Epoch=7, Loss=0.1622, Train_tpAUC(0.3,0.7)=0.3806, Test_tpAUC(0.3,0.7)=0.1483, lr=0.0010
      Epoch=8, Loss=0.1318, Train_tpAUC(0.3,0.7)=0.3751, Test_tpAUC(0.3,0.7)=0.0954, lr=0.0010
      Epoch=9, Loss=0.1049, Train_tpAUC(0.3,0.7)=0.6161, Test_tpAUC(0.3,0.7)=0.1791, lr=0.0010
      Epoch=10, Loss=0.0851, Train_tpAUC(0.3,0.7)=0.6697, Test_tpAUC(0.3,0.7)=0.2554, lr=0.0010
      Epoch=11, Loss=0.0735, Train_tpAUC(0.3,0.7)=0.6783, Test_tpAUC(0.3,0.7)=0.2324, lr=0.0010
      Epoch=12, Loss=0.0570, Train_tpAUC(0.3,0.7)=0.6791, Test_tpAUC(0.3,0.7)=0.2125, lr=0.0010
      Epoch=13, Loss=0.0509, Train_tpAUC(0.3,0.7)=0.7166, Test_tpAUC(0.3,0.7)=0.1990, lr=0.0010
      Epoch=14, Loss=0.0432, Train_tpAUC(0.3,0.7)=0.6712, Test_tpAUC(0.3,0.7)=0.2173, lr=0.0010
      Epoch=15, Loss=0.0366, Train_tpAUC(0.3,0.7)=0.6469, Test_tpAUC(0.3,0.7)=0.2462, lr=0.0010
      Epoch=16, Loss=0.0299, Train_tpAUC(0.3,0.7)=0.8008, Test_tpAUC(0.3,0.7)=0.3238, lr=0.0010
      Epoch=17, Loss=0.0250, Train_tpAUC(0.3,0.7)=0.7510, Test_tpAUC(0.3,0.7)=0.2185, lr=0.0010
      Epoch=18, Loss=0.0211, Train_tpAUC(0.3,0.7)=0.7333, Test_tpAUC(0.3,0.7)=0.2713, lr=0.0010
      Epoch=19, Loss=0.0182, Train_tpAUC(0.3,0.7)=0.7187, Test_tpAUC(0.3,0.7)=0.2233, lr=0.0010
      Reducing learning rate to 0.00010 @ T=15620!
      Epoch=20, Loss=0.0079, Train_tpAUC(0.3,0.7)=0.9566, Test_tpAUC(0.3,0.7)=0.3584, lr=0.0001
      Epoch=21, Loss=0.0041, Train_tpAUC(0.3,0.7)=0.9705, Test_tpAUC(0.3,0.7)=0.3665, lr=0.0001
      Epoch=22, Loss=0.0028, Train_tpAUC(0.3,0.7)=0.9784, Test_tpAUC(0.3,0.7)=0.3786, lr=0.0001
      Epoch=23, Loss=0.0022, Train_tpAUC(0.3,0.7)=0.9817, Test_tpAUC(0.3,0.7)=0.3660, lr=0.0001
      Epoch=24, Loss=0.0019, Train_tpAUC(0.3,0.7)=0.9878, Test_tpAUC(0.3,0.7)=0.3658, lr=0.0001
      Epoch=25, Loss=0.0016, Train_tpAUC(0.3,0.7)=0.9879, Test_tpAUC(0.3,0.7)=0.3671, lr=0.0001
      Epoch=26, Loss=0.0014, Train_tpAUC(0.3,0.7)=0.9880, Test_tpAUC(0.3,0.7)=0.3684, lr=0.0001
      Epoch=27, Loss=0.0012, Train_tpAUC(0.3,0.7)=0.9902, Test_tpAUC(0.3,0.7)=0.3505, lr=0.0001
      Epoch=28, Loss=0.0013, Train_tpAUC(0.3,0.7)=0.9921, Test_tpAUC(0.3,0.7)=0.3423, lr=0.0001
      Epoch=29, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9916, Test_tpAUC(0.3,0.7)=0.3578, lr=0.0001
      Epoch=30, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9904, Test_tpAUC(0.3,0.7)=0.3379, lr=0.0001
      Epoch=31, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9933, Test_tpAUC(0.3,0.7)=0.3677, lr=0.0001
      Epoch=32, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9916, Test_tpAUC(0.3,0.7)=0.3506, lr=0.0001
      Epoch=33, Loss=0.0011, Train_tpAUC(0.3,0.7)=0.9921, Test_tpAUC(0.3,0.7)=0.3405, lr=0.0001
      Epoch=34, Loss=0.0010, Train_tpAUC(0.3,0.7)=0.9917, Test_tpAUC(0.3,0.7)=0.3405, lr=0.0001
      Epoch=35, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9919, Test_tpAUC(0.3,0.7)=0.3500, lr=0.0001
      Epoch=36, Loss=0.0010, Train_tpAUC(0.3,0.7)=0.9917, Test_tpAUC(0.3,0.7)=0.3385, lr=0.0001
      Epoch=37, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9922, Test_tpAUC(0.3,0.7)=0.3134, lr=0.0001
      Epoch=38, Loss=0.0010, Train_tpAUC(0.3,0.7)=0.9916, Test_tpAUC(0.3,0.7)=0.3365, lr=0.0001
      Epoch=39, Loss=0.0009, Train_tpAUC(0.3,0.7)=0.9928, Test_tpAUC(0.3,0.7)=0.3325, lr=0.0001
      Reducing learning rate to 0.00001 @ T=31240!
      Epoch=40, Loss=0.0006, Train_tpAUC(0.3,0.7)=0.9965, Test_tpAUC(0.3,0.7)=0.3450, lr=0.0000
      Epoch=41, Loss=0.0004, Train_tpAUC(0.3,0.7)=0.9973, Test_tpAUC(0.3,0.7)=0.3559, lr=0.0000
      Epoch=42, Loss=0.0003, Train_tpAUC(0.3,0.7)=0.9985, Test_tpAUC(0.3,0.7)=0.3526, lr=0.0000
      Epoch=43, Loss=0.0003, Train_tpAUC(0.3,0.7)=0.9986, Test_tpAUC(0.3,0.7)=0.3488, lr=0.0000
      Epoch=44, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9987, Test_tpAUC(0.3,0.7)=0.3589, lr=0.0000
      Epoch=45, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9990, Test_tpAUC(0.3,0.7)=0.3578, lr=0.0000
      Epoch=46, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9990, Test_tpAUC(0.3,0.7)=0.3590, lr=0.0000
      Epoch=47, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.3574, lr=0.0000
      Epoch=48, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3645, lr=0.0000
      Epoch=49, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3697, lr=0.0000
      Epoch=50, Loss=0.0002, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3672, lr=0.0000
      Epoch=51, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3669, lr=0.0000
      Epoch=52, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3709, lr=0.0000
      Epoch=53, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9996, Test_tpAUC(0.3,0.7)=0.3621, lr=0.0000
      Epoch=54, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9997, Test_tpAUC(0.3,0.7)=0.3673, lr=0.0000
      Epoch=55, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3683, lr=0.0000
      Epoch=56, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9997, Test_tpAUC(0.3,0.7)=0.3755, lr=0.0000
      Epoch=57, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9998, Test_tpAUC(0.3,0.7)=0.3776, lr=0.0000
      Epoch=58, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9996, Test_tpAUC(0.3,0.7)=0.3722, lr=0.0000
      Epoch=59, Loss=0.0001, Train_tpAUC(0.3,0.7)=0.9999, Test_tpAUC(0.3,0.7)=0.3707, lr=0.0000
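The reported tpAUC(0.3,0.7) metric can be understood through pair counting: it is roughly the fraction of correctly ranked (positive, negative) pairs, restricted to the hardest ``1 - min_tpr`` fraction of positives (lowest scores) and the hardest ``max_fpr`` fraction of negatives (highest scores). The sketch below illustrates this pair-based definition; it is not LibAUC's ``pauc_roc_score`` implementation, and it ignores ties and interpolation:

```python
import numpy as np

def tpauc_pair_count(y_true, y_score, max_fpr=0.3, min_tpr=0.7):
    """Pair-counting illustration of two-way partial AUC:
    correctly ranked pairs among the hardest (1 - min_tpr) fraction
    of positives and the hardest max_fpr fraction of negatives."""
    pos = np.sort(y_score[y_true == 1])
    neg = np.sort(y_score[y_true == 0])
    k_pos = max(1, int(np.ceil((1 - min_tpr) * len(pos))))
    k_neg = max(1, int(np.ceil(max_fpr * len(neg))))
    hard_pos = pos[:k_pos]    # lowest-scored positives
    hard_neg = neg[-k_neg:]   # highest-scored negatives
    return (hard_pos[:, None] > hard_neg[None, :]).mean()

y_true = np.array([1, 1, 1, 0, 0, 0])
scores = np.array([0.9, 0.8, 0.7, 0.3, 0.2, 0.1])
print(tpauc_pair_count(y_true, scores))  # 1.0: perfectly separated
```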
Visualization
-----------------------

Now, let's plot the two-way partial AUC scores on the training and test sets over epochs, and compare STACO with the AUCM method (trained with the PESG optimizer).

.. container:: cell code

   .. code:: python

      import matplotlib.pyplot as plt
      plt.rcParams["figure.figsize"] = (9, 5)
      x = np.arange(60)

      # TPAUC curves of the PESG (AUCM) baseline, recorded from a separate run
      aucm_tr_tpAUC = [0.0, 0.0036274632780175535, 0.005047971892709222, 0.16224877002854712, 0.08772204107278256, 0.09231572626711304,
                       0.2732654643710093, 0.26647582593555136, 0.24948327111467297, 0.3138958243403896, 0.41470878408333095, 0.43854233806173326,
                       0.44279415206018125, 0.41232459713094294, 0.42158369465786594, 0.28455608967939594, 0.47755676075725645, 0.4896781046120168,
                       0.3863189337375359, 0.41758365936793596, 0.842293465507402, 0.8849664436167597, 0.908744785930645, 0.925100663457797,
                       0.9326945885667315, 0.9383691079068582, 0.9493701298918441, 0.953469756067609, 0.959401363614275, 0.9611639035102166,
                       0.9687268876750703, 0.96784177567828, 0.9702642497143897, 0.9702752956047458, 0.9733664551144465, 0.9760006331554156,
                       0.9760579703978677, 0.9755799377033129, 0.9770459817114208, 0.9787550086264667, 0.9813317160205113, 0.9802740408912104,
                       0.9814502948096553, 0.9811205278462805, 0.9821393644867049, 0.9823378881296415, 0.9827693626631546, 0.9815768245053843,
                       0.9835002857701681, 0.9825182242961066, 0.9819347131316462, 0.9819089304938505, 0.9819178330190087, 0.9830069649230193,
                       0.982775401438961, 0.9835906895346361, 0.9843369096878757, 0.9844427527966737, 0.9829346774859711, 0.9838700118199921]
      aucm_te_tpAUC = [0.0, 0.000301777777777778, 0.0, 0.050526666666666664, 0.022042666666666665, 0.006600444444444444,
                       0.09903022222222221, 0.10774866666666666, 0.1029888888888889, 0.08559822222222221, 0.17839955555555553, 0.16736266666666666,
                       0.1591968888888889, 0.14970355555555553, 0.16401288888888887, 0.10874133333333333, 0.2060488888888889, 0.20340622222222224,
                       0.15917866666666666, 0.15321466666666667, 0.3240817777777778, 0.3328191111111111, 0.33990511111111105, 0.33390533333333333,
                       0.3470675555555556, 0.3300217777777778, 0.35229422222222223, 0.32868844444444445, 0.34444044444444444, 0.33212177777777785,
                       0.3557311111111111, 0.33471155555555554, 0.332584, 0.3388742222222223, 0.31508044444444444, 0.3285328888888889,
                       0.32913555555555557, 0.35162800000000005, 0.32234244444444443, 0.31436844444444445, 0.331578888888889, 0.3387742222222222,
                       0.3338368888888889, 0.3277377777777778, 0.335048, 0.3345755555555555, 0.34122044444444444, 0.3344044444444444,
                       0.33397066666666664, 0.32639111111111113, 0.3255946666666667, 0.3287395555555556, 0.3287168888888889, 0.33072799999999997,
                       0.3306648888888889, 0.33826266666666666, 0.3271982222222222, 0.3275493333333333, 0.33227422222222225, 0.33158488888888893]

      plt.figure()
      plt.plot(x, tr_tpAUC, linestyle='--', label='STACO train', linewidth=3)
      plt.plot(x, te_tpAUC, label='STACO test', linewidth=3)
      plt.plot(x, aucm_tr_tpAUC, linestyle='--', label='PESG train', linewidth=3)
      plt.plot(x, aucm_te_tpAUC, label='PESG test', linewidth=3)
      plt.title('CIFAR-10 (20% imbalanced)', fontsize=25)
      plt.legend(fontsize=15)
      plt.ylabel('TPAUC(0.7,0.3)', fontsize=25)
      plt.xlabel('epochs', fontsize=25)

.. container:: output display_data

   .. image:: ./imgs/staco_tp.png