.. _isogclr_unimodal:
Optimizing Robust Global Contrastive Loss with Small Batch Size (iSogCLR)
================================================================================================================================
| **Author**: Zhuoning Yuan, Xiyuan Wei
Introduction
------------------------------------------------------------------------------------
In this tutorial, you will learn how to train a self-supervised model by
optimizing the `Robust Global Contrastive
Loss `__ (RGCLoss) on CIFAR10. This
version is implemented in PyTorch based on
the `moco `__ codebase. We
recommend running this notebook in a GPU-enabled environment, e.g.,
`Google Colab `__.
**Reference**
If you find this tutorial helpful in your work, please cite our `library
paper `__ and the following papers:
.. code-block:: bibtex
@inproceedings{qiu2023isogclr,
title={Not All Semantics are Created Equal: Contrastive Self-supervised Learning with Automatic Temperature Individualization},
author={Qiu, Zi-Hao and Hu, Quanqi and Yuan, Zhuoning and Zhou, Denny and Zhang, Lijun and Yang, Tianbao},
booktitle={International Conference on Machine Learning},
year={2023},
organization={PMLR}
}
Install LibAUC
------------------------------------------------------------------------------------
Let's start by installing our library. In this tutorial, we will
use the latest version of LibAUC via ``pip install -U``.
.. code:: python
!pip install -U libauc
Importing LibAUC
--------------------------------------------------------------------------------
.. code:: python
import libauc
from libauc.models import resnet50, resnet18
from libauc.datasets import CIFAR10
from libauc.optimizers import SogCLR, LARS
from libauc.losses import GCLoss
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import os,math,shutil
Reproducibility
--------------------------------------------------------------------------------
The following function ``set_all_seeds`` limits sources of
randomness, such as model initialization and data shuffling.
However, completely reproducible results are not guaranteed across
PyTorch releases [Ref].
.. code:: python
def set_all_seeds(SEED):
# REPRODUCIBILITY
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Robust Global Contrastive Loss
--------------------------------------------------------------------------------
Robust Global Contrastive Loss (RGCLoss) aims to maximize the similarity
between an anchor image :math:`\mathbf{x}_i` and its corresponding
positive image :math:`\mathbf{x}_i^{+}`, while minimizing the similarity
between the anchor and a set of negative samples
:math:`\mathbf{S}_i^{-}`. In this version, negative samples are from
full data instead of mini-batch data. For more details about the
formulation of RGCL, please refer to the
`iSogCLR `__ paper.
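Concretely, the global contrastive objective being optimized has roughly the
following form (a simplified sketch; see the SogCLR and iSogCLR papers for the
exact objective with margin and regularization terms):

.. math::

   F(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\tau\log\left(\frac{1}{|\mathbf{S}_i^{-}|}\sum_{\mathbf{z}\in\mathbf{S}_i^{-}}\exp\left(\frac{h_i(\mathbf{z})-h_i(\mathbf{x}_i^{+})}{\tau}\right)\right),

where :math:`h_i(\cdot)` denotes the similarity between the representations of
the anchor :math:`\mathbf{x}_i` and another image, :math:`\tau` is the
temperature, and the inner sum ranges over the negatives
:math:`\mathbf{S}_i^{-}` drawn from the full dataset (tracked with
moving-average estimators rather than recomputed from a mini-batch). iSogCLR
further replaces the global :math:`\tau` with learned individualized
temperatures :math:`\tau_i`.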
.. code:: python
# model: non-linear projection layer
num_proj_layers=2
dim=256
mlp_dim=2048
# dataset: cifar10
data_name = 'cifar10'
batch_size = 256
# optimizer
weight_decay = 1e-6
init_lr=0.075
epochs = 50
warmup_epochs = 2
# dynamic loss
gamma = 0.9
temperature = 0.5
eta = 0.03
rho = 0.4
beta = 0.5
# path
logdir = './logs/'
logname = 'resnet18_cifar10'
os.makedirs(os.path.join(logdir, logname), exist_ok=True)
Dataset Pipeline for Contrastive Learning
--------------------------------------------------------------------------------
The dataset pipeline presented here differs from standard
pipelines. Firstly, ``TwoCropsTransform`` generates two randomly augmented
crops of a single image to construct pairwise samples, as opposed to
just one random crop in a standard pipeline. Secondly, the
``augmentation`` follows SimCLR's implementation. Lastly,
``libauc.datasets.CIFAR10`` returns the index of the image along with
the image and its label.
.. code:: python
class TwoCropsTransform:
"""Take two random crops of one image."""
def __init__(self, base_transform1, base_transform2):
self.base_transform1 = base_transform1
self.base_transform2 = base_transform2
def __call__(self, x):
im1 = self.base_transform1(x)
im2 = self.base_transform2(x)
return [im1, im2]
image_size = 32
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
# SimCLR augmentations
augmentation = [
transforms.RandomResizedCrop(image_size, scale=(0.08, 1.)),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # Not strengthened
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
]
DATA_ROOT = './'
train_dataset = libauc.datasets.CIFAR10(
root=DATA_ROOT, train=True, download=True, return_index=True,
transform=TwoCropsTransform(
transforms.Compose(augmentation),
transforms.Compose(augmentation)
)
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
drop_last=True
)
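As a quick, self-contained sanity check, ``TwoCropsTransform`` simply returns a
two-element list of views. The class is repeated below with plain callables
standing in for the torchvision transform pipelines (toy stand-ins, not part of
the tutorial's pipeline):

```python
class TwoCropsTransform:
    """Take two (possibly different) views of one input."""

    def __init__(self, base_transform1, base_transform2):
        self.base_transform1 = base_transform1
        self.base_transform2 = base_transform2

    def __call__(self, x):
        # Apply each pipeline independently -> two augmented "crops"
        return [self.base_transform1(x), self.base_transform2(x)]

# Toy stand-ins for the two augmentation pipelines
two_crops = TwoCropsTransform(str.upper, str.lower)
views = two_crops("Cat")
print(views)  # ['CAT', 'cat']
```

In the real pipeline each call draws fresh random augmentations, so the two
views of the same image differ on every epoch.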
Helper functions
--------------------------------------------------------------------------------
.. code:: python
def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
mlp = []
for l in range(num_layers):
dim1 = input_dim if l == 0 else mlp_dim
dim2 = output_dim if l == num_layers - 1 else mlp_dim
mlp.append(nn.Linear(dim1, dim2, bias=False))
if l < num_layers - 1:
mlp.append(nn.BatchNorm1d(dim2))
mlp.append(nn.ReLU(inplace=True))
elif last_bn:
# Follow SimCLR's design:
# https://github.com/google-research/simclr/blob/master/model_util.py#L157
# For simplicity, we further removed gamma in BN
mlp.append(nn.BatchNorm1d(dim2, affine=False))
return nn.Sequential(*mlp)
def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
"""Decays the learning rate with half-cycle cosine after warmup."""
if epoch < warmup_epochs:
cur_lr = init_lr * epoch / warmup_epochs
else:
cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
for param_group in optimizer.param_groups:
param_group['lr'] = cur_lr
return cur_lr
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
def train(train_loader, model, loss_fn, optimizer, epoch, lr):
model.train()
iters_per_epoch = len(train_loader)
for i, (images, _, index) in enumerate(train_loader):
cur_lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, lr)
images[0] = images[0].cuda()
images[1] = images[1].cuda()
with torch.cuda.amp.autocast(True):
hidden1 = model(images[0])
hidden2 = model(images[1])
loss = loss_fn(hidden1, hidden2, index)
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
    print(f'Epoch: {epoch}, Dynamic Loss: {loss:.3f}')
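Evaluated in isolation with this tutorial's settings (``warmup_epochs=2``,
``epochs=50``, ``init_lr=0.075``), the warmup-plus-cosine schedule above
produces values like these (a self-contained sketch of the same formula):

```python
import math

# This tutorial's pretraining settings
warmup_epochs, epochs = 2, 50

def lr_at(epoch, init_lr=0.075):
    if epoch < warmup_epochs:
        # linear warmup from 0 to init_lr
        return init_lr * epoch / warmup_epochs
    # half-cycle cosine decay from init_lr down to 0
    return init_lr * 0.5 * (1. + math.cos(
        math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))

print(round(lr_at(0), 4))   # 0.0     (start of warmup)
print(round(lr_at(1), 4))   # 0.0375  (halfway through warmup)
print(round(lr_at(2), 4))   # 0.075   (peak, warmup finished)
print(round(lr_at(26), 4))  # 0.0375  (midpoint of cosine decay)
print(round(lr_at(50), 4))  # 0.0     (fully decayed)
```

Note that ``train`` passes a fractional epoch (``epoch + i / iters_per_epoch``),
so the learning rate is updated smoothly at every iteration, not once per epoch.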
Creating Model & Optimizer
--------------------------------------------------------------------------------
Note that to enable iSogCLR, we need to set ``enable_isogclr=True`` in
``GCLoss``.
.. code:: python
set_all_seeds(123)
# ResNet-18 + 2-layer non-linear projection head
base_encoder = resnet18(
pretrained=False, last_activation=None, num_classes=128
)
hidden_dim = base_encoder.fc.weight.shape[1]
del base_encoder.fc # Remove original fc layer
base_encoder.fc = build_mlp(num_proj_layers, hidden_dim, mlp_dim, dim)
base_encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
base_encoder.maxpool = nn.Identity()
model = base_encoder.cuda()
# square root lr scaling
lr = init_lr * math.sqrt(batch_size)
# LARS optimizer
optimizer = libauc.optimizers.SogCLR(
base_encoder.parameters(),
mode = 'lars',
lr=lr,
weight_decay=weight_decay,
momentum=0.9
)
# Global Contrastive Loss
loss_fn = GCLoss('unimodal', N=50000, tau=temperature, gamma=gamma,
eta=eta, rho=rho, beta=beta, distributed=False,
enable_isogclr=True)
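With this tutorial's settings, the square-root scaling rule above yields a
fairly large base learning rate (a quick check):

```python
import math

# Square-root lr scaling used for pretraining: lr = init_lr * sqrt(batch_size)
init_lr = 0.075
batch_size = 256
lr = init_lr * math.sqrt(batch_size)
print(round(lr, 6))  # 1.2
```

Square-root scaling is commonly paired with LARS-style optimizers in
self-supervised pretraining (e.g., SimCLR), whereas linear scaling is the usual
choice for plain SGD.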
Pretraining
--------------------------------------------------------------------------------
.. code:: python
# mixed precision training
scaler = torch.cuda.amp.GradScaler()
print ('Pretraining')
for epoch in range(epochs):
#if epoch in [int(epochs)*0.5, int(epochs)*0.75]:
# optimizer.update_regularizer()
# train for one epoch
train(train_loader, model, loss_fn, optimizer, epoch, lr)
# save checkpoint
if epoch % 5 == 0 or epochs - epoch < 3:
save_checkpoint(
{'epoch': epoch + 1,
'arch': 'resnet18',
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scaler': scaler.state_dict(),
},
is_best=False,
filename=os.path.join(logdir, logname, f'checkpoint_{epoch:04d}.pth.tar'))
Linear Evaluation
--------------------------------------------------------------------------------
By default, we use momentum SGD without weight decay and a batch size of
512 for linear classification on frozen features/weights. This
stage runs for 90 epochs.
Configurations
^^^^^^^^^^^^^^^^
.. code:: python
# dataset
image_size = 32
batch_size = 512
num_classes = 10 # cifar10
# optimizer
epochs = 90
init_lr = 0.075
weight_decay = 0
# checkpoint
checkpoint_dir = './logs/resnet18_cifar10/checkpoint_0010.pth.tar'
.. code:: python
torch.cuda.empty_cache()
Dataset pipeline
^^^^^^^^^^^^^^^^
.. code:: python
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
train_dataset = libauc.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True,
transform=transforms.Compose([transforms.RandomResizedCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,]))
val_dataset = libauc.datasets.CIFAR10(root=DATA_ROOT, train=False,download=True,
transform=transforms.Compose([transforms.ToTensor(),
normalize,]))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
Helper functions (linear evaluation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code:: python
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
    """Decay the learning rate with a half-cycle cosine schedule (no warmup)."""
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
def train(train_loader, model, criterion, optimizer, epoch):
    # Keep the frozen backbone in eval mode (fixed BN statistics);
    # only the linear classifier receives gradients.
    model.eval()
for i, (images, target) in enumerate(train_loader):
images = images.float().cuda()
target = target.long().cuda()
output = model(images)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
def validate(val_loader, model, criterion):
model.eval()
acc1_list = []
acc5_list = []
with torch.no_grad():
for i, (images, target) in enumerate(val_loader):
images = images.float().cuda()
target = target.long().cuda()
output = model(images)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1_list.append(acc1)
acc5_list.append(acc5)
acc1_array = torch.stack(acc1_list)
acc5_array = torch.stack(acc5_list)
return torch.mean(acc1_array), torch.mean(acc5_array)
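For intuition, here is a pure-Python sketch of the same top-k computation as
``accuracy`` above, operating on plain lists instead of tensors (toy inputs,
not part of the tutorial's pipeline):

```python
def topk_accuracy(logits, targets, ks=(1, 5)):
    """logits: list of per-sample score lists; targets: list of true labels."""
    results = []
    for k in ks:
        correct = 0
        for scores, label in zip(logits, targets):
            # indices of the k highest-scoring classes
            topk = sorted(range(len(scores)),
                          key=lambda j: scores[j], reverse=True)[:k]
            correct += int(label in topk)
        results.append(100.0 * correct / len(targets))
    return results

logits = [
    [0.1, 0.7, 0.2],  # predicted top-1 = class 1
    [0.5, 0.2, 0.3],  # predicted top-1 = class 0, top-2 adds class 2
]
targets = [1, 2]
print(topk_accuracy(logits, targets, ks=(1, 2)))  # [50.0, 100.0]
```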
Define model
^^^^^^^^^^^^
.. code:: python
set_all_seeds(123)
# ResNet-18 + classification layer
model = resnet18(pretrained=False, last_activation=None, num_classes=128)
hidden_dim = model.fc.weight.shape[1]
del model.fc
model.fc = nn.Linear(hidden_dim, num_classes, bias=True)
# cifar head for resnet18
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
# load pretrained checkpoint excluding non-linear layers
linear_keyword = 'fc'
checkpoint = torch.load(checkpoint_dir, map_location="cpu")
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
if linear_keyword in k:
del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
print ('Linear Classifier Variables: %s'%(msg.missing_keys))
# cuda
model = model.cuda()
# freeze all layers but the last fc
for name, param in model.named_parameters():
if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
param.requires_grad = False
# init the fc layer
getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
getattr(model, linear_keyword).bias.data.zero_()
# optimize only the linear classifier
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2 # weight, bias
Define loss and optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^
.. code:: python
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
# linear lr scaling
lr = init_lr * batch_size / 256
optimizer = torch.optim.SGD(parameters,
lr=lr,
momentum=0.9,
weight_decay=weight_decay)
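For reference, the linear scaling rule used here evaluates to the following
(contrast with the square-root rule used during pretraining):

```python
# Linear lr scaling for linear evaluation: lr = init_lr * batch_size / 256
init_lr = 0.075
batch_size = 512
lr = init_lr * batch_size / 256
print(lr)  # 0.15
```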
Training (linear evaluation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code:: python
# linear evaluation
print ('Linear Evaluation')
for epoch in range(epochs):
adjust_learning_rate(optimizer, epoch, lr)
# train for one epoch
train(train_loader, model, criterion, optimizer, epoch)
# evaluate on validation set
acc1, acc5 = validate(val_loader, model, criterion)
# log
print ('Epoch: %s, Top1: %.2f, Top5: %.2f'%(epoch, acc1, acc5))
Visualization
--------------------------------------------------------------------------------
We run SimCLR, SogCLR and iSogCLR for 400 epochs on CIFAR10, and
visualize the learned features in the following figures. From the
results we can see that iSogCLR is better at separating the features
from different classes than the other two approaches.
.. code:: python
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
.. code:: python
fig, ax = plt.subplots(1, 3, figsize=(20, 4.8))
simclr = mpimg.imread('simclr.png')
sogclr = mpimg.imread('sogclr.png')
isogclr = mpimg.imread('isogclr.png')
ax[0].imshow(simclr)
ax[1].imshow(sogclr)
ax[2].imshow(isogclr)
ax[0].axis('off')
ax[1].axis('off')
ax[2].axis('off')
ax[0].set_title('SimCLR')
ax[1].set_title('SogCLR')
ax[2].set_title('iSogCLR')
plt.show()
.. image:: ./imgs/iSogCLR_Unimodal_33_0.png