Optimizing Two-Way Partial AUC on Imbalanced CIFAR10 Dataset (SOTAs)


Author: Dixian Zhu
Edited by: Zhuoning Yuan, Tianbao Yang

Introduction

In this tutorial, we will learn how to quickly train a ResNet18 model by optimizing the two-way partial AUC (TPAUC) score with our novel tpAUC_KL_Loss and the SOTAs optimizer [Ref] on a binary image classification task on CIFAR10. After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.
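As a quick refresher, the two-way partial AUC restricts the ROC curve to the region where the false positive rate (FPR) is at most α and the true positive rate (TPR) is at least β; the evaluation below uses α = 0.3 and β = 0.7. One common normalized formulation (a reference definition for intuition, not necessarily the exact form implemented by the library's metric) is

\mathrm{TPAUC}(\alpha, \beta) = \frac{1}{\alpha\,(1-\beta)} \int_{0}^{\alpha} \max\big(\mathrm{TPR}(u) - \beta,\ 0\big)\, du,

where u ranges over FPR values, so a perfect ranking attains TPAUC = 1 and the score is dominated by the hardest positive and negative examples.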

Reference:

If you find this tutorial helpful in your work, please cite our [library paper] and the following paper:

@inproceedings{zhu2022auc,
               title={When auc meets dro: Optimizing partial auc for deep learning with non-convex convergence guarantee},
               author={Zhu, Dixian and Li, Gang and Wang, Bokun and Wu, Xiaodong and Yang, Tianbao},
               booktitle={International Conference on Machine Learning},
               pages={27548--27573},
               year={2022},
               organization={PMLR}
             }

Install LibAUC

Let’s start by installing our library here. In this tutorial, we will use the latest version of LibAUC by running pip install -U.

!pip install -U libauc
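If you want to confirm which version was installed, you can print the package's version string (this assumes the package exposes the standard __version__ attribute):

import libauc
print(libauc.__version__)  # verify the installed LibAUC version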

Importing LibAUC

Import the required libraries.

from libauc.models import resnet18
from libauc.datasets import CIFAR10
from libauc.losses import tpAUC_KL_Loss
from libauc.optimizers import SOTAs
from libauc.utils import ImbalancedDataGenerator
from libauc.sampler import DualSampler # data resampling (for binary class)
from libauc.metrics import pauc_roc_score

import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset

Reproducibility

These functions limit the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases [Ref].

def set_all_seeds(SEED):
    # REPRODUCIBILITY
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
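Depending on your setup, you may also want to seed Python's built-in random module and ask PyTorch to prefer deterministic algorithms (at some speed cost). This optional sketch is not part of the original pipeline and requires a recent PyTorch version:

import random

def set_extra_determinism(SEED):
    # optional, stricter determinism (may slow training down)
    random.seed(SEED)                  # Python's own RNG
    torch.cuda.manual_seed_all(SEED)   # all CUDA devices
    torch.use_deterministic_algorithms(True, warn_only=True)  # warn on non-deterministic ops

# set_extra_determinism(SEED)  # call alongside set_all_seeds if desired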

Image Dataset

Now we define the data input pipeline, including data augmentations. In this tutorial, we use RandomCrop and RandomHorizontalFlip. The pos_index_map maps a global dataset index to a local positive-sample index, which reduces the memory cost in the loss function since we only need to track the indices of positive samples.

class ImageDataset(Dataset):
    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        self.transform_train = transforms.Compose([
                               transforms.ToTensor(),
                               transforms.RandomCrop((crop_size, crop_size), padding=None),
                               transforms.RandomHorizontalFlip(),
                               transforms.Resize((image_size, image_size)),
                               ])
        self.transform_test = transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Resize((image_size, image_size)),
                              ])

        # for the loss function: map global dataset index -> local positive-sample index
        self.pos_indices = np.flatnonzero(targets==1)
        self.pos_index_map = {}
        for i, idx in enumerate(self.pos_indices):
            self.pos_index_map[idx] = i

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        target = self.targets[idx]
        image = Image.fromarray(image.astype('uint8'))
        if self.mode == 'train':
            # positives return their local positive index, negatives return -1
            idx = self.pos_index_map[idx] if idx in self.pos_indices else -1
            image = self.transform_train(image)
        else:
            image = self.transform_test(image)
        return image, target, idx
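To make the index mapping concrete, here is an optional sanity check (not from the original notebook) on a tiny, made-up targets array: a positive sample returns its local positive index, while a negative sample returns -1.

# toy example: 5 samples, positives at global indices 1 and 3
toy_images  = np.zeros((5, 32, 32, 3), dtype=np.uint8)
toy_targets = np.array([0, 1, 0, 1, 0])
toy_ds = ImageDataset(toy_images, toy_targets)

_, _, local_idx = toy_ds[3]   # global index 3 is the 2nd positive
print(local_idx)              # expected: 1
_, _, local_idx = toy_ds[0]   # global index 0 is a negative
print(local_idx)              # expected: -1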

HyperParameters

# HyperParameters
SEED = 123
batch_size = 64
total_epochs = 60
weight_decay = 1e-4
lr = 1e-3
decay_epochs = [20, 40]
decay_factor = 10

gamma0 = 0.5 # learning rate for controlling the negative-sample weights
gamma1 = 0.5 # learning rate for controlling the positive-sample weights

tau = 1.0    # KL-DRO regularization for the outer (positive samples) part
Lambda = 0.5 # KL-DRO regularization for the inner (negative samples) part

# oversample the minority class; sampling_rate can be tuned in (0, 0.5]
# e.g., sampling_rate=0.5 means the number of positive samples per mini-batch is sampling_rate*batch_size = 32
sampling_rate = 0.5
num_pos = round(sampling_rate*batch_size)
num_neg = batch_size - num_pos

Loading datasets

# load data as numpy arrays
train_data, train_targets = CIFAR10(root='./data', train=True).as_array()
test_data, test_targets  = CIFAR10(root='./data', train=False).as_array()

# generate imbalanced data
generator = ImbalancedDataGenerator(shuffle=True, verbose=True, random_seed=0)
(train_images, train_labels) = generator.transform(train_data, train_targets, imratio=0.2)
(test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)

# data augmentations
trainDataset = ImageDataset(train_images, train_labels)
testDataset = ImageDataset(test_images, test_labels, mode='test')

# dataloaders
sampler = DualSampler(trainDataset, batch_size, sampling_rate=sampling_rate)
trainloader = torch.utils.data.DataLoader(trainDataset, batch_size, sampler=sampler, shuffle=False, num_workers=1)
testloader = torch.utils.data.DataLoader(testDataset, batch_size=batch_size, shuffle=False, num_workers=1)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
#SAMPLES: 31250, CLASS 0.0 COUNT: 25000, CLASS RATIO: 0.8000
#SAMPLES: 31250, CLASS 1.0 COUNT: 6250, CLASS RATIO: 0.2000
#SAMPLES: 10000, CLASS 0 COUNT: 5000, CLASS RATIO: 0.5000
#SAMPLES: 10000, CLASS 1 COUNT: 5000, CLASS RATIO: 0.5000
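Before training, it can help to sanity-check the generated data and batch layout. The sketch below is not part of the original notebook; it assumes, as the training loop later does via index[:num_pos], that DualSampler places the positive samples at the start of every mini-batch.

# class ratio of the generated imbalanced training split: should be ~0.2
print('positive ratio:', (train_labels == 1).mean())

# peek at one mini-batch: the first num_pos entries should be positives
images, labels, index = next(iter(trainloader))
print('positives in first half :', labels[:num_pos].sum().item())   # expected: num_pos
print('positives in second half:', labels[num_pos:].sum().item())   # expected: 0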

Creating models & TPAUC Optimizer

# You can optionally apply a sigmoid/l2 activation to the model's outputs before computing the loss;
# here last_activation=None, and sigmoid is applied manually in the training loop below
set_all_seeds(SEED)
model = resnet18(pretrained=False, num_classes=1, last_activation=None)
model = model.cuda()


# Initialize the loss function and optimizer
loss_fn = tpAUC_KL_Loss(data_len=sampler.pos_len, Lambda=Lambda, tau=tau, gammas=(gamma0, gamma1))
optimizer = SOTAs(model.parameters(), loss_fn=loss_fn, mode='adam', lr=lr, weight_decay=weight_decay)
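Since the loss keeps one weight per positive training sample, data_len is set to sampler.pos_len. Whether pos_len counts the distinct positive samples (6,250 in the log above) or the positives drawn per epoch after oversampling may depend on the library version; printing both is a quick way to check this assumption:

# compare the sampler's positive count with the number of distinct positives in the dataset
print(sampler.pos_len, len(trainDataset.pos_indices))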

Training

import warnings
warnings.filterwarnings("ignore")
print ('Start Training')
print ('-'*30)

tr_tpAUC=[]
te_tpAUC=[]

for epoch in range(total_epochs):
    if epoch in decay_epochs:
        optimizer.update_lr(decay_factor=decay_factor)

    train_loss = 0
    model.train()
    for idx, data in enumerate(trainloader):
        train_data, train_labels, index = data
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        y_prob = torch.sigmoid(y_pred)
        loss = loss_fn(y_prob, train_labels, index[:num_pos])
        train_loss = train_loss  + loss.cpu().detach().numpy()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loss = train_loss/(idx+1)

    # evaluation
    model.eval()
    with torch.no_grad():
        train_pred = []
        train_true = []
        for jdx, data in enumerate(trainloader):
            train_data, train_labels,_ = data
            train_data = train_data.cuda()
            y_pred = model(train_data)
            y_prob = torch.sigmoid(y_pred)
            train_pred.append(y_prob.cpu().detach().numpy())
            train_true.append(train_labels.numpy())
        train_true = np.concatenate(train_true)
        train_pred = np.concatenate(train_pred)
        single_train_auc =  pauc_roc_score(train_true, train_pred, max_fpr = 0.3, min_tpr=0.7)

        test_pred = []
        test_true = []
        for jdx, data in enumerate(testloader):
            test_data, test_labels, _ = data
            test_data = test_data.cuda()
            y_pred = model(test_data)
            y_prob = torch.sigmoid(y_pred)
            test_pred.append(y_prob.cpu().detach().numpy())
            test_true.append(test_labels.numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        single_test_auc =  pauc_roc_score(test_true, test_pred, max_fpr = 0.3, min_tpr=0.7)
        print('Epoch=%s, Loss=%0.4f, Train_tpAUC(0.3,0.7)=%.4f, Test_tpAUC(0.3,0.7)=%.4f, lr=%.4f'%(epoch, train_loss, single_train_auc, single_test_auc, optimizer.lr))

        tr_tpAUC.append(single_train_auc)
        te_tpAUC.append(single_test_auc)
Start Training
------------------------------
Epoch=0, Loss=1.3559, Train_tpAUC(0.3,0.7)=0.0000, Test_tpAUC(0.3,0.7)=0.0000, lr=0.0010
Epoch=1, Loss=0.9319, Train_tpAUC(0.3,0.7)=0.0238, Test_tpAUC(0.3,0.7)=0.0072, lr=0.0010
Epoch=2, Loss=0.8985, Train_tpAUC(0.3,0.7)=0.0201, Test_tpAUC(0.3,0.7)=0.0076, lr=0.0010
Epoch=3, Loss=0.8669, Train_tpAUC(0.3,0.7)=0.1492, Test_tpAUC(0.3,0.7)=0.0692, lr=0.0010
Epoch=4, Loss=0.8389, Train_tpAUC(0.3,0.7)=0.1715, Test_tpAUC(0.3,0.7)=0.1109, lr=0.0010
Epoch=5, Loss=0.8074, Train_tpAUC(0.3,0.7)=0.2485, Test_tpAUC(0.3,0.7)=0.1212, lr=0.0010
Epoch=6, Loss=0.7615, Train_tpAUC(0.3,0.7)=0.1549, Test_tpAUC(0.3,0.7)=0.0624, lr=0.0010
Epoch=7, Loss=0.7338, Train_tpAUC(0.3,0.7)=0.3678, Test_tpAUC(0.3,0.7)=0.2257, lr=0.0010
Epoch=8, Loss=0.7005, Train_tpAUC(0.3,0.7)=0.4222, Test_tpAUC(0.3,0.7)=0.1703, lr=0.0010
Epoch=9, Loss=0.6752, Train_tpAUC(0.3,0.7)=0.5709, Test_tpAUC(0.3,0.7)=0.2833, lr=0.0010
Epoch=10, Loss=0.6236, Train_tpAUC(0.3,0.7)=0.6198, Test_tpAUC(0.3,0.7)=0.3084, lr=0.0010
Epoch=11, Loss=0.5767, Train_tpAUC(0.3,0.7)=0.6812, Test_tpAUC(0.3,0.7)=0.2975, lr=0.0010
Epoch=12, Loss=0.5610, Train_tpAUC(0.3,0.7)=0.7206, Test_tpAUC(0.3,0.7)=0.2951, lr=0.0010
Epoch=13, Loss=0.4974, Train_tpAUC(0.3,0.7)=0.6405, Test_tpAUC(0.3,0.7)=0.2713, lr=0.0010
Epoch=14, Loss=0.4837, Train_tpAUC(0.3,0.7)=0.7563, Test_tpAUC(0.3,0.7)=0.3307, lr=0.0010
Epoch=15, Loss=0.4584, Train_tpAUC(0.3,0.7)=0.8308, Test_tpAUC(0.3,0.7)=0.2802, lr=0.0010
Epoch=16, Loss=0.4179, Train_tpAUC(0.3,0.7)=0.7479, Test_tpAUC(0.3,0.7)=0.2940, lr=0.0010
Epoch=17, Loss=0.3814, Train_tpAUC(0.3,0.7)=0.8547, Test_tpAUC(0.3,0.7)=0.3091, lr=0.0010
Epoch=18, Loss=0.3980, Train_tpAUC(0.3,0.7)=0.8610, Test_tpAUC(0.3,0.7)=0.3040, lr=0.0010
Epoch=19, Loss=0.3435, Train_tpAUC(0.3,0.7)=0.8609, Test_tpAUC(0.3,0.7)=0.3290, lr=0.0010
Reducing learning rate to 0.00010 @ T=15620!
Epoch=20, Loss=0.1765, Train_tpAUC(0.3,0.7)=0.9732, Test_tpAUC(0.3,0.7)=0.3890, lr=0.0001
Epoch=21, Loss=0.1250, Train_tpAUC(0.3,0.7)=0.9828, Test_tpAUC(0.3,0.7)=0.3940, lr=0.0001
Epoch=22, Loss=0.0892, Train_tpAUC(0.3,0.7)=0.9875, Test_tpAUC(0.3,0.7)=0.3982, lr=0.0001
Epoch=23, Loss=0.0818, Train_tpAUC(0.3,0.7)=0.9920, Test_tpAUC(0.3,0.7)=0.3997, lr=0.0001
Epoch=24, Loss=0.0590, Train_tpAUC(0.3,0.7)=0.9943, Test_tpAUC(0.3,0.7)=0.3934, lr=0.0001
Epoch=25, Loss=0.0555, Train_tpAUC(0.3,0.7)=0.9937, Test_tpAUC(0.3,0.7)=0.3870, lr=0.0001
Epoch=26, Loss=0.0487, Train_tpAUC(0.3,0.7)=0.9953, Test_tpAUC(0.3,0.7)=0.3958, lr=0.0001
Epoch=27, Loss=0.0488, Train_tpAUC(0.3,0.7)=0.9956, Test_tpAUC(0.3,0.7)=0.4003, lr=0.0001
Epoch=28, Loss=0.0352, Train_tpAUC(0.3,0.7)=0.9971, Test_tpAUC(0.3,0.7)=0.3877, lr=0.0001
Epoch=29, Loss=0.0392, Train_tpAUC(0.3,0.7)=0.9974, Test_tpAUC(0.3,0.7)=0.3874, lr=0.0001
Epoch=30, Loss=0.0299, Train_tpAUC(0.3,0.7)=0.9979, Test_tpAUC(0.3,0.7)=0.3865, lr=0.0001
Epoch=31, Loss=0.0297, Train_tpAUC(0.3,0.7)=0.9968, Test_tpAUC(0.3,0.7)=0.3822, lr=0.0001
Epoch=32, Loss=0.0286, Train_tpAUC(0.3,0.7)=0.9977, Test_tpAUC(0.3,0.7)=0.3928, lr=0.0001
Epoch=33, Loss=0.0241, Train_tpAUC(0.3,0.7)=0.9984, Test_tpAUC(0.3,0.7)=0.3885, lr=0.0001
Epoch=34, Loss=0.0295, Train_tpAUC(0.3,0.7)=0.9981, Test_tpAUC(0.3,0.7)=0.3935, lr=0.0001
Epoch=35, Loss=0.0225, Train_tpAUC(0.3,0.7)=0.9987, Test_tpAUC(0.3,0.7)=0.3826, lr=0.0001
Epoch=36, Loss=0.0198, Train_tpAUC(0.3,0.7)=0.9983, Test_tpAUC(0.3,0.7)=0.3836, lr=0.0001
Epoch=37, Loss=0.0214, Train_tpAUC(0.3,0.7)=0.9982, Test_tpAUC(0.3,0.7)=0.4031, lr=0.0001
Epoch=38, Loss=0.0215, Train_tpAUC(0.3,0.7)=0.9987, Test_tpAUC(0.3,0.7)=0.3978, lr=0.0001
Epoch=39, Loss=0.0155, Train_tpAUC(0.3,0.7)=0.9988, Test_tpAUC(0.3,0.7)=0.3939, lr=0.0001
Reducing learning rate to 0.00001 @ T=31240!
Epoch=40, Loss=0.0171, Train_tpAUC(0.3,0.7)=0.9987, Test_tpAUC(0.3,0.7)=0.3909, lr=0.0000
Epoch=41, Loss=0.0142, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.4008, lr=0.0000
Epoch=42, Loss=0.0076, Train_tpAUC(0.3,0.7)=0.9989, Test_tpAUC(0.3,0.7)=0.4031, lr=0.0000
Epoch=43, Loss=0.0094, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3930, lr=0.0000
Epoch=44, Loss=0.0096, Train_tpAUC(0.3,0.7)=0.9989, Test_tpAUC(0.3,0.7)=0.3971, lr=0.0000
Epoch=45, Loss=0.0093, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.3965, lr=0.0000
Epoch=46, Loss=0.0089, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3928, lr=0.0000
Epoch=47, Loss=0.0066, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3991, lr=0.0000
Epoch=48, Loss=0.0079, Train_tpAUC(0.3,0.7)=0.9990, Test_tpAUC(0.3,0.7)=0.3880, lr=0.0000
Epoch=49, Loss=0.0088, Train_tpAUC(0.3,0.7)=0.9995, Test_tpAUC(0.3,0.7)=0.3928, lr=0.0000
Epoch=50, Loss=0.0048, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.3874, lr=0.0000
Epoch=51, Loss=0.0062, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3957, lr=0.0000
Epoch=52, Loss=0.0077, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3913, lr=0.0000
Epoch=53, Loss=0.0049, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.3926, lr=0.0000
Epoch=54, Loss=0.0060, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.4049, lr=0.0000
Epoch=55, Loss=0.0040, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.4000, lr=0.0000
Epoch=56, Loss=0.0056, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.3869, lr=0.0000
Epoch=57, Loss=0.0036, Train_tpAUC(0.3,0.7)=0.9993, Test_tpAUC(0.3,0.7)=0.3948, lr=0.0000
Epoch=58, Loss=0.0053, Train_tpAUC(0.3,0.7)=0.9991, Test_tpAUC(0.3,0.7)=0.3949, lr=0.0000
Epoch=59, Loss=0.0057, Train_tpAUC(0.3,0.7)=0.9994, Test_tpAUC(0.3,0.7)=0.3939, lr=0.0000
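The loop above only logs metrics. If you also want to keep the best weights, a minimal sketch (the file name is arbitrary, not part of the tutorial) is to add a check right after te_tpAUC.append(single_test_auc) inside the epoch loop:

# keep the checkpoint with the best test TPAUC seen so far
if single_test_auc >= max(te_tpAUC):
    torch.save(model.state_dict(), 'sotas_cifar10_best.pth')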

Visualization

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (9,5)
x=np.arange(60)

plt.figure()
plt.plot(x, tr_tpAUC, linestyle='--', label='SOTA-s train', linewidth=3)
plt.plot(x, te_tpAUC, label='SOTA-s test', linewidth=3)
plt.title('CIFAR-10 (20% imbalanced)',fontsize=25)
plt.legend(fontsize=15)
plt.ylabel('TPAUC(0.3, 0.7)',fontsize=25)
plt.xlabel('epochs',fontsize=25)
[Figure: train vs. test TPAUC(0.3, 0.7) curves over 60 epochs on CIFAR-10 (20% imbalanced)]
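If you run this as a plain Python script rather than in a notebook, you may need to display or save the figure explicitly (the file name below is arbitrary):

plt.tight_layout()
plt.savefig('sotas_cifar10_tpauc.png', dpi=150)  # save the plot to disk
plt.show()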