.. _pauc_sopas: ================================================================================================================================ Optimizing One-Way partial AUC on Imbalanced CIFAR10 Dataset (SOPAs) ================================================================================================================================ .. raw:: html
Run on Colab
Download Notebook
View on Github
------------------------------------------------------------------------------------ .. container:: cell markdown | **Author**: Gang Li, Zhuoning Yuan, Tianbao Yang | **Version**: 1.4.0 \ Introduction ----------------------- In this tutorial, you will learn how to quickly train a Resnet18 model by optimizing **One way Partial AUC (OPAUC)** with our novel :obj:`pAUC_DRO_Loss` and :obj:`SOPAs` optimizer `[ref] `__ on a binary image classification task with CIFAR-10 dataset. After completion of this tutorial, you should be able to use LibAUC to train your own models on your own datasets. **Reference**: If you find this tutorial helpful in your work, please cite our `library paper `__ and the following papers: .. code-block:: RST @inproceedings{zhu2022auc, title={When auc meets dro: Optimizing partial auc for deep learning with non-convex convergence guarantee}, author={Zhu, Dixian and Li, Gang and Wang, Bokun and Wu, Xiaodong and Yang, Tianbao}, booktitle={International Conference on Machine Learning}, pages={27548--27573}, year={2022}, organization={PMLR} } Install LibAUC ------------------------------------------------------------------------------------ Let's start with installing our library here. In this tutorial, we will use the lastest version for LibAUC by using ``pip install -U``. .. container:: cell code .. code:: python !pip install -U libauc Importing LibAUC ----------------------- Import required packages to use .. container:: cell code .. code:: python from libauc.losses import pAUC_DRO_Loss from libauc.optimizers import SOPAs from libauc.models import resnet18 as ResNet18 from libauc.datasets import CIFAR10 from libauc.utils import ImbalancedDataGenerator from libauc.sampler import DualSampler from libauc.metrics import auc_roc_score import torchvision.transforms as transforms from torch.utils.data import Dataset import numpy as np import torch from PIL import Image Reproducibility ----------------------- The following function ``set_all_seeds`` limits the number of sources of randomness behaviors, such as model intialization, data shuffling, etcs. However, completely reproducible results are not guaranteed across PyTorch releases `[Ref] `__. .. container:: cell code .. code:: python def set_all_seeds(SEED): # REPRODUCIBILITY np.random.seed(SEED) torch.manual_seed(SEED) torch.cuda.manual_seed(SEED) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False set_all_seeds(2023) Loading datasets ----------------------- .. container:: cell markdown In this step, we will use the `CIFAR10 `__ as benchmark dataset. Before importing data to ``dataloader``, we construct imbalanced version for CIFAR10 by ``ImbalanceDataGenerator``. Specifically, it first randomly splits the training data by class ID (e.g., 10 classes) into two even portions as the positive and negative classes, and then it randomly removes some samples from the positive class to make it imbalanced. We refer ``imratio`` to the ratio of number of positive examples to number of all examples. .. container:: cell code .. code:: python train_data, train_targets = CIFAR10(root='./data', train=True).as_array() test_data, test_targets = CIFAR10(root='./data', train=False).as_array() imratio = 0.2 ##we set the imratio as 0.2 here. generator = ImbalancedDataGenerator(verbose=True, random_seed=2023) (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio) (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=imratio) .. container:: cell markdown We define the data input pipeline such as data augmentations. In this tutorial, we use ``RandomCrop``, ``RandomHorizontalFlip``. .. container:: cell code .. code:: python class ImageDataset(Dataset): def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'): self.images = images.astype(np.uint8) self.targets = targets self.mode = mode self.transform_train = transforms.Compose([ transforms.ToTensor(), transforms.RandomCrop((crop_size, crop_size), padding=None), transforms.RandomHorizontalFlip(), transforms.Resize((image_size, image_size), antialias=True), ]) self.transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Resize((image_size, image_size), antialias=True), ]) def __len__(self): return len(self.images) def __getitem__(self, idx): image = self.images[idx] target = self.targets[idx] image = Image.fromarray(image.astype('uint8')) if self.mode == 'train': image = self.transform_train(image) else: image = self.transform_test(image) return image, target, idx .. container:: cell code .. code:: python batch_size = 64 trainSet = ImageDataset(train_images, train_labels) trainSet_eval = ImageDataset(train_images, train_labels,mode='test') testSet = ImageDataset(test_images, test_labels, mode='test') Configuration ----------------------- .. container:: cell markdown Hyper-Parameters .. container:: cell code .. code:: python lr = 1e-3 margin = 0.6 gamma = 0.1 Lambda = 1.0 weight_decay = 2e-4 total_epoch = 60 decay_epoch = [30, 45] load_pretrain = False Pretraining (Recommended) ----------------------- .. container:: cell markdown Following the original `paper `__, it's recommended to start from a pretrained checkpoint with cross-entropy loss to significantly boost models' performance. It includes a pre-training step with standard cross-entropy loss, and a Partial AUC maximization step that maximizes a Partial AUC surrogate loss of the pre-trained model. .. container:: cell code .. code:: python from torch.optim import Adam import warnings warnings.filterwarnings('ignore') load_pretrain = True model = ResNet18(pretrained=False, last_activation=None, num_classes=1) model = model.cuda() loss_fn = torch.nn.BCELoss() optimizer =Adam(model.parameters(), lr=lr, weight_decay=weight_decay) trainloader = torch.utils.data.DataLoader(trainSet, batch_size=batch_size, shuffle=True, num_workers=2) testloader = torch.utils.data.DataLoader(testSet, batch_size=batch_size, shuffle=False, num_workers=2) best_test = 0 for epoch in range(total_epoch): if epoch in decay_epoch: for param_group in optimizer.param_groups: param_group['lr'] = 0.1 * param_group['lr'] model.train() for idx, (data, targets, index) in enumerate(trainloader): data, targets, index = data.cuda(), targets.cuda(), index.cuda() y_pred = model(data) y_prob = torch.sigmoid(y_pred) loss = loss_fn(y_prob, targets) optimizer.zero_grad() loss.backward() optimizer.step() ######***evaluation***#### # evaluation on test sets model.eval() test_pred_list, test_true_list = [], [] with torch.no_grad(): for j, data in enumerate(testloader): test_data, test_targets, _ = data test_data = test_data.cuda() y_pred = model(test_data) y_prob = torch.sigmoid(y_pred) test_pred_list.append(y_prob.cpu().detach().numpy()) test_true_list.append(test_targets.numpy()) test_true = np.concatenate(test_true_list) test_pred = np.concatenate(test_pred_list) test_pauc = auc_roc_score(test_true, test_pred,max_fpr=0.3) if best_test < test_pauc: best_test = test_pauc torch.save(model.state_dict(), 'ce_pretrained_model_sopas.pth') model.train() print("epoch: %s, test_pauc: %.4f, best_test_pauc: %.4f, lr: %.4f"%(epoch, test_pauc, best_test, optimizer.param_groups[0]['lr'] )) .. container:: output stream stdout :: epoch: 0, test_pauc: 0.6030, best_test_pauc: 0.6030, lr: 0.0010 epoch: 1, test_pauc: 0.6749, best_test_pauc: 0.6749, lr: 0.0010 epoch: 2, test_pauc: 0.6666, best_test_pauc: 0.6749, lr: 0.0010 epoch: 3, test_pauc: 0.6917, best_test_pauc: 0.6917, lr: 0.0010 epoch: 4, test_pauc: 0.6525, best_test_pauc: 0.6917, lr: 0.0010 epoch: 5, test_pauc: 0.7070, best_test_pauc: 0.7070, lr: 0.0010 epoch: 6, test_pauc: 0.7555, best_test_pauc: 0.7555, lr: 0.0010 epoch: 7, test_pauc: 0.7448, best_test_pauc: 0.7555, lr: 0.0010 epoch: 8, test_pauc: 0.7583, best_test_pauc: 0.7583, lr: 0.0010 epoch: 9, test_pauc: 0.7545, best_test_pauc: 0.7583, lr: 0.0010 epoch: 10, test_pauc: 0.7258, best_test_pauc: 0.7583, lr: 0.0010 epoch: 11, test_pauc: 0.7956, best_test_pauc: 0.7956, lr: 0.0010 epoch: 12, test_pauc: 0.8150, best_test_pauc: 0.8150, lr: 0.0010 epoch: 13, test_pauc: 0.8036, best_test_pauc: 0.8150, lr: 0.0010 epoch: 14, test_pauc: 0.7952, best_test_pauc: 0.8150, lr: 0.0010 epoch: 15, test_pauc: 0.8003, best_test_pauc: 0.8150, lr: 0.0010 epoch: 16, test_pauc: 0.7778, best_test_pauc: 0.8150, lr: 0.0010 epoch: 17, test_pauc: 0.8249, best_test_pauc: 0.8249, lr: 0.0010 epoch: 18, test_pauc: 0.8038, best_test_pauc: 0.8249, lr: 0.0010 epoch: 19, test_pauc: 0.8295, best_test_pauc: 0.8295, lr: 0.0010 epoch: 20, test_pauc: 0.8236, best_test_pauc: 0.8295, lr: 0.0010 epoch: 21, test_pauc: 0.8370, best_test_pauc: 0.8370, lr: 0.0010 epoch: 22, test_pauc: 0.8464, best_test_pauc: 0.8464, lr: 0.0010 epoch: 23, test_pauc: 0.8088, best_test_pauc: 0.8464, lr: 0.0010 epoch: 24, test_pauc: 0.8337, best_test_pauc: 0.8464, lr: 0.0010 epoch: 25, test_pauc: 0.8365, best_test_pauc: 0.8464, lr: 0.0010 epoch: 26, test_pauc: 0.8448, best_test_pauc: 0.8464, lr: 0.0010 epoch: 27, test_pauc: 0.8076, best_test_pauc: 0.8464, lr: 0.0010 epoch: 28, test_pauc: 0.8388, best_test_pauc: 0.8464, lr: 0.0010 epoch: 29, test_pauc: 0.8477, best_test_pauc: 0.8477, lr: 0.0010 epoch: 30, test_pauc: 0.8617, best_test_pauc: 0.8617, lr: 0.0001 epoch: 31, test_pauc: 0.8637, best_test_pauc: 0.8637, lr: 0.0001 epoch: 32, test_pauc: 0.8625, best_test_pauc: 0.8637, lr: 0.0001 epoch: 33, test_pauc: 0.8640, best_test_pauc: 0.8640, lr: 0.0001 epoch: 34, test_pauc: 0.8611, best_test_pauc: 0.8640, lr: 0.0001 epoch: 35, test_pauc: 0.8651, best_test_pauc: 0.8651, lr: 0.0001 epoch: 36, test_pauc: 0.8583, best_test_pauc: 0.8651, lr: 0.0001 epoch: 37, test_pauc: 0.8665, best_test_pauc: 0.8665, lr: 0.0001 epoch: 38, test_pauc: 0.8643, best_test_pauc: 0.8665, lr: 0.0001 epoch: 39, test_pauc: 0.8634, best_test_pauc: 0.8665, lr: 0.0001 epoch: 40, test_pauc: 0.8619, best_test_pauc: 0.8665, lr: 0.0001 epoch: 41, test_pauc: 0.8612, best_test_pauc: 0.8665, lr: 0.0001 epoch: 42, test_pauc: 0.8610, best_test_pauc: 0.8665, lr: 0.0001 epoch: 43, test_pauc: 0.8600, best_test_pauc: 0.8665, lr: 0.0001 epoch: 44, test_pauc: 0.8612, best_test_pauc: 0.8665, lr: 0.0001 epoch: 45, test_pauc: 0.8629, best_test_pauc: 0.8665, lr: 0.0000 epoch: 46, test_pauc: 0.8630, best_test_pauc: 0.8665, lr: 0.0000 epoch: 47, test_pauc: 0.8625, best_test_pauc: 0.8665, lr: 0.0000 epoch: 48, test_pauc: 0.8627, best_test_pauc: 0.8665, lr: 0.0000 epoch: 49, test_pauc: 0.8642, best_test_pauc: 0.8665, lr: 0.0000 epoch: 50, test_pauc: 0.8619, best_test_pauc: 0.8665, lr: 0.0000 epoch: 51, test_pauc: 0.8629, best_test_pauc: 0.8665, lr: 0.0000 epoch: 52, test_pauc: 0.8627, best_test_pauc: 0.8665, lr: 0.0000 epoch: 53, test_pauc: 0.8637, best_test_pauc: 0.8665, lr: 0.0000 epoch: 54, test_pauc: 0.8598, best_test_pauc: 0.8665, lr: 0.0000 epoch: 55, test_pauc: 0.8605, best_test_pauc: 0.8665, lr: 0.0000 epoch: 56, test_pauc: 0.8606, best_test_pauc: 0.8665, lr: 0.0000 epoch: 57, test_pauc: 0.8622, best_test_pauc: 0.8665, lr: 0.0000 epoch: 58, test_pauc: 0.8618, best_test_pauc: 0.8665, lr: 0.0000 epoch: 59, test_pauc: 0.8610, best_test_pauc: 0.8665, lr: 0.0000 Optimizing pAUC Loss with SOPAs ----------------------- .. container:: cell markdown We define ``dataset``, ``DualSampler`` and ``dataloader`` here. By default, we use ``batch_size`` 64 and we oversample the minority class with ``pos:neg=1:1`` by setting ``sampling_rate=0.5``. .. container:: cell code .. code:: python sampling_rate = 0.5 sampler = DualSampler(trainSet, batch_size, sampling_rate=sampling_rate) trainloader = torch.utils.data.DataLoader(trainSet, batch_size=batch_size, sampler=sampler, num_workers=2) trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=batch_size, shuffle=False, num_workers=2) testloader = torch.utils.data.DataLoader(testSet, batch_size=batch_size, shuffle=False, num_workers=2) Model, Loss and Optimizer ----------------------- .. container:: cell code .. code:: python model = ResNet18(pretrained=False, last_activation=None, num_classes=1) model = model.cuda() # load pretrained model if load_pretrain: PATH = 'ce_pretrained_model_sopas.pth' state_dict = torch.load(PATH) filtered = {k:v for k,v in state_dict.items() if 'fc' not in k} msg = model.load_state_dict(filtered, False) print(msg) model.fc.reset_parameters() loss_fn = pAUC_DRO_Loss(data_len=len(trainSet), margin=margin, gamma=gamma) optimizer = SOPAs(model.parameters(), lr=lr, mode='adam', weight_decay=weight_decay) Training ----------------------- .. container:: cell markdown Now it's time for training. And we evaluate partial AUC performance with False Positive Rate(FPR) less than or equal to 0.3, i.e., FPR ≤ 0.3. .. container:: cell code .. code:: python print ('Start Training') print ('-'*30) train_log, test_log = [], [] best_test = 0 for epoch in range(total_epoch): if epoch in decay_epoch: optimizer.update_lr(decay_factor=10) train_loss = [] model.train() for idx, (data, targets, index) in enumerate(trainloader): data, targets, index = data.cuda(), targets.cuda(), index.cuda() y_pred = model(data) y_prob = torch.sigmoid(y_pred) loss = loss_fn(y_prob, targets, index) optimizer.zero_grad() loss.backward() optimizer.step() train_loss.append(loss.item()) ######***evaluation***#### # evaluation on training sets model.eval() train_pred_list, train_true_list = [], [] with torch.no_grad(): for i, data in enumerate(trainloader_eval): train_data, train_targets, _ = data train_data = train_data.cuda() y_pred = model(train_data) y_prob = torch.sigmoid(y_pred) train_pred_list.append(y_prob.cpu().detach().numpy()) train_true_list.append(train_targets.cpu().detach().numpy()) train_true = np.concatenate(train_true_list) train_pred = np.concatenate(train_pred_list) train_pauc = auc_roc_score(train_true, train_pred, max_fpr=0.3) train_loss = np.mean(train_loss) train_log.append(train_pauc) # evaluation on test sets model.eval() test_pred_list, test_true_list = [], [] with torch.no_grad(): for j, data in enumerate(testloader): test_data, test_targets, _ = data test_data = test_data.cuda() y_pred = model(test_data) y_prob = torch.sigmoid(y_pred) test_pred_list.append(y_prob.cpu().detach().numpy()) test_true_list.append(test_targets.numpy()) test_true = np.concatenate(test_true_list) test_pred = np.concatenate(test_pred_list) test_pauc = auc_roc_score(test_true, test_pred,max_fpr=0.3) test_log.append(test_pauc) if best_test < test_pauc: best_test = test_pauc model.train() # print results print("epoch: %s, train_loss: %.4f, train_pauc: %.4f, test_pauc: %.4f, best_test_pauc: %.4f, lr: %.5f"%(epoch, train_loss, train_pauc, test_pauc, best_test, optimizer.lr )) .. container:: output stream stdout :: Start Training ------------------------------ epoch: 0, train_loss: 0.3522, train_pauc: 0.9630, test_pauc: 0.8541, best_test_pauc: 0.8541, lr: 0.00100 epoch: 1, train_loss: 0.0697, train_pauc: 0.9768, test_pauc: 0.8649, best_test_pauc: 0.8649, lr: 0.00100 epoch: 2, train_loss: 0.0506, train_pauc: 0.9738, test_pauc: 0.8564, best_test_pauc: 0.8649, lr: 0.00100 epoch: 3, train_loss: 0.0469, train_pauc: 0.9654, test_pauc: 0.8444, best_test_pauc: 0.8649, lr: 0.00100 epoch: 4, train_loss: 0.0520, train_pauc: 0.9526, test_pauc: 0.8472, best_test_pauc: 0.8649, lr: 0.00100 epoch: 5, train_loss: 0.0553, train_pauc: 0.9336, test_pauc: 0.8261, best_test_pauc: 0.8649, lr: 0.00100 epoch: 6, train_loss: 0.0549, train_pauc: 0.9498, test_pauc: 0.8424, best_test_pauc: 0.8649, lr: 0.00100 epoch: 7, train_loss: 0.0581, train_pauc: 0.9371, test_pauc: 0.8402, best_test_pauc: 0.8649, lr: 0.00100 epoch: 8, train_loss: 0.0574, train_pauc: 0.9286, test_pauc: 0.8284, best_test_pauc: 0.8649, lr: 0.00100 epoch: 9, train_loss: 0.0565, train_pauc: 0.9347, test_pauc: 0.8321, best_test_pauc: 0.8649, lr: 0.00100 epoch: 10, train_loss: 0.0549, train_pauc: 0.9601, test_pauc: 0.8605, best_test_pauc: 0.8649, lr: 0.00100 epoch: 11, train_loss: 0.0561, train_pauc: 0.9535, test_pauc: 0.8406, best_test_pauc: 0.8649, lr: 0.00100 epoch: 12, train_loss: 0.0580, train_pauc: 0.9267, test_pauc: 0.8260, best_test_pauc: 0.8649, lr: 0.00100 epoch: 13, train_loss: 0.0533, train_pauc: 0.9428, test_pauc: 0.8412, best_test_pauc: 0.8649, lr: 0.00100 epoch: 14, train_loss: 0.0559, train_pauc: 0.9352, test_pauc: 0.8281, best_test_pauc: 0.8649, lr: 0.00100 epoch: 15, train_loss: 0.0543, train_pauc: 0.9522, test_pauc: 0.8451, best_test_pauc: 0.8649, lr: 0.00100 epoch: 16, train_loss: 0.0540, train_pauc: 0.9401, test_pauc: 0.8430, best_test_pauc: 0.8649, lr: 0.00100 epoch: 17, train_loss: 0.0512, train_pauc: 0.9563, test_pauc: 0.8546, best_test_pauc: 0.8649, lr: 0.00100 epoch: 18, train_loss: 0.0546, train_pauc: 0.9514, test_pauc: 0.8495, best_test_pauc: 0.8649, lr: 0.00100 epoch: 19, train_loss: 0.0530, train_pauc: 0.9581, test_pauc: 0.8441, best_test_pauc: 0.8649, lr: 0.00100 epoch: 20, train_loss: 0.0512, train_pauc: 0.9253, test_pauc: 0.8182, best_test_pauc: 0.8649, lr: 0.00100 epoch: 21, train_loss: 0.0483, train_pauc: 0.9604, test_pauc: 0.8523, best_test_pauc: 0.8649, lr: 0.00100 epoch: 22, train_loss: 0.0519, train_pauc: 0.9605, test_pauc: 0.8637, best_test_pauc: 0.8649, lr: 0.00100 epoch: 23, train_loss: 0.0521, train_pauc: 0.9546, test_pauc: 0.8557, best_test_pauc: 0.8649, lr: 0.00100 epoch: 24, train_loss: 0.0522, train_pauc: 0.9569, test_pauc: 0.8489, best_test_pauc: 0.8649, lr: 0.00100 epoch: 25, train_loss: 0.0479, train_pauc: 0.9481, test_pauc: 0.8423, best_test_pauc: 0.8649, lr: 0.00100 epoch: 26, train_loss: 0.0504, train_pauc: 0.9369, test_pauc: 0.8345, best_test_pauc: 0.8649, lr: 0.00100 epoch: 27, train_loss: 0.0487, train_pauc: 0.9485, test_pauc: 0.8392, best_test_pauc: 0.8649, lr: 0.00100 epoch: 28, train_loss: 0.0509, train_pauc: 0.9250, test_pauc: 0.8219, best_test_pauc: 0.8649, lr: 0.00100 epoch: 29, train_loss: 0.0488, train_pauc: 0.9502, test_pauc: 0.8422, best_test_pauc: 0.8649, lr: 0.00100 Reducing learning rate to 0.00010 @ T=23430! epoch: 30, train_loss: 0.0240, train_pauc: 0.9849, test_pauc: 0.8724, best_test_pauc: 0.8724, lr: 0.00010 epoch: 31, train_loss: 0.0153, train_pauc: 0.9880, test_pauc: 0.8733, best_test_pauc: 0.8733, lr: 0.00010 epoch: 32, train_loss: 0.0117, train_pauc: 0.9906, test_pauc: 0.8746, best_test_pauc: 0.8746, lr: 0.00010 epoch: 33, train_loss: 0.0100, train_pauc: 0.9914, test_pauc: 0.8739, best_test_pauc: 0.8746, lr: 0.00010 epoch: 34, train_loss: 0.0087, train_pauc: 0.9930, test_pauc: 0.8735, best_test_pauc: 0.8746, lr: 0.00010 epoch: 35, train_loss: 0.0082, train_pauc: 0.9927, test_pauc: 0.8746, best_test_pauc: 0.8746, lr: 0.00010 epoch: 36, train_loss: 0.0069, train_pauc: 0.9925, test_pauc: 0.8687, best_test_pauc: 0.8746, lr: 0.00010 epoch: 37, train_loss: 0.0059, train_pauc: 0.9939, test_pauc: 0.8746, best_test_pauc: 0.8746, lr: 0.00010 epoch: 38, train_loss: 0.0056, train_pauc: 0.9955, test_pauc: 0.8726, best_test_pauc: 0.8746, lr: 0.00010 epoch: 39, train_loss: 0.0051, train_pauc: 0.9939, test_pauc: 0.8707, best_test_pauc: 0.8746, lr: 0.00010 epoch: 40, train_loss: 0.0048, train_pauc: 0.9946, test_pauc: 0.8727, best_test_pauc: 0.8746, lr: 0.00010 epoch: 41, train_loss: 0.0048, train_pauc: 0.9954, test_pauc: 0.8733, best_test_pauc: 0.8746, lr: 0.00010 epoch: 42, train_loss: 0.0041, train_pauc: 0.9963, test_pauc: 0.8803, best_test_pauc: 0.8803, lr: 0.00010 epoch: 43, train_loss: 0.0041, train_pauc: 0.9952, test_pauc: 0.8714, best_test_pauc: 0.8803, lr: 0.00010 epoch: 44, train_loss: 0.0037, train_pauc: 0.9964, test_pauc: 0.8760, best_test_pauc: 0.8803, lr: 0.00010 Reducing learning rate to 0.00001 @ T=35145! epoch: 45, train_loss: 0.0032, train_pauc: 0.9968, test_pauc: 0.8773, best_test_pauc: 0.8803, lr: 0.00001 epoch: 46, train_loss: 0.0027, train_pauc: 0.9970, test_pauc: 0.8776, best_test_pauc: 0.8803, lr: 0.00001 epoch: 47, train_loss: 0.0024, train_pauc: 0.9973, test_pauc: 0.8782, best_test_pauc: 0.8803, lr: 0.00001 epoch: 48, train_loss: 0.0022, train_pauc: 0.9966, test_pauc: 0.8763, best_test_pauc: 0.8803, lr: 0.00001 epoch: 49, train_loss: 0.0022, train_pauc: 0.9973, test_pauc: 0.8776, best_test_pauc: 0.8803, lr: 0.00001 epoch: 50, train_loss: 0.0019, train_pauc: 0.9970, test_pauc: 0.8769, best_test_pauc: 0.8803, lr: 0.00001 epoch: 51, train_loss: 0.0022, train_pauc: 0.9975, test_pauc: 0.8777, best_test_pauc: 0.8803, lr: 0.00001 epoch: 52, train_loss: 0.0018, train_pauc: 0.9970, test_pauc: 0.8773, best_test_pauc: 0.8803, lr: 0.00001 epoch: 53, train_loss: 0.0019, train_pauc: 0.9970, test_pauc: 0.8763, best_test_pauc: 0.8803, lr: 0.00001 epoch: 54, train_loss: 0.0019, train_pauc: 0.9975, test_pauc: 0.8776, best_test_pauc: 0.8803, lr: 0.00001 epoch: 55, train_loss: 0.0018, train_pauc: 0.9974, test_pauc: 0.8775, best_test_pauc: 0.8803, lr: 0.00001 epoch: 56, train_loss: 0.0019, train_pauc: 0.9973, test_pauc: 0.8778, best_test_pauc: 0.8803, lr: 0.00001 epoch: 57, train_loss: 0.0016, train_pauc: 0.9978, test_pauc: 0.8787, best_test_pauc: 0.8803, lr: 0.00001 epoch: 58, train_loss: 0.0018, train_pauc: 0.9974, test_pauc: 0.8767, best_test_pauc: 0.8803, lr: 0.00001 epoch: 59, train_loss: 0.0015, train_pauc: 0.9975, test_pauc: 0.8758, best_test_pauc: 0.8803, lr: 0.00001 Visualization ----------------------- Now, let's see the learning curves for optimizing pAUC from scratch and from a pretrained model with cross entropy loss. .. container:: cell code .. code:: python import matplotlib.pyplot as plt import numpy as np train_log= [0.9629983121568626, 0.976789697254902, 0.9737728125490197, 0.9653674415686273, 0.9526080000000001, 0.9336441725490195, 0.9497985317647059, 0.9371098352941176, 0.9286405458823529, 0.9346945380392156, 0.9601394321568628, 0.953512702745098, 0.9266596831372549, 0.9428325396078432, 0.9351774494117648, 0.9522345788235294, 0.9401062211764706, 0.9563442384313725, 0.9514192501960784, 0.9580526368627451, 0.9253358619607843, 0.9603840376470587, 0.9605023058823529, 0.9546136219607844, 0.9568967592156863, 0.9480975999999999, 0.9369100862745098, 0.9484633725490195, 0.9249502117647059, 0.9502059231372548, 0.9849486870588235, 0.9879563984313726, 0.9905859952941177, 0.9913526462745099, 0.9930277772549019, 0.9927193474509803, 0.9925047341176472, 0.9938547639215686, 0.9954592501960785, 0.993869954509804, 0.9945825756862745, 0.995418591372549, 0.996329662745098, 0.9951745380392156, 0.9963752847058824, 0.9967506447058823, 0.9969702023529412, 0.9972820956862744, 0.9966350243137254, 0.9973091764705881, 0.9970341521568626, 0.9974569160784313, 0.9970073286274509, 0.997043074509804, 0.9975259356862745, 0.9974475294117646, 0.9973477584313726, 0.9977820172549019, 0.9974407278431374, 0.9975082290196078] test_log = [0.8540740392156863, 0.8649267450980391, 0.8564147450980393, 0.8444150588235293, 0.8471686274509804, 0.826121725490196, 0.8424131764705882, 0.8401847843137256, 0.8284373333333332, 0.8320996078431373, 0.8604749803921569, 0.840553725490196, 0.8259722352941177, 0.8412304313725489, 0.8280956862745098, 0.8450829803921568, 0.843001725490196, 0.854564862745098, 0.8495196862745098, 0.8440784313725489, 0.8182186666666666, 0.8522569411764706, 0.8636740392156863, 0.8557358431372548, 0.8489019607843137, 0.8422538039215686, 0.8344840784313725, 0.8391745882352941, 0.8218570980392157, 0.8421825882352941, 0.8723526274509803, 0.8732980392156863, 0.8746343529411764, 0.8739411764705883, 0.8735287843137255, 0.8745794509803921, 0.8686682352941176, 0.8746359215686275, 0.8726382745098038, 0.8706779607843138, 0.8727118431372549, 0.8732894117647059, 0.880304156862745, 0.8714334117647058, 0.8760409411764706, 0.8772798431372548, 0.8776307450980392, 0.8781731764705882, 0.8763093333333334, 0.8776247843137255, 0.8768696470588234, 0.8777276862745098, 0.8773055686274509, 0.8763356862745098, 0.877628862745098, 0.8775463529411764, 0.8778290196078431, 0.8787496470588235, 0.876702431372549, 0.8758170980392157] train_log_scratch = [0.6890966588235294, 0.7498066949019607, 0.77299424, 0.7538028737254903, 0.7556846180392157, 0.8407256533333333, 0.8185373364705881, 0.8357328627450981, 0.8290430870588235, 0.8706400188235295, 0.8983218823529411, 0.8901459952941176, 0.8762650792156863, 0.8643537192156863, 0.8762003952941176, 0.9327183498039215, 0.9058492549019608, 0.9225154133333333, 0.9391776564705883, 0.9153873443137255, 0.9427768156862745, 0.918611105882353, 0.9488297160784314, 0.9422692141176472, 0.9397862211764705, 0.9447884047058823, 0.9376482133333333, 0.9170201662745099, 0.9400590368627451, 0.9428715294117647, 0.9827926776470589, 0.9869608282352941, 0.9892946447058824, 0.9910100705882354, 0.9913128219607843, 0.993834779607843, 0.9938273317647058, 0.9955538698039215, 0.9949663435294117, 0.9957406619607845, 0.9957660737254902, 0.9957580674509803, 0.9957963294117649, 0.9963942901960784, 0.9965594917647058, 0.9969945600000001, 0.9972550839215686, 0.9974924862745098, 0.9969971074509804, 0.997616803137255, 0.9976887654901962, 0.9980113945098039, 0.9977375623529412, 0.9977079843137253, 0.9979535686274509, 0.998025951372549, 0.998142908235294, 0.9981611168627451, 0.9977764141176471, 0.9979765709803923] test_log_scratch = [0.6823946666666666, 0.7297479215686274, 0.747155137254902, 0.7293926274509804, 0.7213540392156863, 0.7886792156862745, 0.7819234509803922, 0.7881909019607842, 0.7723794509803921, 0.7988341960784313, 0.8359945098039215, 0.8195965490196079, 0.8108655686274511, 0.7986409411764706, 0.7982105098039216, 0.852886274509804, 0.819928156862745, 0.8454426666666666, 0.8503579607843137, 0.8292087843137255, 0.8498487843137255, 0.8227733333333334, 0.8496687058823529, 0.8344301176470588, 0.84, 0.8395469803921568, 0.8437207843137255, 0.8221675294117646, 0.8352611764705882, 0.8306180392156863, 0.8663220392156863, 0.8681480784313725, 0.8710854901960785, 0.8700506666666667, 0.8700886274509803, 0.8714505098039216, 0.8706748235294117, 0.8733140392156862, 0.873332705882353, 0.874576, 0.8714436078431371, 0.8704718431372549, 0.8702296470588236, 0.8724760784313725, 0.8709339607843137, 0.8711752156862745, 0.8721429019607843, 0.8720614901960784, 0.869857725490196, 0.8717300392156864, 0.8717584313725489, 0.8734641568627451, 0.8732724705882353, 0.8717474509803922, 0.8723347450980391, 0.8721515294117648, 0.8720542745098039, 0.8722836078431372, 0.87048, 0.8709046274509804] fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12,5)) plt.suptitle('CIFAR-10 (20% imbalanced)',fontsize=25) x=np.arange(len(train_log)) ax0.plot(x, train_log_scratch, label='From Scratch', linewidth=3) ax0.plot(x, train_log, label='From Pretraining', linewidth=3) ax0.set_title('Training',fontsize=25) ax1.plot(x, test_log_scratch, label='From Scratch', linewidth=3) ax1.plot(x, test_log, label='From Pretraining', linewidth=3) ax1.set_title('Testing',fontsize=25) ax0.legend(fontsize=15) ax1.legend(fontsize=15) ax0.set_ylabel('pAUC(FPR≤0.3)', fontsize=20) ax0.set_xlabel('Epoch', fontsize=20) ax1.set_ylabel('pAUC(FPR≤0.3)', fontsize=20) ax1.set_xlabel('Epoch', fontsize=20) plt.tight_layout() .. container:: output display_data .. image:: ./imgs/sopas_p.png .. container:: cell markdown .. rubric:: **Comparison** Furthermore, we compare our performance in terms of pAUC with the `AUCM Loss`, which can directly optimize AUROC. For detail about `AUCM Loss`, please refer to :ref:`AUCM `. .. container:: output display_data .. image:: ./imgs/comparison_dro.png