Optimizing Global Contrastive Loss with Small Batch Size (SogCLR)
Authors: Zhuoning Yuan, Tianbao Yang
Introduction
In this tutorial, you will learn how to train a self-supervised model by optimizing the Global Contrastive Loss (GCLoss) on CIFAR10/CIFAR100. This version is implemented in PyTorch based on MoCo's codebase. It is recommended to run this notebook in a GPU-enabled environment, e.g., Google Colab. For training on ImageNet-1K, please refer to this GitHub repo.
Reference
If you find this tutorial helpful in your work, please cite our library paper and the following paper:
@inproceedings{yuan2022provable,
title={Provable stochastic optimization for global contrastive learning: Small batch does not harm performance},
author={Yuan, Zhuoning and Wu, Yuexin and Qiu, Zi-Hao and Du, Xianzhi and Zhang, Lijun and Zhou, Denny and Yang, Tianbao},
booktitle={International Conference on Machine Learning},
pages={25760--25782},
year={2022},
organization={PMLR}
}
Install LibAUC
Let's start by installing our library. In this tutorial, we will use the latest version of LibAUC, installed via pip install -U.
!pip install -U libauc
Importing LibAUC
Importing related packages
import libauc
from libauc.models import resnet50, resnet18
from libauc.datasets import CIFAR100
from libauc.optimizers import SogCLR
from libauc.losses import GCLoss
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import os,math,shutil
Reproducibility
The following function set_all_seeds
limits the number of sources
of randomness behaviors, such as model intialization, data shuffling,
etcs. However, completely reproducible results are not guaranteed
across PyTorch releases [Ref].
def set_all_seeds(SEED):
# REPRODUCIBILITY
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Global Contrastive Loss
Global Contrastive Loss (GCLoss) aims to maximize the similarity between an anchor image \(\mathbf{x}_i\) and its corresponding positive image \(\mathbf{x}_i^{+}\), while minimizing the similarity between the anchor and a set of negative samples \(\mathbf{S}_i^{-}\). For GCLoss, the negative samples are drawn from the full dataset rather than only the current mini-batch. For more details on the formulation of GCL, please refer to the SogCLR paper.
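In slightly simplified form (see the SogCLR paper for the exact definition, including constants and the averaging over the two augmented views), the global contrastive objective can be written as

\[
F(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \tau \log\Big(\frac{1}{|\mathbf{S}_i^{-}|}\sum_{\mathbf{z}\in \mathbf{S}_i^{-}} \exp\big(h_i(\mathbf{z})/\tau\big)\Big),
\qquad
h_i(\mathbf{z}) = E(\mathbf{x}_i)^{\top}E(\mathbf{z}) - E(\mathbf{x}_i)^{\top}E(\mathbf{x}_i^{+}),
\]

where \(E(\cdot)\) is the (normalized) encoder and \(\tau\) is the temperature. Because the inner average runs over all negatives in the dataset, SogCLR maintains a per-sample moving-average estimate of it, updated roughly as \(u_i \leftarrow (1-\gamma)\,u_i + \gamma\,\hat{u}_i\) from each mini-batch. This is why the loss takes the sample index as an input, why \(\gamma\) appears as a hyper-parameter below, and why small batch sizes do not hurt performance.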
Hyper-parameters
# model: non-linear projection layer
num_proj_layers=2
dim=256
mlp_dim=2048
# dataset: cifar100
data_name = 'cifar100'
batch_size = 256
# optimizer
weight_decay = 1e-6
init_lr=0.075
epochs = 200
warmup_epochs = 10
# dynamic loss
gamma = 0.9
temperature = 0.5
# path
logdir = './logs/'
logname = 'resnet18_cifar100'
os.makedirs(os.path.join(logdir, logname), exist_ok=True)
Dataset Pipeline for Contrastive Learning
The dataset pipeline presented here differs from a standard pipeline in three ways. First, TwoCropsTransform generates two randomly augmented crops of each image to construct a positive pair, as opposed to the single random crop in a standard pipeline. Second, the augmentation follows SimCLR's implementation. Lastly, libauc.datasets.CIFAR100 returns the index of the image along with the image and its label.
class TwoCropsTransform:
"""Take two random crops of one image."""
def __init__(self, base_transform1, base_transform2):
self.base_transform1 = base_transform1
self.base_transform2 = base_transform2
def __call__(self, x):
im1 = self.base_transform1(x)
im2 = self.base_transform2(x)
return [im1, im2]
image_size = 32
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
# SimCLR augmentations
augmentation = [
transforms.RandomResizedCrop(image_size, scale=(0.08, 1.)),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # Not strengthened
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
]
DATA_ROOT = './'
train_dataset = libauc.datasets.CIFAR100(
root=DATA_ROOT, train=True, download=True, return_index=True,
transform=TwoCropsTransform(
transforms.Compose(augmentation),
transforms.Compose(augmentation)
)
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
drop_last=True
)
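As a quick, optional sanity check, each mini-batch produced by this loader is a list of two augmented views together with the labels and the sample indices; the indices are what the global contrastive loss below uses to keep its per-sample running statistics (see the SogCLR paper).

# Optional sanity check of the batch structure (illustration only)
images, targets, index = next(iter(train_loader))
print(images[0].shape, images[1].shape)  # torch.Size([256, 3, 32, 32]) for each view
print(targets.shape, index.shape)        # torch.Size([256]) for both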
Helper functions
def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
mlp = []
for l in range(num_layers):
dim1 = input_dim if l == 0 else mlp_dim
dim2 = output_dim if l == num_layers - 1 else mlp_dim
mlp.append(nn.Linear(dim1, dim2, bias=False))
if l < num_layers - 1:
mlp.append(nn.BatchNorm1d(dim2))
mlp.append(nn.ReLU(inplace=True))
elif last_bn:
# Follow SimCLR's design:
# https://github.com/google-research/simclr/blob/master/model_util.py#L157
# For simplicity, we further removed gamma in BN
mlp.append(nn.BatchNorm1d(dim2, affine=False))
return nn.Sequential(*mlp)
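# Illustration only: with the settings above and the 512-dimensional ResNet-18 feature
# (later read into `hidden_dim`), build_mlp yields
# Linear(512, 2048) -> BatchNorm1d -> ReLU -> Linear(2048, 256) -> BatchNorm1d (no affine)
print(build_mlp(num_proj_layers, 512, mlp_dim, dim))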
def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
"""Decays the learning rate with half-cycle cosine after warmup."""
if epoch < warmup_epochs:
cur_lr = init_lr * epoch / warmup_epochs
else:
cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
for param_group in optimizer.param_groups:
param_group['lr'] = cur_lr
return cur_lr
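# Illustration only: the schedule warms up linearly over the first `warmup_epochs`
# epochs and then follows a half-cycle cosine decay; the training loop passes in a
# fractional epoch (epoch + i / iters_per_epoch), so the lr is adjusted every iteration.
_dummy_opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
for e in [0, 5, 10, 100, 199]:
    print(e, round(adjust_learning_rate(_dummy_opt, e, init_lr=1.0), 4))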
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
def train(train_loader, model, loss_fn, optimizer, epoch, lr):
    model.train()
    iters_per_epoch = len(train_loader)
    for i, (images, _, index) in enumerate(train_loader):
        # pass a fractional epoch so the warmup/cosine schedule is applied per iteration
        cur_lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, lr)
        images[0] = images[0].cuda()
        images[1] = images[1].cuda()
        # mixed-precision forward pass; `scaler` is the global GradScaler defined in the Pretraining section
        with torch.cuda.amp.autocast(True):
            hidden1 = model(images[0])
            hidden2 = model(images[1])
            loss = loss_fn(hidden1, hidden2, index)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    print(f'Epoch: {epoch}, Dynamic Loss: {loss:.3f}')
Creating Model & Optimizer
set_all_seeds(123)
# ResNet-18 + 2-layer non-linear projection head
base_encoder = resnet18(
pretrained=False, last_activation=None, num_classes=128
)
hidden_dim = base_encoder.fc.weight.shape[1]
del base_encoder.fc # Remove original fc layer
base_encoder.fc = build_mlp(num_proj_layers, hidden_dim, mlp_dim, dim)
# CIFAR stem: 3x3 conv with stride 1 and no max-pooling (CIFAR images are only 32x32)
base_encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
base_encoder.maxpool = nn.Identity()
model = base_encoder.cuda()
# square-root lr scaling: 0.075 * sqrt(256) = 1.2
lr = init_lr * math.sqrt(batch_size)
# LARS optimizer
optimizer = libauc.optimizers.SogCLR(
base_encoder.parameters(),
mode = 'lars',
lr=lr,
weight_decay=weight_decay,
momentum=0.9
)
# Global Contrastive Loss; N matches the number of training images (CIFAR-100 has 50,000),
# and gamma controls the moving-average estimator (see the SogCLR paper)
loss_fn = GCLoss('unimodal', N=50000, tau=temperature, gamma=gamma, distributed=False)
Pretraining
# mixed precision training
scaler = torch.cuda.amp.GradScaler()
print ('Pretraining')
for epoch in range(epochs):
#if epoch in [int(epochs)*0.5, int(epochs)*0.75]:
# optimizer.update_regularizer()
# train for one epoch
train(train_loader, model, loss_fn, optimizer, epoch, lr)
# save checkpoint
if epoch % 10 == 0 or epochs - epoch < 3:
save_checkpoint(
{'epoch': epoch + 1,
'arch': 'resnet18',
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scaler': scaler.state_dict(),
},
is_best=False,
filename=os.path.join(logdir, logname, f'checkpoint_{epoch:04d}.pth.tar'))
Pretraining
Epoch: 0, Dynamtic Loss: -0.466
Epoch: 1, Dynamtic Loss: -0.716
Epoch: 2, Dynamtic Loss: -0.851
Epoch: 3, Dynamtic Loss: -0.842
Epoch: 4, Dynamtic Loss: -0.928
Epoch: 5, Dynamtic Loss: -1.007
Epoch: 6, Dynamtic Loss: -1.031
Epoch: 7, Dynamtic Loss: -1.064
Epoch: 8, Dynamtic Loss: -1.056
Epoch: 9, Dynamtic Loss: -1.127
Epoch: 10, Dynamtic Loss: -1.136
Epoch: 11, Dynamtic Loss: -1.132
Epoch: 12, Dynamtic Loss: -1.118
Epoch: 13, Dynamtic Loss: -1.177
Epoch: 14, Dynamtic Loss: -1.140
Epoch: 15, Dynamtic Loss: -1.198
Epoch: 16, Dynamtic Loss: -1.217
Epoch: 17, Dynamtic Loss: -1.251
Epoch: 18, Dynamtic Loss: -1.216
Epoch: 19, Dynamtic Loss: -1.192
Epoch: 20, Dynamtic Loss: -1.248
Epoch: 21, Dynamtic Loss: -1.197
Epoch: 22, Dynamtic Loss: -1.294
Epoch: 23, Dynamtic Loss: -1.313
Epoch: 24, Dynamtic Loss: -1.239
Epoch: 25, Dynamtic Loss: -1.305
Epoch: 26, Dynamtic Loss: -1.259
Epoch: 27, Dynamtic Loss: -1.275
Epoch: 28, Dynamtic Loss: -1.287
Epoch: 29, Dynamtic Loss: -1.229
Epoch: 30, Dynamtic Loss: -1.272
Epoch: 31, Dynamtic Loss: -1.311
Epoch: 32, Dynamtic Loss: -1.272
Epoch: 33, Dynamtic Loss: -1.293
Epoch: 34, Dynamtic Loss: -1.274
Epoch: 35, Dynamtic Loss: -1.331
Epoch: 36, Dynamtic Loss: -1.352
Epoch: 37, Dynamtic Loss: -1.299
Epoch: 38, Dynamtic Loss: -1.319
Epoch: 39, Dynamtic Loss: -1.317
Epoch: 40, Dynamtic Loss: -1.327
Epoch: 41, Dynamtic Loss: -1.340
Epoch: 42, Dynamtic Loss: -1.334
Epoch: 43, Dynamtic Loss: -1.322
Epoch: 44, Dynamtic Loss: -1.350
Epoch: 45, Dynamtic Loss: -1.338
Epoch: 46, Dynamtic Loss: -1.311
Epoch: 47, Dynamtic Loss: -1.291
Epoch: 48, Dynamtic Loss: -1.329
Epoch: 49, Dynamtic Loss: -1.345
Epoch: 50, Dynamtic Loss: -1.376
Epoch: 51, Dynamtic Loss: -1.358
Epoch: 52, Dynamtic Loss: -1.315
Epoch: 53, Dynamtic Loss: -1.355
Epoch: 54, Dynamtic Loss: -1.330
Epoch: 55, Dynamtic Loss: -1.346
Epoch: 56, Dynamtic Loss: -1.379
Epoch: 57, Dynamtic Loss: -1.375
Epoch: 58, Dynamtic Loss: -1.376
Epoch: 59, Dynamtic Loss: -1.353
Epoch: 60, Dynamtic Loss: -1.340
Epoch: 61, Dynamtic Loss: -1.387
Epoch: 62, Dynamtic Loss: -1.356
Epoch: 63, Dynamtic Loss: -1.370
Epoch: 64, Dynamtic Loss: -1.364
Epoch: 65, Dynamtic Loss: -1.344
Epoch: 66, Dynamtic Loss: -1.380
Epoch: 67, Dynamtic Loss: -1.410
Epoch: 68, Dynamtic Loss: -1.428
Epoch: 69, Dynamtic Loss: -1.377
Epoch: 70, Dynamtic Loss: -1.408
Epoch: 71, Dynamtic Loss: -1.420
Epoch: 72, Dynamtic Loss: -1.398
Epoch: 73, Dynamtic Loss: -1.395
Epoch: 74, Dynamtic Loss: -1.409
Epoch: 75, Dynamtic Loss: -1.361
Epoch: 76, Dynamtic Loss: -1.396
Epoch: 77, Dynamtic Loss: -1.371
Epoch: 78, Dynamtic Loss: -1.373
Epoch: 79, Dynamtic Loss: -1.406
Epoch: 80, Dynamtic Loss: -1.407
Epoch: 81, Dynamtic Loss: -1.437
Epoch: 82, Dynamtic Loss: -1.425
Epoch: 83, Dynamtic Loss: -1.375
Epoch: 84, Dynamtic Loss: -1.405
Epoch: 85, Dynamtic Loss: -1.372
Epoch: 86, Dynamtic Loss: -1.417
Epoch: 87, Dynamtic Loss: -1.458
Epoch: 88, Dynamtic Loss: -1.399
Epoch: 89, Dynamtic Loss: -1.433
Epoch: 90, Dynamtic Loss: -1.442
Epoch: 91, Dynamtic Loss: -1.423
Epoch: 92, Dynamtic Loss: -1.426
Epoch: 93, Dynamtic Loss: -1.443
Epoch: 94, Dynamtic Loss: -1.454
Epoch: 95, Dynamtic Loss: -1.426
Epoch: 96, Dynamtic Loss: -1.437
Epoch: 97, Dynamtic Loss: -1.478
Epoch: 98, Dynamtic Loss: -1.425
Epoch: 99, Dynamtic Loss: -1.452
Epoch: 100, Dynamtic Loss: -1.453
Epoch: 101, Dynamtic Loss: -1.485
Epoch: 102, Dynamtic Loss: -1.441
Epoch: 103, Dynamtic Loss: -1.446
Epoch: 104, Dynamtic Loss: -1.400
Epoch: 105, Dynamtic Loss: -1.448
Epoch: 106, Dynamtic Loss: -1.433
Epoch: 107, Dynamtic Loss: -1.406
Epoch: 108, Dynamtic Loss: -1.436
Epoch: 109, Dynamtic Loss: -1.460
Epoch: 110, Dynamtic Loss: -1.488
Epoch: 111, Dynamtic Loss: -1.463
Epoch: 112, Dynamtic Loss: -1.502
Epoch: 113, Dynamtic Loss: -1.504
Epoch: 114, Dynamtic Loss: -1.443
Epoch: 115, Dynamtic Loss: -1.395
Epoch: 116, Dynamtic Loss: -1.487
Epoch: 117, Dynamtic Loss: -1.472
Epoch: 118, Dynamtic Loss: -1.446
Epoch: 119, Dynamtic Loss: -1.484
Epoch: 120, Dynamtic Loss: -1.438
Epoch: 121, Dynamtic Loss: -1.494
Epoch: 122, Dynamtic Loss: -1.490
Epoch: 123, Dynamtic Loss: -1.464
Epoch: 124, Dynamtic Loss: -1.490
Epoch: 125, Dynamtic Loss: -1.462
Epoch: 126, Dynamtic Loss: -1.479
Epoch: 127, Dynamtic Loss: -1.478
Epoch: 128, Dynamtic Loss: -1.448
Epoch: 129, Dynamtic Loss: -1.432
Epoch: 130, Dynamtic Loss: -1.498
Epoch: 131, Dynamtic Loss: -1.469
Epoch: 132, Dynamtic Loss: -1.463
Epoch: 133, Dynamtic Loss: -1.514
Epoch: 134, Dynamtic Loss: -1.495
Epoch: 135, Dynamtic Loss: -1.494
Epoch: 136, Dynamtic Loss: -1.503
Epoch: 137, Dynamtic Loss: -1.515
Epoch: 138, Dynamtic Loss: -1.478
Epoch: 139, Dynamtic Loss: -1.467
Epoch: 140, Dynamtic Loss: -1.508
Epoch: 141, Dynamtic Loss: -1.476
Epoch: 142, Dynamtic Loss: -1.523
Epoch: 143, Dynamtic Loss: -1.499
Epoch: 144, Dynamtic Loss: -1.471
Epoch: 145, Dynamtic Loss: -1.471
Epoch: 146, Dynamtic Loss: -1.492
Epoch: 147, Dynamtic Loss: -1.532
Epoch: 148, Dynamtic Loss: -1.503
Epoch: 149, Dynamtic Loss: -1.507
Epoch: 150, Dynamtic Loss: -1.518
Epoch: 151, Dynamtic Loss: -1.494
Epoch: 152, Dynamtic Loss: -1.510
Epoch: 153, Dynamtic Loss: -1.480
Epoch: 154, Dynamtic Loss: -1.525
Epoch: 155, Dynamtic Loss: -1.499
Epoch: 156, Dynamtic Loss: -1.502
Epoch: 157, Dynamtic Loss: -1.527
Epoch: 158, Dynamtic Loss: -1.531
Epoch: 159, Dynamtic Loss: -1.477
Epoch: 160, Dynamtic Loss: -1.491
Epoch: 161, Dynamtic Loss: -1.519
Epoch: 162, Dynamtic Loss: -1.534
Epoch: 163, Dynamtic Loss: -1.475
Epoch: 164, Dynamtic Loss: -1.493
Epoch: 165, Dynamtic Loss: -1.516
Epoch: 166, Dynamtic Loss: -1.494
Epoch: 167, Dynamtic Loss: -1.515
Epoch: 168, Dynamtic Loss: -1.534
Epoch: 169, Dynamtic Loss: -1.503
Epoch: 170, Dynamtic Loss: -1.516
Epoch: 171, Dynamtic Loss: -1.493
Epoch: 172, Dynamtic Loss: -1.498
Epoch: 173, Dynamtic Loss: -1.509
Epoch: 174, Dynamtic Loss: -1.510
Epoch: 175, Dynamtic Loss: -1.509
Epoch: 176, Dynamtic Loss: -1.519
Epoch: 177, Dynamtic Loss: -1.524
Epoch: 178, Dynamtic Loss: -1.520
Epoch: 179, Dynamtic Loss: -1.490
Epoch: 180, Dynamtic Loss: -1.504
Epoch: 181, Dynamtic Loss: -1.489
Epoch: 182, Dynamtic Loss: -1.540
Epoch: 183, Dynamtic Loss: -1.552
Epoch: 184, Dynamtic Loss: -1.504
Epoch: 185, Dynamtic Loss: -1.527
Epoch: 186, Dynamtic Loss: -1.525
Epoch: 187, Dynamtic Loss: -1.541
Epoch: 188, Dynamtic Loss: -1.535
Epoch: 189, Dynamtic Loss: -1.501
Epoch: 190, Dynamtic Loss: -1.530
Epoch: 191, Dynamtic Loss: -1.529
Epoch: 192, Dynamtic Loss: -1.527
Epoch: 193, Dynamtic Loss: -1.521
Epoch: 194, Dynamtic Loss: -1.523
Epoch: 195, Dynamtic Loss: -1.514
Epoch: 196, Dynamtic Loss: -1.501
Epoch: 197, Dynamtic Loss: -1.535
Epoch: 198, Dynamtic Loss: -1.532
Epoch: 199, Dynamtic Loss: -1.507
Linear Evaluation
By default, we use momentum SGD without weight decay and a batch size of 1024 for linear classification on frozen features/weights. This stage runs for 90 epochs.
Configurations
# dataset
image_size = 32
batch_size = 1024
num_classes = 100 # cifar100
# optimizer
epochs = 90
init_lr = 0.075
weight_decay = 0
# checkpoint from the pretraining stage (absolute path assumes Google Colab; adjust as needed)
checkpoint_dir = '/content/logs/resnet18_cifar100/checkpoint_0199.pth.tar'
Dataset pipeline
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean=mean, std=std)
train_dataset = libauc.datasets.CIFAR100(root=DATA_ROOT, train=True, download=True,
transform=transforms.Compose([transforms.RandomResizedCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,]))
val_dataset = libauc.datasets.CIFAR100(root=DATA_ROOT, train=False,download=True,
transform=transforms.Compose([transforms.ToTensor(),
normalize,]))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
Helper functions
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
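# Toy example (illustration only): four 3-class predictions with made-up logits
example_logits = torch.tensor([[2.0, 1.0, 0.1],
                               [0.2, 3.0, 0.5],
                               [1.5, 0.3, 1.0],
                               [0.1, 0.2, 2.5]])
example_labels = torch.tensor([0, 1, 2, 2])
top1, top2 = accuracy(example_logits, example_labels, topk=(1, 2))
print(top1.item(), top2.item())  # 75.0 and 100.0 (percentages)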
def adjust_learning_rate(optimizer, epoch, init_lr=0.075):
"""Decay the learning rate based on schedule"""
cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
for param_group in optimizer.param_groups:
param_group['lr'] = cur_lr
def train(train_loader, model, criterion, optimizer, epoch):
    # The backbone is frozen, so keep the model in eval mode (BatchNorm statistics stay fixed);
    # only the linear classifier receives gradient updates.
    model.eval()
    for i, (images, target) in enumerate(train_loader):
        images = images.float().cuda()
        target = target.long().cuda()
        output = model(images)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def validate(val_loader, model, criterion):
model.eval()
acc1_list = []
acc5_list = []
with torch.no_grad():
for i, (images, target) in enumerate(val_loader):
images = images.float().cuda()
target = target.long().cuda()
output = model(images)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1_list.append(acc1)
acc5_list.append(acc5)
acc1_array = torch.stack(acc1_list)
acc5_array = torch.stack(acc5_list)
return torch.mean(acc1_array), torch.mean(acc5_array)
Define model
set_all_seeds(123)
# ResNet-18 + classification layer
model = resnet18(pretrained=False, last_activation=None, num_classes=128)
hidden_dim = model.fc.weight.shape[1]
del model.fc
model.fc = nn.Linear(hidden_dim, num_classes, bias=True)
# CIFAR stem for ResNet-18: 3x3 conv with stride 1 and no max-pooling
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
# load the pretrained checkpoint, excluding the projection head ('fc') weights
linear_keyword = 'fc'
checkpoint = torch.load(checkpoint_dir, map_location="cpu")
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
if linear_keyword in k:
del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
print ('Linear Classifier Variables: %s'%(msg.missing_keys))
# cuda
model = model.cuda()
# freeze all layers but the last fc
for name, param in model.named_parameters():
if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
param.requires_grad = False
# init the fc layer
getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
getattr(model, linear_keyword).bias.data.zero_()
# optimize only the linear classifier
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2 # weight, bias
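As a quick, optional check that the freezing worked, only the classifier's weight and bias should remain trainable:

# Optional: list the trainable parameters (should be exactly ['fc.weight', 'fc.bias'])
print([name for name, p in model.named_parameters() if p.requires_grad])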
Define loss & optimizer
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
# linear lr scaling: 0.075 * 1024 / 256 = 0.3
lr = init_lr * batch_size / 256
optimizer = torch.optim.SGD(parameters,
lr=lr,
momentum=0.9,
weight_decay=weight_decay)
Training
# linear evaluation
print ('Linear Evaluation')
for epoch in range(epochs):
adjust_learning_rate(optimizer, epoch, lr)
# train for one epoch
train(train_loader, model, criterion, optimizer, epoch)
# evaluate on validation set
acc1, acc5 = validate(val_loader, model, criterion)
# log
print ('Epoch: %s, Top1: %.2f, Top5: %.2f'%(epoch, acc1, acc5))
Linear Evaluation
Epoch: 0, Top1: 49.25, Top5: 81.33
Epoch: 1, Top1: 51.79, Top5: 82.66
Epoch: 2, Top1: 52.86, Top5: 83.64
Epoch: 3, Top1: 53.38, Top5: 84.51
Epoch: 4, Top1: 53.22, Top5: 84.41
Epoch: 5, Top1: 54.46, Top5: 84.61
Epoch: 6, Top1: 54.76, Top5: 85.17
Epoch: 7, Top1: 54.53, Top5: 84.90
Epoch: 8, Top1: 54.93, Top5: 85.01
Epoch: 9, Top1: 55.07, Top5: 85.55
Epoch: 10, Top1: 55.65, Top5: 85.41
Epoch: 11, Top1: 55.48, Top5: 85.69
Epoch: 12, Top1: 55.64, Top5: 85.57
Epoch: 13, Top1: 56.29, Top5: 85.84
Epoch: 14, Top1: 56.13, Top5: 85.53
Epoch: 15, Top1: 55.94, Top5: 85.93
Epoch: 16, Top1: 55.92, Top5: 85.99
Epoch: 17, Top1: 56.58, Top5: 86.02
Epoch: 18, Top1: 56.41, Top5: 85.83
Epoch: 19, Top1: 56.62, Top5: 86.17
Epoch: 20, Top1: 56.78, Top5: 86.25
Epoch: 21, Top1: 56.09, Top5: 85.70
Epoch: 22, Top1: 56.65, Top5: 85.69
Epoch: 23, Top1: 56.50, Top5: 85.96
Epoch: 24, Top1: 56.62, Top5: 86.07
Epoch: 25, Top1: 56.75, Top5: 86.05
Epoch: 26, Top1: 56.73, Top5: 85.95
Epoch: 27, Top1: 56.92, Top5: 86.06
Epoch: 28, Top1: 56.88, Top5: 86.02
Epoch: 29, Top1: 57.26, Top5: 86.10
Epoch: 30, Top1: 56.69, Top5: 86.10
Epoch: 31, Top1: 57.29, Top5: 86.24
Epoch: 32, Top1: 56.95, Top5: 86.18
Epoch: 33, Top1: 56.99, Top5: 86.20
Epoch: 34, Top1: 56.95, Top5: 86.14
Epoch: 35, Top1: 57.36, Top5: 86.40
Epoch: 36, Top1: 57.50, Top5: 86.50
Epoch: 37, Top1: 57.35, Top5: 86.63
Epoch: 38, Top1: 57.44, Top5: 86.12
Epoch: 39, Top1: 57.51, Top5: 86.09
Epoch: 40, Top1: 57.28, Top5: 86.27
Epoch: 41, Top1: 57.43, Top5: 86.04
Epoch: 42, Top1: 57.56, Top5: 86.30
Epoch: 43, Top1: 57.67, Top5: 86.04
Epoch: 44, Top1: 57.82, Top5: 86.33
Epoch: 45, Top1: 57.71, Top5: 86.34
Epoch: 46, Top1: 57.56, Top5: 86.46
Epoch: 47, Top1: 57.69, Top5: 86.25
Epoch: 48, Top1: 57.77, Top5: 86.36
Epoch: 49, Top1: 57.87, Top5: 86.32
Epoch: 50, Top1: 57.78, Top5: 86.14
Epoch: 51, Top1: 57.69, Top5: 86.31
Epoch: 52, Top1: 57.33, Top5: 86.31
Epoch: 53, Top1: 57.89, Top5: 86.37
Epoch: 54, Top1: 57.68, Top5: 86.15
Epoch: 55, Top1: 57.88, Top5: 86.24
Epoch: 56, Top1: 58.18, Top5: 86.25
Epoch: 57, Top1: 57.56, Top5: 86.45
Epoch: 58, Top1: 57.97, Top5: 86.46
Epoch: 59, Top1: 57.78, Top5: 86.18
Epoch: 60, Top1: 57.97, Top5: 86.37
Epoch: 61, Top1: 57.97, Top5: 86.28
Epoch: 62, Top1: 57.85, Top5: 86.30
Epoch: 63, Top1: 58.04, Top5: 86.39
Epoch: 64, Top1: 57.80, Top5: 86.40
Epoch: 65, Top1: 57.94, Top5: 86.61
Epoch: 66, Top1: 57.95, Top5: 86.46
Epoch: 67, Top1: 58.13, Top5: 86.52
Epoch: 68, Top1: 58.09, Top5: 86.45
Epoch: 69, Top1: 57.94, Top5: 86.51
Epoch: 70, Top1: 57.97, Top5: 86.44
Epoch: 71, Top1: 58.09, Top5: 86.46
Epoch: 72, Top1: 58.26, Top5: 86.63
Epoch: 73, Top1: 58.24, Top5: 86.42
Epoch: 74, Top1: 58.15, Top5: 86.42
Epoch: 75, Top1: 58.10, Top5: 86.55
Epoch: 76, Top1: 58.05, Top5: 86.58
Epoch: 77, Top1: 58.13, Top5: 86.54
Epoch: 78, Top1: 57.95, Top5: 86.51
Epoch: 79, Top1: 58.00, Top5: 86.52
Epoch: 80, Top1: 58.07, Top5: 86.56
Epoch: 81, Top1: 58.00, Top5: 86.50
Epoch: 82, Top1: 57.99, Top5: 86.50
Epoch: 83, Top1: 58.15, Top5: 86.57
Epoch: 84, Top1: 58.10, Top5: 86.53
Epoch: 85, Top1: 58.11, Top5: 86.56
Epoch: 86, Top1: 58.12, Top5: 86.51
Epoch: 87, Top1: 58.10, Top5: 86.51
Epoch: 88, Top1: 58.10, Top5: 86.52
Epoch: 89, Top1: 58.10, Top5: 86.51