Optimizing Robust Global Contrastive Loss with Small Batch Size (iSogCLR)


Authors: Zhuoning Yuan, Xiyuan Wei

Introduction

In this tutorial, you will learn how to train a self-supervised model on CIFAR10 by optimizing the Robust Global Contrastive Loss (RGCLoss). This version is implemented in PyTorch and based on MoCo’s codebase. We recommend running this notebook in a GPU-enabled environment, e.g., Google Colab.

Reference

If you find this tutorial helpful in your work, please cite our library paper and the following paper:

@inproceedings{qiu2023isogclr,
    title={Not All Semantics are Created Equal: Contrastive Self-supervised Learning with Automatic Temperature Individualization},
    author={Qiu, Zi-Hao and Hu, Quanqi and Yuan, Zhuoning and Zhou, Denny and Zhang, Lijun and Yang, Tianbao},
    booktitle={International Conference on Machine Learning},
    year={2023},
    organization={PMLR}
}

Install LibAUC

Let’s start by installing our library. In this tutorial, we use the latest version of LibAUC, installed with pip install -U.

!pip install -U libauc

Importing LibAUC

import libauc
from libauc.models import resnet50, resnet18
from libauc.datasets import CIFAR10
from libauc.optimizers import SogCLR, LARS
from libauc.losses import GCLoss

import torch
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import os, math, shutil

Reproducibility

The following function, set_all_seeds, limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases [Ref].

def set_all_seeds(SEED):
    # REPRODUCIBILITY
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
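set_all_seeds does not seed Python's built-in random module. If any custom code in the pipeline draws from it, you can call a helper like the hypothetical one below alongside set_all_seeds. The torch calls are left out so the snippet runs on its own. It only checks that re-seeding replays identical draws from both generators.

```python
import random

import numpy as np


def seed_python_and_numpy(seed):
    """Hypothetical helper: seeds Python's `random` and NumPy's global RNG.

    The torch.* calls from set_all_seeds are omitted so this sketch runs
    without PyTorch; in the notebook you would call both helpers.
    """
    random.seed(seed)
    np.random.seed(seed)


# Re-seeding replays the same sequence of draws from both generators.
seed_python_and_numpy(123)
first = (random.random(), np.random.rand(3).tolist())
seed_python_and_numpy(123)
second = (random.random(), np.random.rand(3).tolist())
print(first == second)  # True
```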

Robust Global Contrastive Loss

Robust Global Contrastive Loss (RGCLoss) aims to maximize the similarity between an anchor image \(\mathbf{x}_i\) and its corresponding positive image \(\mathbf{x}_i^{+}\), while minimizing the similarity between the anchor and a set of negative samples \(\mathbf{S}_i^{-}\). In this version, the negative set \(\mathbf{S}_i^{-}\) is drawn from the full dataset rather than only from the current mini-batch, and each anchor has its own individualized temperature. For more details about the formulation of RGCL, please refer to the iSogCLR paper.
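For reference, here is our reading of the robust objective in the iSogCLR paper. It is a sketch; consult the paper for the exact statement and the constraints on \(\tau_i\). It uses a normalized encoder \(E(\cdot)\) with weights \(\mathbf{w}\), one temperature \(\tau_i\) per anchor, and a radius \(\rho\) (the rho setting in the configuration):

\[
\min_{\mathbf{w},\,\boldsymbol{\tau}}\ \frac{1}{n}\sum_{i=1}^{n}\left\{\tau_i \log\left(\frac{1}{|\mathbf{S}_i^{-}|}\sum_{\mathbf{z}\in\mathbf{S}_i^{-}}\exp\left(\frac{h_i(\mathbf{z})}{\tau_i}\right)\right)+\tau_i\,\rho\right\},
\qquad
h_i(\mathbf{z}) = E(\mathbf{x}_i)^{\top}E(\mathbf{z}) - E(\mathbf{x}_i)^{\top}E(\mathbf{x}_i^{+}).
\]

By standard KL-DRO duality, minimizing the term in braces over \(\tau_i\) gives the largest average of \(h_i\) over distributions on \(\mathbf{S}_i^{-}\) that lie within KL divergence \(\rho\) of the uniform distribution. This worst-case weighting of the negatives is what makes the loss "robust", and it gives each anchor its own temperature.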

# model: non-linear projection layer
num_proj_layers=2
dim=256
mlp_dim=2048

# dataset: cifar10
data_name = 'cifar10'
batch_size = 256

# optimizer
weight_decay = 1e-6
init_lr=0.075
epochs = 50
warmup_epochs = 2

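# dynamic loss
gamma = 0.9         # moving-average factor for the per-sample estimates
temperature = 0.5   # (initial) temperature
eta = 0.03          # step size for updating the individualized temperatures
rho = 0.4           # radius of the KL constraint in the robust (DRO) loss
beta = 0.5          # momentum for the temperature update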

# path
logdir = './logs/'
logname = 'resnet18_cifar10'
os.makedirs(os.path.join(logdir, logname), exist_ok=True)
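The "dynamic loss" settings refer to state that the loss keeps across iterations. In our reading of the SogCLR/iSogCLR papers, the inner average over negatives for anchor i is tracked by a moving average u_i, and gamma weights the new mini-batch estimate. The NumPy sketch below assumes that form. LibAUC's GCLoss manages its own internal buffers, and update_u is a hypothetical name.

```python
import numpy as np

N = 10           # dataset size (the notebook uses N=50000 for CIFAR10)
gamma = 0.9      # moving-average factor, as in the config above
u = np.zeros(N)  # one running estimate per training sample


def update_u(u, index, batch_estimate, gamma):
    """Moving-average update of the per-sample estimates u[index].

    batch_estimate[k] is a mini-batch estimate of the inner average
    (1/|S_i^-|) * sum_z exp(h_i(z) / tau) for the sample with dataset
    index index[k]. Entries outside `index` are left untouched.
    """
    u[index] = (1.0 - gamma) * u[index] + gamma * batch_estimate
    return u


index = np.array([2, 5, 7])               # dataset indices of the current batch
batch_estimate = np.array([1.0, 2.0, 4.0])
u = update_u(u, index, batch_estimate, gamma)
print(u[index])  # [0.9 1.8 3.6]
```

This is why the training data must report each image's dataset index: only the rows of u belonging to the current batch are read and updated.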

Dataset Pipeline for Contrastive Learning

The dataset pipeline here differs from a standard pipeline in three ways. First, TwoCropsTransform generates two randomly augmented crops of each image to form a positive pair, instead of the single random crop used in a standard pipeline. Second, the augmentations follow SimCLR’s implementation. Lastly, with return_index=True, libauc.datasets.CIFAR10 returns the index of each image along with the image and its label.

class TwoCropsTransform:
    """Take two random crops of one image."""
    def __init__(self, base_transform1, base_transform2):
        self.base_transform1 = base_transform1
        self.base_transform2 = base_transform2

    def __call__(self, x):
        im1 = self.base_transform1(x)
        im2 = self.base_transform2(x)
        return [im1, im2]


image_size = 32
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]

normalize = transforms.Normalize(mean=mean, std=std)

# SimCLR augmentations
augmentation = [
    transforms.RandomResizedCrop(image_size, scale=(0.08, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # Not strengthened
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
]

DATA_ROOT = './'
train_dataset = libauc.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, return_index=True,
    transform=TwoCropsTransform(
        transforms.Compose(augmentation),
        transforms.Compose(augmentation)
    )
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True
)
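Why return the index? The global loss keeps per-sample state (one entry per training image, N=50000 for CIFAR10), and the index tells the loss which entries to read and update. Below is a minimal pure-Python sketch of an index-returning dataset combined with a two-view transform. IndexedDataset and two_views are hypothetical stand-ins, not part of LibAUC.

```python
class IndexedDataset:
    """Hypothetical wrapper: returns (sample, label, index), mirroring what
    libauc.datasets.CIFAR10(..., return_index=True) yields per item."""

    def __init__(self, samples, labels, transform=None):
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.labels[idx], idx


def two_views(x):
    # Toy stand-in for TwoCropsTransform: two different "augmentations".
    return [x + 1, x * 10]


ds = IndexedDataset(samples=[3, 4, 5], labels=[0, 1, 0], transform=two_views)
views, label, index = ds[1]
print(views, label, index)  # [5, 40] 1 1
```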

Helper functions

def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
    mlp = []
    for l in range(num_layers):
        dim1 = input_dim if l == 0 else mlp_dim
        dim2 = output_dim if l == num_layers - 1 else mlp_dim
        mlp.append(nn.Linear(dim1, dim2, bias=False))

        if l < num_layers - 1:
            mlp.append(nn.BatchNorm1d(dim2))
            mlp.append(nn.ReLU(inplace=True))
        elif last_bn:
            # Follow SimCLR's design:
            # https://github.com/google-research/simclr/blob/master/model_util.py#L157
            # For simplicity, we further removed gamma in BN
            mlp.append(nn.BatchNorm1d(dim2, affine=False))
    return nn.Sequential(*mlp)

def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
    """Decays the learning rate with half-cycle cosine after warmup."""
    if epoch < warmup_epochs:
        cur_lr = init_lr * epoch / warmup_epochs
    else:
        cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

def train(train_loader, model, loss_fn, optimizer, epoch, lr):
    model.train()
    iters_per_epoch = len(train_loader)
    for i, (images, _, index) in enumerate(train_loader):
        # per-iteration lr schedule (fractional epoch)
        cur_lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, lr)
        images[0] = images[0].cuda()
        images[1] = images[1].cuda()

        # forward both views under mixed precision; `index` tells the loss
        # which per-sample entries to update
        with torch.cuda.amp.autocast(True):
            hidden1 = model(images[0])
            hidden2 = model(images[1])
            loss = loss_fn(hidden1, hidden2, index)

        # `scaler` is the GradScaler created in the Pretraining cell
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # loss of the last mini-batch in this epoch
    print(f'Epoch: {epoch}, Dynamic Loss: {loss.item():.3f}')
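adjust_learning_rate is called with a fractional epoch (epoch + i / iters_per_epoch), so the learning rate changes at every iteration. It rises linearly during the first warmup_epochs, then follows a half-cycle cosine decay toward zero. The standalone sketch below reproduces the same formula and evaluates it with this notebook's settings. The name warmup_cosine_lr is hypothetical.

```python
import math


def warmup_cosine_lr(epoch, init_lr, warmup_epochs, epochs):
    """Same formula as adjust_learning_rate above, without the optimizer."""
    if epoch < warmup_epochs:
        return init_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return init_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


peak = 0.075 * math.sqrt(256)  # square-root scaled lr used for pretraining
for e in [0, 1, 2, 26, 50]:
    print(e, round(warmup_cosine_lr(e, peak, warmup_epochs=2, epochs=50), 4))
```

The learning rate climbs from 0 to the peak of 1.2 over two epochs, is halved at the midpoint of the cosine phase, and reaches 0 at the end.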

Creating Model & Optimizer

Note that to enable iSogCLR, we need to set enable_isogclr=True in GCLoss; otherwise, GCLoss optimizes the SogCLR objective with a single fixed temperature.

set_all_seeds(123)

# ResNet-18 + 2-layer non-linear projection head
base_encoder = resnet18(
    pretrained=False, last_activation=None, num_classes=128
)
hidden_dim = base_encoder.fc.weight.shape[1]
del base_encoder.fc  # Remove original fc layer

base_encoder.fc = build_mlp(num_proj_layers, hidden_dim, mlp_dim, dim)
# CIFAR stem: 3x3 stride-1 conv1 and no max-pooling for 32x32 inputs
base_encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
base_encoder.maxpool = nn.Identity()
model = base_encoder.cuda()

# square root lr scaling
lr = init_lr * math.sqrt(batch_size)

# SogCLR optimizer in LARS mode
optimizer = libauc.optimizers.SogCLR(
    base_encoder.parameters(),
    mode = 'lars',
    lr=lr,
    weight_decay=weight_decay,
    momentum=0.9
)


# Global contrastive loss; N is the number of training images (50,000 for CIFAR10)
loss_fn = GCLoss('unimodal', N=50000, tau=temperature, gamma=gamma,
                 eta=eta, rho=rho, beta=beta, distributed=False,
                 enable_isogclr=True)
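With enable_isogclr=True, each anchor's temperature \(\tau_i\) is optimized rather than fixed. Under our reading of the iSogCLR paper, one anchor contributes the term \(\tau_i \log(\mathrm{mean}_{\mathbf{z}} \exp(h_i(\mathbf{z})/\tau_i)) + \tau_i\rho\). Here \(h_i(\mathbf{z})\) is the anchor's similarity to negative \(\mathbf{z}\) minus its similarity to the positive.

The NumPy sketch below evaluates that term for a toy vector of gaps and computes its derivative with respect to \(\tau_i\). It checks the derivative by finite differences, then takes one projected gradient step. The function names, the step size and the clipping range [0.05, 1.0] are illustrative assumptions. LibAUC's GCLoss uses its own update (with moving averages and the eta/beta settings), which this sketch does not reproduce.

```python
import numpy as np


def rgcl_term(h, tau, rho):
    """One anchor's term: tau * log(mean(exp(h / tau))) + tau * rho.

    h holds the gaps h_i(z) = sim(x_i, z) - sim(x_i, x_i^+) over the negatives z.
    """
    a = h / tau
    m = a.max()  # log-mean-exp computed stably
    return tau * (m + np.log(np.mean(np.exp(a - m)))) + tau * rho


def rgcl_grad_tau(h, tau, rho):
    """d/dtau of rgcl_term: log(mean(exp(h/tau))) + rho - E_p[h] / tau,
    where p = softmax(h / tau) over the negatives."""
    a = h / tau
    m = a.max()
    w = np.exp(a - m)
    p = w / w.sum()
    log_mean_exp = m + np.log(np.mean(w))
    return log_mean_exp + rho - (p @ h) / tau


rng = np.random.default_rng(0)
h = rng.uniform(-1.0, 0.5, size=64)  # toy similarity gaps for one anchor
tau, rho = 0.5, 0.4

# Central finite-difference check of the analytic derivative.
eps = 1e-6
fd = (rgcl_term(h, tau + eps, rho) - rgcl_term(h, tau - eps, rho)) / (2 * eps)
print(np.isclose(fd, rgcl_grad_tau(h, tau, rho), atol=1e-6))  # True

# One projected gradient step on this anchor's own temperature (toy settings).
tau_new = float(np.clip(tau - 0.03 * rgcl_grad_tau(h, tau, rho), 0.05, 1.0))
```

By Jensen's inequality the term is always larger than mean(h) + \(\tau_i\rho\). Each anchor's optimal temperature therefore depends on how its own gaps are spread, which is the motivation for individualizing \(\tau_i\).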

Pretraining

# mixed precision training
scaler = torch.cuda.amp.GradScaler()

print ('Pretraining')
for epoch in range(epochs):

    #if epoch in [int(epochs)*0.5, int(epochs)*0.75]:
    #   optimizer.update_regularizer()

    # train for one epoch
    train(train_loader, model, loss_fn, optimizer, epoch, lr)

    # save checkpoint
    if epoch % 5 == 0 or epochs - epoch < 3:
        save_checkpoint(
            {'epoch': epoch + 1,
             'arch': 'resnet18',
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'scaler': scaler.state_dict(),
            },
            is_best=False,
            filename=os.path.join(logdir, logname, f'checkpoint_{epoch:04d}.pth.tar'))
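The save condition keeps a checkpoint every 5 epochs plus the final two epochs, because epochs - epoch < 3 holds only for epochs 48 and 49. A quick check of which files it produces with epochs = 50:

```python
n_epochs = 50  # matches `epochs` in the pretraining config
saved = [e for e in range(n_epochs) if e % 5 == 0 or n_epochs - e < 3]
print(saved)
# [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 48, 49]
print(f'checkpoint_{saved[-1]:04d}.pth.tar')  # checkpoint_0049.pth.tar
```

The checkpoint_0010.pth.tar file loaded for linear evaluation is among these. The fully trained encoder is in checkpoint_0049.pth.tar.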

Linear Evaluation

By default, we train a linear classifier on top of the frozen encoder, using momentum SGD without weight decay and a batch size of 512 for 90 epochs.

Configurations

# dataset
image_size = 32
batch_size = 512
num_classes = 10 # cifar10

# optimizer
epochs = 90
init_lr = 0.075
weight_decay = 0

# checkpoint
checkpoint_dir = './logs/resnet18_cifar10/checkpoint_0010.pth.tar'
torch.cuda.empty_cache()

Dataset pipeline

mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
train_dataset = libauc.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True,
                                         transform=transforms.Compose([transforms.RandomResizedCrop(32),
                                                                       transforms.RandomHorizontalFlip(),
                                                                       transforms.ToTensor(),
                                                                       normalize,]))
val_dataset = libauc.datasets.CIFAR10(root=DATA_ROOT, train=False,download=True,
                                      transform=transforms.Compose([transforms.ToTensor(),
                                                                    normalize,]))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

Helper functions (linear evaluation)

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
    """Decays the learning rate with a half-cycle cosine (no warmup) over `epochs`."""
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr

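def train(train_loader, model, criterion, optimizer, epoch):
    # eval mode keeps the frozen backbone's BatchNorm running statistics fixed
    # (as in MoCo's linear evaluation); only the fc layer receives gradients
    model.eval()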
    for i, (images, target) in enumerate(train_loader):
        images = images.float().cuda()
        target = target.long().cuda()
        output = model(images)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def validate(val_loader, model, criterion):
    """Returns top-1 and top-5 accuracy (%) over the whole validation set."""
    model.eval()
    top1_sum, top5_sum, num_samples = 0.0, 0.0, 0
    with torch.no_grad():
        for images, target in val_loader:
            images = images.float().cuda()
            target = target.long().cuda()
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # weight by batch size so a smaller last batch is not over-counted
            bs = target.size(0)
            top1_sum += acc1.item() * bs
            top5_sum += acc5.item() * bs
            num_samples += bs
    return top1_sum / num_samples, top5_sum / num_samples
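When the last batch is smaller than the others, an unweighted average of per-batch accuracies slightly misstates the dataset-level accuracy. The CIFAR10 test set has 10,000 images, so a batch size of 512 gives 19 full batches and a final batch of 272. The example below uses hypothetical per-batch accuracies:

```python
# Hypothetical per-batch top-1 accuracies (%): 19 full batches of 512 images
# and a final batch of 272 images (10,000 test images in total).
batch_sizes = [512] * 19 + [272]
batch_acc = [90.0] * 19 + [80.0]

unweighted = sum(batch_acc) / len(batch_acc)
weighted = sum(a * n for a, n in zip(batch_acc, batch_sizes)) / sum(batch_sizes)
print(round(unweighted, 2), round(weighted, 2))  # 89.5 89.73
```

The sample-weighted mean equals the true percentage of correctly classified test images.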

Define model

set_all_seeds(123)

# ResNet-18 + classification layer
model = resnet18(pretrained=False, last_activation=None, num_classes=128)
hidden_dim = model.fc.weight.shape[1]
del model.fc
model.fc = nn.Linear(hidden_dim, num_classes, bias=True)

# CIFAR stem for ResNet-18 (3x3 conv1, no max-pooling)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

# load the pretrained checkpoint, excluding the projection head (keys containing 'fc')
linear_keyword = 'fc'
checkpoint = torch.load(checkpoint_dir, map_location="cpu")
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
    if linear_keyword in k:
        del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
# only the new linear classifier should be missing
assert set(msg.missing_keys) == {'%s.weight' % linear_keyword, '%s.bias' % linear_keyword}
print('Linear Classifier Variables: %s' % (msg.missing_keys))

# cuda
model = model.cuda()

# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
        param.requires_grad = False

# init the fc layer
getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
getattr(model, linear_keyword).bias.data.zero_()

# optimize only the linear classifier
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # weight, bias
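The loading code drops every key containing 'fc'. This removes the whole projection head built by build_mlp (fc.0, fc.1, and so on), so the only keys reported as missing are fc.weight and fc.bias of the new classifier. A pure-Python sketch on a hypothetical subset of keys:

```python
# Hypothetical subset of keys in the pretrained state_dict (values elided).
state_dict = {
    'conv1.weight': None,
    'layer1.0.conv1.weight': None,
    'fc.0.weight': None,        # projection head: first Linear
    'fc.1.running_mean': None,  # projection head: BatchNorm1d
    'fc.3.weight': None,        # projection head: second Linear
}
linear_keyword = 'fc'
backbone = {k: v for k, v in state_dict.items() if linear_keyword not in k}
print(sorted(backbone))  # ['conv1.weight', 'layer1.0.conv1.weight']

# The new classifier's keys are absent from the filtered dict, so
# load_state_dict(..., strict=False) reports them as missing.
model_keys = ['conv1.weight', 'layer1.0.conv1.weight', 'fc.weight', 'fc.bias']
missing = [k for k in model_keys if k not in backbone]
print(missing)  # ['fc.weight', 'fc.bias']
```

The filter is a substring test. A prefix test such as k.startswith('fc.') would be stricter, but for ResNet-18 both give the same result because no backbone key contains 'fc'.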

Define loss and optimizer

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()

# linear lr scaling
lr = init_lr * batch_size / 256
optimizer = torch.optim.SGD(parameters,
                            lr=lr,
                            momentum=0.9,
                            weight_decay=weight_decay)
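The two stages scale the base learning rate differently. Pretraining uses square-root scaling, the rule SimCLR suggests for smaller batches. Linear evaluation uses linear scaling by batch_size / 256. With this notebook's values:

```python
import math

base_lr = 0.075
pretrain_lr = base_lr * math.sqrt(256)  # square-root scaling, batch size 256
linear_eval_lr = base_lr * 512 / 256    # linear scaling, batch size 512
print(round(pretrain_lr, 4), round(linear_eval_lr, 4))  # 1.2 0.15
```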

Training (linear evaluation)

# linear evaluation
print ('Linear Evaluation')
for epoch in range(epochs):

    adjust_learning_rate(optimizer, epoch, lr)

    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    acc1, acc5 = validate(val_loader, model, criterion)

    # log
    print('Epoch: %s, Top1: %.2f, Top5: %.2f' % (epoch, acc1, acc5))

Visualization

We ran SimCLR, SogCLR and iSogCLR for 400 epochs on CIFAR10 and visualized the learned features in the following figures. Qualitatively, iSogCLR separates the features of different classes more clearly than the other two approaches.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
fig, ax = plt.subplots(1, 3, figsize=(20, 4.8))
simclr = mpimg.imread('simclr.png')
sogclr = mpimg.imread('sogclr.png')
isogclr = mpimg.imread('isogclr.png')
ax[0].imshow(simclr)
ax[1].imshow(sogclr)
ax[2].imshow(isogclr)
ax[0].axis('off')
ax[1].axis('off')
ax[2].axis('off')
ax[0].set_title('SimCLR')
ax[1].set_title('SogCLR')
ax[2].set_title('iSogCLR')
plt.show()
[Figure: learned CIFAR10 features for SimCLR, SogCLR and iSogCLR]