Optimizing Robust Global Contrastive Loss with Small Batch Size (iSogCLR)
Introduction
In this tutorial, you will learn how to train a self-supervised model by optimizing the Robust Global Contrastive Loss (RGCLoss) on CIFAR10. This version is implemented in PyTorch and is based on MoCo’s codebase. We recommend running this notebook in a GPU-enabled environment, e.g., Google Colab.
Reference
If you find this tutorial helpful in your work, please cite our library paper and the following papers:
@inproceedings{qiu2023isogclr,
title={Not All Semantics are Created Equal: Contrastive Self-supervised Learning with Automatic Temperature Individualization},
author={Qiu, Zi-Hao and Hu, Quanqi and Yuan, Zhuoning and Zhou, Denny and Zhang, Lijun and Yang, Tianbao},
booktitle={International Conference on Machine Learning},
year={2023},
organization={PMLR}
}
Install LibAUC
Let’s start by installing our library. In this tutorial, we will
use the latest version of LibAUC, installed with pip install -U.
!pip install -U libauc
Importing LibAUC
import libauc
from libauc.models import resnet50, resnet18
from libauc.datasets import CIFAR10
from libauc.optimizers import SogCLR, LARS
from libauc.losses import GCLoss
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import os,math,shutil
Reproducibility
The following function set_all_seeds limits the sources of
randomness, such as model initialization and data shuffling.
However, completely reproducible results are not guaranteed across
PyTorch releases [Ref].
def set_all_seeds(SEED):
    # REPRODUCIBILITY
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
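As a small illustration of what seeding buys you, the sketch below uses only Python's built-in random module. The helper name seeded_draws is hypothetical and not part of LibAUC or PyTorch. Note that set_all_seeds above does not seed this module; if your own pipeline relies on it (an assumption about your setup), you may want to call random.seed as well.

```python
import random

def seeded_draws(seed, n=3):
    """Return n pseudo-random floats from a generator seeded with `seed`."""
    rng = random.Random(seed)  # independent generator; global state is untouched
    return [rng.random() for _ in range(n)]

first = seeded_draws(123)
second = seeded_draws(123)
other = seeded_draws(124)

print(first == second)  # same seed gives the same sequence
print(first == other)   # a different seed gives a different sequence
```

The same principle applies to the NumPy and PyTorch generators seeded in set_all_seeds.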
Robust Global Contrastive Loss
Robust Global Contrastive Loss (RGCLoss) aims to maximize the similarity between an anchor image \(\mathbf{x}_i\) and its corresponding positive image \(\mathbf{x}_i^{+}\), while minimizing the similarity between the anchor and a set of negative samples \(\mathbf{S}_i^{-}\). In this version, the negative samples are drawn from the full dataset rather than from the mini-batch only. For more details about the formulation of RGCLoss, please refer to the iSogCLR paper.
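To make the roles of the positive pair, the negatives, and the temperature concrete, here is a minimal pure-Python sketch of a generic temperature-scaled contrastive loss for a single anchor. It is an illustration written for this tutorial, not the RGCLoss formula and not LibAUC's implementation; the function name contrastive_loss and the toy similarity values are assumptions.

```python
import math

def contrastive_loss(pos_sim, neg_sims, tau):
    """Generic contrastive loss for one anchor (illustrative only).

    Computes tau * log(mean_j exp((neg_sims[j] - pos_sim) / tau)).
    The value drops when the positive similarity grows and rises
    when any negative similarity grows.
    """
    scaled = [(s - pos_sim) / tau for s in neg_sims]
    m = max(scaled)  # subtract the max for numerical stability
    log_mean_exp = m + math.log(sum(math.exp(v - m) for v in scaled) / len(scaled))
    return tau * log_mean_exp

negatives = [0.1, 0.2, 0.3]  # toy cosine similarities to negative samples
for tau in (0.01, 0.5, 100.0):
    print(tau, contrastive_loss(0.9, negatives, tau))
```

With a small temperature the loss is dominated by the hardest negative (the largest similarity), while a large temperature approaches the average over all negatives. This is one way to see why the choice of temperature, and in iSogCLR a per-sample temperature, can matter.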
# model: non-linear projection layer
num_proj_layers=2
dim=256
mlp_dim=2048
# dataset: cifar10
data_name = 'cifar10'
batch_size = 256
# optimizer
weight_decay = 1e-6
init_lr=0.075
epochs = 50
warmup_epochs = 2
# dynamic loss
gamma = 0.9
temperature = 0.5
eta = 0.03
rho = 0.4
beta = 0.5
# path
logdir = './logs/'
logname = 'resnet18_cifar10'
os.makedirs(os.path.join(logdir, logname), exist_ok=True)
Dataset Pipeline for Contrastive Learning
The dataset pipeline presented here differs from standard
pipelines in three ways. First, TwoCropsTransform generates two randomly
augmented crops of a single image to construct pairwise samples, as opposed to
just one random crop in a standard pipeline. Second, the
augmentation follows SimCLR’s implementation. Finally,
libauc.datasets.CIFAR10 returns the index of the image along with
the image and its label.
class TwoCropsTransform:
    """Take two random crops of one image."""
    def __init__(self, base_transform1, base_transform2):
        self.base_transform1 = base_transform1
        self.base_transform2 = base_transform2
    def __call__(self, x):
        im1 = self.base_transform1(x)
        im2 = self.base_transform2(x)
        return [im1, im2]
image_size = 32
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
# SimCLR augmentations
augmentation = [
    transforms.RandomResizedCrop(image_size, scale=(0.08, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # Not strengthened
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
]
DATA_ROOT = './'
train_dataset = libauc.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, return_index=True,
    transform=TwoCropsTransform(
        transforms.Compose(augmentation),
        transforms.Compose(augmentation)
    )
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True
)
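As a quick check of what TwoCropsTransform produces, the sketch below applies the same wrapper to a toy input. The noise-based make_toy_augmentation is a hypothetical stand-in for the image augmentations (an assumption, chosen so that the example runs without torchvision).

```python
import random

class TwoCropsTransform:
    """Same wrapper as in the tutorial: apply two transforms to one input."""
    def __init__(self, base_transform1, base_transform2):
        self.base_transform1 = base_transform1
        self.base_transform2 = base_transform2

    def __call__(self, x):
        return [self.base_transform1(x), self.base_transform2(x)]

def make_toy_augmentation(seed):
    """Hypothetical augmentation: add small random noise to a list of floats."""
    rng = random.Random(seed)

    def augment(x):
        return [v + rng.uniform(-0.1, 0.1) for v in x]

    return augment

two_views = TwoCropsTransform(make_toy_augmentation(0), make_toy_augmentation(1))
sample = [0.2, 0.5, 0.8]        # stands in for one image
view1, view2 = two_views(sample)
print(view1)
print(view2)
```

Both views come from the same sample but differ from each other, which is what lets the loss treat them as a positive pair.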
Helper functions
def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
    mlp = []
    for l in range(num_layers):
        dim1 = input_dim if l == 0 else mlp_dim
        dim2 = output_dim if l == num_layers - 1 else mlp_dim
        mlp.append(nn.Linear(dim1, dim2, bias=False))
        if l < num_layers - 1:
            mlp.append(nn.BatchNorm1d(dim2))
            mlp.append(nn.ReLU(inplace=True))
        elif last_bn:
            # Follow SimCLR's design:
            # https://github.com/google-research/simclr/blob/master/model_util.py#L157
            # For simplicity, we further removed gamma in BN
            mlp.append(nn.BatchNorm1d(dim2, affine=False))
    return nn.Sequential(*mlp)
def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
    """Decays the learning rate with half-cycle cosine after warmup."""
    if epoch < warmup_epochs:
        cur_lr = init_lr * epoch / warmup_epochs
    else:
        cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
def train(train_loader, model, loss_fn, optimizer, epoch, lr):
    model.train()
    iters_per_epoch = len(train_loader)
    for i, (images, _, index) in enumerate(train_loader):
        # update the learning rate at every iteration (fractional epoch)
        cur_lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, lr)
        images[0] = images[0].cuda()
        images[1] = images[1].cuda()
        with torch.cuda.amp.autocast(True):
            hidden1 = model(images[0])
            hidden2 = model(images[1])
            loss = loss_fn(hidden1, hidden2, index)
        optimizer.zero_grad()
        # scaler is the global GradScaler created in the Pretraining section
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    # report the loss of the last mini-batch of the epoch
    print(f'Epoch: {epoch}, Dynamic Loss: {loss:.3f}')
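To see what the schedule in adjust_learning_rate does over the course of pretraining, the sketch below restates the same formula as a pure function and evaluates it at a few epochs. The function name warmup_cosine_lr and its explicit arguments are assumptions introduced here so that the example does not depend on global variables; the values 0.075, 256, 50, and 2 are taken from the configuration above.

```python
import math

def warmup_cosine_lr(epoch, init_lr, epochs, warmup_epochs):
    """Linear warmup followed by half-cycle cosine decay."""
    if epoch < warmup_epochs:
        return init_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return init_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

peak_lr = 0.075 * math.sqrt(256)  # square-root scaled learning rate
for e in [0, 1, 2, 26, 50]:
    print(e, round(warmup_cosine_lr(e, peak_lr, epochs=50, warmup_epochs=2), 4))
```

The rate rises linearly from zero during the two warmup epochs, reaches its peak at epoch 2, and then follows half a cosine cycle down towards zero at the final epoch. Because train passes a fractional epoch (epoch + i / iters_per_epoch), the rate changes smoothly at every iteration rather than once per epoch.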
Creating Model & Optimizer
Note that to enable iSogCLR, we need to set enable_isogclr=True in
GCLoss.
set_all_seeds(123)
# ResNet-18 + 2-layer non-linear projection head
base_encoder = resnet18(
    pretrained=False, last_activation=None, num_classes=128
)
hidden_dim = base_encoder.fc.weight.shape[1]
del base_encoder.fc # Remove original fc layer
base_encoder.fc = build_mlp(num_proj_layers, hidden_dim, mlp_dim, dim)
base_encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
base_encoder.maxpool = nn.Identity()
model = base_encoder.cuda()
# square root lr scaling
lr = init_lr * math.sqrt(batch_size)
# LARS optimizer
optimizer = libauc.optimizers.SogCLR(
    base_encoder.parameters(),
    mode='lars',
    lr=lr,
    weight_decay=weight_decay,
    momentum=0.9
)
# Global Contrastive Loss
loss_fn = GCLoss('unimodal', N=50000, tau=temperature, gamma=gamma,
                 eta=eta, rho=rho, beta=beta, distributed=False,
                 enable_isogclr=True)
Pretraining
# mixed precision training
scaler = torch.cuda.amp.GradScaler()
print ('Pretraining')
for epoch in range(epochs):
    #if epoch in [int(epochs)*0.5, int(epochs)*0.75]:
    #    optimizer.update_regularizer()
    # train for one epoch
    train(train_loader, model, loss_fn, optimizer, epoch, lr)
    # save checkpoint
    if epoch % 5 == 0 or epochs - epoch < 3:
        save_checkpoint(
            {'epoch': epoch + 1,
             'arch': 'resnet18',
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'scaler': scaler.state_dict(),
            },
            is_best=False,
            filename=os.path.join(logdir, logname, f'checkpoint_{epoch:04d}.pth.tar'))
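It can help to know in advance which checkpoint files the loop above writes, since the linear evaluation stage loads one of them by name. The sketch below applies the same condition (epoch % 5 == 0 or epochs - epoch < 3) to list the saved epochs; the helper name checkpoint_epochs and its parameters every and last are hypothetical names introduced for this illustration.

```python
def checkpoint_epochs(epochs, every=5, last=3):
    """Epoch indices e for which `e % every == 0 or epochs - e < last` holds."""
    return [e for e in range(epochs) if e % every == 0 or epochs - e < last]

saved = checkpoint_epochs(50)
print(saved)
print([f'checkpoint_{e:04d}.pth.tar' for e in saved[:3]])
```

With epochs = 50 this yields every fifth epoch from 0 to 45 plus epochs 48 and 49, so a file such as checkpoint_0010.pth.tar is available once training has passed epoch 10.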
Linear Evaluation
By default, we use momentum-SGD without weight decay and a batch size of 512 for linear classification on frozen features/weights. This stage runs for 90 epochs.
Configurations
# dataset
image_size = 32
batch_size = 512
num_classes = 10 # cifar10
# optimizer
epochs = 90
init_lr = 0.075
weight_decay = 0
# checkpoint
checkpoint_dir = './logs/resnet18_cifar10/checkpoint_0010.pth.tar'
torch.cuda.empty_cache()
Dataset pipeline
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
train_dataset = libauc.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True,
                                        transform=transforms.Compose([transforms.RandomResizedCrop(32),
                                                                      transforms.RandomHorizontalFlip(),
                                                                      transforms.ToTensor(),
                                                                      normalize,]))
val_dataset = libauc.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True,
                                      transform=transforms.Compose([transforms.ToTensor(),
                                                                    normalize,]))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
Helper functions (linear evaluation)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
    """Decay the learning rate based on schedule"""
    # epochs and warmup_epochs are read from the global configuration;
    # warmup_epochs still holds the value set for pretraining
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
def train(train_loader, model, criterion, optimizer, epoch):
    # keep the frozen encoder in eval mode so that BatchNorm statistics are not updated
    model.eval()
    for i, (images, target) in enumerate(train_loader):
        images = images.float().cuda()
        target = target.long().cuda()
        output = model(images)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def validate(val_loader, model, criterion):
    model.eval()
    acc1_list = []
    acc5_list = []
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images = images.float().cuda()
            target = target.long().cuda()
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            acc1_list.append(acc1)
            acc5_list.append(acc5)
    acc1_array = torch.stack(acc1_list)
    acc5_array = torch.stack(acc5_list)
    return torch.mean(acc1_array), torch.mean(acc5_array)
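The accuracy helper above counts a prediction as correct when the true label is among the k highest-scoring classes. The following pure-Python sketch shows the same idea on a tiny hand-made example; the function name topk_accuracy and the toy scores are assumptions for illustration and are not part of LibAUC.

```python
def topk_accuracy(scores, targets, k):
    """Percentage of samples whose target is among the k highest scores."""
    correct = 0
    for row, target in zip(scores, targets):
        # class indices sorted by descending score, truncated to the top k
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        correct += target in topk
    return 100.0 * correct / len(targets)

scores = [
    [0.1, 0.7, 0.2],  # top-1 is class 1
    [0.5, 0.3, 0.2],  # top-1 is class 0, top-2 adds class 1
    [0.2, 0.3, 0.5],  # top-1 is class 2, top-2 adds class 1
]
targets = [1, 1, 0]
print(topk_accuracy(scores, targets, k=1))
print(topk_accuracy(scores, targets, k=2))
```

Here one of three samples is correct at top-1 and two of three at top-2. Note that validate averages the per-batch accuracies, so a smaller final batch carries the same weight as a full one.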
Define model
set_all_seeds(123)
# ResNet-18 + classification layer
model = resnet18(pretrained=False, last_activation=None, num_classes=128)
hidden_dim = model.fc.weight.shape[1]
del model.fc
model.fc = nn.Linear(hidden_dim, num_classes, bias=True)
# cifar head for resnet18
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
# load pretrained checkpoint, excluding the non-linear projection head
linear_keyword = 'fc'
checkpoint = torch.load(checkpoint_dir, map_location="cpu")
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
    if linear_keyword in k:
        del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
print ('Linear Classifier Variables: %s'%(msg.missing_keys))
# cuda
model = model.cuda()
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
        param.requires_grad = False
# init the fc layer
getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
getattr(model, linear_keyword).bias.data.zero_()
# optimize only the linear classifier
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2 # weight, bias
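The checkpoint filtering step above can be illustrated on a plain dictionary. In this sketch the helper name drop_head_keys and the toy keys are assumptions chosen to resemble a ResNet state dict; real checkpoints contain many more entries.

```python
def drop_head_keys(state_dict, keyword='fc'):
    """Return a copy of state_dict without entries whose key contains keyword."""
    return {k: v for k, v in state_dict.items() if keyword not in k}

toy_state_dict = {
    'conv1.weight': 'encoder tensor',
    'layer1.0.conv1.weight': 'encoder tensor',
    'fc.0.weight': 'projection head tensor',
    'fc.3.weight': 'projection head tensor',
}
kept = drop_head_keys(toy_state_dict)
print(sorted(kept))
```

Only the encoder entries remain, so load_state_dict with strict=False leaves the new linear classifier at its fresh initialization and reports its weight and bias as missing keys. Because the test is a substring match, any key containing 'fc' anywhere would also be dropped; a stricter alternative would be k.startswith('fc.').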
Define loss and optimizer
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
# linear lr scaling
lr = init_lr * batch_size / 256
optimizer = torch.optim.SGD(parameters,
lr=lr,
momentum=0.9,
weight_decay=weight_decay)
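The two stages of this tutorial scale the base learning rate differently: pretraining uses square-root scaling (init_lr * sqrt(batch_size)), while linear evaluation uses linear scaling (init_lr * batch_size / 256). The short sketch below works through the arithmetic for the values configured in this notebook; the function names are hypothetical.

```python
import math

def sqrt_scaled_lr(init_lr, batch_size):
    """Square-root scaling, as used for pretraining."""
    return init_lr * math.sqrt(batch_size)

def linear_scaled_lr(init_lr, batch_size, base_batch_size=256):
    """Linear scaling relative to a base batch size, as used for linear evaluation."""
    return init_lr * batch_size / base_batch_size

print(sqrt_scaled_lr(0.075, 256))    # pretraining: 0.075 * 16
print(linear_scaled_lr(0.075, 512))  # linear evaluation: 0.075 * 2
```

If you change batch_size in either stage, the effective learning rate changes with it, so the two settings should be adjusted together.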
Training (linear evaluation)
# linear evaluation
print ('Linear Evaluation')
for epoch in range(epochs):
    adjust_learning_rate(optimizer, epoch, lr)
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)
    # evaluate on validation set
    acc1, acc5 = validate(val_loader, model, criterion)
    # log
    print ('Epoch: %s, Top1: %.2f, Top5: %.2f'%(epoch, acc1, acc5))
Visualization
We ran SimCLR, SogCLR, and iSogCLR for 400 epochs on CIFAR10 and visualize the learned features in the following figures. The results show that iSogCLR separates the features of different classes better than the other two approaches.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
fig, ax = plt.subplots(1, 3, figsize=(20, 4.8))
simclr = mpimg.imread('simclr.png')
sogclr = mpimg.imread('sogclr.png')
isogclr = mpimg.imread('isogclr.png')
ax[0].imshow(simclr)
ax[1].imshow(sogclr)
ax[2].imshow(isogclr)
ax[0].axis('off')
ax[1].axis('off')
ax[2].axis('off')
ax[0].set_title('SimCLR')
ax[1].set_title('SogCLR')
ax[2].set_title('iSogCLR')
plt.show()