.. _fastclip: Optimizing Robust Global Contrastive Loss with FastCLIP (CLIP) ================================================================================================================================ .. raw:: html
Run on Colab
Download Notebook
View on Github
------------------------------------------------------------------------------------ .. container:: cell markdown | **Author**: Xiyuan Wei, Linli Zhou, Tianbao Yang \ Introduction ------------------------------------------------------------------------------------ In this tutorial, we will train a model using FastCLIP-v3. In the pretraining stage, we use a subset of the `CC3M dataset `__, which contains about 300,000 image-text pairs. We then evaluate the pretrained models via zero-shot image/text retrieval on the `MS-COCO `__ dataset. We provide the metadata of the training subset and evaluation set `here `__. **Reference** If you find this tutorial helpful in your work, please cite our `library paper `__ and the following paper: .. code-block:: RST @article{wei2024fastclip, title={Fastclip: A suite of optimization techniques to accelerate clip training with limited resources}, author={Wei, Xiyuan and Ye, Fanjiang and Yonay, Ori and Chen, Xingyu and Sun, Baixi and Tao, Dingwen and Yang, Tianbao}, journal={arXiv preprint arXiv:2407.01445}, year={2024} } Importing required libs ------------------------------------------------------------------------------------ .. code:: python !pip install -U libauc !pip install timm !pip install transformers .. 
code:: python import os os.environ["TOKENIZERS_PARALLELISM"] = "true" import re import argparse from pathlib import Path import json import os import random import math from functools import partial import torch import torch.nn as nn import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch import optim import torchvision from torchvision import transforms from torch.utils.data import Dataset, Subset, DataLoader from PIL import Image from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True Image.MAX_IMAGE_PIXELS = None import cv2 import numpy as np import timm from transformers import AutoModel, AutoTokenizer import libauc from libauc.losses.contrastive import GCLoss_v2 from libauc.optimizers import SogCLR from libauc.utils.paper_utils import CosineLRScheduler Arguments for experiments ------------------------------------------------------------------------------------ .. code:: python # path to data folder data_path = './cc3m' train_file = 'cc3m_subset_new.json' # model config image_encoder = 'resnet50' text_encoder = 'distilbert-base-uncased' image_res = 256 vision_width = 768 embed_dim = 256 seed = 42 # optimizer and schedular opt = 'adamW' lr = 3e-4 min_lr = 1e-5 warmup = True warmup_lr = 1e-5 weight_decay = 0.02 decay_rate = 1 epochs = 30 warmup_epochs = 20 cooldown_epochs = 0 # training & test settings batch_size_train = 1024 batch_size_test = 512 k_test = 256 # output path output_dir = './output/' # AMP training use_amp = True # loss config temp = 0.07 # the temperature parameter for clip or sogclr n_gpus = torch.cuda.device_count() val_coco_file = 'coco_val_new.json' test_coco_file = 'coco_test_new.json' coco_image_root = './mscoco' Path(output_dir).mkdir(parents=True, exist_ok=True) Define helper functions ------------------------------------------------------------------------------------ .. 
code:: python # we employ this function to preprocess the captions def pre_caption(caption, max_words): caption = re.sub( r"([,.'!?\"()*#:;~])", '', caption.lower(), ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person') caption = re.sub( r"\s{2,}", ' ', caption, ) caption = caption.rstrip('\n') caption = caption.strip(' ') #truncate caption caption_words = caption.split(' ') if len(caption_words)>max_words: caption = ' '.join(caption_words[:max_words]) return caption .. code:: python class train_set(Dataset): def __init__(self, ann_file, transform, image_root, max_words=30): self.ann = [] for f in ann_file: self.ann += json.load(open(f,'r')) self.transform = transform self.image_root = image_root self.max_words = max_words self.img_ids = {} n = 0 for ann in self.ann: img_id = ann['image_id'] if img_id not in self.img_ids.keys(): self.img_ids[img_id] = n n += 1 def __len__(self): return len(self.ann) def __getitem__(self, index): ann = self.ann[index] # image_path = os.path.join(self.image_root, ann['image']) image_path = ann['image'] image = Image.open(image_path).convert('RGB') image = self.transform(image) caption = pre_caption(ann['caption'], self.max_words) return image, caption, self.img_ids[ann['image_id']], index class eval_set(Dataset): def __init__(self, ann_file, transform, image_root, max_words=30): self.ann = json.load(open(ann_file,'r')) self.transform = transform self.image_root = image_root self.max_words = max_words self.text = [] self.image = [] self.txt2img = {} self.img2txt = {} txt_id = 0 for img_id, ann in enumerate(self.ann): self.image.append(ann['image']) self.img2txt[img_id] = [] # 'val2014/000000184613.jpg' image_path = os.path.join(self.image_root, ann['image'].split('/')[0], f"COCO_val2014_{ann['image'].split('/')[-1]}") ann['image'] = image_path for i, caption in enumerate(ann['caption']): self.text.append(pre_caption(caption,self.max_words)) self.img2txt[img_id].append(txt_id) self.txt2img[txt_id] = img_id txt_id += 1 def 
__len__(self): return len(self.image) def __getitem__(self, index): image_path = self.ann[index]['image'] image = Image.open(image_path).convert('RGB') image = self.transform(image) return image, index .. code:: python def add_weight_decay(model, weight_decay=1e-5, skip_list=()): decay = [] no_decay = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: no_decay.append(param) else: decay.append(param) return [ {'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': weight_decay}] def create_optimizer(model, opt, weight_decay=1e-5, filter_bias_and_bn=True): if weight_decay and filter_bias_and_bn: skip = {} if hasattr(model, 'no_weight_decay'): skip = model.no_weight_decay() parameters = add_weight_decay(model, weight_decay, skip) weight_decay = 0. else: parameters = model.parameters() opt_args = dict(lr=lr, weight_decay=weight_decay) optimizer = SogCLR(parameters, mode=opt, **opt_args) return optimizer .. code:: python def create_scheduler(optimizer): num_epochs = epochs lr_scheduler = CosineLRScheduler( optimizer, t_initial = num_epochs, t_mul = 1.0, lr_min = min_lr, decay_rate = decay_rate, warmup_lr_init = warmup_lr, warmup_t = warmup_epochs, cycle_limit = 1, t_in_epochs = True, noise_range_t = None, noise_pct = 0.67, noise_std = 1.0, noise_seed = 42, ) return lr_scheduler Reproducibility ------------------------------------------------------------------------------------ The following calls limit the number of sources of random behavior, such as model initialization, data shuffling, etc. .. code:: python # fix the seed for reproducibility torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) cudnn.benchmark = True Building the model ------------------------------------------------------------------------------------ .. 
code:: python # The following class includes the image encoder, text encoder and several objectives class Model(nn.Module): def __init__(self, image_encoder = None, text_encoder = None, embed_dim = 256, init_model = True, bsz = 128, gamma = 0.9, # the coefficient for moving average estimator temp = 0.01, # temperature for clip or sogclr gamma_schedule = 'cosine', gamma_decay_epochs = 1, ): super().__init__() self.temp = temp self.visual_encoder = timm.create_model(image_encoder, pretrained=init_model) self.visual_encoder.reset_classifier(0) self.text_encoder = AutoModel.from_pretrained(text_encoder, local_files_only=False) if not init_model: self.text_encoder.init_weights() self.vision_proj = nn.Linear(self.visual_encoder.num_features, embed_dim) self.text_proj = nn.Linear(768, embed_dim) self.criterion = GCLoss_v2(tau=temp, gamma=gamma, enable_isogclr=False, gamma_schedule=gamma_schedule, gamma_decay_epochs=gamma_decay_epochs, temperature_scheme="global_learnable") def forward(self, image, text_ids, text_att_masks, idx, text_idx, epoch): image_embeds = self.visual_encoder(image) image_embeds = self.vision_proj(image_embeds) image_feat = F.normalize(image_embeds, dim=-1) text_output = self.text_encoder(text_ids, attention_mask=text_att_masks, output_hidden_states=False) text_embeds = self.text_proj(text_output.last_hidden_state[:,0,:]) text_feat = F.normalize(text_embeds, dim=-1) loss, info = self.criterion(image_feat, text_feat, idx) return loss, info Training functions ------------------------------------------------------------------------------------ .. code:: python print_freq = 50 # user can determine how many iterations to output (e.g., 50 by default) .. 
code:: python def epoch_train(model, data_loader, optimizer, tokenizer, epoch, max_epoch, warmup_steps, device, scheduler, grad_scaler, tau_optimizer, tau_lr_scheduler, tau_grad_scaler): # train model.train() step_size = 100 warmup_iterations = warmup_steps * step_size if hasattr(model, 'module'): model_orig = model.module else: model_orig = model # adjust gamma based on gamma schedule if hasattr(model_orig.criterion, 'adjust_gamma'): model_orig.criterion.adjust_gamma(epoch) for i,(image, text, idx, text_idx) in enumerate(data_loader): optimizer.zero_grad() tau_optimizer.zero_grad() image = image.to(device, non_blocking=True) # idx = idx.to(device, non_blocking=True) text_idx = text_idx.to(device, non_blocking=True) text_input = tokenizer(text, padding='max_length', truncation=True, max_length=30, return_tensors="pt").to(device) if grad_scaler is None: loss, info = model(image, text_input.input_ids, text_input.attention_mask, idx=idx, text_idx=text_idx, epoch=epoch) loss.mean().backward() optimizer.step() tau_optimizer.step() else: with torch.cuda.amp.autocast(): loss, info = model(image, text_input.input_ids, text_input.attention_mask, idx=idx, text_idx=text_idx, epoch=epoch) grad_scaler.scale(loss.mean()).backward() grad_scaler.step(optimizer) grad_scaler.update() tau_grad_scaler.scale(model_orig.criterion.tau.mean()).backward() tau_grad_scaler.step(tau_optimizer) tau_grad_scaler.update() model_orig.criterion.tau.data.clamp_(0.01, 0.07) if epoch==0 and i%step_size==0 and i<=warmup_iterations: scheduler.step(i//step_size) tau_lr_scheduler.step(i//step_size) if i%print_freq == 0: lr = optimizer.param_groups[0]["lr"] print("Epoch:", epoch, "iteration:", i, "lr:", lr, "loss:", loss.mean().item(), "tau:", model_orig.criterion.tau.mean().item()) if info is not None: print("tau_img: %.4f, tau_txt: %.4f" % (info[0].mean(), info[1].mean())) Evaluation functions ------------------------------------------------------------------------------------ .. 
code:: python @torch.no_grad() def evaluation(model, data_loader, tokenizer, device): # test model.eval() print('Computing features for evaluation...') texts = data_loader.dataset.text num_text = len(texts) text_bs = 256 text_embeds = [] for i in range(0, num_text, text_bs): text = texts[i: min(num_text, i+text_bs)] text_input = tokenizer(text, padding='max_length', truncation=True, max_length=30, return_tensors="pt").to(device) text_output = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, output_hidden_states=False) text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]), dim=-1) text_embeds.append(text_embed) text_embeds = torch.cat(text_embeds,dim=0) image_embeds = [] for image, img_id in data_loader: image = image.to(device) image_feat = model.visual_encoder(image) image_embed = model.vision_proj(image_feat) image_embed = F.normalize(image_embed, dim=-1) image_embeds.append(image_embed) image_embeds = torch.cat(image_embeds,dim=0) sims_matrix = image_embeds @ text_embeds.t() score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device) for i,sims in enumerate(sims_matrix): topk_sim, topk_idx = sims.topk(k=k_test, dim=0) score_matrix_i2t[i, topk_idx] = topk_sim sims_matrix = sims_matrix.t() score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device) for i,sims in enumerate(sims_matrix): topk_sim, topk_idx = sims.topk(k=k_test, dim=0) score_matrix_t2i[i, topk_idx] = topk_sim return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() @torch.no_grad() def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): #Images->Text ranks = np.zeros(scores_i2t.shape[0]) for index,score in enumerate(scores_i2t): inds = np.argsort(score)[::-1] # Score rank = 1e20 for i in img2txt[index]: tmp = np.where(inds == i)[0][0] if tmp < rank: rank = tmp ranks[index] = rank # Compute metrics tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) tr5 = 100.0 * 
len(np.where(ranks < 5)[0]) / len(ranks) tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) #Text->Images ranks = np.zeros(scores_t2i.shape[0]) for index,score in enumerate(scores_t2i): inds = np.argsort(score)[::-1] ranks[index] = np.where(inds == txt2img[index])[0][0] # Compute metrics ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) tr_mean = (tr1 + tr5 + tr10) / 3 ir_mean = (ir1 + ir5 + ir10) / 3 r_mean = (tr_mean + ir_mean) / 2 eval_result = {'txt_r1': tr1, 'txt_r5': tr5, 'txt_r10': tr10, 'txt_r_mean': tr_mean, 'img_r1': ir1, 'img_r5': ir5, 'img_r10': ir10, 'img_r_mean': ir_mean, 'r_mean': r_mean} return eval_result Dataset Pipeline for Bimodal Contrastive Learning ------------------------------------------------------------------------------------ .. code:: python # set up the transformation, datasets and dataloaders train_transform = transforms.Compose([ transforms.RandomResizedCrop(image_res, scale=(0.5, 1.0), interpolation=Image.BICUBIC), transforms.RandomHorizontalFlip(), transforms.RandAugment(), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) test_transform = transforms.Compose([ transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) train_dataset = train_set([train_file], train_transform, data_path) val_coco_dataset = eval_set(val_coco_file, test_transform, coco_image_root) test_coco_dataset = eval_set(test_coco_file, test_transform, coco_image_root) print("len of train_dataset:", len(train_dataset)) print("len of coco val/test:", len(val_coco_dataset), len(test_coco_dataset)) train_loader = DataLoader(train_dataset, batch_size=batch_size_train * n_gpus, num_workers=16, pin_memory=True, shuffle=True, drop_last=True, 
prefetch_factor=4) val_loader = DataLoader(val_coco_dataset, batch_size=batch_size_test, num_workers=16, pin_memory=True, shuffle=False, drop_last=False, prefetch_factor=12) test_loader = DataLoader(test_coco_dataset, batch_size=batch_size_test, num_workers=16, pin_memory=True, shuffle=False, drop_last=False, prefetch_factor=12) Pretraining and evaluation for SogCLR with Cosine Gamma ------------------------------------------------------------------------------------ For the cosine gamma schedule, we recommend setting **gamma_decay_epochs** to 50% of the number of training epochs. .. code:: python gamma = 0.2 # the parameter for the moving average estimator in sogclr/isogclr gamma_schedule = "cosine" gamma_decay_epochs = epochs // 2 # create the model and wrap it in DDP tokenizer = AutoTokenizer.from_pretrained(text_encoder, local_files_only=False) model = Model(image_encoder=image_encoder, text_encoder=text_encoder, embed_dim=embed_dim, init_model=True, bsz=batch_size_train, gamma=gamma, temp=temp, gamma_schedule=gamma_schedule, gamma_decay_epochs=gamma_decay_epochs) model = model.cuda() .. code:: python if n_gpus > 1: print("Using", n_gpus, "GPUs") model = nn.DataParallel(model) .. 
code:: python # set up the optimizer and objective function optimizer = create_optimizer(model, opt, weight_decay) lr_scheduler = create_scheduler(optimizer) tau_optimizer = torch.optim.AdamW([model.criterion.tau], lr=lr/4, weight_decay=0.0) tau_lr_scheduler = create_scheduler(tau_optimizer) if use_amp: grad_scaler = torch.cuda.amp.GradScaler() tau_grad_scaler = torch.cuda.amp.GradScaler() else: grad_scaler = None tau_grad_scaler = None # training loop for epoch in range(0, epochs): train_stats = epoch_train(model, train_loader, optimizer, tokenizer, epoch, epochs, warmup_epochs, torch.device('cuda'), lr_scheduler, grad_scaler, tau_optimizer, tau_lr_scheduler, tau_grad_scaler) # evaluate the model on ms-coco data try: score_val_i2t_coco, score_val_t2i_coco = evaluation(model.module, val_loader, tokenizer, torch.device('cuda')) score_test_i2t_coco, score_test_t2i_coco = evaluation(model.module, test_loader, tokenizer, torch.device('cuda')) except: # for non-distributed training score_val_i2t_coco, score_val_t2i_coco = evaluation(model, val_loader, tokenizer, torch.device('cuda')) score_test_i2t_coco, score_test_t2i_coco = evaluation(model, test_loader, tokenizer, torch.device('cuda')) print("Epoch:", epoch) val_result_coco = itm_eval(score_val_i2t_coco, score_val_t2i_coco, val_loader.dataset.txt2img, val_loader.dataset.img2txt) print("coco val:", val_result_coco) test_result_coco = itm_eval(score_test_i2t_coco, score_test_t2i_coco, test_loader.dataset.txt2img, test_loader.dataset.img2txt) print("coco test:", test_result_coco) lr_scheduler.step(epoch+warmup_epochs+1) Visualization ------------------------------------------------------------------------------------ To compare the performance of different algorithms, we also train CLIP models using OpenCLIP; the notebook is `here `__. We plot the training curves of the mean validation recall values of the two algorithms. .. 
code:: python clip_recall_vals = [9.146024256963882, 25.643747834199658, 29.270318539250965, 30.15803545248567, 30.758616553378648, 30.071351459416235, 30.393121418099422, 30.48455817672931, 30.86781287485006, 30.960474476875916, 30.23930161268826, 30.297093162734903, 30.965095295215246, 31.358384646141545, 31.467666266826605, 30.952405704384915, 31.01773290683726, 31.05107556977209, 31.651619352259097, 31.91153405304545, 31.72560575769692, 31.692890843662532, 31.85492869518859, 32.204224976675995, 32.286870585099294, 32.64414634146342, 32.50149940023991, 32.45683726509397, 32.46283220045315, 32.77081167532987] fastclip_recall_vals = [24.293855791016924, 30.252612288417968, 32.606966546714645, 34.127257097161134, 34.355180594428894, 34.57910569105691, 34.640355857656935, 35.06099160335866, 34.64842729574836, 34.22647474343596, 33.93986405437825, 34.24375716380115, 34.322396374783416, 34.83299613487938, 34.437050513128085, 34.4030281220845, 33.790442489670795, 34.651717979474874, 34.08245635079302, 35.12101692656271, 34.557049180327866, 34.295081967213115, 34.567054511528724, 34.401079568172726, 34.18844328935093, 34.359738771158206, 34.63501666000267, 34.466374783419965, 34.38638011462082, 34.37372917499667] .. code:: python import matplotlib.pyplot as plt import numpy as np epochs = np.arange(1, 31) plt.plot(epochs, clip_recall_vals, label='OpenCLIP', ls=':', marker='+', color='blue') plt.plot(epochs, fastclip_recall_vals, label='FastCLIP', marker='*', color='orange') plt.ylabel('Mean Validation Recall', fontsize=18) plt.xlabel('Epoch', fontsize=18) plt.legend(fontsize=16) plt.show() .. image:: ./imgs/FastCLIP_31_0.png