Multiple Instance Deep AUC Maximization with smoothed-max pooling (MIDAM-smx) on Tabular Dataset
Introduction
In this tutorial, we will learn how to quickly train a simple Feed Forward Neural Network (FFNN) model by optimizing Multiple Instance Deep AUC Maximization (MIDAM) with our novel MIDAMLoss(mode='softmax') loss and MIDAM optimizer [Ref] on a binary classification task on the MUSK2 dataset. After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.
Reference:
If you find this tutorial helpful in your work, please cite our library paper and the following paper:
@inproceedings{zhu2023mil,
  title={Provable Multi-instance Deep AUC Maximization with Stochastic Pooling},
  author={Zhu, Dixian and Wang, Bokun and Chen, Zhi and Wang, Yaxing and Sonka, Milan and Wu, Xiaodong and Yang, Tianbao},
  booktitle={International Conference on Machine Learning},
  year={2023},
  organization={PMLR}
}
Install LibAUC
Let’s start by installing our library. In this tutorial, we will use the latest version of LibAUC, installed via pip install -U.
!pip install -U libauc
Importing LibAUC
Import the required libraries.
import torch
import matplotlib.pyplot as plt
import numpy as np
from libauc.optimizers import MIDAM
from libauc.losses import MIDAMLoss
from libauc.models import FFNN
from libauc.utils import set_all_seeds, collate_fn, TabularDataset, MIL_sampling, MIL_evaluate_auc
from libauc.sampler import DualSampler
Reproducibility
This function limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases [Ref].
def set_all_seeds(SEED):
# REPRODUCIBILITY
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
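If you run on a GPU, you may additionally want to seed the CUDA random number generators. This is an optional extra (not part of set_all_seeds above), using the standard PyTorch call:
# optional: also seed all CUDA devices when training on GPU,
# e.g. by adding this line inside set_all_seeds
torch.cuda.manual_seed_all(SEED)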
Introduction for Loss and Optimizer
In this section, we will introduce the MIDAM optimization algorithm and how to utilize the MIDAMLoss function and the MIDAM optimizer.
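As a rough sketch of the idea (see the paper [Ref] for the precise formulation), the smoothed-max (softmax) pooling aggregates the instance-level predictions $\phi_w(x)$ of a bag $X$ (here, the sigmoid outputs of the FFNN) through a temperature-scaled log-sum-exp:

$$h(w; X) = \tau \log\Big(\frac{1}{|X|}\sum_{x \in X}\exp\big(\phi_w(x)/\tau\big)\Big)$$

A small temperature $\tau$ pushes the pooling toward the hard max over instances, while a large $\tau$ pushes it toward the mean. The pooled bag-level scores are then plugged into an AUC margin loss, and the MIDAM method maintains stochastic (moving-average) estimates of the pooled scores so that only a small batch of instances per bag needs to be evaluated at each step.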
HyperParameters
The hyper-parameters: batch size (bag-level), instance batch size (instance-level), positive sampling rate, learning rate, weight decay, and margin for the AUC loss, as well as the smoothing temperature tau for the softmax pooling, the optimizer momentum, and the parameter gamma of the loss.
# HyperParameters
SEED = 123
set_all_seeds(SEED)
batch_size = 16
instance_batch_size = 4
sampling_rate = 0.5
lr = 1e-2
weight_decay = 1e-4
margin = 1.0
tau = 0.1
momentum = 0.1
gamma = 0.9
Load Data, initialize model and loss
In this step, we use MUSK2 as the benchmark dataset [Ref] and load it into dataloaders. We adopt a simple FFNN as the backbone model. Data format: a list whose length equals the number of bags, where each bag is an array of shape (number of instances in the bag, feature dimension).
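For illustration, here is a hypothetical toy example of this bag-list format (synthetic values, not the MUSK2 data):
toy_data = [np.random.randn(4, 166), np.random.randn(2, 166), np.random.randn(7, 166)]  # 3 bags with 4, 2 and 7 instances
toy_labels = np.array([1, 0, 1])  # one binary label per bag
toy_set = TabularDataset(toy_data, toy_labels)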
tmp = np.load('/home/dixzhu/data/musk_2.npz', allow_pickle=True)  # replace this path with your own copy of the MUSK2 data (file size: 8.8 MB)
train_data = tmp['train_X']
test_data = tmp['test_X']
train_labels = tmp['train_Y'].astype(int)
test_labels = tmp['test_Y'].astype(int)
traindSet = TabularDataset(train_data, train_labels)
testSet = TabularDataset(test_data, test_labels)
DIMS=166
sampler = DualSampler(dataset=traindSet, batch_size=batch_size, shuffle=True, sampling_rate=sampling_rate)
trainloader = torch.utils.data.DataLoader(dataset=traindSet, sampler=sampler, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
testloader = torch.utils.data.DataLoader(testSet, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FFNN(input_dim=DIMS, hidden_sizes=(DIMS, ), num_classes=1, last_activation='sigmoid').to(device)
Loss = MIDAMLoss(mode='softmax',data_len=len(traindSet), tau=tau, margin=margin, gamma=gamma)
optimizer = MIDAM(model.parameters(), loss_fn=Loss, lr=lr, weight_decay=weight_decay, momentum=momentum)
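Note that MIDAMLoss takes data_len (the number of training bags) at construction and the bag indices index at every call: per the paper [Ref], the method keeps a moving-average estimate of the pooled prediction for each bag (controlled by gamma), so the loss needs to know which bags appear in the current mini-batch.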
Training
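We train for 100 epochs. At epochs 50 and 75 we decay the learning rate by a factor of 10 with optimizer.update_lr and update the softmax smoothing with Loss.update_smoothing. In each iteration we draw a mini-batch of bags from the DualSampler, sub-sample instance_batch_size instances from every bag with MIL_sampling, and update the model with the MIDAM optimizer; after each epoch we evaluate the bag-level AUC on the training and test sets.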
total_epochs = 100
decay_epoch = [50, 75]
train_auc = np.zeros(total_epochs)
test_auc = np.zeros(total_epochs)
for epoch in range(total_epochs):
    if epoch in decay_epoch:
        optimizer.update_lr(decay_factor=10)   # stage-wise learning rate decay
        Loss.update_smoothing(decay_factor=2)  # update the softmax smoothing
    tr_loss = 0
    for idx, data in enumerate(trainloader):
        y_pred = []
        train_data_bags, train_labels, ids = data
        for i in range(len(ids)):
            # stochastic pooling: sub-sample instances from each bag and pool their scores
            tmp_pred = MIL_sampling(bag_X=train_data_bags[i], model=model, instance_batch_size=instance_batch_size, tau=tau, mode='exp')
            y_pred.append(tmp_pred)
        y_pred = torch.cat(y_pred, dim=0)
        ids = torch.from_numpy(np.array(ids))
        train_labels = torch.from_numpy(np.array(train_labels))
        # pass the pooled predictions, bag labels, and bag indices to the loss
        loss = Loss(y_pred=y_pred, y_true=train_labels, index=ids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tr_loss += loss.detach().cpu().numpy()
    # evaluate bag-level AUC on the training and test sets after each epoch
    model.eval()
    with torch.no_grad():
        single_tr_auc = MIL_evaluate_auc(trainloader, model, mode='softmax', tau=tau)
        single_te_auc = MIL_evaluate_auc(testloader, model, mode='softmax', tau=tau)
    train_auc[epoch] = single_tr_auc
    test_auc[epoch] = single_te_auc
    model.train()
    tr_loss = tr_loss/idx  # approximate average training loss over this epoch
    print ('Epoch=%s, BatchID=%s, loss=%.4f, Tr_AUC=%.4f, Test_AUC=%.4f, lr=%.4f'%(epoch, idx, tr_loss, single_tr_auc, single_te_auc, optimizer.lr))
Start Training
------------------------------
Epoch=0, BatchID=5, loss=0.4095, Tr_AUC=0.4596, Test_AUC=0.3333, lr=0.0100
Epoch=1, BatchID=5, loss=0.0932, Tr_AUC=0.4453, Test_AUC=0.3333, lr=0.0100
Epoch=2, BatchID=5, loss=0.0392, Tr_AUC=0.3980, Test_AUC=0.3333, lr=0.0100
Epoch=3, BatchID=5, loss=0.0211, Tr_AUC=0.4514, Test_AUC=0.3333, lr=0.0100
Epoch=4, BatchID=5, loss=0.0133, Tr_AUC=0.4596, Test_AUC=0.3333, lr=0.0100
Epoch=5, BatchID=5, loss=0.0117, Tr_AUC=0.5269, Test_AUC=0.3333, lr=0.0100
Epoch=6, BatchID=5, loss=0.0080, Tr_AUC=0.5868, Test_AUC=0.3333, lr=0.0100
Epoch=7, BatchID=5, loss=0.0076, Tr_AUC=0.6467, Test_AUC=0.2917, lr=0.0100
Epoch=8, BatchID=5, loss=0.0089, Tr_AUC=0.6419, Test_AUC=0.3333, lr=0.0100
Epoch=9, BatchID=5, loss=0.0062, Tr_AUC=0.7161, Test_AUC=0.3333, lr=0.0100
Epoch=10, BatchID=5, loss=0.0076, Tr_AUC=0.7335, Test_AUC=0.3333, lr=0.0100
Epoch=11, BatchID=5, loss=0.0076, Tr_AUC=0.7496, Test_AUC=0.3333, lr=0.0100
Epoch=12, BatchID=5, loss=0.0091, Tr_AUC=0.8164, Test_AUC=0.4583, lr=0.0100
Epoch=13, BatchID=5, loss=0.0146, Tr_AUC=0.8151, Test_AUC=0.4583, lr=0.0100
Epoch=14, BatchID=5, loss=0.0156, Tr_AUC=0.8251, Test_AUC=0.4583, lr=0.0100
Epoch=15, BatchID=5, loss=0.0185, Tr_AUC=0.8416, Test_AUC=0.5000, lr=0.0100
Epoch=16, BatchID=5, loss=0.0249, Tr_AUC=0.8815, Test_AUC=0.5417, lr=0.0100
Epoch=17, BatchID=5, loss=0.0392, Tr_AUC=0.8911, Test_AUC=0.6250, lr=0.0100
Epoch=18, BatchID=5, loss=0.0390, Tr_AUC=0.8906, Test_AUC=0.6667, lr=0.0100
Epoch=19, BatchID=5, loss=0.0468, Tr_AUC=0.9128, Test_AUC=0.6667, lr=0.0100
Epoch=20, BatchID=5, loss=0.0481, Tr_AUC=0.9314, Test_AUC=0.7917, lr=0.0100
Epoch=21, BatchID=5, loss=0.0562, Tr_AUC=0.9119, Test_AUC=0.7917, lr=0.0100
Epoch=22, BatchID=5, loss=0.0664, Tr_AUC=0.9288, Test_AUC=0.8750, lr=0.0100
Epoch=23, BatchID=5, loss=0.0696, Tr_AUC=0.9284, Test_AUC=0.8750, lr=0.0100
Epoch=24, BatchID=5, loss=0.0579, Tr_AUC=0.9392, Test_AUC=0.7917, lr=0.0100
Epoch=25, BatchID=5, loss=0.0676, Tr_AUC=0.9527, Test_AUC=0.8750, lr=0.0100
Epoch=26, BatchID=5, loss=0.1044, Tr_AUC=0.9601, Test_AUC=0.8750, lr=0.0100
Epoch=27, BatchID=5, loss=0.0965, Tr_AUC=0.9688, Test_AUC=0.8750, lr=0.0100
Epoch=28, BatchID=5, loss=0.0928, Tr_AUC=0.9614, Test_AUC=0.8750, lr=0.0100
Epoch=29, BatchID=5, loss=0.0851, Tr_AUC=0.9614, Test_AUC=0.9167, lr=0.0100
Epoch=30, BatchID=5, loss=0.0936, Tr_AUC=0.9544, Test_AUC=0.8750, lr=0.0100
Epoch=31, BatchID=5, loss=0.0914, Tr_AUC=0.9622, Test_AUC=0.8750, lr=0.0100
Epoch=32, BatchID=5, loss=0.0942, Tr_AUC=0.9679, Test_AUC=0.8750, lr=0.0100
Epoch=33, BatchID=5, loss=0.0637, Tr_AUC=0.9622, Test_AUC=0.8750, lr=0.0100
Epoch=34, BatchID=5, loss=0.1122, Tr_AUC=0.9661, Test_AUC=0.8750, lr=0.0100
Epoch=35, BatchID=5, loss=0.0640, Tr_AUC=0.9761, Test_AUC=0.8750, lr=0.0100
Epoch=36, BatchID=5, loss=0.0794, Tr_AUC=0.9757, Test_AUC=0.8750, lr=0.0100
Epoch=37, BatchID=5, loss=0.0924, Tr_AUC=0.9648, Test_AUC=0.8750, lr=0.0100
Epoch=38, BatchID=5, loss=0.0803, Tr_AUC=0.9757, Test_AUC=0.8750, lr=0.0100
Epoch=39, BatchID=5, loss=0.0796, Tr_AUC=0.9696, Test_AUC=0.8750, lr=0.0100
Epoch=40, BatchID=5, loss=0.0847, Tr_AUC=0.9740, Test_AUC=0.8750, lr=0.0100
Epoch=41, BatchID=5, loss=0.0692, Tr_AUC=0.9731, Test_AUC=0.8750, lr=0.0100
Epoch=42, BatchID=5, loss=0.0665, Tr_AUC=0.9596, Test_AUC=0.8750, lr=0.0100
Epoch=43, BatchID=5, loss=0.1014, Tr_AUC=0.9661, Test_AUC=0.8750, lr=0.0100
Epoch=44, BatchID=5, loss=0.0931, Tr_AUC=0.9609, Test_AUC=0.8750, lr=0.0100
Epoch=45, BatchID=5, loss=0.1023, Tr_AUC=0.9644, Test_AUC=0.8750, lr=0.0100
Epoch=46, BatchID=5, loss=0.0795, Tr_AUC=0.9731, Test_AUC=0.8750, lr=0.0100
Epoch=47, BatchID=5, loss=0.0615, Tr_AUC=0.9809, Test_AUC=0.8750, lr=0.0100
Epoch=48, BatchID=5, loss=0.0750, Tr_AUC=0.9800, Test_AUC=0.8750, lr=0.0100
Epoch=49, BatchID=5, loss=0.0884, Tr_AUC=0.9614, Test_AUC=0.8750, lr=0.0100
Reducing learning rate to 0.00100 @ T=300!
Updating regularizer @ T=300!
Epoch=50, BatchID=5, loss=0.0753, Tr_AUC=0.9701, Test_AUC=0.8750, lr=0.0010
Epoch=51, BatchID=5, loss=0.0595, Tr_AUC=0.9766, Test_AUC=0.9167, lr=0.0010
Epoch=52, BatchID=5, loss=0.0451, Tr_AUC=0.9757, Test_AUC=0.9167, lr=0.0010
Epoch=53, BatchID=5, loss=0.0571, Tr_AUC=0.9722, Test_AUC=0.9167, lr=0.0010
Epoch=54, BatchID=5, loss=0.0612, Tr_AUC=0.9813, Test_AUC=0.9167, lr=0.0010
Epoch=55, BatchID=5, loss=0.0717, Tr_AUC=0.9818, Test_AUC=0.8750, lr=0.0010
Epoch=56, BatchID=5, loss=0.0901, Tr_AUC=0.9844, Test_AUC=0.8750, lr=0.0010
Epoch=57, BatchID=5, loss=0.0749, Tr_AUC=0.9848, Test_AUC=0.8750, lr=0.0010
Epoch=58, BatchID=5, loss=0.0750, Tr_AUC=0.9848, Test_AUC=0.8750, lr=0.0010
Epoch=59, BatchID=5, loss=0.0717, Tr_AUC=0.9740, Test_AUC=0.8750, lr=0.0010
Epoch=60, BatchID=5, loss=0.0712, Tr_AUC=0.9831, Test_AUC=0.8750, lr=0.0010
Epoch=61, BatchID=5, loss=0.0620, Tr_AUC=0.9874, Test_AUC=0.9167, lr=0.0010
Epoch=62, BatchID=5, loss=0.0758, Tr_AUC=0.9874, Test_AUC=0.8750, lr=0.0010
Epoch=63, BatchID=5, loss=0.0921, Tr_AUC=0.9831, Test_AUC=0.8750, lr=0.0010
Epoch=64, BatchID=5, loss=0.0866, Tr_AUC=0.9883, Test_AUC=0.8750, lr=0.0010
Epoch=65, BatchID=5, loss=0.0719, Tr_AUC=0.9891, Test_AUC=0.9167, lr=0.0010
Epoch=66, BatchID=5, loss=0.0924, Tr_AUC=0.9835, Test_AUC=0.9167, lr=0.0010
Epoch=67, BatchID=5, loss=0.0916, Tr_AUC=0.9813, Test_AUC=0.9167, lr=0.0010
Epoch=68, BatchID=5, loss=0.0772, Tr_AUC=0.9839, Test_AUC=0.9167, lr=0.0010
Epoch=69, BatchID=5, loss=0.0695, Tr_AUC=0.9826, Test_AUC=0.9167, lr=0.0010
Epoch=70, BatchID=5, loss=0.0769, Tr_AUC=0.9905, Test_AUC=0.9167, lr=0.0010
Epoch=71, BatchID=5, loss=0.0755, Tr_AUC=0.9844, Test_AUC=0.9167, lr=0.0010
Epoch=72, BatchID=5, loss=0.0646, Tr_AUC=0.9891, Test_AUC=0.9167, lr=0.0010
Epoch=73, BatchID=5, loss=0.0826, Tr_AUC=0.9896, Test_AUC=0.9167, lr=0.0010
Epoch=74, BatchID=5, loss=0.0830, Tr_AUC=0.9844, Test_AUC=0.9167, lr=0.0010
Reducing learning rate to 0.00010 @ T=450!
Updating regularizer @ T=450!
Epoch=75, BatchID=5, loss=0.0838, Tr_AUC=0.9839, Test_AUC=0.9167, lr=0.0001
Epoch=76, BatchID=5, loss=0.0754, Tr_AUC=0.9909, Test_AUC=0.9167, lr=0.0001
Epoch=77, BatchID=5, loss=0.0596, Tr_AUC=0.9831, Test_AUC=0.9167, lr=0.0001
Epoch=78, BatchID=5, loss=0.0711, Tr_AUC=0.9900, Test_AUC=0.9167, lr=0.0001
Epoch=79, BatchID=5, loss=0.0890, Tr_AUC=0.9874, Test_AUC=0.9167, lr=0.0001
Epoch=80, BatchID=5, loss=0.0815, Tr_AUC=0.9809, Test_AUC=0.9167, lr=0.0001
Epoch=81, BatchID=5, loss=0.0744, Tr_AUC=0.9870, Test_AUC=0.9167, lr=0.0001
Epoch=82, BatchID=5, loss=0.0694, Tr_AUC=0.9905, Test_AUC=0.9167, lr=0.0001
Epoch=83, BatchID=5, loss=0.0884, Tr_AUC=0.9913, Test_AUC=0.9167, lr=0.0001
Epoch=84, BatchID=5, loss=0.0923, Tr_AUC=0.9813, Test_AUC=0.9167, lr=0.0001
Epoch=85, BatchID=5, loss=0.0632, Tr_AUC=0.9935, Test_AUC=0.9167, lr=0.0001
Epoch=86, BatchID=5, loss=0.0800, Tr_AUC=0.9909, Test_AUC=0.9167, lr=0.0001
Epoch=87, BatchID=5, loss=0.0902, Tr_AUC=0.9844, Test_AUC=0.9167, lr=0.0001
Epoch=88, BatchID=5, loss=0.0751, Tr_AUC=0.9909, Test_AUC=0.9167, lr=0.0001
Epoch=89, BatchID=5, loss=0.0613, Tr_AUC=0.9900, Test_AUC=0.9167, lr=0.0001
Epoch=90, BatchID=5, loss=0.0873, Tr_AUC=0.9926, Test_AUC=0.9167, lr=0.0001
Epoch=91, BatchID=5, loss=0.0826, Tr_AUC=0.9935, Test_AUC=0.9167, lr=0.0001
Epoch=92, BatchID=5, loss=0.0873, Tr_AUC=0.9900, Test_AUC=0.9167, lr=0.0001
Epoch=93, BatchID=5, loss=0.0872, Tr_AUC=0.9861, Test_AUC=0.9167, lr=0.0001
Epoch=94, BatchID=5, loss=0.0815, Tr_AUC=0.9922, Test_AUC=0.9167, lr=0.0001
Epoch=95, BatchID=5, loss=0.0748, Tr_AUC=0.9913, Test_AUC=0.9167, lr=0.0001
Epoch=96, BatchID=5, loss=0.0724, Tr_AUC=0.9931, Test_AUC=0.9167, lr=0.0001
Epoch=97, BatchID=5, loss=0.0652, Tr_AUC=0.9896, Test_AUC=0.9167, lr=0.0001
Epoch=98, BatchID=5, loss=0.0828, Tr_AUC=0.9870, Test_AUC=0.9167, lr=0.0001
Epoch=99, BatchID=5, loss=0.1051, Tr_AUC=0.9883, Test_AUC=0.9167, lr=0.0001
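As the log shows, the training AUC climbs to roughly 0.99, while the test AUC improves from 0.33 at the beginning to about 0.92 by the end of training.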
Visualization
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (9,5)
x=np.arange(len(train_auc))
plt.figure()
plt.plot(x, train_auc, linestyle='--', label='train', linewidth=3)
plt.plot(x, test_auc, label='test', linewidth=3)
plt.title('MUSK2',fontsize=25)
plt.legend(fontsize=15)
plt.ylabel('AUC',fontsize=25)
plt.xlabel('epochs',fontsize=25)
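Optionally, you can save the learning curves and the trained model for later reuse; a minimal sketch with standard matplotlib and PyTorch calls (the file names are placeholders):
plt.savefig('musk2_auc_curves.png', bbox_inches='tight')  # save the AUC curves to disk
torch.save(model.state_dict(), 'midam_smx_musk2.pth')  # save the trained FFNN weights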