============================================================
Multiple Instance Deep AUC Maximization on Tabular Dataset
============================================================
.. container:: cell markdown

   | **Author**: Dixian Zhu
   | **Edited by**: Zhuoning Yuan, Tianbao Yang

Introduction
----------------------------------------

In this tutorial, we train a simple Feed Forward Neural Network (FFNN) with **Multiple Instance Deep AUC Maximization (MIDAM)**, using our :obj:`MIDAMLoss` and :obj:`MIDAM` optimizer [Ref], on a binary classification task on the MUSK2 dataset.

Please note that :obj:`MIDAMLoss` is a wrapper function for different types of MIDAM losses. It currently supports two primary modes:

- :obj:`MIDAMLoss(mode='attention')`: uses :obj:`MIDAM_attention_pooling_loss` as the backend and the :obj:`MIDAM` optimizer for optimization.
- :obj:`MIDAMLoss(mode='softmax')`: uses :obj:`MIDAM_softmax_pooling_loss` as the backend and the :obj:`MIDAM` optimizer for optimization.

In this tutorial, we focus on attention pooling (MIDAM-att), i.e., :obj:`MIDAMLoss(mode='attention')`. For other tutorials, please refer to :ref:`MIDAM-att-image` and :ref:`MIDAM-smx-tabular`.

After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.

**Reference**:

If you find this tutorial helpful in your work, please cite our library paper and the following paper:

.. code-block:: RST

   @inproceedings{zhu2023mil,
     title={Provable Multi-instance Deep AUC Maximization with Stochastic Pooling},
     author={Zhu, Dixian and Wang, Bokun and Chen, Zhi and Wang, Yaxing and Sonka, Milan and Wu, Xiaodong and Yang, Tianbao},
     booktitle={International Conference on Machine Learning},
     year={2023},
     organization={PMLR}
   }

Install LibAUC
----------------------------------------

Let's start by installing the library. In this tutorial, we use the latest version of LibAUC, installed with ``pip install -U``.

.. container:: cell code

   .. code:: python

      !pip install -U libauc

Importing LibAUC
----------------------------------------

Import the required libraries.

.. container:: cell code

   .. code:: python

      import torch
      import matplotlib.pyplot as plt
      import numpy as np
      from libauc.optimizers import MIDAM
      from libauc.losses import MIDAMLoss
      from libauc.models import FFNN_stoc_att
      from libauc.utils import set_all_seeds, collate_fn, MIL_sampling, MIL_evaluate_auc
      from libauc.sampler import DualSampler
      from libauc.datasets import MUSK2, CustomDataset

Reproducibility
----------------------------------------

The function below limits the sources of nondeterministic behavior, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases [Ref].

.. container:: cell code

   .. code:: python

      # Local definition; it shadows the set_all_seeds imported from libauc.utils above.
      def set_all_seeds(SEED):
          # REPRODUCIBILITY
          torch.manual_seed(SEED)
          np.random.seed(SEED)
          torch.backends.cudnn.deterministic = True
          torch.backends.cudnn.benchmark = False

Introduction for Loss and Optimizer
----------------------------------------

In this section, we show how to use the ``MIDAMLoss`` function and the ``MIDAM`` optimizer.

HyperParameters
----------------------------------------

The hyper-parameters are the batch size (bag level), the instance batch size (instance level), the positive sampling rate, the learning rate, the weight decay, and the margin for the AUC loss. The cell below also sets ``momentum`` and ``gamma``, which are passed to the optimizer and the loss, respectively.

.. container:: cell code

   .. code:: python

      # HyperParameters
      SEED = 123
      set_all_seeds(SEED)
      batch_size = 16            # number of bags per batch
      instance_batch_size = 4    # number of instances sampled per bag
      sampling_rate = 0.5        # positive sampling rate used by DualSampler
      lr = 1e-2
      weight_decay = 1e-4
      margin = 1.0
      momentum = 0.1
      gamma = 0.9

Load Data, initialize model and loss
----------------------------------------

In this step, we use MUSK2 as the benchmark dataset [Ref] and load it into data loaders. We extend the traditional FFNN with an additional attention module: ``FFNN_stoc_att``.

Data format: a list whose length equals the number of bags. Each bag is an array with shape (number of instances in the bag, dimension).

.. container:: cell code

   .. code:: python

      (train_data, train_labels), (test_data, test_labels) = MUSK2()
      traindSet = CustomDataset(train_data, train_labels, return_index=True)
      testSet = CustomDataset(test_data, test_labels, return_index=True)
      DIMS = 166  # feature dimension of each instance

      sampler = DualSampler(dataset=traindSet, batch_size=batch_size,
                            shuffle=True, sampling_rate=sampling_rate)
      trainloader = torch.utils.data.DataLoader(dataset=traindSet, sampler=sampler,
                                                batch_size=batch_size, shuffle=False,
                                                collate_fn=collate_fn)
      testloader = torch.utils.data.DataLoader(testSet, batch_size=batch_size,
                                               shuffle=False, collate_fn=collate_fn)

      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      model = FFNN_stoc_att(input_dim=DIMS, hidden_sizes=(DIMS, ), num_classes=1).to(device)
      Loss = MIDAMLoss(mode='attention', data_len=len(traindSet), gamma=gamma, margin=margin)
      optimizer = MIDAM(model.parameters(), loss_fn=Loss, lr=lr,
                        weight_decay=weight_decay, momentum=momentum)

Training
----------------------------------------

.. container:: cell code

   .. code:: python

      total_epochs = 100
      decay_epoch = [50, 75]
      train_auc = np.zeros(total_epochs)
      test_auc = np.zeros(total_epochs)
      for epoch in range(total_epochs):
          # At each decay epoch, shrink the learning rate (10x) and update the loss smoothing
          if epoch in decay_epoch:
              optimizer.update_lr(decay_factor=10)
              Loss.update_smoothing(decay_factor=2)
          for idx, data in enumerate(trainloader):
              y_pred = []
              sd = []
              train_data_bags, train_labels, ids = data
              # Forward a sampled mini-batch of instances from each bag
              for i in range(len(ids)):
                  tmp_pred, tmp_sd = MIL_sampling(bag_X=train_data_bags[i], model=model,
                                                  instance_batch_size=instance_batch_size,
                                                  mode='att')
                  y_pred.append(tmp_pred)
                  sd.append(tmp_sd)
              y_pred = torch.cat(y_pred, dim=0)
              sd = torch.cat(sd, dim=0)
              ids = torch.from_numpy(np.array(ids))
              train_labels = torch.from_numpy(np.array(train_labels))
              loss = Loss(y_pred=(y_pred, sd), y_true=train_labels, index=ids)
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
          # Evaluate bag-level AUC on the train and test sets once per epoch
          model.eval()
          with torch.no_grad():
              single_tr_auc = MIL_evaluate_auc(trainloader, model, mode='att')
              single_te_auc = MIL_evaluate_auc(testloader, model, mode='att')
              train_auc[epoch] = single_tr_auc
              test_auc[epoch] = single_te_auc
          model.train()
          print('Epoch=%s, BatchID=%s, Tr_AUC=%.4f, Test_AUC=%.4f, lr=%.4f'
                % (epoch, idx, single_tr_auc, single_te_auc, optimizer.lr))

.. container:: output stream stdout

   ::

      Epoch=0, BatchID=5, Tr_AUC=0.3338, Test_AUC=0.1250, lr=0.0100
      Epoch=1, BatchID=5, Tr_AUC=0.3099, Test_AUC=0.1250, lr=0.0100
      Epoch=2, BatchID=5, Tr_AUC=0.3025, Test_AUC=0.1250, lr=0.0100
      Epoch=3, BatchID=5, Tr_AUC=0.3372, Test_AUC=0.1667, lr=0.0100
      Epoch=4, BatchID=5, Tr_AUC=0.3993, Test_AUC=0.2083, lr=0.0100
      Epoch=5, BatchID=5, Tr_AUC=0.3615, Test_AUC=0.2083, lr=0.0100
      Epoch=6, BatchID=5, Tr_AUC=0.4484, Test_AUC=0.2083, lr=0.0100
      Epoch=7, BatchID=5, Tr_AUC=0.4670, Test_AUC=0.2083, lr=0.0100
      Epoch=8, BatchID=5, Tr_AUC=0.4605, Test_AUC=0.2083, lr=0.0100
      Epoch=9, BatchID=5, Tr_AUC=0.4891, Test_AUC=0.2083, lr=0.0100
      Epoch=10, BatchID=5, Tr_AUC=0.5482, Test_AUC=0.2083, lr=0.0100
      Epoch=11, BatchID=5, Tr_AUC=0.5260, Test_AUC=0.2917, lr=0.0100
      Epoch=12, BatchID=5, Tr_AUC=0.6211, Test_AUC=0.2500, lr=0.0100
      Epoch=13, BatchID=5, Tr_AUC=0.5412, Test_AUC=0.2917, lr=0.0100
      Epoch=14, BatchID=5, Tr_AUC=0.6289, Test_AUC=0.3750, lr=0.0100
      Epoch=15, BatchID=5, Tr_AUC=0.6185, Test_AUC=0.4583, lr=0.0100
      Epoch=16, BatchID=5, Tr_AUC=0.6241, Test_AUC=0.4583, lr=0.0100
      Epoch=17, BatchID=5, Tr_AUC=0.6549, Test_AUC=0.4583, lr=0.0100
      Epoch=18, BatchID=5, Tr_AUC=0.6901, Test_AUC=0.3750, lr=0.0100
      Epoch=19, BatchID=5, Tr_AUC=0.6727, Test_AUC=0.3750, lr=0.0100
      Epoch=20, BatchID=5, Tr_AUC=0.6797, Test_AUC=0.4167, lr=0.0100
      Epoch=21, BatchID=5, Tr_AUC=0.6879, Test_AUC=0.4583, lr=0.0100
      Epoch=22, BatchID=5, Tr_AUC=0.6997, Test_AUC=0.5000, lr=0.0100
      Epoch=23, BatchID=5, Tr_AUC=0.7127, Test_AUC=0.5833, lr=0.0100
      Epoch=24, BatchID=5, Tr_AUC=0.7509, Test_AUC=0.5417, lr=0.0100
      Epoch=25, BatchID=5, Tr_AUC=0.7161, Test_AUC=0.5833, lr=0.0100
      Epoch=26, BatchID=5, Tr_AUC=0.7365, Test_AUC=0.5000, lr=0.0100
      Epoch=27, BatchID=5, Tr_AUC=0.8173, Test_AUC=0.5833, lr=0.0100
      Epoch=28, BatchID=5, Tr_AUC=0.8051, Test_AUC=0.5417, lr=0.0100
      Epoch=29, BatchID=5, Tr_AUC=0.7760, Test_AUC=0.5833, lr=0.0100
      Epoch=30, BatchID=5, Tr_AUC=0.8464, Test_AUC=0.5833, lr=0.0100
      Epoch=31, BatchID=5, Tr_AUC=0.8507, Test_AUC=0.6667, lr=0.0100
      Epoch=32, BatchID=5, Tr_AUC=0.8885, Test_AUC=0.7083, lr=0.0100
      Epoch=33, BatchID=5, Tr_AUC=0.8885, Test_AUC=0.7500, lr=0.0100
      Epoch=34, BatchID=5, Tr_AUC=0.8906, Test_AUC=0.7500, lr=0.0100
      Epoch=35, BatchID=5, Tr_AUC=0.8733, Test_AUC=0.7083, lr=0.0100
      Epoch=36, BatchID=5, Tr_AUC=0.9210, Test_AUC=0.7500, lr=0.0100
      Epoch=37, BatchID=5, Tr_AUC=0.8967, Test_AUC=0.6667, lr=0.0100
      Epoch=38, BatchID=5, Tr_AUC=0.9062, Test_AUC=0.7500, lr=0.0100
      Epoch=39, BatchID=5, Tr_AUC=0.9280, Test_AUC=0.7917, lr=0.0100
      Epoch=40, BatchID=5, Tr_AUC=0.9219, Test_AUC=0.8333, lr=0.0100
      Epoch=41, BatchID=5, Tr_AUC=0.9457, Test_AUC=0.9167, lr=0.0100
      Epoch=42, BatchID=5, Tr_AUC=0.9631, Test_AUC=0.8750, lr=0.0100
      Epoch=43, BatchID=5, Tr_AUC=0.9570, Test_AUC=0.8750, lr=0.0100
      Epoch=44, BatchID=5, Tr_AUC=0.9592, Test_AUC=0.8333, lr=0.0100
      Epoch=45, BatchID=5, Tr_AUC=0.9605, Test_AUC=0.8750, lr=0.0100
      Epoch=46, BatchID=5, Tr_AUC=0.9514, Test_AUC=0.8750, lr=0.0100
      Epoch=47, BatchID=5, Tr_AUC=0.9657, Test_AUC=0.8333, lr=0.0100
      Epoch=48, BatchID=5, Tr_AUC=0.9796, Test_AUC=0.8750, lr=0.0100
      Epoch=49, BatchID=5, Tr_AUC=0.9774, Test_AUC=0.9167, lr=0.0100
      Reducing learning rate to 0.00100 @ T=300!
      Updating regularizer @ T=300!
      Epoch=50, BatchID=5, Tr_AUC=0.9648, Test_AUC=0.9167, lr=0.0010
      Epoch=51, BatchID=5, Tr_AUC=0.9753, Test_AUC=0.9167, lr=0.0010
      Epoch=52, BatchID=5, Tr_AUC=0.9865, Test_AUC=0.9167, lr=0.0010
      Epoch=53, BatchID=5, Tr_AUC=0.9766, Test_AUC=0.9167, lr=0.0010
      Epoch=54, BatchID=5, Tr_AUC=0.9878, Test_AUC=0.9167, lr=0.0010
      Epoch=55, BatchID=5, Tr_AUC=0.9731, Test_AUC=0.9167, lr=0.0010
      Epoch=56, BatchID=5, Tr_AUC=0.9848, Test_AUC=0.9167, lr=0.0010
      Epoch=57, BatchID=5, Tr_AUC=0.9761, Test_AUC=0.9167, lr=0.0010
      Epoch=58, BatchID=5, Tr_AUC=0.9800, Test_AUC=0.9167, lr=0.0010
      Epoch=59, BatchID=5, Tr_AUC=0.9787, Test_AUC=0.9167, lr=0.0010
      Epoch=60, BatchID=5, Tr_AUC=0.9826, Test_AUC=0.9167, lr=0.0010
      Epoch=61, BatchID=5, Tr_AUC=0.9826, Test_AUC=0.9167, lr=0.0010
      Epoch=62, BatchID=5, Tr_AUC=0.9740, Test_AUC=0.9167, lr=0.0010
      Epoch=63, BatchID=5, Tr_AUC=0.9705, Test_AUC=0.9167, lr=0.0010
      Epoch=64, BatchID=5, Tr_AUC=0.9839, Test_AUC=0.9167, lr=0.0010
      Epoch=65, BatchID=5, Tr_AUC=0.9818, Test_AUC=0.9167, lr=0.0010
      Epoch=66, BatchID=5, Tr_AUC=0.9822, Test_AUC=0.9167, lr=0.0010
      Epoch=67, BatchID=5, Tr_AUC=0.9861, Test_AUC=0.9167, lr=0.0010
      Epoch=68, BatchID=5, Tr_AUC=0.9844, Test_AUC=0.9167, lr=0.0010
      Epoch=69, BatchID=5, Tr_AUC=0.9874, Test_AUC=0.9167, lr=0.0010
      Epoch=70, BatchID=5, Tr_AUC=0.9757, Test_AUC=0.9167, lr=0.0010
      Epoch=71, BatchID=5, Tr_AUC=0.9831, Test_AUC=0.9167, lr=0.0010
      Epoch=72, BatchID=5, Tr_AUC=0.9852, Test_AUC=0.9167, lr=0.0010
      Epoch=73, BatchID=5, Tr_AUC=0.9805, Test_AUC=0.9167, lr=0.0010
      Epoch=74, BatchID=5, Tr_AUC=0.9861, Test_AUC=0.9167, lr=0.0010
      Reducing learning rate to 0.00010 @ T=450!
      Updating regularizer @ T=450!
      Epoch=75, BatchID=5, Tr_AUC=0.9809, Test_AUC=0.9167, lr=0.0001
      Epoch=76, BatchID=5, Tr_AUC=0.9913, Test_AUC=0.9167, lr=0.0001
      Epoch=77, BatchID=5, Tr_AUC=0.9839, Test_AUC=0.9167, lr=0.0001
      Epoch=78, BatchID=5, Tr_AUC=0.9883, Test_AUC=0.9167, lr=0.0001
      Epoch=79, BatchID=5, Tr_AUC=0.9844, Test_AUC=0.9167, lr=0.0001
      Epoch=80, BatchID=5, Tr_AUC=0.9896, Test_AUC=0.9167, lr=0.0001
      Epoch=81, BatchID=5, Tr_AUC=0.9805, Test_AUC=0.9167, lr=0.0001
      Epoch=82, BatchID=5, Tr_AUC=0.9887, Test_AUC=0.9167, lr=0.0001
      Epoch=83, BatchID=5, Tr_AUC=0.9818, Test_AUC=0.9167, lr=0.0001
      Epoch=84, BatchID=5, Tr_AUC=0.9861, Test_AUC=0.9167, lr=0.0001
      Epoch=85, BatchID=5, Tr_AUC=0.9861, Test_AUC=0.9167, lr=0.0001
      Epoch=86, BatchID=5, Tr_AUC=0.9848, Test_AUC=0.9167, lr=0.0001
      Epoch=87, BatchID=5, Tr_AUC=0.9852, Test_AUC=0.9167, lr=0.0001
      Epoch=88, BatchID=5, Tr_AUC=0.9870, Test_AUC=0.9167, lr=0.0001
      Epoch=89, BatchID=5, Tr_AUC=0.9770, Test_AUC=0.9167, lr=0.0001
      Epoch=90, BatchID=5, Tr_AUC=0.9939, Test_AUC=0.9167, lr=0.0001
      Epoch=91, BatchID=5, Tr_AUC=0.9822, Test_AUC=0.9167, lr=0.0001
      Epoch=92, BatchID=5, Tr_AUC=0.9770, Test_AUC=0.9167, lr=0.0001
      Epoch=93, BatchID=5, Tr_AUC=0.9822, Test_AUC=0.9167, lr=0.0001
      Epoch=94, BatchID=5, Tr_AUC=0.9861, Test_AUC=0.9167, lr=0.0001
      Epoch=95, BatchID=5, Tr_AUC=0.9878, Test_AUC=0.9167, lr=0.0001
      Epoch=96, BatchID=5, Tr_AUC=0.9844, Test_AUC=0.9167, lr=0.0001
      Epoch=97, BatchID=5, Tr_AUC=0.9688, Test_AUC=0.9167, lr=0.0001
      Epoch=98, BatchID=5, Tr_AUC=0.9857, Test_AUC=0.9167, lr=0.0001
      Epoch=99, BatchID=5, Tr_AUC=0.9909, Test_AUC=0.9167, lr=0.0001

Visualization
----------------------------------------

.. container:: cell code

   .. code:: python

      plt.rcParams["figure.figsize"] = (9, 5)
      x = np.arange(len(train_auc))
      plt.figure()
      plt.plot(x, train_auc, linestyle='--', label='train', linewidth=3)
      plt.plot(x, test_auc, label='test', linewidth=3)
      plt.title('MUSK2', fontsize=25)
      plt.legend(fontsize=15)
      plt.ylabel('AUC', fontsize=25)
      plt.xlabel('epochs', fontsize=25)

.. container:: output display_data

   .. image:: ./imgs/midam-att-musk2.png
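
The AUC reported during training can be read as the fraction of positive-negative bag pairs that the model ranks in the right order. The sketch below computes that quantity with plain NumPy on a few made-up bag scores. It is an illustration only and not the implementation behind ``MIL_evaluate_auc``; the helper name ``pairwise_auc`` and the sample scores and labels are hypothetical.

```python
import numpy as np


def pairwise_auc(scores, labels):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as 0.5."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC needs at least one positive and one negative bag")
    # Compare every positive score with every negative score.
    diff = pos[:, None] - neg[None, :]
    correct = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return correct / (len(pos) * len(neg))


# Hypothetical bag-level scores and labels (not taken from MUSK2).
bag_scores = [0.9, 0.8, 0.35, 0.6, 0.2, 0.1]
bag_labels = [1, 1, 1, 0, 0, 0]
auc = pairwise_auc(bag_scores, bag_labels)
print(f"AUC = {auc:.4f}")
```

With only a few bags, the number of positive-negative pairs is small, so the AUC can only change in steps of one over the number of pairs. This may be why the test AUC in the log above moves in coarse increments (0.1250, 0.1667, 0.2083, ...).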