.. _isogclr_unimodal:

Optimizing Robust Global Contrastive Loss with Small Batch Size (iSogCLR)
================================================================================================================================
.. container:: cell markdown

   | **Author**: Zhuoning Yuan, Xiyuan Wei

Introduction
------------------------------------------------------------------------------------

In this tutorial, you will learn how to train a self-supervised model by optimizing `Robust Global Contrastive Loss `__ (RGCLoss) on CIFAR10. This version is implemented in PyTorch, based on `moco's `__ codebase. We recommend running this notebook in a GPU-enabled environment, e.g., `Google Colab `__.

**Reference**

If you find this tutorial helpful in your work, please cite our `library paper `__ and the following paper:

.. code-block:: bibtex

   @inproceedings{qiu2023isogclr,
     title={Not All Semantics are Created Equal: Contrastive Self-supervised Learning with Automatic Temperature Individualization},
     author={Qiu, Zi-Hao and Hu, Quanqi and Yuan, Zhuoning and Zhou, Denny and Zhang, Lijun and Yang, Tianbao},
     booktitle={International Conference on Machine Learning},
     year={2023},
     organization={PMLR}
   }

Install LibAUC
------------------------------------------------------------------------------------

Let's start by installing our library. In this tutorial, we use the latest version of LibAUC, installed with ``pip install -U``.

.. container:: cell code

   .. code:: python

      !pip install -U libauc

Importing LibAUC
--------------------------------------------------------------------------------
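Before importing, you can optionally confirm which LibAUC release ``pip install -U`` resolved. This minimal sketch uses only Python's standard ``importlib.metadata`` module (Python 3.8+); the helper name ``installed_version`` is hypothetical and not part of LibAUC.

.. code:: python

   from importlib.metadata import PackageNotFoundError, version


   def installed_version(package="libauc"):
       """Hypothetical helper: return the installed version of `package`, or None if absent."""
       try:
           return version(package)  # reads the installed distribution's metadata
       except PackageNotFoundError:
           return None


   print("LibAUC version:", installed_version("libauc"))

The metadata reflects what is installed on disk. If an older LibAUC was already imported in the running kernel, restart the runtime so the new version is actually used.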
.. code:: python

   import libauc
   from libauc.models import resnet50, resnet18
   from libauc.datasets import CIFAR10
   from libauc.optimizers import SogCLR, LARS
   from libauc.losses import GCLoss

   import torch
   import torchvision.transforms as transforms
   import torch.nn as nn
   import numpy as np
   import os, math, shutil

Reproducibility
--------------------------------------------------------------------------------

The following function ``set_all_seeds`` limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases [Ref].

.. code:: python

   def set_all_seeds(SEED):
       # REPRODUCIBILITY
       np.random.seed(SEED)
       torch.manual_seed(SEED)
       torch.cuda.manual_seed(SEED)
       torch.backends.cudnn.deterministic = True
       torch.backends.cudnn.benchmark = False

Robust Global Contrastive Loss
--------------------------------------------------------------------------------

Robust Global Contrastive Loss (RGCLoss) aims to maximize the similarity between an anchor image :math:`\mathbf{x}_i` and its corresponding positive image :math:`\mathbf{x}_i^{+}`, while minimizing the similarity between the anchor and a set of negative samples :math:`\mathbf{S}_i^{-}`. In this version, the negatives for each anchor come from the full dataset rather than only the current mini-batch. For more details about the formulation of RGCLoss, please refer to the `iSogCLR `__ paper.
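The moving-average idea that lets a global contrastive loss look beyond the current mini-batch can be sketched in plain PyTorch. This is a simplified, hypothetical illustration of the SogCLR-style estimator, not LibAUC's ``GCLoss``. It assumes the negative term :math:`g_i = \frac{1}{|\mathbf{S}_i^{-}|}\sum_{\mathbf{z}\in\mathbf{S}_i^{-}} \exp\big((s(\mathbf{x}_i,\mathbf{z}) - s(\mathbf{x}_i,\mathbf{x}_i^{+}))/\tau\big)` with cosine similarity :math:`s`. Each training sample keeps a running estimate :math:`u_i \leftarrow (1-\gamma)\,u_i + \gamma\,\hat{g}_i` built from mini-batch estimates :math:`\hat{g}_i`, and the gradient of :math:`\tau \log g_i` is approximated by :math:`\tau \nabla \hat{g}_i / u_i`. The sketch uses a fixed temperature ``tau`` and takes negatives only from the other view; it omits iSogCLR's individualized temperatures and its ``eta``, ``rho`` and ``beta`` hyperparameters. The class name ``MovingAverageGCLoss`` is made up for this example.

.. code:: python

   import torch
   import torch.nn.functional as F


   class MovingAverageGCLoss(torch.nn.Module):
       """Hypothetical, simplified global contrastive loss with per-sample moving averages.

       Sketch only: fixed temperature, cross-view negatives, no iSogCLR temperature updates.
       """

       def __init__(self, num_samples, tau=0.5, gamma=0.9):
           super().__init__()
           self.tau = tau
           self.gamma = gamma
           # u[i]: running estimate of anchor i's negative term, kept across iterations
           self.register_buffer("u", torch.zeros(num_samples))

       def forward(self, z1, z2, index):
           z1 = F.normalize(z1, dim=1)
           z2 = F.normalize(z2, dim=1)
           sim = z1 @ z2.t()                       # [B, B] cosine similarities between views
           batch = sim.size(0)
           pos = sim.diagonal().unsqueeze(1)       # s(x_i, x_i^+), shape [B, 1]
           neg_mask = ~torch.eye(batch, dtype=torch.bool, device=sim.device)

           # exp((s(x_i, z) - s(x_i, x_i^+)) / tau) for every in-batch negative z
           exp_h = torch.exp((sim - pos) / self.tau) * neg_mask
           g_hat = exp_h.sum(dim=1) / (batch - 1)  # mini-batch estimate of g_i

           with torch.no_grad():
               # u_i <- (1 - gamma) * u_i + gamma * g_hat_i, only for samples in this batch
               self.u[index] = (1 - self.gamma) * self.u[index] + self.gamma * g_hat
               u = self.u[index]

           # Surrogate whose gradient is tau * grad(g_hat_i) / u_i, i.e. the gradient of
           # tau * log(g_i) with g_i replaced by the running estimate u_i.
           return (self.tau * g_hat / u).mean()


   # Usage with random embeddings standing in for the two projected views.
   torch.manual_seed(0)
   loss_fn = MovingAverageGCLoss(num_samples=100, tau=0.5, gamma=0.9)
   z1 = torch.randn(8, 16, requires_grad=True)
   z2 = torch.randn(8, 16, requires_grad=True)
   index = torch.arange(8)                     # dataset indices of this batch
   loss = loss_fn(z1, z2, index)
   loss.backward()

On the first visit to an index, :math:`u_i = \gamma \hat{g}_i`, so the returned surrogate equals :math:`\tau / \gamma` (about 0.556 here) regardless of the embeddings; only its gradient is meaningful. Keeping one estimate per sample is also why this sketch needs each image's dataset index.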
.. code:: python

   # model: non-linear projection layer
   num_proj_layers = 2
   dim = 256
   mlp_dim = 2048

   # dataset: cifar10
   data_name = 'cifar10'
   batch_size = 256

   # optimizer
   weight_decay = 1e-6
   init_lr = 0.075
   epochs = 50
   warmup_epochs = 2

   # dynamic loss (GCLoss hyperparameters)
   gamma = 0.9
   temperature = 0.5
   eta = 0.03
   rho = 0.4
   beta = 0.5

   # path
   logdir = './logs/'
   logname = 'resnet18_cifar10'
   os.makedirs(os.path.join(logdir, logname), exist_ok=True)

Dataset Pipeline for Contrastive Learning
--------------------------------------------------------------------------------

The dataset pipeline presented here differs from standard pipelines in three ways. First, ``TwoCropsTransform`` generates two randomly augmented crops of a single image to construct a positive pair, whereas a standard pipeline produces a single random crop. Second, the ``augmentation`` list follows SimCLR's implementation. Third, with ``return_index=True``, ``libauc.datasets.CIFAR10`` returns the index of each image along with the image and its label; the loss uses this index to keep per-sample statistics.

.. code:: python

   class TwoCropsTransform:
       """Take two random crops of one image."""

       def __init__(self, base_transform1, base_transform2):
           self.base_transform1 = base_transform1
           self.base_transform2 = base_transform2

       def __call__(self, x):
           im1 = self.base_transform1(x)
           im2 = self.base_transform2(x)
           return [im1, im2]

   image_size = 32
   mean = [0.4914, 0.4822, 0.4465]
   std = [0.2470, 0.2435, 0.2616]

   normalize = transforms.Normalize(mean=mean, std=std)

   # SimCLR augmentations
   augmentation = [
       transforms.RandomResizedCrop(image_size, scale=(0.08, 1.)),
       transforms.RandomApply([
           transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # Not strengthened
       ], p=0.8),
       transforms.RandomGrayscale(p=0.2),
       transforms.RandomHorizontalFlip(),
       transforms.ToTensor(),
       normalize
   ]

   DATA_ROOT = './'
   train_dataset = libauc.datasets.CIFAR10(
       root=DATA_ROOT, train=True, download=True, return_index=True,
       transform=TwoCropsTransform(
           transforms.Compose(augmentation),
           transforms.Compose(augmentation)
       )
   )

   train_loader = torch.utils.data.DataLoader(
       train_dataset,
       batch_size=batch_size,
       shuffle=True,
       num_workers=4,
       drop_last=True
   )

Helper functions
--------------------------------------------------------------------------------

.. code:: python

   def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
       mlp = []
       for l in range(num_layers):
           dim1 = input_dim if l == 0 else mlp_dim
           dim2 = output_dim if l == num_layers - 1 else mlp_dim

           mlp.append(nn.Linear(dim1, dim2, bias=False))

           if l < num_layers - 1:
               mlp.append(nn.BatchNorm1d(dim2))
               mlp.append(nn.ReLU(inplace=True))
           elif last_bn:
               # Follow SimCLR's design:
               # https://github.com/google-research/simclr/blob/master/model_util.py#L157
               # For simplicity, we further removed gamma in BN
               mlp.append(nn.BatchNorm1d(dim2, affine=False))

       return nn.Sequential(*mlp)

   def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
       """Decays the learning rate with half-cycle cosine after warmup."""
       if epoch < warmup_epochs:
           cur_lr = init_lr * epoch / warmup_epochs
       else:
           cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
       for param_group in optimizer.param_groups:
           param_group['lr'] = cur_lr
       return cur_lr

   def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
       torch.save(state, filename)
       if is_best:
           shutil.copyfile(filename, 'model_best.pth.tar')

   def train(train_loader, model, loss_fn, optimizer, epoch, lr):
       model.train()
       iters_per_epoch = len(train_loader)
       for i, (images, _, index) in enumerate(train_loader):
           # update the learning rate every iteration (fractional epoch)
           cur_lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, lr)
           images[0] = images[0].cuda()
           images[1] = images[1].cuda()

           # mixed precision; `scaler` is created in the Pretraining section
           with torch.cuda.amp.autocast(True):
               hidden1 = model(images[0])
               hidden2 = model(images[1])
               loss = loss_fn(hidden1, hidden2, index)

           optimizer.zero_grad()
           scaler.scale(loss).backward()
           scaler.step(optimizer)
           scaler.update()

       print(f'Epoch: {epoch}, Dynamic Loss: {loss.item():.3f}')

Creating Model & Optimizer
--------------------------------------------------------------------------------

Note that to enable iSogCLR, we need to set ``enable_isogclr=True`` in ``GCLoss``.

.. code:: python

   set_all_seeds(123)

   # ResNet-18 + 2-layer non-linear projection head
   base_encoder = resnet18(
       pretrained=False, last_activation=None, num_classes=128
   )
   hidden_dim = base_encoder.fc.weight.shape[1]
   del base_encoder.fc  # Remove original fc layer

   base_encoder.fc = build_mlp(num_proj_layers, hidden_dim, mlp_dim, dim)
   # CIFAR stem: 3x3 conv with stride 1 and no max pooling for 32x32 inputs
   base_encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
   base_encoder.maxpool = nn.Identity()
   model = base_encoder.cuda()

   # square root lr scaling
   lr = init_lr * math.sqrt(batch_size)

   # SogCLR optimizer in LARS mode
   optimizer = libauc.optimizers.SogCLR(
       base_encoder.parameters(),
       mode='lars',
       lr=lr,
       weight_decay=weight_decay,
       momentum=0.9
   )

   # Global Contrastive Loss with iSogCLR enabled
   # N: number of training images (CIFAR10 training split)
   loss_fn = GCLoss('unimodal', N=50000, tau=temperature, gamma=gamma, eta=eta, rho=rho, beta=beta, distributed=False, enable_isogclr=True)

Pretraining
--------------------------------------------------------------------------------

.. code:: python

   # mixed precision training
   scaler = torch.cuda.amp.GradScaler()

   print('Pretraining')
   for epoch in range(epochs):

       #if epoch in [int(epochs)*0.5, int(epochs)*0.75]:
       #    optimizer.update_regularizer()

       # train for one epoch
       train(train_loader, model, loss_fn, optimizer, epoch, lr)

       # save checkpoint
       if epoch % 5 == 0 or epochs - epoch < 3:
           save_checkpoint(
               {'epoch': epoch + 1,
                'arch': 'resnet18',
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
               },
               is_best=False,
               filename=os.path.join(logdir, logname, f'checkpoint_{epoch:04d}.pth.tar'))

Linear Evaluation
--------------------------------------------------------------------------------

By default, we train a linear classifier on the frozen features using momentum SGD without weight decay and a batch size of 512. This stage runs for 90 epochs.

Configurations
^^^^^^^^^^^^^^^^
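The two stages scale the base learning rate ``init_lr = 0.075`` differently: pretraining uses square-root scaling (``lr = init_lr * math.sqrt(batch_size)``), while linear evaluation uses linear scaling relative to a batch size of 256 (``lr = init_lr * batch_size / 256``). The hypothetical helper ``scaled_lr`` below only reproduces that arithmetic for the batch sizes used in this tutorial (256 for pretraining, 512 for linear evaluation); it is not a LibAUC function.

.. code:: python

   import math


   def scaled_lr(init_lr, batch_size, rule):
       """Hypothetical helper: scale a base learning rate by the batch size."""
       if rule == "sqrt":    # pretraining: lr = init_lr * sqrt(batch_size)
           return init_lr * math.sqrt(batch_size)
       if rule == "linear":  # linear evaluation: lr = init_lr * batch_size / 256
           return init_lr * batch_size / 256
       raise ValueError(f"unknown scaling rule: {rule!r}")


   print("pretraining lr:", scaled_lr(0.075, 256, "sqrt"))
   print("linear evaluation lr:", scaled_lr(0.075, 512, "linear"))

With these settings, the pretraining learning rate is 1.2 and the linear-evaluation learning rate is 0.15.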
.. code:: python

   # dataset
   image_size = 32
   batch_size = 512
   num_classes = 10  # cifar10

   # optimizer
   epochs = 90
   init_lr = 0.075
   weight_decay = 0

   # checkpoint
   checkpoint_dir = './logs/resnet18_cifar10/checkpoint_0010.pth.tar'

.. code:: python

   torch.cuda.empty_cache()

Dataset pipeline
^^^^^^^^^^^^^^^^

.. code:: python

   mean = [0.4914, 0.4822, 0.4465]
   std = [0.2470, 0.2435, 0.2616]
   normalize = transforms.Normalize(mean=mean, std=std)

   train_dataset = libauc.datasets.CIFAR10(
       root=DATA_ROOT, train=True, download=True,
       transform=transforms.Compose([
           transforms.RandomResizedCrop(32),
           transforms.RandomHorizontalFlip(),
           transforms.ToTensor(),
           normalize,
       ]))
   val_dataset = libauc.datasets.CIFAR10(
       root=DATA_ROOT, train=False, download=True,
       transform=transforms.Compose([
           transforms.ToTensor(),
           normalize,
       ]))

   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
   val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

Helper functions (linear evaluation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python

   def accuracy(output, target, topk=(1,)):
       """Computes the accuracy over the k top predictions for the specified values of k"""
       with torch.no_grad():
           maxk = max(topk)
           batch_size = target.size(0)

           _, pred = output.topk(maxk, 1, True, True)
           pred = pred.t()
           correct = pred.eq(target.view(1, -1).expand_as(pred))

           res = []
           for k in topk:
               correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
               res.append(correct_k.mul_(100.0 / batch_size))
           return res

   def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
       """Decay the learning rate with a half-cycle cosine schedule (no warmup)."""
       cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
       for param_group in optimizer.param_groups:
           param_group['lr'] = cur_lr

   def train(train_loader, model, criterion, optimizer, epoch):
       # eval mode keeps the BatchNorm statistics of the frozen backbone fixed
       model.eval()
       for i, (images, target) in enumerate(train_loader):
           images = images.float().cuda()
           target = target.long().cuda()
           output = model(images)
           loss = criterion(output, target)
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()

   def validate(val_loader, model, criterion):
       model.eval()
       correct1, correct5, total = 0.0, 0.0, 0
       with torch.no_grad():
           for images, target in val_loader:
               images = images.float().cuda()
               target = target.long().cuda()
               output = model(images)
               acc1, acc5 = accuracy(output, target, topk=(1, 5))
               # weight each batch by its size, since the last batch may be smaller
               n = target.size(0)
               correct1 += acc1.item() * n
               correct5 += acc5.item() * n
               total += n
       return correct1 / total, correct5 / total

Define model
^^^^^^^^^^^^

.. code:: python

   set_all_seeds(123)

   # ResNet-18 + classification layer
   model = resnet18(pretrained=False, last_activation=None, num_classes=128)
   hidden_dim = model.fc.weight.shape[1]
   del model.fc
   model.fc = nn.Linear(hidden_dim, num_classes, bias=True)

   # cifar head for resnet18
   model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
   model.maxpool = nn.Identity()

   # load pretrained checkpoint, excluding the projection head (fc) weights
   linear_keyword = 'fc'
   checkpoint = torch.load(checkpoint_dir, map_location="cpu")
   state_dict = checkpoint['state_dict']
   for k in list(state_dict.keys()):
       if linear_keyword in k:
           del state_dict[k]
   msg = model.load_state_dict(state_dict, strict=False)
   print('Linear Classifier Variables: %s' % (msg.missing_keys))

   # cuda
   model = model.cuda()

   # freeze all layers but the last fc
   for name, param in model.named_parameters():
       if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
           param.requires_grad = False

   # init the fc layer
   getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
   getattr(model, linear_keyword).bias.data.zero_()

   # optimize only the linear classifier
   parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
   assert len(parameters) == 2  # weight, bias

Define loss and optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python

   # define loss function (criterion) and optimizer
   criterion = nn.CrossEntropyLoss().cuda()

   # linear lr scaling
   lr = init_lr * batch_size / 256
   optimizer = torch.optim.SGD(parameters, lr=lr,
                               momentum=0.9,
                               weight_decay=weight_decay)

Training (linear evaluation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python

   # linear evaluation
   print('Linear Evaluation')
   for epoch in range(epochs):

       adjust_learning_rate(optimizer, epoch, lr)

       # train for one epoch
       train(train_loader, model, criterion, optimizer, epoch)

       # evaluate on validation set
       acc1, acc5 = validate(val_loader, model, criterion)

       # log
       print('Epoch: %s, Top1: %.2f, Top5: %.2f' % (epoch, acc1, acc5))

Visualization
--------------------------------------------------------------------------------

We ran SimCLR, SogCLR and iSogCLR for 400 epochs on CIFAR10 and visualized the learned features in the figure below. From the results, we can see that iSogCLR separates the features of different classes better than the other two approaches. The plotting code expects the three pre-rendered images ``simclr.png``, ``sogclr.png`` and ``isogclr.png`` in the working directory.

.. code:: python

   import matplotlib.pyplot as plt
   import matplotlib.image as mpimg

.. code:: python

   fig, ax = plt.subplots(1, 3, figsize=(20, 4.8))

   simclr = mpimg.imread('simclr.png')
   sogclr = mpimg.imread('sogclr.png')
   isogclr = mpimg.imread('isogclr.png')

   ax[0].imshow(simclr)
   ax[1].imshow(sogclr)
   ax[2].imshow(isogclr)

   ax[0].axis('off')
   ax[1].axis('off')
   ax[2].axis('off')

   ax[0].set_title('SimCLR')
   ax[1].set_title('SogCLR')
   ax[2].set_title('iSogCLR')

   plt.show()

.. image:: ./imgs/iSogCLR_Unimodal_33_0.png