=========================================================================
Optimizing AUCMLoss on Imbalanced CIFAR10 Dataset (PESG)
=========================================================================
.. container:: cell markdown

   **Author**: Zhuoning Yuan, Tianbao Yang

Introduction
-----------------------

In this tutorial, you will learn how to quickly train a ResNet20 model by optimizing **AUROC** using our novel :obj:`AUCMLoss` and :obj:`PESG` optimizer [Ref] on a binary image classification task on CIFAR10. After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.

**Reference**: If you find this tutorial helpful in your work, please cite our library paper and the following paper:

.. code-block:: RST

   @inproceedings{yuan2021large,
     title={Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification},
     author={Yuan, Zhuoning and Yan, Yan and Sonka, Milan and Yang, Tianbao},
     booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
     pages={3040--3049},
     year={2021}
   }

Install LibAUC
-----------------------

Let's start by installing the library. In this tutorial, we use the latest version of LibAUC via ``pip install -U``.

.. container:: cell code

   .. code:: python

      !pip install -U libauc

Importing LibAUC
-----------------------

Import the required libraries.

.. container:: cell code

   .. code:: python

      from libauc.losses import AUCMLoss
      from libauc.optimizers import PESG
      from libauc.models import resnet20 as ResNet20
      from libauc.datasets import CIFAR10
      from libauc.utils import ImbalancedDataGenerator
      from libauc.sampler import DualSampler
      from libauc.metrics import auc_roc_score

      import torch
      from PIL import Image
      import numpy as np
      import torchvision.transforms as transforms
      from torch.utils.data import Dataset
      from sklearn.metrics import roc_auc_score

Reproducibility
-----------------------

This function limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases [Ref].

.. container:: cell code

   .. code:: python

      def set_all_seeds(SEED):
          # REPRODUCIBILITY
          torch.manual_seed(SEED)
          np.random.seed(SEED)
          torch.backends.cudnn.deterministic = True
          torch.backends.cudnn.benchmark = False

Image Dataset
-----------------------

Now we define the data input pipeline, including data augmentations. In this tutorial, we use ``RandomCrop`` and ``RandomHorizontalFlip``.

.. container:: cell code

   .. code:: python

      class ImageDataset(Dataset):
          def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
              self.images = images.astype(np.uint8)
              self.targets = targets
              self.mode = mode
              self.transform_train = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.RandomCrop((crop_size, crop_size), padding=None),
                  transforms.RandomHorizontalFlip(),
                  transforms.Resize((image_size, image_size)),
              ])
              self.transform_test = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Resize((image_size, image_size)),
              ])

          def __len__(self):
              return len(self.images)

          def __getitem__(self, idx):
              image = self.images[idx]
              target = self.targets[idx]
              image = Image.fromarray(image.astype('uint8'))
              if self.mode == 'train':
                  image = self.transform_train(image)
              else:
                  image = self.transform_test(image)
              return image, target
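Before moving on, it can help to sanity-check the pipeline on a few random images. The snippet below is a minimal sketch and is not part of the original tutorial; the names ``dummy_images``, ``dummy_targets``, and ``demo_set`` are illustrative, and the fabricated arrays only serve to confirm the output shapes produced by the ``ImageDataset`` class defined above.

.. container:: cell code

   .. code:: python

      # Illustrative sanity check of the data pipeline (assumes ImageDataset from above).
      import numpy as np

      dummy_images = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)  # fake CIFAR-like images
      dummy_targets = np.random.randint(0, 2, size=(8, 1)).astype(np.float32)        # fake binary labels

      demo_set = ImageDataset(dummy_images, dummy_targets, mode='train')
      img, target = demo_set[0]
      print(img.shape, target)  # expected: torch.Size([3, 32, 32]) and a 0/1 label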
Hyper-parameters
-----------------------

.. container:: cell code

   .. code:: python

      # HyperParameters
      SEED = 123
      BATCH_SIZE = 128
      imratio = 0.1  # for demo
      total_epochs = 100
      decay_epochs = [50, 75]
      lr = 0.1
      margin = 1.0
      epoch_decay = 0.003  # refers to gamma in the paper
      weight_decay = 0.0001

      # oversampling the minority class; you can tune it in (0, 0.5]
      # e.g., sampling_rate=0.2 means the number of positive samples per mini-batch
      # is about sampling_rate*batch_size = 0.2*128 ≈ 26
      sampling_rate = 0.2

Loading datasets
-----------------------

.. container:: cell code

   .. code:: python

      # load data as numpy arrays
      train_data, train_targets = CIFAR10(root='./data', train=True).as_array()
      test_data, test_targets = CIFAR10(root='./data', train=False).as_array()

      # generate imbalanced data
      generator = ImbalancedDataGenerator(verbose=True, random_seed=0)
      (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio)
      (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)

      # data augmentations
      trainSet = ImageDataset(train_images, train_labels)
      trainSet_eval = ImageDataset(train_images, train_labels, mode='test')
      testSet = ImageDataset(test_images, test_labels, mode='test')

      # dataloaders
      sampler = DualSampler(trainSet, BATCH_SIZE, sampling_rate=sampling_rate)
      trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
      trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
      testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

Model, Loss & Optimizer
-----------------------

.. container:: cell code

   .. code:: python

      # fix randomness for reproducibility (uses the helper defined above)
      set_all_seeds(SEED)

      # You can include a sigmoid/l2 activation on the model's outputs before computing the loss
      model = ResNet20(pretrained=False, last_activation=None, num_classes=1)
      model = model.cuda()

      # You can also pass Loss.a, Loss.b, Loss.alpha to the optimizer (for users of older versions)
      loss_fn = AUCMLoss()
      optimizer = PESG(model.parameters(),
                       loss_fn=loss_fn,
                       lr=lr,
                       momentum=0.9,
                       margin=margin,
                       epoch_decay=epoch_decay,
                       weight_decay=weight_decay)
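For reference, :obj:`AUCMLoss` optimizes a margin-based min-max surrogate of AUROC. Sketched loosely below (a paraphrase of the objective in the cited paper, not the library's exact implementation), :math:`h_{\mathbf{w}}(x)` is the model's prediction score, :math:`m` is the margin, :math:`\mathcal{P}` and :math:`\mathcal{N}` are the positive and negative class distributions, and :math:`a`, :math:`b`, :math:`\alpha` are auxiliary variables that :obj:`PESG` updates alongside the model weights:

.. math::

   \min_{\mathbf{w},\,a,\,b}\;\max_{\alpha \ge 0}\;
   \mathbb{E}_{x \sim \mathcal{P}}\big[(h_{\mathbf{w}}(x) - a)^2\big]
   + \mathbb{E}_{x' \sim \mathcal{N}}\big[(h_{\mathbf{w}}(x') - b)^2\big]
   + 2\alpha\big(m - \mathbb{E}_{x \sim \mathcal{P}}[h_{\mathbf{w}}(x)]
   + \mathbb{E}_{x' \sim \mathcal{N}}[h_{\mathbf{w}}(x')]\big) - \alpha^2

The ``margin`` and ``epoch_decay`` hyper-parameters above correspond to :math:`m` and to the :math:`\gamma` regularization parameter in the paper, respectively.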
Training
-----------------------

.. container:: cell code

   .. code:: python

      print('Start Training')
      print('-'*30)

      train_log = []
      test_log = []
      for epoch in range(total_epochs):
          if epoch in decay_epochs:
              optimizer.update_regularizer(decay_factor=10)  # decrease learning rate by 10x & update regularizer

          train_loss = []
          model.train()
          for data, targets in trainloader:
              data, targets = data.cuda(), targets.cuda()
              y_pred = model(data)
              y_pred = torch.sigmoid(y_pred)
              loss = loss_fn(y_pred, targets)
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
              train_loss.append(loss.item())

          # evaluation on train & test sets
          model.eval()
          train_pred_list = []
          train_true_list = []
          for train_data, train_targets in trainloader_eval:
              train_data = train_data.cuda()
              train_pred = model(train_data)
              train_pred_list.append(train_pred.cpu().detach().numpy())
              train_true_list.append(train_targets.numpy())
          train_true = np.concatenate(train_true_list)
          train_pred = np.concatenate(train_pred_list)
          train_auc = auc_roc_score(train_true, train_pred)
          train_loss = np.mean(train_loss)

          test_pred_list = []
          test_true_list = []
          for test_data, test_targets in testloader:
              test_data = test_data.cuda()
              test_pred = model(test_data)
              test_pred_list.append(test_pred.cpu().detach().numpy())
              test_true_list.append(test_targets.numpy())
          test_true = np.concatenate(test_true_list)
          test_pred = np.concatenate(test_pred_list)
          val_auc = auc_roc_score(test_true, test_pred)
          model.train()

          # print results
          print("epoch: %s, train_loss: %.4f, train_auc: %.4f, test_auc: %.4f, lr: %.4f" % (epoch, train_loss, train_auc, val_auc, optimizer.lr))
          train_log.append(train_auc)
          test_log.append(val_auc)

.. container:: output stream stdout

   ::

      Start Training
      ------------------------------
      epoch: 0, train_loss: 0.1447, train_auc: 0.6534, test_auc: 0.6479, lr: 0.1000
      epoch: 1, train_loss: 0.1283, train_auc: 0.6918, test_auc: 0.6849, lr: 0.1000
      epoch: 2, train_loss: 0.1194, train_auc: 0.6901, test_auc: 0.6885, lr: 0.1000
      epoch: 3, train_loss: 0.1127, train_auc: 0.6964, test_auc: 0.6718, lr: 0.1000
      epoch: 4, train_loss: 0.1064, train_auc: 0.7178, test_auc: 0.7023, lr: 0.1000
      epoch: 5, train_loss: 0.1023, train_auc: 0.7654, test_auc: 0.7388, lr: 0.1000
      epoch: 6, train_loss: 0.0972, train_auc: 0.8062, test_auc: 0.7748, lr: 0.1000
      epoch: 7, train_loss: 0.0915, train_auc: 0.7813, test_auc: 0.7545, lr: 0.1000
      epoch: 8, train_loss: 0.0875, train_auc: 0.8070, test_auc: 0.7834, lr: 0.1000
      epoch: 9, train_loss: 0.0848, train_auc: 0.7982, test_auc: 0.7764, lr: 0.1000
      epoch: 10, train_loss: 0.0813, train_auc: 0.8180, test_auc: 0.7883, lr: 0.1000
      epoch: 11, train_loss: 0.0778, train_auc: 0.8375, test_auc: 0.8098, lr: 0.1000
      epoch: 12, train_loss: 0.0745, train_auc: 0.8527, test_auc: 0.8148, lr: 0.1000
      epoch: 13, train_loss: 0.0721, train_auc: 0.8615, test_auc: 0.8268, lr: 0.1000
      epoch: 14, train_loss: 0.0697, train_auc: 0.8118, test_auc: 0.7781, lr: 0.1000
      epoch: 15, train_loss: 0.0683, train_auc: 0.8657, test_auc: 0.8316, lr: 0.1000
      epoch: 16, train_loss: 0.0655, train_auc: 0.8495, test_auc: 0.8084, lr: 0.1000
      epoch: 17, train_loss: 0.0642, train_auc: 0.8664, test_auc: 0.8286, lr: 0.1000
      epoch: 18, train_loss: 0.0627, train_auc: 0.8706, test_auc: 0.8383, lr: 0.1000
      epoch: 19, train_loss: 0.0608, train_auc: 0.8465, test_auc: 0.8147, lr: 0.1000
      epoch: 20, train_loss: 0.0589, train_auc: 0.8429, test_auc: 0.8053, lr: 0.1000
      epoch: 21, train_loss: 0.0577, train_auc: 0.8858, test_auc: 0.8509, lr: 0.1000
      epoch: 22, train_loss: 0.0562, train_auc: 0.7541, test_auc: 0.7374, lr: 0.1000
      epoch: 23, train_loss: 0.0564, train_auc: 0.8896, test_auc: 0.8495, lr: 0.1000
      epoch: 24, train_loss: 0.0548, train_auc: 0.9161, test_auc: 0.8745, lr: 0.1000
      epoch: 25, train_loss: 0.0552, train_auc: 0.8962, test_auc: 0.8543, lr: 0.1000
      epoch: 26, train_loss: 0.0537, train_auc: 0.8778, test_auc: 0.8356, lr: 0.1000
      epoch: 27, train_loss: 0.0533, train_auc: 0.8778, test_auc: 0.8446, lr: 0.1000
      epoch: 28, train_loss: 0.0524, train_auc: 0.9000, test_auc: 0.8614, lr: 0.1000
      epoch: 29, train_loss: 0.0513, train_auc: 0.9135, test_auc: 0.8717, lr: 0.1000
      epoch: 30, train_loss: 0.0505, train_auc: 0.9130, test_auc: 0.8703, lr: 0.1000
      epoch: 31, train_loss: 0.0496, train_auc: 0.8591, test_auc: 0.8237, lr: 0.1000
      epoch: 32, train_loss: 0.0489, train_auc: 0.8694, test_auc: 0.8343, lr: 0.1000
      epoch: 33, train_loss: 0.0478, train_auc: 0.8602, test_auc: 0.8171, lr: 0.1000
      epoch: 34, train_loss: 0.0469, train_auc: 0.8828, test_auc: 0.8412, lr: 0.1000
      epoch: 35, train_loss: 0.0468, train_auc: 0.8995, test_auc: 0.8604, lr: 0.1000
      epoch: 36, train_loss: 0.0473, train_auc: 0.9174, test_auc: 0.8756, lr: 0.1000
      epoch: 37, train_loss: 0.0466, train_auc: 0.8961, test_auc: 0.8504, lr: 0.1000
      epoch: 38, train_loss: 0.0459, train_auc: 0.8932, test_auc: 0.8485, lr: 0.1000
      epoch: 39, train_loss: 0.0443, train_auc: 0.8867, test_auc: 0.8414, lr: 0.1000
      epoch: 40, train_loss: 0.0450, train_auc: 0.9071, test_auc: 0.8611, lr: 0.1000
      epoch: 41, train_loss: 0.0438, train_auc: 0.8573, test_auc: 0.8100, lr: 0.1000
      epoch: 42, train_loss: 0.0441, train_auc: 0.8667, test_auc: 0.8213, lr: 0.1000
      epoch: 43, train_loss: 0.0429, train_auc: 0.9191, test_auc: 0.8803, lr: 0.1000
      epoch: 44, train_loss: 0.0440, train_auc: 0.9014, test_auc: 0.8563, lr: 0.1000
      epoch: 45, train_loss: 0.0426, train_auc: 0.8835, test_auc: 0.8448, lr: 0.1000
      epoch: 46, train_loss: 0.0412, train_auc: 0.9271, test_auc: 0.8810, lr: 0.1000
      epoch: 47, train_loss: 0.0419, train_auc: 0.9306, test_auc: 0.8867, lr: 0.1000
      epoch: 48, train_loss: 0.0413, train_auc: 0.9173, test_auc: 0.8681, lr: 0.1000
      epoch: 49, train_loss: 0.0425, train_auc: 0.9144, test_auc: 0.8706, lr: 0.1000
      Reducing learning rate to 0.01000 @ T=12100!
      Updating regularizer @ T=12100!
      epoch: 50, train_loss: 0.0274, train_auc: 0.9614, test_auc: 0.9100, lr: 0.0100
      epoch: 51, train_loss: 0.0216, train_auc: 0.9663, test_auc: 0.9131, lr: 0.0100
      epoch: 52, train_loss: 0.0196, train_auc: 0.9674, test_auc: 0.9108, lr: 0.0100
      epoch: 53, train_loss: 0.0185, train_auc: 0.9677, test_auc: 0.9103, lr: 0.0100
      epoch: 54, train_loss: 0.0173, train_auc: 0.9708, test_auc: 0.9111, lr: 0.0100
      epoch: 55, train_loss: 0.0162, train_auc: 0.9714, test_auc: 0.9106, lr: 0.0100
      epoch: 56, train_loss: 0.0148, train_auc: 0.9738, test_auc: 0.9131, lr: 0.0100
      epoch: 57, train_loss: 0.0150, train_auc: 0.9751, test_auc: 0.9131, lr: 0.0100
      epoch: 58, train_loss: 0.0139, train_auc: 0.9721, test_auc: 0.9068, lr: 0.0100
      epoch: 59, train_loss: 0.0129, train_auc: 0.9786, test_auc: 0.9152, lr: 0.0100
      epoch: 60, train_loss: 0.0129, train_auc: 0.9769, test_auc: 0.9114, lr: 0.0100
      epoch: 61, train_loss: 0.0125, train_auc: 0.9764, test_auc: 0.9094, lr: 0.0100
      epoch: 62, train_loss: 0.0116, train_auc: 0.9772, test_auc: 0.9086, lr: 0.0100
      epoch: 63, train_loss: 0.0117, train_auc: 0.9789, test_auc: 0.9120, lr: 0.0100
      epoch: 64, train_loss: 0.0111, train_auc: 0.9789, test_auc: 0.9113, lr: 0.0100
      epoch: 65, train_loss: 0.0103, train_auc: 0.9798, test_auc: 0.9096, lr: 0.0100
      epoch: 66, train_loss: 0.0102, train_auc: 0.9801, test_auc: 0.9085, lr: 0.0100
      epoch: 67, train_loss: 0.0100, train_auc: 0.9815, test_auc: 0.9138, lr: 0.0100
      epoch: 68, train_loss: 0.0102, train_auc: 0.9804, test_auc: 0.9077, lr: 0.0100
      epoch: 69, train_loss: 0.0094, train_auc: 0.9810, test_auc: 0.9090, lr: 0.0100
      epoch: 70, train_loss: 0.0092, train_auc: 0.9814, test_auc: 0.9070, lr: 0.0100
      epoch: 71, train_loss: 0.0092, train_auc: 0.9815, test_auc: 0.9079, lr: 0.0100
      epoch: 72, train_loss: 0.0085, train_auc: 0.9809, test_auc: 0.9075, lr: 0.0100
      epoch: 73, train_loss: 0.0083, train_auc: 0.9817, test_auc: 0.9061, lr: 0.0100
      epoch: 74, train_loss: 0.0084, train_auc: 0.9810, test_auc: 0.9044, lr: 0.0100
      Reducing learning rate to 0.00100 @ T=18150!
      Updating regularizer @ T=18150!
      epoch: 75, train_loss: 0.0075, train_auc: 0.9833, test_auc: 0.9076, lr: 0.0010
      epoch: 76, train_loss: 0.0070, train_auc: 0.9838, test_auc: 0.9074, lr: 0.0010
      epoch: 77, train_loss: 0.0070, train_auc: 0.9834, test_auc: 0.9064, lr: 0.0010
      epoch: 78, train_loss: 0.0066, train_auc: 0.9844, test_auc: 0.9082, lr: 0.0010
      epoch: 79, train_loss: 0.0067, train_auc: 0.9837, test_auc: 0.9061, lr: 0.0010
      epoch: 80, train_loss: 0.0069, train_auc: 0.9840, test_auc: 0.9058, lr: 0.0010
      epoch: 81, train_loss: 0.0071, train_auc: 0.9840, test_auc: 0.9067, lr: 0.0010
      epoch: 82, train_loss: 0.0069, train_auc: 0.9841, test_auc: 0.9053, lr: 0.0010
      epoch: 83, train_loss: 0.0065, train_auc: 0.9839, test_auc: 0.9057, lr: 0.0010
      epoch: 84, train_loss: 0.0067, train_auc: 0.9837, test_auc: 0.9053, lr: 0.0010
      epoch: 85, train_loss: 0.0065, train_auc: 0.9842, test_auc: 0.9060, lr: 0.0010
      epoch: 86, train_loss: 0.0066, train_auc: 0.9840, test_auc: 0.9051, lr: 0.0010
      epoch: 87, train_loss: 0.0066, train_auc: 0.9847, test_auc: 0.9061, lr: 0.0010
      epoch: 88, train_loss: 0.0063, train_auc: 0.9838, test_auc: 0.9036, lr: 0.0010
      epoch: 89, train_loss: 0.0062, train_auc: 0.9847, test_auc: 0.9062, lr: 0.0010
      epoch: 90, train_loss: 0.0063, train_auc: 0.9840, test_auc: 0.9047, lr: 0.0010
      epoch: 91, train_loss: 0.0064, train_auc: 0.9835, test_auc: 0.9032, lr: 0.0010
      epoch: 92, train_loss: 0.0064, train_auc: 0.9842, test_auc: 0.9053, lr: 0.0010
      epoch: 93, train_loss: 0.0063, train_auc: 0.9838, test_auc: 0.9045, lr: 0.0010
      epoch: 94, train_loss: 0.0063, train_auc: 0.9844, test_auc: 0.9040, lr: 0.0010
      epoch: 95, train_loss: 0.0063, train_auc: 0.9848, test_auc: 0.9054, lr: 0.0010
      epoch: 96, train_loss: 0.0062, train_auc: 0.9836, test_auc: 0.9030, lr: 0.0010
      epoch: 97, train_loss: 0.0059, train_auc: 0.9842, test_auc: 0.9041, lr: 0.0010
      epoch: 98, train_loss: 0.0063, train_auc: 0.9845, test_auc: 0.9044, lr: 0.0010
      epoch: 99, train_loss: 0.0061, train_auc: 0.9846, test_auc: 0.9044, lr: 0.0010

Visualization
-----------------------

Now, let's see the learning curves for optimizing AUROC on the train and test sets.

.. container:: cell code

   .. code:: python

      import matplotlib.pyplot as plt
      plt.rcParams["figure.figsize"] = (9, 5)

      x = np.arange(len(train_log))
      plt.figure()
      plt.plot(x, train_log, linestyle='-', label='Train Set', linewidth=3)
      plt.plot(x, test_log, linestyle='-', label='Test Set', linewidth=3)
      plt.title('AUCMLoss (10% CIFAR10)', fontsize=25)
      plt.legend(fontsize=15)
      plt.ylabel('AUROC', fontsize=25)
      plt.xlabel('Epoch', fontsize=25)

.. container:: output display_data

   .. image:: ./imgs/auroc.png
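If you want to keep the results for later comparison, the snippet below is a small optional sketch (not part of the original tutorial). It assumes the ``model``, the ``train_log``/``test_log`` lists, and the active matplotlib figure from the cells above; the file names are arbitrary.

.. container:: cell code

   .. code:: python

      # Optional: save the learning-curve figure, the AUROC logs, and the trained weights.
      plt.savefig('aucm_pesg_cifar10_auroc.png', bbox_inches='tight')
      np.save('train_auc_log.npy', np.array(train_log))
      np.save('test_auc_log.npy', np.array(test_log))
      torch.save(model.state_dict(), 'resnet20_aucm_pesg_cifar10.pth')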