.. _midam_att_image:

================================================================================================================================
Multiple Instance Deep AUC Maximization with attention pooling (MIDAM-att) on Histopathology (Image) Dataset
================================================================================================================================
.. container:: cell markdown

   | **Author**: Dixian Zhu
   | **Edited by**: Zhuoning Yuan, Tianbao Yang

Introduction
-----------------------

In this tutorial, we will learn how to quickly train a ResNet20 model by optimizing **Multiple Instance Deep AUC Maximization (MIDAM)** with our novel :obj:`MIDAMLoss(mode='attention')` loss and :obj:`MIDAM` optimizer `[Ref] `__ on a binary classification task on the Breast Cancer Histopathology dataset. After completing this tutorial, you should be able to use LibAUC to train your own models on your own datasets.

**Reference**: If you find this tutorial helpful in your work, please cite our `library paper `__ and the following paper:

.. code-block:: RST

   @inproceedings{zhu2023mil,
     title={Provable Multi-instance Deep AUC Maximization with Stochastic Pooling},
     author={Zhu, Dixian and Wang, Bokun and Chen, Zhi and Wang, Yaxing and Sonka, Milan and Wu, Xiaodong and Yang, Tianbao},
     booktitle={International Conference on Machine Learning},
     year={2023},
     organization={PMLR}
   }

Install LibAUC
------------------------------------------------------------------------------------

Let's start by installing our library. In this tutorial, we use the latest version of LibAUC, installed with ``pip install -U``.

.. container:: cell code

   .. code:: python

      !pip install -U libauc

Importing LibAUC
-----------------------

Import the required libraries:

.. container:: cell code

   .. code:: python

      import torch
      import matplotlib.pyplot as plt
      import numpy as np

      from libauc.optimizers import MIDAM
      from libauc.losses import MIDAMLoss
      from libauc.models import ResNet20_stoc_att
      from libauc.utils import set_all_seeds, collate_fn, MIL_sampling, MIL_evaluate_auc
      from libauc.sampler import DualSampler
      from libauc.datasets import BreastCancer, CustomDataset

Reproducibility
-----------------------

The function below limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases `[Ref] `__.

.. container:: cell code

   .. code:: python

      def set_all_seeds(SEED):
          # REPRODUCIBILITY (this local definition replaces the imported set_all_seeds)
          torch.manual_seed(SEED)                     # seeds PyTorch's RNGs on all devices
          np.random.seed(SEED)                        # seeds NumPy's global RNG
          torch.backends.cudnn.deterministic = True   # use deterministic cuDNN kernels
          torch.backends.cudnn.benchmark = False      # disable cuDNN autotuning

Introduction for Loss and Optimizer
------------------------------------------------------------------------------------

In this section, we introduce the MIDAM optimization algorithm and show how to use the ``MIDAMLoss`` function and the ``MIDAM`` optimizer. With ``mode='attention'``, a bag's prediction is an attention-weighted average of its instance predictions. Because each training step only processes a sampled subset of each bag's instances (``instance_batch_size`` of them, via ``MIL_sampling``), ``MIDAMLoss`` takes the number of training bags (``data_len``) and the index of each bag in the mini-batch, which lets it maintain per-bag estimates across iterations.

HyperParameters
-----------------------

The hyper-parameters are the batch size (bag level), the instance batch size (instance level), the positive sampling rate, the learning rate, the weight decay, the margin of the AUC loss, the momentum of the ``MIDAM`` optimizer, and ``gamma`` for ``MIDAMLoss``.

.. container:: cell code

   .. code:: python

      # HyperParameters
      SEED = 123
      set_all_seeds(SEED)
      batch_size = 8               # number of bags per mini-batch
      instance_batch_size = 128    # number of instances sampled from each bag per step
      sampling_rate = 0.5          # target share of positive bags per mini-batch (DualSampler)
      lr = 5e-2
      weight_decay = 5e-4
      margin = 0.1                 # margin of the AUC loss
      momentum = 0.1               # momentum parameter of the MIDAM optimizer
      gamma = 0.9                  # moving-average parameter of MIDAMLoss

Load Data, initialize model and loss
------------------------------------------------------------------------------------

In this step, we use the Breast Cancer dataset as the benchmark `[Ref] `__ and load it into data loaders. We extend the traditional ResNet20 with an additional attention module: ``ResNet20_stoc_att``. The data are arranged with the shape ``(num_bag, num_instance_in_bag, C, H, W)``.

.. container:: cell code

   .. code:: python

      # each bag holds 672 patches of shape 3x32x32 (see the printed shapes below)
      (train_data, train_labels), (test_data, test_labels) = BreastCancer(MIL_flag=True)
      trainSet = CustomDataset(train_data, train_labels, return_index=True)  # also returns each bag's index
      testSet = CustomDataset(test_data, test_labels, return_index=True)

      # DualSampler controls the share of positive bags per mini-batch; because a sampler
      # is given, the DataLoader itself must use shuffle=False
      sampler = DualSampler(dataset=trainSet, batch_size=batch_size, shuffle=True, sampling_rate=sampling_rate)
      trainloader = torch.utils.data.DataLoader(dataset=trainSet, sampler=sampler, batch_size=batch_size, shuffle=False, collate_fn=None)
      testloader = torch.utils.data.DataLoader(testSet, batch_size=batch_size, shuffle=False, collate_fn=None)

      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      model = ResNet20_stoc_att(num_classes=1).to(device)
      Loss = MIDAMLoss(mode='attention', data_len=len(trainSet), gamma=gamma, margin=margin)
      optimizer = MIDAM(model.parameters(), loss_fn=Loss, lr=lr, weight_decay=weight_decay, momentum=momentum)

.. container:: output stream stdout

   The data shapes for training data/label and testing data/label:

   ::

      Downloading https://objects.githubusercontent.com/github-production-release-asset-2e65be/647580747/d75046eb-60ac-47e1-b732-67cc3c71e49a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20230609%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20230609T164941Z&X-Amz-Expires=300&X-Amz-Signature=3089bd6659de76da484b10a9c75dfd83be65ee4dcbe4af6c7dd672b167d2e6bc&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=647580747&response-content-disposition=attachment%3B%20filename%3Dbreast.npz&response-content-type=application%2Foctet-stream to ./data/Breast_Cancer/breast.npz
      100%|██████████| 239469738/239469738 [00:03<00:00, 72018432.08it/s]
      (52, 672, 3, 32, 32)
      (52, 1)
      (6, 672, 3, 32, 32)
      (6, 1)

.. container:: cell markdown

   .. rubric:: **Training**

.. container:: cell code

   .. code:: python

      total_epochs = 100
      decay_epoch = [50, 75]
      train_auc = np.zeros(total_epochs)
      test_auc = np.zeros(total_epochs)
      for epoch in range(total_epochs):
          # at each decay epoch: divide the learning rate by 10 and decay the loss smoothing by 2
          if epoch in decay_epoch:
              optimizer.update_lr(decay_factor=10)
              Loss.update_smoothing(decay_factor=2)
          for idx, data in enumerate(trainloader):
              y_pred = []
              sd = []
              train_data_bags, train_labels, ids = data
              # stochastic pooling: MIL_sampling passes instance_batch_size sampled instances of each bag through the model
              for i in range(len(ids)):
                  tmp_pred, tmp_sd = MIL_sampling(bag_X=train_data_bags[i], model=model, instance_batch_size=instance_batch_size, mode='att')
                  y_pred.append(tmp_pred)
                  sd.append(tmp_sd)
              y_pred = torch.cat(y_pred, dim=0)
              sd = torch.cat(sd, dim=0)
              ids = torch.from_numpy(np.array(ids))
              train_labels = torch.from_numpy(np.array(train_labels))
              loss = Loss(y_pred=(y_pred, sd), y_true=train_labels, index=ids)
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
          # evaluate bag-level AUC on the training and test sets once per epoch
          model.eval()
          with torch.no_grad():
              single_tr_auc = MIL_evaluate_auc(trainloader, model, mode='att')
              single_te_auc = MIL_evaluate_auc(testloader, model, mode='att')
              train_auc[epoch] = single_tr_auc
              test_auc[epoch] = single_te_auc
          model.train()
          print('Epoch=%s, BatchID=%s, Tr_AUC=%.4f, Test_AUC=%.4f, lr=%.4f' % (epoch, idx, single_tr_auc, single_te_auc, optimizer.lr))

.. container:: output stream stdout

   ::

      Epoch=0, BatchID=6, Tr_AUC=0.7258, Test_AUC=0.8750, lr=0.0500
      Epoch=1, BatchID=6, Tr_AUC=0.6760, Test_AUC=0.8750, lr=0.0500
      Epoch=2, BatchID=6, Tr_AUC=0.7321, Test_AUC=1.0000, lr=0.0500
      Epoch=3, BatchID=6, Tr_AUC=0.7054, Test_AUC=0.8750, lr=0.0500
      Epoch=4, BatchID=6, Tr_AUC=0.6480, Test_AUC=0.8750, lr=0.0500
      Epoch=5, BatchID=6, Tr_AUC=0.7194, Test_AUC=1.0000, lr=0.0500
      Epoch=6, BatchID=6, Tr_AUC=0.6837, Test_AUC=1.0000, lr=0.0500
      Epoch=7, BatchID=6, Tr_AUC=0.6862, Test_AUC=1.0000, lr=0.0500
      Epoch=8, BatchID=6, Tr_AUC=0.7462, Test_AUC=1.0000, lr=0.0500
      Epoch=9, BatchID=6, Tr_AUC=0.7015, Test_AUC=1.0000, lr=0.0500
      Epoch=10, BatchID=6, Tr_AUC=0.7360, Test_AUC=1.0000, lr=0.0500
      Epoch=11, BatchID=6, Tr_AUC=0.7679, Test_AUC=1.0000, lr=0.0500
      Epoch=12, BatchID=6, Tr_AUC=0.7717, Test_AUC=1.0000, lr=0.0500
      Epoch=13, BatchID=6, Tr_AUC=0.7538, Test_AUC=1.0000, lr=0.0500
      Epoch=14, BatchID=6, Tr_AUC=0.7449, Test_AUC=1.0000, lr=0.0500
      Epoch=15, BatchID=6, Tr_AUC=0.7168, Test_AUC=1.0000, lr=0.0500
      Epoch=16, BatchID=6, Tr_AUC=0.6505, Test_AUC=1.0000, lr=0.0500
      Epoch=17, BatchID=6, Tr_AUC=0.7653, Test_AUC=1.0000, lr=0.0500
      Epoch=18, BatchID=6, Tr_AUC=0.7921, Test_AUC=1.0000, lr=0.0500
      Epoch=19, BatchID=6, Tr_AUC=0.7819, Test_AUC=1.0000, lr=0.0500
      Epoch=20, BatchID=6, Tr_AUC=0.7411, Test_AUC=1.0000, lr=0.0500
      Epoch=21, BatchID=6, Tr_AUC=0.7870, Test_AUC=1.0000, lr=0.0500
      Epoch=22, BatchID=6, Tr_AUC=0.7385, Test_AUC=1.0000, lr=0.0500
      Epoch=23, BatchID=6, Tr_AUC=0.7436, Test_AUC=1.0000, lr=0.0500
      Epoch=24, BatchID=6, Tr_AUC=0.8202, Test_AUC=1.0000, lr=0.0500
      Epoch=25, BatchID=6, Tr_AUC=0.8061, Test_AUC=1.0000, lr=0.0500
      Epoch=26, BatchID=6, Tr_AUC=0.8036, Test_AUC=1.0000, lr=0.0500
      Epoch=27, BatchID=6, Tr_AUC=0.8010, Test_AUC=1.0000, lr=0.0500
      Epoch=28, BatchID=6, Tr_AUC=0.8253, Test_AUC=1.0000, lr=0.0500
      Epoch=29, BatchID=6, Tr_AUC=0.8010, Test_AUC=1.0000, lr=0.0500
      Epoch=30, BatchID=6, Tr_AUC=0.7436, Test_AUC=1.0000, lr=0.0500
      Epoch=31, BatchID=6, Tr_AUC=0.8112, Test_AUC=1.0000, lr=0.0500
      Epoch=32, BatchID=6, Tr_AUC=0.8163, Test_AUC=1.0000, lr=0.0500
      Epoch=33, BatchID=6, Tr_AUC=0.8099, Test_AUC=1.0000, lr=0.0500
      Epoch=34, BatchID=6, Tr_AUC=0.8342, Test_AUC=1.0000, lr=0.0500
      Epoch=35, BatchID=6, Tr_AUC=0.8087, Test_AUC=1.0000, lr=0.0500
      Epoch=36, BatchID=6, Tr_AUC=0.8087, Test_AUC=1.0000, lr=0.0500
      Epoch=37, BatchID=6, Tr_AUC=0.8176, Test_AUC=1.0000, lr=0.0500
      Epoch=38, BatchID=6, Tr_AUC=0.8227, Test_AUC=1.0000, lr=0.0500
      Epoch=39, BatchID=6, Tr_AUC=0.8495, Test_AUC=1.0000, lr=0.0500
      Epoch=40, BatchID=6, Tr_AUC=0.8508, Test_AUC=1.0000, lr=0.0500
      Epoch=41, BatchID=6, Tr_AUC=0.8431, Test_AUC=1.0000, lr=0.0500
      Epoch=42, BatchID=6, Tr_AUC=0.8036, Test_AUC=1.0000, lr=0.0500
      Epoch=43, BatchID=6, Tr_AUC=0.8546, Test_AUC=1.0000, lr=0.0500
      Epoch=44, BatchID=6, Tr_AUC=0.8584, Test_AUC=1.0000, lr=0.0500
      Epoch=45, BatchID=6, Tr_AUC=0.8355, Test_AUC=1.0000, lr=0.0500
      Epoch=46, BatchID=6, Tr_AUC=0.8291, Test_AUC=1.0000, lr=0.0500
      Epoch=47, BatchID=6, Tr_AUC=0.8482, Test_AUC=1.0000, lr=0.0500
      Epoch=48, BatchID=6, Tr_AUC=0.8444, Test_AUC=1.0000, lr=0.0500
      Epoch=49, BatchID=6, Tr_AUC=0.8508, Test_AUC=1.0000, lr=0.0500
      Reducing learning rate to 0.00500 @ T=350!
      Updating regularizer @ T=350!
      Epoch=50, BatchID=6, Tr_AUC=0.8495, Test_AUC=1.0000, lr=0.0050
      Epoch=51, BatchID=6, Tr_AUC=0.8304, Test_AUC=1.0000, lr=0.0050
      Epoch=52, BatchID=6, Tr_AUC=0.8635, Test_AUC=1.0000, lr=0.0050
      Epoch=53, BatchID=6, Tr_AUC=0.8776, Test_AUC=1.0000, lr=0.0050
      Epoch=54, BatchID=6, Tr_AUC=0.8418, Test_AUC=1.0000, lr=0.0050
      Epoch=55, BatchID=6, Tr_AUC=0.8240, Test_AUC=1.0000, lr=0.0050
      Epoch=56, BatchID=6, Tr_AUC=0.8622, Test_AUC=1.0000, lr=0.0050
      Epoch=57, BatchID=6, Tr_AUC=0.8635, Test_AUC=1.0000, lr=0.0050
      Epoch=58, BatchID=6, Tr_AUC=0.8559, Test_AUC=1.0000, lr=0.0050
      Epoch=59, BatchID=6, Tr_AUC=0.8508, Test_AUC=1.0000, lr=0.0050
      Epoch=60, BatchID=6, Tr_AUC=0.8291, Test_AUC=1.0000, lr=0.0050
      Epoch=61, BatchID=6, Tr_AUC=0.8240, Test_AUC=1.0000, lr=0.0050
      Epoch=62, BatchID=6, Tr_AUC=0.8661, Test_AUC=1.0000, lr=0.0050
      Epoch=63, BatchID=6, Tr_AUC=0.8559, Test_AUC=1.0000, lr=0.0050
      Epoch=64, BatchID=6, Tr_AUC=0.8304, Test_AUC=1.0000, lr=0.0050
      Epoch=65, BatchID=6, Tr_AUC=0.8444, Test_AUC=1.0000, lr=0.0050
      Epoch=66, BatchID=6, Tr_AUC=0.8482, Test_AUC=1.0000, lr=0.0050
      Epoch=67, BatchID=6, Tr_AUC=0.8801, Test_AUC=1.0000, lr=0.0050
      Epoch=68, BatchID=6, Tr_AUC=0.8355, Test_AUC=1.0000, lr=0.0050
      Epoch=69, BatchID=6, Tr_AUC=0.8699, Test_AUC=1.0000, lr=0.0050
      Epoch=70, BatchID=6, Tr_AUC=0.8367, Test_AUC=1.0000, lr=0.0050
      Epoch=71, BatchID=6, Tr_AUC=0.8482, Test_AUC=1.0000, lr=0.0050
      Epoch=72, BatchID=6, Tr_AUC=0.8597, Test_AUC=1.0000, lr=0.0050
      Epoch=73, BatchID=6, Tr_AUC=0.8673, Test_AUC=1.0000, lr=0.0050
      Epoch=74, BatchID=6, Tr_AUC=0.8673, Test_AUC=1.0000, lr=0.0050
      Reducing learning rate to 0.00050 @ T=525!
      Updating regularizer @ T=525!
      Epoch=75, BatchID=6, Tr_AUC=0.8367, Test_AUC=1.0000, lr=0.0005
      Epoch=76, BatchID=6, Tr_AUC=0.9005, Test_AUC=1.0000, lr=0.0005
      Epoch=77, BatchID=6, Tr_AUC=0.8686, Test_AUC=1.0000, lr=0.0005
      Epoch=78, BatchID=6, Tr_AUC=0.8661, Test_AUC=1.0000, lr=0.0005
      Epoch=79, BatchID=6, Tr_AUC=0.8750, Test_AUC=1.0000, lr=0.0005
      Epoch=80, BatchID=6, Tr_AUC=0.8508, Test_AUC=1.0000, lr=0.0005
      Epoch=81, BatchID=6, Tr_AUC=0.8597, Test_AUC=1.0000, lr=0.0005
      Epoch=82, BatchID=6, Tr_AUC=0.8520, Test_AUC=1.0000, lr=0.0005
      Epoch=83, BatchID=6, Tr_AUC=0.8214, Test_AUC=1.0000, lr=0.0005
      Epoch=84, BatchID=6, Tr_AUC=0.8061, Test_AUC=1.0000, lr=0.0005
      Epoch=85, BatchID=6, Tr_AUC=0.8737, Test_AUC=1.0000, lr=0.0005
      Epoch=86, BatchID=6, Tr_AUC=0.8686, Test_AUC=1.0000, lr=0.0005
      Epoch=87, BatchID=6, Tr_AUC=0.8469, Test_AUC=1.0000, lr=0.0005
      Epoch=88, BatchID=6, Tr_AUC=0.8635, Test_AUC=1.0000, lr=0.0005
      Epoch=89, BatchID=6, Tr_AUC=0.8686, Test_AUC=1.0000, lr=0.0005
      Epoch=90, BatchID=6, Tr_AUC=0.8431, Test_AUC=1.0000, lr=0.0005
      Epoch=91, BatchID=6, Tr_AUC=0.9031, Test_AUC=1.0000, lr=0.0005
      Epoch=92, BatchID=6, Tr_AUC=0.8673, Test_AUC=1.0000, lr=0.0005
      Epoch=93, BatchID=6, Tr_AUC=0.8495, Test_AUC=1.0000, lr=0.0005
      Epoch=94, BatchID=6, Tr_AUC=0.8712, Test_AUC=1.0000, lr=0.0005
      Epoch=95, BatchID=6, Tr_AUC=0.8469, Test_AUC=1.0000, lr=0.0005
      Epoch=96, BatchID=6, Tr_AUC=0.8571, Test_AUC=1.0000, lr=0.0005
      Epoch=97, BatchID=6, Tr_AUC=0.8329, Test_AUC=1.0000, lr=0.0005
      Epoch=98, BatchID=6, Tr_AUC=0.8750, Test_AUC=1.0000, lr=0.0005
      Epoch=99, BatchID=6, Tr_AUC=0.8648, Test_AUC=1.0000, lr=0.0005

Visualization
-----------------------

Plot the training and test AUC recorded at the end of each epoch.

.. container:: cell code

   .. code:: python

      plt.rcParams["figure.figsize"] = (9, 5)
      x = np.arange(len(train_auc))
      plt.figure()
      plt.plot(x, train_auc, linestyle='--', label='train', linewidth=3)
      plt.plot(x, test_auc, label='test', linewidth=3)
      plt.title('Breast Cancer', fontsize=25)
      plt.legend(fontsize=15)
      plt.ylabel('AUC', fontsize=25)
      plt.xlabel('epochs', fontsize=25)

.. container:: output display_data

   .. image:: ./imgs/midam-att-BC.png

Ablation Study on Attention Weights
----------------------------------------------

To inspect what the attention module focuses on, we first find the training bag with the highest attention-pooled score, then plot its patches, the per-patch prediction scores, and the attention weights side by side.

.. container:: cell code

   .. code:: python

      maxV = -1e10      # highest pooled bag score found so far
      maxbag = None     # the bag (672 patches) that achieves it
      maxlabel = None
      for idx, data in enumerate(trainloader):
          train_data_bags, train_labels, ids = data
          for i in range(len(ids)):
              # the model returns per-instance scores and attention weights for the whole bag
              y_pred_bag, weights_bag = model(train_data_bags[i].float().to(device))
              # attention pooling: weighted average of the instance scores
              value = torch.sum(y_pred_bag*weights_bag)/torch.sum(weights_bag)
              if value > maxV:
                  maxV = value
                  maxbag = train_data_bags[i]
                  maxlabel = train_labels[i]
      print(maxV)
      print(maxlabel)

.. container:: output stream stdout

   ::

      tensor(-1.9086, device='cuda:0', grad_fn=<DivBackward0>)
      tensor([1.], dtype=torch.float64)

.. container:: cell code

   .. code:: python

      from mpl_toolkits.axes_grid1 import make_axes_locatable

      # (672, C, H, W) -> (672, H, W, C) so that the patches can be shown with imshow
      tmp = maxbag.numpy()
      tmp = np.transpose(tmp, [0, 2, 3, 1])
      # collate the 672 patches back to the original 24 x 28 image arrangement
      img = []
      for i in range(24):
          tmpimg = []
          for j in range(28):
              tmpimg.append(tmp[i*28+j])
          tmpimg = np.concatenate(tmpimg, axis=1)   # one row of patches, side by side
          img.append(tmpimg)
      img = np.concatenate(img, axis=0)             # stack the rows vertically

      # per-patch prediction scores and attention weights, arranged on the same grid
      y_pred_bag, weights_bag = model(maxbag.float().to(device))
      weights = weights_bag.detach().cpu().numpy()
      weights = np.reshape(weights, [24, 28])
      preds = y_pred_bag.detach().cpu().numpy()
      preds = np.reshape(preds, [24, 28])

      fig, subfigs = plt.subplots(figsize=(15, 6), ncols=3)
      imgax = subfigs[0]
      predax = subfigs[1]
      attax = subfigs[2]
      imgfig = imgax.imshow(img, interpolation='nearest')
      imgax.set_title('medical image')
      predfig = predax.imshow(preds, interpolation=None, norm=None)
      predax.set_title('prediction scores')
      attfig = attax.imshow(weights, interpolation=None, norm=None)
      attax.set_title('attention weights')
      divider = make_axes_locatable(imgax)
      cax = divider.new_vertical(size='5%', pad=0.1, pack_start=True)
      plt.colorbar(imgfig, cax=cax, shrink=0.0, orientation='horizontal', pad=0.06)
      plt.colorbar(predfig, ax=predax, shrink=0.4, orientation='horizontal', pad=0.06)
      plt.colorbar(attfig, ax=attax, shrink=0.4, orientation='horizontal', pad=0.06)
      plt.show()

.. container:: output display_data

   .. image:: ./imgs/midam-att-BC-demo.png
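
In the ablation code above, a bag's score is computed as ``torch.sum(y_pred_bag*weights_bag)/torch.sum(weights_bag)``, that is, an attention-weighted average of its instance scores. The minimal NumPy sketch below isolates this pooling formula on toy numbers. The helper ``attention_pool`` and the toy arrays are hypothetical, the sketch assumes non-negative attention weights that are not all zero, and it is an illustration of the formula rather than LibAUC's internal implementation.

.. code:: python

   import numpy as np

   def attention_pool(instance_scores, weights):
       """Hypothetical helper: attention-weighted average of instance scores.

       Mirrors sum(score * weight) / sum(weight) from the ablation code.
       Assumes non-negative weights that are not all zero.
       """
       instance_scores = np.asarray(instance_scores, dtype=float).ravel()
       weights = np.asarray(weights, dtype=float).ravel()
       return float(np.sum(instance_scores * weights) / np.sum(weights))

   scores = np.array([-2.0, 0.5, 3.0])  # toy instance scores for one bag

   # Uniform weights: attention pooling reduces to mean pooling.
   print(attention_pool(scores, np.ones(3)))  # 0.5

   # One dominant weight: the bag score is pulled toward that instance's score (close to 3.0).
   print(attention_pool(scores, np.array([1e-3, 1e-3, 1.0])))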
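
The nested loops in the ablation code rebuild the full image by laying the 672 patches of a bag out on a 24 x 28 grid, row by row. Assuming the same row-major patch order and patches already transposed to ``(H, W, C)``, the rearrangement can also be written with a single ``reshape``/``transpose``. The sketch below uses hypothetical helper names and random 4 x 4 patches, and checks the vectorized version against a loop version that follows the tutorial's logic.

.. code:: python

   import numpy as np

   def patches_to_image(patches, n_rows, n_cols):
       """Hypothetical helper: tile (n_rows * n_cols, H, W, C) patches into one image, row by row."""
       n, h, w, c = patches.shape
       if n != n_rows * n_cols:
           raise ValueError("expected n_rows * n_cols patches")
       grid = patches.reshape(n_rows, n_cols, h, w, c)  # grid[i, j] is patch i * n_cols + j
       # reorder to (rows, H, cols, W, C) so pixel rows of neighbouring patches line up
       return grid.transpose(0, 2, 1, 3, 4).reshape(n_rows * h, n_cols * w, c)

   def patches_to_image_loops(patches, n_rows, n_cols):
       """Loop version following the ablation code: concatenate along width, then height."""
       rows = []
       for i in range(n_rows):
           rows.append(np.concatenate([patches[i * n_cols + j] for j in range(n_cols)], axis=1))
       return np.concatenate(rows, axis=0)

   rng = np.random.default_rng(0)
   demo = rng.random((24 * 28, 4, 4, 3))  # 672 toy RGB patches of size 4 x 4

   print(np.array_equal(patches_to_image(demo, 24, 28), patches_to_image_loops(demo, 24, 28)))  # True
   print(patches_to_image(demo, 24, 28).shape)  # (96, 112, 3)

With the tutorial's 32 x 32 patches, the same call would return a 768 x 896 x 3 image.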