# Adapted from https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
import logging
import random
from multiprocessing import Value
from typing import Dict, Callable, Optional
from torch.utils.data import get_worker_info
try:
import webdataset as wds
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
except ImportError:
raise ImportError("webdataset is not installed. Please install it by running `pip install webdataset`.")
class SharedEpoch:
"""Epoch number for distributed training"""
def __init__(self, epoch: int = 0):
self.shared_epoch = Value('i', epoch)
def set_value(self, epoch):
self.shared_epoch.value = epoch
def get_value(self):
return self.shared_epoch.value
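# Hedged usage sketch (not part of this module): the training loop advances the
# shared value once per epoch so that dataloader worker processes, which hold a
# handle to the same multiprocessing.Value, observe the current epoch.
#
#   shared_epoch = SharedEpoch(epoch=0)
#   for epoch in range(num_epochs):        # num_epochs is hypothetical
#       shared_epoch.set_value(epoch)
#       ...                                # run one training epoch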
def filter_no_caption_or_no_image(sample):
"""Check if sample has caption and image"""
has_caption = ('txt' in sample)
has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
return has_caption and has_image
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
return True
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
# FIXME: the upstream webdataset version throws if the suffix is already in current_sample,
# but this can happen in the current LAION-400M dataset when one tar ends with the same prefix
# the next begins with; rare, but possible since prefixes aren't unique across tar files there.
if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
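# Hedged illustration of the grouping above (filenames and shard name are placeholders):
# two tar members that share a prefix,
#   {"fname": "000001.jpg", "data": b"...", "__url__": "shard-0.tar"}
#   {"fname": "000001.txt", "data": b"...", "__url__": "shard-0.tar"}
# are merged into a single sample
#   {"__key__": "000001", "__url__": "shard-0.tar", "jpg": b"...", "txt": b"..."}
# and a repeated suffix for the same prefix starts a new sample rather than raising,
# which is the behavioural difference from webdataset's own group_by_keys.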
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
"""A re-implementation of the webdataset impl with group_by_keys that doesn't throw"""
streams = url_opener(src, handler=handler)
files = tar_file_expander(streams, handler=handler)
samples = group_by_keys_nothrow(files, handler=handler)
return samples
def pytorch_worker_seed(increment=0):
"""Get dataloader worker seed from pytorch"""
worker_info = get_worker_info()
if worker_info is not None:
# favour using the seed already created for pytorch dataloader workers if it exists
seed = worker_info.seed
if increment:
# space out seed increments so they can't overlap across workers in different iterations
seed += increment * max(1, worker_info.num_workers)
return seed
# fallback to wds rank based seed
return wds.utils.pytorch_worker_seed()
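# Hedged example of the increment spacing above: inside a dataloader worker with
# worker_info.num_workers == 4, successive increments (epochs) shift that worker's
# seed by multiples of 4. Since PyTorch assigns consecutive seeds to workers
# (base_seed + worker_id), this spacing keeps increments from overlapping across workers.
#
#   pytorch_worker_seed(increment=0)   # -> worker_info.seed
#   pytorch_worker_seed(increment=1)   # -> worker_info.seed + 4
#   pytorch_worker_seed(increment=2)   # -> worker_info.seed + 8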
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
class detshuffle2(wds.PipelineStage):
"""Shuffle according to seed and epoch"""
def __init__(
self,
bufsize=1000,
initial=100,
seed=0,
epoch=-1,
):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
rng = random.Random()
if self.seed < 0:
# If seed is negative, use the worker's seed; this will differ across all nodes/workers
seed = pytorch_worker_seed(epoch)
else:
# This seed is deterministic AND the same across all nodes/workers in each epoch
seed = self.seed + epoch
rng.seed(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
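# Hedged sketch of how this stage is wired into a pipeline (the shard pattern is a
# placeholder): with a non-negative seed every node/worker draws the same shard order
# for a given epoch, and the order changes between epochs because the epoch is folded
# into the seed.
#
#   epoch_store = SharedEpoch(0)
#   pipeline = wds.DataPipeline(
#       wds.SimpleShardList("/data/shards/{00000..00099}.tar"),
#       detshuffle2(bufsize=_SHARD_SHUFFLE_SIZE, initial=_SHARD_SHUFFLE_INITIAL,
#                   seed=42, epoch=epoch_store),
#       wds.split_by_node,
#       wds.split_by_worker,
#   )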
class WebDataset(wds.DataPipeline):
r"""
An image-text dataset that is stored in webdataset format. For more information on webdataset format,
refer to https://github.com/webdataset/webdataset.
Args:
input_shards (str): Path to the dataset shards.
is_train (bool): Whether the dataset is for training or evaluation.
batch_size (int): Batch size per worker.
preprocess_img (Callable): Function to preprocess the image.
seed (int): Seed for shuffling the dataset.
epoch (int): Start epoch number.
tokenize (Optional[Callable]): Tokenizer function for the text data.
return_index (bool): Whether to also return each sample's integer key parsed from its json metadata.
"""
def __init__(self,
input_shards: str,
is_train: bool,
batch_size: int,
preprocess_img: Callable,
seed: int = 0,
epoch: int = 0,
tokenize: Optional[Callable] = None,
return_index: bool = False,
):
self.shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
pipeline = [wds.SimpleShardList(input_shards)]
# at this point we have an iterator over all the shards
if is_train:
pipeline.extend([
detshuffle2(
bufsize=_SHARD_SHUFFLE_SIZE,
initial=_SHARD_SHUFFLE_INITIAL,
seed=seed,
epoch=self.shared_epoch,
),
wds.split_by_node,
wds.split_by_worker,
])
pipeline.extend([
# at this point, we have an iterator over the shards assigned to each worker at each node
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue),
wds.shuffle(
bufsize=_SAMPLE_SHUFFLE_SIZE,
initial=_SAMPLE_SHUFFLE_INITIAL,
),
])
else:
pipeline.extend([
wds.split_by_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
])
# when return_index is set, the sample's integer key is also parsed from its json metadata
def json_parse_key(json_dict: Dict) -> int:
return int(json_dict["key"])
if return_index:
rename = wds.rename(image="jpg;png;jpeg;webp", text="txt", key="json")
if tokenize is not None:
map_dict = wds.map_dict(image=preprocess_img, text=tokenize, key=json_parse_key)
else:
map_dict = wds.map_dict(image=preprocess_img, key=json_parse_key)
to_tuple = wds.to_tuple("image", "text", "key", "key")
else:
rename = wds.rename(image="jpg;png;jpeg;webp", text="txt")
if tokenize is not None:
map_dict = wds.map_dict(image=preprocess_img, text=tokenize)
else:
map_dict = wds.map_dict(image=preprocess_img)
to_tuple = wds.to_tuple("image", "text")
pipeline.extend([
wds.select(filter_no_caption_or_no_image),
wds.decode("pilrgb", handler=log_and_continue),
rename, map_dict, to_tuple,
wds.batched(batch_size, partial=not is_train)
])
super().__init__(*pipeline)
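# Hedged end-to-end sketch (paths, transforms, and tokenizer are placeholders, not part
# of this module). Batches are already formed inside the pipeline via wds.batched, so
# the loader is created with batch_size=None:
#
#   import torchvision.transforms as T
#
#   preprocess = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])
#   dataset = WebDataset(
#       input_shards="/data/laion/{00000..00099}.tar",
#       is_train=True,
#       batch_size=64,
#       preprocess_img=preprocess,
#       tokenize=my_tokenizer,            # hypothetical tokenizer callable
#   )
#   loader = wds.WebLoader(dataset, batch_size=None, num_workers=4)
#   for epoch in range(num_epochs):       # num_epochs is hypothetical
#       dataset.shared_epoch.set_value(epoch)   # keeps detshuffle2 deterministic per epoch
#       for images, texts in loader:
#           ...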