Optimizing Global Contrastive Loss with Small Batch Size (SogCLR)


Authors: Zhuoning Yuan, Tianbao Yang

Introduction

In this tutorial, you will learn how to train a self-supervised model by optimizing the Global Contrastive Loss (GCLoss) on CIFAR10/CIFAR100. This version is implemented in PyTorch and based on MoCo's codebase. It is recommended to run this notebook in a GPU-enabled environment, e.g., Google Colab. For training on ImageNet-1K, please refer to this Github repo.

Reference

If you find this tutorial helpful in your work, please cite our library paper and the following paper:

@inproceedings{yuan2022provable,
  title={Provable stochastic optimization for global contrastive learning: Small batch does not harm performance},
  author={Yuan, Zhuoning and Wu, Yuexin and Qiu, Zi-Hao and Du, Xianzhi and Zhang, Lijun and Zhou, Denny and Yang, Tianbao},
  booktitle={International Conference on Machine Learning},
  pages={25760--25782},
  year={2022},
  organization={PMLR}
}

Install LibAUC

Let's start by installing our library. In this tutorial, we will use the latest version of LibAUC by running pip install -U.

!pip install -U libauc
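
Optionally, you can confirm which version was installed (this assumes the package exposes the usual __version__ attribute):

import libauc
print(libauc.__version__)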

Importing LibAUC

Importing related packages

import libauc
from libauc.models import resnet50, resnet18
from libauc.datasets import CIFAR100
from libauc.optimizers import SogCLR
from libauc.losses import GCLoss

import torch
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import os,math,shutil

Reproducibility

The following function set_all_seeds limits the sources of randomness, such as model initialization and data shuffling. However, completely reproducible results are not guaranteed across PyTorch releases [Ref].

def set_all_seeds(SEED):
    # REPRODUCIBILITY
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Global Contrastive Loss

Global Contrastive Loss (GCLoss) aims to maximize the similarity between an anchor image \(\mathbf{x}_i\) and its corresponding positive image \(\mathbf{x}_i^{+}\), while minimizing the similarity between the anchor and a set of negative samples \(\mathbf{S}_i^{-}\). Unlike the standard mini-batch contrastive loss, the negatives in GCLoss are drawn from the full dataset rather than the current mini-batch. For more details on the formulation of GCL, please refer to the SogCLR paper.
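
For reference, the global contrastive objective optimized here can be written, up to constants, roughly as (a paraphrase of the SogCLR formulation; see the paper for the precise definition):

\[
F(\mathbf{w}) \;=\; \frac{1}{n}\sum_{i=1}^{n} \tau \log\!\Big(\frac{1}{|\mathbf{S}_i^{-}|}\sum_{\mathbf{z}\in\mathbf{S}_i^{-}} \exp\Big(\frac{h_{\mathbf{w}}(\mathbf{x}_i)^{\top} h_{\mathbf{w}}(\mathbf{z}) - h_{\mathbf{w}}(\mathbf{x}_i)^{\top} h_{\mathbf{w}}(\mathbf{x}_i^{+})}{\tau}\Big)\Big),
\]

where \(h_{\mathbf{w}}(\cdot)\) is the normalized encoder output, \(\tau\) is the temperature, and \(n\) is the number of training images. SogCLR approximates the inner average over \(\mathbf{S}_i^{-}\) with a per-sample moving-average estimator (controlled by the parameter \(\gamma\) below), so the objective can be optimized accurately even with small mini-batches.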

Hyper-parameters

# model: non-linear projection layer
num_proj_layers=2
dim=256
mlp_dim=2048

# dataset: cifar100
data_name = 'cifar100'
batch_size = 256

# optimizer
weight_decay = 1e-6
init_lr=0.075
epochs = 200
warmup_epochs = 10

# dynamic contrastive loss (SogCLR)
gamma = 0.9        # moving-average parameter of SogCLR's per-sample estimator
temperature = 0.5  # temperature tau of the contrastive loss

# path
logdir = './logs/'
logname = 'resnet18_cifar100'
os.makedirs(os.path.join(logdir, logname), exist_ok=True)

Dataset Pipeline for Contrastive Learning

The dataset pipeline presented here differs from the standard pipeline in three ways. First, TwoCropsTransform generates two randomly augmented crops of each image to construct pairwise samples, as opposed to the single random crop used in the standard pipeline. Second, the augmentations follow SimCLR's implementation. Lastly, libauc.datasets.CIFAR100 returns the index of each image along with the image and its label; the index is required by the dynamic loss. A quick check of the resulting batch structure is shown after the data loader below.

class TwoCropsTransform:
    """Take two random crops of one image."""
    def __init__(self, base_transform1, base_transform2):
        self.base_transform1 = base_transform1
        self.base_transform2 = base_transform2

    def __call__(self, x):
        im1 = self.base_transform1(x)
        im2 = self.base_transform2(x)
        return [im1, im2]


image_size = 32
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]

normalize = transforms.Normalize(mean=mean, std=std)

# SimCLR augmentations
augmentation = [
    transforms.RandomResizedCrop(image_size, scale=(0.08, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # Not strengthened
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
]

DATA_ROOT = './'
train_dataset = libauc.datasets.CIFAR100(
    root=DATA_ROOT, train=True, download=True, return_index=True,
    transform=TwoCropsTransform(
        transforms.Compose(augmentation),
        transforms.Compose(augmentation)
    )
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True
)
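
As an optional sanity check, you can inspect one batch from the loader. With TwoCropsTransform and return_index=True, each batch consists of a list of two augmented views, the labels (unused during pretraining), and the sample indices consumed by the dynamic loss:

images, labels, index = next(iter(train_loader))
print(images[0].shape, images[1].shape)  # two augmented views, each [256, 3, 32, 32]
print(index.shape)                       # per-sample indices for GCLoss, [256]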

Helper functions

def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
    mlp = []
    for l in range(num_layers):
        dim1 = input_dim if l == 0 else mlp_dim
        dim2 = output_dim if l == num_layers - 1 else mlp_dim
        mlp.append(nn.Linear(dim1, dim2, bias=False))

        if l < num_layers - 1:
            mlp.append(nn.BatchNorm1d(dim2))
            mlp.append(nn.ReLU(inplace=True))
        elif last_bn:
            # Follow SimCLR's design:
            # https://github.com/google-research/simclr/blob/master/model_util.py#L157
            # For simplicity, we further removed gamma in BN
            mlp.append(nn.BatchNorm1d(dim2, affine=False))
    return nn.Sequential(*mlp)
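
# With the settings used in this tutorial (num_proj_layers=2, mlp_dim=2048, dim=256,
# and the 512-dimensional ResNet-18 features), build_mlp produces:
#   Linear(512, 2048, bias=False) -> BatchNorm1d(2048) -> ReLU
#   -> Linear(2048, 256, bias=False) -> BatchNorm1d(256, affine=False)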

def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
    """Decays the learning rate with half-cycle cosine after warmup."""
    if epoch < warmup_epochs:
        cur_lr = init_lr * epoch / warmup_epochs
    else:
        cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr
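
# Shape of the schedule above: the learning rate ramps linearly from 0 to init_lr
# over the first warmup_epochs epochs, then follows a half-cycle cosine decay toward 0.
# With warmup_epochs=10 and epochs=200 it equals init_lr at epoch 10 and roughly
# 0.5 * init_lr at epoch 105, the midpoint of the cosine phase.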

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

def train(train_loader, model, loss_fn, optimizer, epoch, lr):
    model.train()
    iters_per_epoch = len(train_loader)
    for i, (images, _, index) in enumerate(train_loader):
        # per-iteration learning rate using a fractional epoch for a smooth schedule
        cur_lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, lr)
        images[0] = images[0].cuda()
        images[1] = images[1].cuda()

        # forward both augmented views under mixed precision
        with torch.cuda.amp.autocast(True):
            hidden1 = model(images[0])
            hidden2 = model(images[1])
            loss = loss_fn(hidden1, hidden2, index)

        # `scaler` is the global GradScaler created in the Pretraining section below
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    print(f'Epoch: {epoch}, Dynamic Loss: {loss:.3f}')

Creating Model & Optimizer

set_all_seeds(123)

# ResNet-18 + 2-layer non-linear projection head
base_encoder = resnet18(
    pretrained=False, last_activation=None, num_classes=128
)
hidden_dim = base_encoder.fc.weight.shape[1]
del base_encoder.fc  # remove the original fc layer

base_encoder.fc = build_mlp(num_proj_layers, hidden_dim, mlp_dim, dim)  # projection head
base_encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # CIFAR stem: 3x3 conv, stride 1
base_encoder.maxpool = nn.Identity()  # no max-pooling for 32x32 inputs
model = base_encoder.cuda()

# square-root lr scaling: 0.075 * sqrt(256) = 1.2
lr = init_lr * math.sqrt(batch_size)

# SogCLR optimizer (LARS mode)
optimizer = libauc.optimizers.SogCLR(
    base_encoder.parameters(),
    mode='lars',
    lr=lr,
    weight_decay=weight_decay,
    momentum=0.9
)

# Global Contrastive Loss (N is the number of training samples; CIFAR100 has 50,000)
loss_fn = GCLoss('unimodal', N=50000, tau=temperature, gamma=gamma, distributed=False)

Pretraining

# mixed precision training
scaler = torch.cuda.amp.GradScaler()

print ('Pretraining')
for epoch in range(epochs):

    #if epoch in [int(epochs)*0.5, int(epochs)*0.75]:
    #   optimizer.update_regularizer()

    # train for one epoch
    train(train_loader, model, loss_fn, optimizer, epoch, lr)

    # save checkpoint
    if epoch % 10 == 0 or epochs - epoch < 3:
        save_checkpoint(
            {'epoch': epoch + 1,
             'arch': 'resnet18',
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'scaler': scaler.state_dict(),
            },
            is_best=False,
            filename=os.path.join(logdir, logname, f'checkpoint_{epoch:04d}.pth.tar'))
Pretraining
Epoch: 0, Dynamic Loss: -0.466
Epoch: 1, Dynamic Loss: -0.716
Epoch: 2, Dynamic Loss: -0.851
Epoch: 3, Dynamic Loss: -0.842
Epoch: 4, Dynamic Loss: -0.928
Epoch: 5, Dynamic Loss: -1.007
Epoch: 6, Dynamic Loss: -1.031
Epoch: 7, Dynamic Loss: -1.064
Epoch: 8, Dynamic Loss: -1.056
Epoch: 9, Dynamic Loss: -1.127
Epoch: 10, Dynamic Loss: -1.136
Epoch: 11, Dynamic Loss: -1.132
Epoch: 12, Dynamic Loss: -1.118
Epoch: 13, Dynamic Loss: -1.177
Epoch: 14, Dynamic Loss: -1.140
Epoch: 15, Dynamic Loss: -1.198
Epoch: 16, Dynamic Loss: -1.217
Epoch: 17, Dynamic Loss: -1.251
Epoch: 18, Dynamic Loss: -1.216
Epoch: 19, Dynamic Loss: -1.192
Epoch: 20, Dynamic Loss: -1.248
Epoch: 21, Dynamic Loss: -1.197
Epoch: 22, Dynamic Loss: -1.294
Epoch: 23, Dynamic Loss: -1.313
Epoch: 24, Dynamic Loss: -1.239
Epoch: 25, Dynamic Loss: -1.305
Epoch: 26, Dynamic Loss: -1.259
Epoch: 27, Dynamic Loss: -1.275
Epoch: 28, Dynamic Loss: -1.287
Epoch: 29, Dynamic Loss: -1.229
Epoch: 30, Dynamic Loss: -1.272
Epoch: 31, Dynamic Loss: -1.311
Epoch: 32, Dynamic Loss: -1.272
Epoch: 33, Dynamic Loss: -1.293
Epoch: 34, Dynamic Loss: -1.274
Epoch: 35, Dynamic Loss: -1.331
Epoch: 36, Dynamic Loss: -1.352
Epoch: 37, Dynamic Loss: -1.299
Epoch: 38, Dynamic Loss: -1.319
Epoch: 39, Dynamic Loss: -1.317
Epoch: 40, Dynamic Loss: -1.327
Epoch: 41, Dynamic Loss: -1.340
Epoch: 42, Dynamic Loss: -1.334
Epoch: 43, Dynamic Loss: -1.322
Epoch: 44, Dynamic Loss: -1.350
Epoch: 45, Dynamic Loss: -1.338
Epoch: 46, Dynamic Loss: -1.311
Epoch: 47, Dynamic Loss: -1.291
Epoch: 48, Dynamic Loss: -1.329
Epoch: 49, Dynamic Loss: -1.345
Epoch: 50, Dynamic Loss: -1.376
Epoch: 51, Dynamic Loss: -1.358
Epoch: 52, Dynamic Loss: -1.315
Epoch: 53, Dynamic Loss: -1.355
Epoch: 54, Dynamic Loss: -1.330
Epoch: 55, Dynamic Loss: -1.346
Epoch: 56, Dynamic Loss: -1.379
Epoch: 57, Dynamic Loss: -1.375
Epoch: 58, Dynamic Loss: -1.376
Epoch: 59, Dynamic Loss: -1.353
Epoch: 60, Dynamic Loss: -1.340
Epoch: 61, Dynamic Loss: -1.387
Epoch: 62, Dynamic Loss: -1.356
Epoch: 63, Dynamic Loss: -1.370
Epoch: 64, Dynamic Loss: -1.364
Epoch: 65, Dynamic Loss: -1.344
Epoch: 66, Dynamic Loss: -1.380
Epoch: 67, Dynamic Loss: -1.410
Epoch: 68, Dynamic Loss: -1.428
Epoch: 69, Dynamic Loss: -1.377
Epoch: 70, Dynamic Loss: -1.408
Epoch: 71, Dynamic Loss: -1.420
Epoch: 72, Dynamic Loss: -1.398
Epoch: 73, Dynamic Loss: -1.395
Epoch: 74, Dynamic Loss: -1.409
Epoch: 75, Dynamic Loss: -1.361
Epoch: 76, Dynamic Loss: -1.396
Epoch: 77, Dynamic Loss: -1.371
Epoch: 78, Dynamic Loss: -1.373
Epoch: 79, Dynamic Loss: -1.406
Epoch: 80, Dynamic Loss: -1.407
Epoch: 81, Dynamic Loss: -1.437
Epoch: 82, Dynamic Loss: -1.425
Epoch: 83, Dynamic Loss: -1.375
Epoch: 84, Dynamic Loss: -1.405
Epoch: 85, Dynamic Loss: -1.372
Epoch: 86, Dynamic Loss: -1.417
Epoch: 87, Dynamic Loss: -1.458
Epoch: 88, Dynamic Loss: -1.399
Epoch: 89, Dynamic Loss: -1.433
Epoch: 90, Dynamic Loss: -1.442
Epoch: 91, Dynamic Loss: -1.423
Epoch: 92, Dynamic Loss: -1.426
Epoch: 93, Dynamic Loss: -1.443
Epoch: 94, Dynamic Loss: -1.454
Epoch: 95, Dynamic Loss: -1.426
Epoch: 96, Dynamic Loss: -1.437
Epoch: 97, Dynamic Loss: -1.478
Epoch: 98, Dynamic Loss: -1.425
Epoch: 99, Dynamic Loss: -1.452
Epoch: 100, Dynamic Loss: -1.453
Epoch: 101, Dynamic Loss: -1.485
Epoch: 102, Dynamic Loss: -1.441
Epoch: 103, Dynamic Loss: -1.446
Epoch: 104, Dynamic Loss: -1.400
Epoch: 105, Dynamic Loss: -1.448
Epoch: 106, Dynamic Loss: -1.433
Epoch: 107, Dynamic Loss: -1.406
Epoch: 108, Dynamic Loss: -1.436
Epoch: 109, Dynamic Loss: -1.460
Epoch: 110, Dynamic Loss: -1.488
Epoch: 111, Dynamic Loss: -1.463
Epoch: 112, Dynamic Loss: -1.502
Epoch: 113, Dynamic Loss: -1.504
Epoch: 114, Dynamic Loss: -1.443
Epoch: 115, Dynamic Loss: -1.395
Epoch: 116, Dynamic Loss: -1.487
Epoch: 117, Dynamic Loss: -1.472
Epoch: 118, Dynamic Loss: -1.446
Epoch: 119, Dynamic Loss: -1.484
Epoch: 120, Dynamic Loss: -1.438
Epoch: 121, Dynamic Loss: -1.494
Epoch: 122, Dynamic Loss: -1.490
Epoch: 123, Dynamic Loss: -1.464
Epoch: 124, Dynamic Loss: -1.490
Epoch: 125, Dynamic Loss: -1.462
Epoch: 126, Dynamic Loss: -1.479
Epoch: 127, Dynamic Loss: -1.478
Epoch: 128, Dynamic Loss: -1.448
Epoch: 129, Dynamic Loss: -1.432
Epoch: 130, Dynamic Loss: -1.498
Epoch: 131, Dynamic Loss: -1.469
Epoch: 132, Dynamic Loss: -1.463
Epoch: 133, Dynamic Loss: -1.514
Epoch: 134, Dynamic Loss: -1.495
Epoch: 135, Dynamic Loss: -1.494
Epoch: 136, Dynamic Loss: -1.503
Epoch: 137, Dynamic Loss: -1.515
Epoch: 138, Dynamic Loss: -1.478
Epoch: 139, Dynamic Loss: -1.467
Epoch: 140, Dynamic Loss: -1.508
Epoch: 141, Dynamic Loss: -1.476
Epoch: 142, Dynamic Loss: -1.523
Epoch: 143, Dynamic Loss: -1.499
Epoch: 144, Dynamic Loss: -1.471
Epoch: 145, Dynamic Loss: -1.471
Epoch: 146, Dynamic Loss: -1.492
Epoch: 147, Dynamic Loss: -1.532
Epoch: 148, Dynamic Loss: -1.503
Epoch: 149, Dynamic Loss: -1.507
Epoch: 150, Dynamic Loss: -1.518
Epoch: 151, Dynamic Loss: -1.494
Epoch: 152, Dynamic Loss: -1.510
Epoch: 153, Dynamic Loss: -1.480
Epoch: 154, Dynamic Loss: -1.525
Epoch: 155, Dynamic Loss: -1.499
Epoch: 156, Dynamic Loss: -1.502
Epoch: 157, Dynamic Loss: -1.527
Epoch: 158, Dynamic Loss: -1.531
Epoch: 159, Dynamic Loss: -1.477
Epoch: 160, Dynamic Loss: -1.491
Epoch: 161, Dynamic Loss: -1.519
Epoch: 162, Dynamic Loss: -1.534
Epoch: 163, Dynamic Loss: -1.475
Epoch: 164, Dynamic Loss: -1.493
Epoch: 165, Dynamic Loss: -1.516
Epoch: 166, Dynamic Loss: -1.494
Epoch: 167, Dynamic Loss: -1.515
Epoch: 168, Dynamic Loss: -1.534
Epoch: 169, Dynamic Loss: -1.503
Epoch: 170, Dynamic Loss: -1.516
Epoch: 171, Dynamic Loss: -1.493
Epoch: 172, Dynamic Loss: -1.498
Epoch: 173, Dynamic Loss: -1.509
Epoch: 174, Dynamic Loss: -1.510
Epoch: 175, Dynamic Loss: -1.509
Epoch: 176, Dynamic Loss: -1.519
Epoch: 177, Dynamic Loss: -1.524
Epoch: 178, Dynamic Loss: -1.520
Epoch: 179, Dynamic Loss: -1.490
Epoch: 180, Dynamic Loss: -1.504
Epoch: 181, Dynamic Loss: -1.489
Epoch: 182, Dynamic Loss: -1.540
Epoch: 183, Dynamic Loss: -1.552
Epoch: 184, Dynamic Loss: -1.504
Epoch: 185, Dynamic Loss: -1.527
Epoch: 186, Dynamic Loss: -1.525
Epoch: 187, Dynamic Loss: -1.541
Epoch: 188, Dynamic Loss: -1.535
Epoch: 189, Dynamic Loss: -1.501
Epoch: 190, Dynamic Loss: -1.530
Epoch: 191, Dynamic Loss: -1.529
Epoch: 192, Dynamic Loss: -1.527
Epoch: 193, Dynamic Loss: -1.521
Epoch: 194, Dynamic Loss: -1.523
Epoch: 195, Dynamic Loss: -1.514
Epoch: 196, Dynamic Loss: -1.501
Epoch: 197, Dynamic Loss: -1.535
Epoch: 198, Dynamic Loss: -1.532
Epoch: 199, Dynamic Loss: -1.507

Linear Evaluation

By default, we use momentum SGD without weight decay and a batch size of 1024 for linear classification on top of the frozen pretrained features. This stage runs for 90 epochs.

Configurations

# dataset
image_size = 32
batch_size = 1024
num_classes = 100 # cifar100

# optimizer
epochs = 90
init_lr = 0.075
weight_decay = 0

# checkpoint from the pretraining stage (epoch 199)
checkpoint_dir = os.path.join(logdir, logname, 'checkpoint_0199.pth.tar')

Dataset pipeline

mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
train_dataset = libauc.datasets.CIFAR100(
    root=DATA_ROOT, train=True, download=True,
    transform=transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
val_dataset = libauc.datasets.CIFAR100(
    root=DATA_ROOT, train=False, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

Helper functions

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
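
# Toy illustration of accuracy() (hypothetical values, not part of the evaluation):
# 3 samples, 5 classes; the second sample is only correct within the top-2 predictions.
toy_logits = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0],
                           [0.5, 0.4, 0.1, 0.0, 0.0],
                           [0.0, 0.0, 0.2, 0.3, 0.5]])
toy_labels = torch.tensor([1, 1, 4])
print(accuracy(toy_logits, toy_labels, topk=(1, 2)))  # top-1 ~ 66.67%, top-2 = 100%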

def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
    """Decay the learning rate based on schedule"""
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr

def train(train_loader, model, criterion, optimizer, epoch):
    # Keep the network in eval mode so BatchNorm uses its frozen running statistics;
    # only the linear classifier receives gradient updates.
    model.eval()
    for i, (images, target) in enumerate(train_loader):
        images = images.float().cuda()
        target = target.long().cuda()
        output = model(images)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def validate(val_loader, model, criterion):
    model.eval()
    acc1_list = []
    acc5_list = []
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images = images.float().cuda()
            target = target.long().cuda()
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            acc1_list.append(acc1)
            acc5_list.append(acc5)
    acc1_array = torch.stack(acc1_list)
    acc5_array = torch.stack(acc5_list)
    return torch.mean(acc1_array), torch.mean(acc5_array)

Define model

set_all_seeds(123)

# ResNet-18 + classification layer
model = resnet18(pretrained=False, last_activation=None, num_classes=128)
hidden_dim = model.fc.weight.shape[1]
del model.fc
model.fc = nn.Linear(hidden_dim, num_classes, bias=True)

# CIFAR stem for ResNet-18: 3x3 conv with stride 1 and no max-pooling (matches pretraining)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

# load the pretrained checkpoint, excluding the projection head ('fc') weights
linear_keyword = 'fc'
checkpoint = torch.load(checkpoint_dir, map_location="cpu")
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
    if linear_keyword in k:
       del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
print ('Linear Classifier Variables: %s'%(msg.missing_keys))

# cuda
model = model.cuda()

# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
       param.requires_grad = False

# init the fc layer
getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
getattr(model, linear_keyword).bias.data.zero_()

# optimize only the linear classifier
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # weight, bias
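
Optionally, double-check that only the linear head is trainable; for 512-dimensional ResNet-18 features and 100 classes this amounts to 512 * 100 + 100 = 51,300 parameters:

num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Trainable parameters:', num_trainable)  # 51300 = 512*100 weights + 100 biases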

Define loss & optimizer

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()

# linear lr scaling: 0.075 * 1024 / 256 = 0.3
lr = init_lr * batch_size / 256
optimizer = torch.optim.SGD(parameters,
                            lr=lr,
                            momentum=0.9,
                            weight_decay=weight_decay)

Training

# linear evaluation
print('Linear Evaluation')
for epoch in range(epochs):

    adjust_learning_rate(optimizer, epoch, lr)

    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    acc1, acc5 = validate(val_loader, model, criterion)

    # log
    print('Epoch: %s, Top1: %.2f, Top5: %.2f' % (epoch, acc1, acc5))
Linear Evaluation
Epoch: 0, Top1: 49.25, Top5: 81.33
Epoch: 1, Top1: 51.79, Top5: 82.66
Epoch: 2, Top1: 52.86, Top5: 83.64
Epoch: 3, Top1: 53.38, Top5: 84.51
Epoch: 4, Top1: 53.22, Top5: 84.41
Epoch: 5, Top1: 54.46, Top5: 84.61
Epoch: 6, Top1: 54.76, Top5: 85.17
Epoch: 7, Top1: 54.53, Top5: 84.90
Epoch: 8, Top1: 54.93, Top5: 85.01
Epoch: 9, Top1: 55.07, Top5: 85.55
Epoch: 10, Top1: 55.65, Top5: 85.41
Epoch: 11, Top1: 55.48, Top5: 85.69
Epoch: 12, Top1: 55.64, Top5: 85.57
Epoch: 13, Top1: 56.29, Top5: 85.84
Epoch: 14, Top1: 56.13, Top5: 85.53
Epoch: 15, Top1: 55.94, Top5: 85.93
Epoch: 16, Top1: 55.92, Top5: 85.99
Epoch: 17, Top1: 56.58, Top5: 86.02
Epoch: 18, Top1: 56.41, Top5: 85.83
Epoch: 19, Top1: 56.62, Top5: 86.17
Epoch: 20, Top1: 56.78, Top5: 86.25
Epoch: 21, Top1: 56.09, Top5: 85.70
Epoch: 22, Top1: 56.65, Top5: 85.69
Epoch: 23, Top1: 56.50, Top5: 85.96
Epoch: 24, Top1: 56.62, Top5: 86.07
Epoch: 25, Top1: 56.75, Top5: 86.05
Epoch: 26, Top1: 56.73, Top5: 85.95
Epoch: 27, Top1: 56.92, Top5: 86.06
Epoch: 28, Top1: 56.88, Top5: 86.02
Epoch: 29, Top1: 57.26, Top5: 86.10
Epoch: 30, Top1: 56.69, Top5: 86.10
Epoch: 31, Top1: 57.29, Top5: 86.24
Epoch: 32, Top1: 56.95, Top5: 86.18
Epoch: 33, Top1: 56.99, Top5: 86.20
Epoch: 34, Top1: 56.95, Top5: 86.14
Epoch: 35, Top1: 57.36, Top5: 86.40
Epoch: 36, Top1: 57.50, Top5: 86.50
Epoch: 37, Top1: 57.35, Top5: 86.63
Epoch: 38, Top1: 57.44, Top5: 86.12
Epoch: 39, Top1: 57.51, Top5: 86.09
Epoch: 40, Top1: 57.28, Top5: 86.27
Epoch: 41, Top1: 57.43, Top5: 86.04
Epoch: 42, Top1: 57.56, Top5: 86.30
Epoch: 43, Top1: 57.67, Top5: 86.04
Epoch: 44, Top1: 57.82, Top5: 86.33
Epoch: 45, Top1: 57.71, Top5: 86.34
Epoch: 46, Top1: 57.56, Top5: 86.46
Epoch: 47, Top1: 57.69, Top5: 86.25
Epoch: 48, Top1: 57.77, Top5: 86.36
Epoch: 49, Top1: 57.87, Top5: 86.32
Epoch: 50, Top1: 57.78, Top5: 86.14
Epoch: 51, Top1: 57.69, Top5: 86.31
Epoch: 52, Top1: 57.33, Top5: 86.31
Epoch: 53, Top1: 57.89, Top5: 86.37
Epoch: 54, Top1: 57.68, Top5: 86.15
Epoch: 55, Top1: 57.88, Top5: 86.24
Epoch: 56, Top1: 58.18, Top5: 86.25
Epoch: 57, Top1: 57.56, Top5: 86.45
Epoch: 58, Top1: 57.97, Top5: 86.46
Epoch: 59, Top1: 57.78, Top5: 86.18
Epoch: 60, Top1: 57.97, Top5: 86.37
Epoch: 61, Top1: 57.97, Top5: 86.28
Epoch: 62, Top1: 57.85, Top5: 86.30
Epoch: 63, Top1: 58.04, Top5: 86.39
Epoch: 64, Top1: 57.80, Top5: 86.40
Epoch: 65, Top1: 57.94, Top5: 86.61
Epoch: 66, Top1: 57.95, Top5: 86.46
Epoch: 67, Top1: 58.13, Top5: 86.52
Epoch: 68, Top1: 58.09, Top5: 86.45
Epoch: 69, Top1: 57.94, Top5: 86.51
Epoch: 70, Top1: 57.97, Top5: 86.44
Epoch: 71, Top1: 58.09, Top5: 86.46
Epoch: 72, Top1: 58.26, Top5: 86.63
Epoch: 73, Top1: 58.24, Top5: 86.42
Epoch: 74, Top1: 58.15, Top5: 86.42
Epoch: 75, Top1: 58.10, Top5: 86.55
Epoch: 76, Top1: 58.05, Top5: 86.58
Epoch: 77, Top1: 58.13, Top5: 86.54
Epoch: 78, Top1: 57.95, Top5: 86.51
Epoch: 79, Top1: 58.00, Top5: 86.52
Epoch: 80, Top1: 58.07, Top5: 86.56
Epoch: 81, Top1: 58.00, Top5: 86.50
Epoch: 82, Top1: 57.99, Top5: 86.50
Epoch: 83, Top1: 58.15, Top5: 86.57
Epoch: 84, Top1: 58.10, Top5: 86.53
Epoch: 85, Top1: 58.11, Top5: 86.56
Epoch: 86, Top1: 58.12, Top5: 86.51
Epoch: 87, Top1: 58.10, Top5: 86.51
Epoch: 88, Top1: 58.10, Top5: 86.52
Epoch: 89, Top1: 58.10, Top5: 86.51