Source code for libauc.sampler.sampler

import numpy as np
import random
import torch
import torchvision
from torch.utils.data.sampler import Sampler

__all__ = ['ControlledDataSampler',
           'DualSampler',
           'TriSampler']

class ControlledDataSampler(Sampler):
    r"""Base class for Controlled Data Sampler."""

    def __init__(self, dataset, batch_size, labels=None, shuffle=True, num_pos=None,
                 num_sampled_tasks=None, sampling_rate=0.5, random_seed=2023):
        assert batch_size is not None, 'batch_size can not be None!'
        assert (num_pos is None) or (sampling_rate is None), 'only one of num_pos and sampling_rate is needed!'
        if sampling_rate:
            assert 0.0 < sampling_rate < 1.0, 'sampling_rate is not a valid number!'
        if labels is None:
            labels = self._get_labels(dataset)
        self.labels = self._check_labels(labels)  # shape: (N,) or (N, T)
        self.random_seed = random_seed
        self.shuffle = shuffle
        self.num_samples = int(len(labels))
        self.sampling_rate = sampling_rate
        self.batch_size = batch_size
        np.random.seed(self.random_seed)
        # fall back to the number of tasks found in the labels if not specified
        total_tasks = num_sampled_tasks
        if num_sampled_tasks is None:
            total_tasks = self._get_num_tasks(self.labels)
        self.total_tasks = total_tasks
        self.num_sampled_tasks = num_sampled_tasks
        self.pos_indices, self.neg_indices = self._get_sample_class_indices(self.labels)  # keyed by task_id: 0, 1, 2, ...
        self.class_counts = self._get_sample_class_counts(self.labels)  # (pos_len, neg_len) per task
        if self.sampling_rate:
            self.num_pos = int(self.sampling_rate * batch_size)
            if self.num_pos == 0:  # guarantee at least one positive per mini-batch
                self.num_pos = 1
            self.num_neg = batch_size - self.num_pos
        elif num_pos:
            self.num_pos = num_pos
            self.num_neg = batch_size - num_pos
        else:
            raise NotImplementedError
        self.num_batches = len(labels) // batch_size
        self.sampled = []

    def _check_array(self, data, squeeze=True):
        if not isinstance(data, (np.ndarray, np.generic)):
            data = np.array(data)
        if squeeze:
            data = np.squeeze(data)
        return data

    def _get_labels(self, dataset):
        r"""Extract labels from a given dataset object."""
        if isinstance(dataset, (torch.utils.data.Dataset, torchvision.datasets.ImageFolder)):
            return np.array(dataset.targets)
        raise NotImplementedError  # TODO: support more Dataset types

    def _check_labels(self, labels):
        r"""Validate labels for three cases: NaN values, negative values, and multi-class format."""
        labels = self._check_array(labels, squeeze=True)
        if np.isnan(labels).sum() > 0:
            raise ValueError('labels contain NaN value!')
        if (labels < 0).sum() > 0:
            raise ValueError('labels contain negative value!')
        if len(labels.shape) == 1:
            num_classes = np.unique(labels).size
            assert num_classes > 1, 'labels must have >= 2 classes!'
            if num_classes > 2:  # convert multi-class labels to multi-label (one-hot) format
                return np.eye(num_classes)[labels]
        return labels

    def _get_num_tasks(self, labels):
        r"""Compute the number of unique labels (tasks) for binary and multi-label datasets."""
        if len(labels.shape) == 1:
            return len(np.unique(labels))
        return labels.shape[-1]

    def _get_unique_labels(self, labels):
        r"""Extract unique labels for binary and multi-label (task) datasets."""
        unique_labels = np.unique(labels) if len(labels.shape) == 1 else np.arange(labels.squeeze().shape[-1])
        assert len(unique_labels) > 1, 'labels must have >= 2 classes!'
        return unique_labels

    def _get_sample_class_counts(self, labels):
        r"""Compute the number of positives and negatives per label (task)."""
        num_sampled_task = self._get_num_tasks(labels)
        counts = {}
        if num_sampled_task == 2:  # binary data, i.e., a single task
            task_id = 0
            counts[task_id] = (np.count_nonzero(labels == 1), np.count_nonzero(labels == 0))
        else:
            for task_id in np.arange(num_sampled_task):
                counts[task_id] = (np.count_nonzero(labels[:, task_id] > 0), np.count_nonzero(labels[:, task_id] == 0))
        return counts

    def _get_sample_class_indices(self, labels, num_sampled_task=None):
        r"""Extract sample indices for positives and negatives per label (task)."""
        if not num_sampled_task:
            num_sampled_task = self._get_num_tasks(labels)
        num_sampled_task = num_sampled_task - 1 if num_sampled_task == 2 else num_sampled_task  # binary data is a single task
        pos_indices, neg_indices = {}, {}
        for task_id in range(num_sampled_task):
            label_t = labels[:, task_id] if num_sampled_task > 2 else labels
            pos_idx = np.flatnonzero(label_t > 0)
            neg_idx = np.flatnonzero(label_t == 0)
            if self.shuffle:
                np.random.shuffle(pos_idx)
                np.random.shuffle(neg_idx)
            pos_indices[task_id] = pos_idx
            neg_indices[task_id] = neg_idx
        return pos_indices, neg_indices

    def __iter__(self):
        r"""Naive implementation for the Controlled Data Sampler (single-task case):
        cycle through the positive and negative index pools of task 0."""
        pos_pool, neg_pool = self.pos_indices[0], self.neg_indices[0]
        pos_len, neg_len = self.class_counts[0]
        pos_id, neg_id = 0, 0
        if self.shuffle:
            np.random.shuffle(pos_pool)
            np.random.shuffle(neg_pool)
        self.sampled = []
        for i in range(self.num_batches):
            for j in range(self.num_pos):
                self.sampled.append(pos_pool[pos_id % pos_len])
                pos_id += 1
            for j in range(self.num_neg):
                self.sampled.append(neg_pool[neg_id % neg_len])
                neg_id += 1
        return iter(self.sampled)

    def __len__(self):
        return len(self.sampled)
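
# Illustrative sketch (not part of LibAUC): a quick demo of the base sampler's
# label handling. `toy_labels` below is made-up data; a 3-class label vector is
# one-hot encoded by _check_labels into an (N, 3) multi-label task matrix.
if __name__ == '__main__':
    toy_labels = np.array([0, 2, 1, 2, 0, 1, 0, 0])
    base = ControlledDataSampler(dataset=None, batch_size=4, labels=toy_labels)
    print(base.labels.shape)           # (8, 3): one column (task) per class
    print(base.num_pos, base.num_neg)  # 2, 2: an even split at sampling_rate=0.5
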
class DualSampler(ControlledDataSampler):
    r"""Dual Sampler customizes the numbers of positives and negatives in mini-batch data
        for binary classification tasks. For more details, please refer to the LibAUC paper [1]_.

        Args:
            dataset (torch.utils.data.Dataset): PyTorch dataset object for training or evaluation.
            batch_size (int): number of samples per mini-batch.
            sampling_rate (float): the ratio of the number of positive samples to the total number of samples per task in a mini-batch (default: ``0.5``).
            num_pos (int, optional): number of positive samples in a batch (default: ``None``).
            labels (list or array, optional): a list or array of labels for the dataset (default: ``None``).
            shuffle (bool): whether to shuffle the data before sampling mini-batch data (default: ``True``).
            num_sampled_tasks (int): number of sampled tasks from the original dataset. If ``None`` is given, all labels (tasks) are used for training (default: ``None``).
            random_seed (int): random seed for reproducibility (default: ``2023``).

        Example:
            >>> sampler = libauc.sampler.DualSampler(trainSet, batch_size=32, sampling_rate=0.5)
            >>> trainloader = torch.utils.data.DataLoader(trainSet, batch_size=32, sampler=sampler, shuffle=False)
            >>> data, targets, index = next(iter(trainloader))

        .. note::
            Practical Tips:

            - In `DualSampler`, ``num_pos`` is equivalent to ``int(sampling_rate * batch_size)``. Use ``num_pos`` if you want to fix the exact number of positive samples per mini-batch; otherwise, ``sampling_rate`` is the required parameter by default.
            - For ``sampling_rate``, we recommend a value slightly higher than the proportion of positive samples in your training dataset. For instance, if the ratio of positive samples in your dataset is 0.01, consider setting ``sampling_rate`` to 0.05, 0.1, or 0.2.

        Reference:
            .. [1] Zhuoning Yuan, Dixian Zhu, Zi-Hao Qiu, Gang Li, Xuanhui Wang, Tianbao Yang. "LibAUC: A Deep Learning Library for X-Risk Optimization." 29th SIGKDD Conference on Knowledge Discovery and Data Mining. https://arxiv.org/abs/2306.03065
    """

    def __init__(self, dataset, batch_size, labels=None, shuffle=True, num_pos=None,
                 num_sampled_tasks=None, sampling_rate=0.5, random_seed=2023):
        super().__init__(dataset, batch_size, labels, shuffle, num_pos,
                         num_sampled_tasks, sampling_rate, random_seed)
        assert self.total_tasks > 1, 'Labels are not binary, e.g., [0, 1]!'
        self.pos_len = self.class_counts[0][0]
        self.neg_len = self.class_counts[0][1]
        self.pos_indices, self.neg_indices = self.pos_indices[0], self.neg_indices[0]
        np.random.seed(self.random_seed)
        if shuffle:
            np.random.shuffle(self.pos_indices)
            np.random.shuffle(self.neg_indices)
        # enough batches to cycle through the larger of the two pools
        self.num_batches = max(self.pos_len // self.num_pos, self.neg_len // self.num_neg)
        self.pos_ptr, self.neg_ptr = 0, 0
        self.sampled = np.zeros(self.num_batches * self.batch_size, dtype=np.int64)

    def __iter__(self):
        self.sampled = np.zeros(self.num_batches * self.batch_size, dtype=np.int64)
        for i in range(self.num_batches):
            start_index = i * self.batch_size
            # fill the positive block of this mini-batch
            if self.pos_ptr + self.num_pos > self.pos_len:
                # pool exhausted: take the remainder, reshuffle, and wrap around
                # TODO: edge case - dataset has very limited positive samples, e.g., < half of batch size
                temp = self.pos_indices[self.pos_ptr:]
                np.random.shuffle(self.pos_indices)
                self.pos_ptr = (self.pos_ptr + self.num_pos) % self.pos_len
                self.sampled[start_index:start_index + self.num_pos] = np.concatenate((temp, self.pos_indices[:self.pos_ptr]))
            else:
                self.sampled[start_index:start_index + self.num_pos] = self.pos_indices[self.pos_ptr:self.pos_ptr + self.num_pos]
                self.pos_ptr += self.num_pos
            start_index += self.num_pos
            # fill the negative block of this mini-batch
            if self.neg_ptr + self.num_neg > self.neg_len:
                temp = self.neg_indices[self.neg_ptr:]
                np.random.shuffle(self.neg_indices)
                self.neg_ptr = (self.neg_ptr + self.num_neg) % self.neg_len
                self.sampled[start_index:start_index + self.num_neg] = np.concatenate((temp, self.neg_indices[:self.neg_ptr]))
            else:
                self.sampled[start_index:start_index + self.num_neg] = self.neg_indices[self.neg_ptr:self.neg_ptr + self.num_neg]
                self.neg_ptr += self.num_neg
        return iter(self.sampled)

    def __len__(self):
        return len(self.sampled)
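
# Illustrative sketch (not part of LibAUC): verify that DualSampler yields exactly
# `num_pos` positives per mini-batch. The imbalanced `toy_labels` array is made up.
if __name__ == '__main__':
    toy_labels = np.array([1] * 10 + [0] * 90)  # 10% positives
    sampler = DualSampler(dataset=None, batch_size=10, labels=toy_labels,
                          num_pos=None, sampling_rate=0.2)
    batches = np.fromiter(sampler, dtype=np.int64).reshape(-1, 10)
    # each row is one mini-batch: exactly 2 positives (0.2 * 10) and 8 negatives
    assert all((toy_labels[row] == 1).sum() == sampler.num_pos for row in batches)
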
class TriSampler(ControlledDataSampler):
    r"""TriSampler customizes the numbers of positives and negatives in mini-batch data
        for multi-label classification or ranking tasks. For more details, please refer to the LibAUC paper [1]_.

        Args:
            dataset (torch.utils.data.Dataset): PyTorch dataset object for training or evaluation.
            batch_size_per_task (int): number of samples per mini-batch for each task.
            num_sampled_tasks (int): number of sampled tasks from the original dataset. If ``None`` is given, all labels (tasks) are used for training (default: ``None``).
            sampling_rate (float): the ratio of the number of positive samples to the total number of samples per task in a mini-batch (default: ``0.5``).
            num_pos (int, optional): number of positive samples in a batch (default: ``None``).
            mode (str, optional): sampling mode for classification or ranking tasks (default: ``'classification'``).
            labels (list or array, optional): a list or array of labels for the dataset (default: ``None``).
            shuffle (bool): whether to shuffle the data before sampling mini-batch data (default: ``True``).
            random_seed (int): random seed for reproducibility (default: ``2023``).

        Example:
            >>> sampler = libauc.sampler.TriSampler(trainSet, batch_size_per_task=32, num_sampled_tasks=10, sampling_rate=0.5)
            >>> trainloader = torch.utils.data.DataLoader(trainSet, batch_size=320, sampler=sampler, shuffle=False)
            >>> data, targets, index = next(iter(trainloader))
            >>> data_id, task_id = index

        .. note::
            `TriSampler` returns an index tuple of ``(sample_id, task_id)``, which requires a slight change
            in your dataset class for training. See the example below:

            .. code-block:: python

                class SampleDataset(torch.utils.data.Dataset):
                    def __init__(self, inputs, targets):
                        self.inputs = inputs
                        self.targets = targets

                    def __len__(self):
                        return len(self.inputs)

                    def __getitem__(self, index):
                        index, task_id = index
                        data = self.inputs[index]
                        target = self.targets[index]
                        return data, target, (index, task_id)

        .. note::
            Practical Tips:

            - In `classification` mode, ``batch_size_per_task * num_sampled_tasks`` is the total ``batch_size``. If ``num_sampled_tasks`` is not specified, all labels will be used.
            - In `ranking` mode, ``batch_size_per_task`` is the number of queries, ``num_pos`` is the number of positive items per user, and ``num_sampled_tasks`` is the number of users sampled from the dataset per mini-batch. For example, ``batch_size_per_task=310``, ``num_pos=10``, ``num_sampled_tasks=256`` samples 256 users per mini-batch, where each user has 10 positive items and 300 negative items.
    """

    def __init__(self, dataset, batch_size_per_task, num_sampled_tasks=None, sampling_rate=0.5,
                 mode='classification', labels=None, shuffle=True, num_pos=None, random_seed=2023):
        super().__init__(dataset, batch_size_per_task, labels, shuffle, num_pos, None, sampling_rate, random_seed)
        self.mode = mode
        assert self.mode in ['classification', 'ranking'], 'TriSampler mode should be classification or ranking!'
        assert self.total_tasks >= 3, 'TriSampler requires the number of tasks >= 3 for the given dataset!'
        self.batch_size_per_task = batch_size_per_task
        # if num_sampled_tasks is not specified, use all tasks by default
        self.num_sampled_tasks = num_sampled_tasks if num_sampled_tasks is not None else self.total_tasks
        self.batch_size = self.batch_size_per_task * self.num_sampled_tasks
        if self.mode == 'classification':
            self.num_batches = self.labels.shape[0] // (self.batch_size_per_task * self.num_sampled_tasks)
        else:
            self.num_batches = self.labels.shape[1] // self.num_sampled_tasks
        self.num_pos = int(self.batch_size_per_task * self.sampling_rate) if not num_pos else num_pos
        if self.num_pos < 1:
            print('batch_size_per_task x sampling_rate < 1!')
            self.num_pos = 1
        self.num_neg = self.batch_size_per_task - self.num_pos
        self.pos_len = [self.class_counts[task_id][0] for task_id in range(self.total_tasks)]
        self.neg_len = [self.class_counts[task_id][1] for task_id in range(self.total_tasks)]
        self.tasks_ids = np.arange(self.total_tasks)
        np.random.seed(self.random_seed)
        if shuffle:
            np.random.shuffle(self.tasks_ids)
            for task_id in range(len(self.pos_indices)):
                np.random.shuffle(self.pos_indices[task_id])
                np.random.shuffle(self.neg_indices[task_id])
        self.pos_ptr = np.zeros(self.total_tasks, dtype=np.int32)
        self.neg_ptr = np.zeros(self.total_tasks, dtype=np.int32)
        self.task_ptr = 0
        if self.mode == 'classification':
            self.sampled = np.zeros(self.num_batches * self.batch_size, dtype=np.int64)
            self.sampled_tasks = np.zeros(self.num_batches * self.batch_size, dtype=np.int32)
        else:
            self.sampled = np.zeros((self.num_batches * self.num_sampled_tasks, self.num_pos + self.num_neg), dtype=np.int32)
            self.sampled_tasks = np.zeros(self.num_batches * self.num_sampled_tasks, dtype=np.int32)

    def __iter__(self):
        sid = 0
        for batch_id in range(self.num_batches):
            start_index = batch_id * self.batch_size
            # select the tasks for this mini-batch
            if self.num_sampled_tasks < self.total_tasks:
                if self.task_ptr + self.num_sampled_tasks >= self.total_tasks:
                    # task pool exhausted: take the remainder, reshuffle, and wrap around
                    temp = self.tasks_ids[self.task_ptr:]
                    self.task_ptr = (self.task_ptr + self.num_sampled_tasks) % len(self.tasks_ids)
                    np.random.shuffle(self.tasks_ids)
                    task_ids = np.concatenate((temp, self.tasks_ids[:self.task_ptr]))
                else:
                    task_ids = self.tasks_ids[self.task_ptr:self.task_ptr + self.num_sampled_tasks]
                    self.task_ptr += self.num_sampled_tasks
            else:
                self.num_sampled_tasks = self.total_tasks
                task_ids = self.tasks_ids
                np.random.shuffle(self.tasks_ids)
            for idx, task_id in enumerate(task_ids):
                # sample positives for this task
                if self.pos_ptr[task_id] + self.num_pos >= self.pos_len[task_id]:
                    temp = self.pos_indices[task_id][self.pos_ptr[task_id]:]
                    np.random.shuffle(self.pos_indices[task_id])
                    self.pos_ptr[task_id] = (self.pos_ptr[task_id] + self.num_pos) % self.pos_len[task_id]
                    pos_list = np.concatenate((temp, self.pos_indices[task_id][:self.pos_ptr[task_id]]))
                else:
                    pos_list = self.pos_indices[task_id][self.pos_ptr[task_id]:self.pos_ptr[task_id] + self.num_pos]
                    self.pos_ptr[task_id] += self.num_pos
                if self.mode == 'classification':
                    self.sampled[start_index:start_index + self.num_pos] = pos_list
                    self.sampled_tasks[start_index:start_index + self.num_pos] = task_id
                    start_index += self.num_pos
                else:
                    self.sampled[sid, :self.num_pos] = pos_list
                # sample negatives for this task
                if self.neg_ptr[task_id] + self.num_neg >= self.neg_len[task_id]:
                    temp = self.neg_indices[task_id][self.neg_ptr[task_id]:]
                    np.random.shuffle(self.neg_indices[task_id])
                    self.neg_ptr[task_id] = (self.neg_ptr[task_id] + self.num_neg) % self.neg_len[task_id]
                    neg_list = np.concatenate((temp, self.neg_indices[task_id][:self.neg_ptr[task_id]]))
                else:
                    neg_list = self.neg_indices[task_id][self.neg_ptr[task_id]:self.neg_ptr[task_id] + self.num_neg]
                    self.neg_ptr[task_id] += self.num_neg
                if self.mode == 'classification':
                    self.sampled[start_index:start_index + self.num_neg] = neg_list
                    self.sampled_tasks[start_index:start_index + self.num_neg] = task_id
                    start_index += self.num_neg
                else:
                    self.sampled[sid, self.num_pos:] = neg_list
                if self.mode == 'ranking':
                    self.sampled_tasks[sid] = task_id
                    sid += 1
        return iter(zip(self.sampled, self.sampled_tasks))  # potential issue: task_id can be zero!

    def __len__(self):
        return len(self.sampled)
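
# Illustrative sketch (not part of LibAUC): TriSampler yields (sample_id, task_id)
# pairs. The 60 x 3 multi-label `toy_labels` matrix is made-up random data.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    toy_labels = rng.randint(0, 2, size=(60, 3))  # 60 samples, 3 tasks
    sampler = TriSampler(dataset=None, batch_size_per_task=4, num_sampled_tasks=3,
                         labels=toy_labels, sampling_rate=0.5)
    pairs = list(sampler)  # flat list of (sample_id, task_id) in classification mode
    sample_id, task_id = pairs[0]
    # the first block of the first batch is positive for its own task
    assert toy_labels[sample_id, task_id] == 1
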