Source code for libauc.trainer.data.datasets

import logging
import os

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

from libauc.utils import ImbalancedDataGenerator
from ogb.graphproppred import PygGraphPropPredDataset

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Dataset classes
# ---------------------------------------------------------------------------

class IndexedDataset(Dataset):
    """Wrap an existing dataset so that items are ``(image, target, index)``.

    All targets are read once at construction time (``dataset[i][1]`` for
    every ``i``) and cached as a ``float32`` NumPy array in ``self.targets``.
    When the cached targets are 2-D and ``class_id`` is given, only that
    column is kept; the array stays 2-D with a single column.

    Args:
        dataset: Indexable dataset with ``__len__`` whose items carry the
            input at position 0 and the target at position 1.
        class_id: Optional column index into 2-D targets. Negative indices
            are supported (``-1`` selects the last column). Ignored when the
            targets are not 2-D.
    """

    def __init__(self, dataset, class_id=None):
        self.dataset = dataset
        self.targets = self._load_targets()
        if self.targets.ndim == 2 and class_id is not None:
            # Index with a one-element list rather than the slice
            # ``class_id:class_id + 1``: the slice becomes ``-1:0`` for
            # ``class_id == -1`` and silently selects nothing.
            self.targets = self.targets[:, [class_id]]

    def _load_targets(self):
        """Return every target of the wrapped dataset as a float32 array."""
        # Index explicitly instead of iterating: the wrapped object is only
        # required to support ``__len__`` and ``__getitem__``.
        targets = [self.dataset[i][1] for i in range(len(self.dataset))]
        return np.array(targets).astype(np.float32)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, target, index)`` for ``idx``.

        ``idx`` may be a plain index or an ``(index, task_id)`` pair (tuple
        or list). In the latter case the pair is echoed back as the third
        element of the result so the caller can recover the task id.
        """
        task_id = None
        if isinstance(idx, (tuple, list)):
            # Unpack task_id if TriSampler is being used
            idx, task_id = idx
        # Read position 0 (as ``_load_targets`` reads position 1) so wrapped
        # datasets that return more than two fields are handled uniformly.
        image = self.dataset[idx][0]
        target = self.targets[idx]
        index = idx if task_id is None else (idx, task_id)
        return image, target, index
class ImageDataset(Dataset):
    """Image dataset held in memory, with separate train and test pipelines.

    Items are ``(image_tensor, target, idx)``. The training pipeline is
    ``ToTensor -> RandomCrop(crop_size) -> RandomHorizontalFlip ->
    Resize(image_size)``; the test pipeline is ``ToTensor ->
    Resize(image_size)``. The training pipeline is used only when ``mode`` is
    exactly ``'train'``; every other value selects the test pipeline.

    Args:
        images: NumPy array of images; it is cast to ``uint8`` on
            construction.
        targets: Indexable collection of labels aligned with ``images``.
        image_size: Side length of the square output image.
        crop_size: Side length of the random crop used during training.
        mode: ``'train'`` for augmentation, anything else for evaluation.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode

        output_size = (image_size, image_size)
        train_steps = [
            transforms.ToTensor(),
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
            transforms.Resize(output_size, antialias=True),
        ]
        test_steps = [
            transforms.ToTensor(),
            transforms.Resize(output_size, antialias=True),
        ]
        self.transform_train = transforms.Compose(train_steps)
        self.transform_test = transforms.Compose(test_steps)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        pil_image = Image.fromarray(self.images[idx].astype('uint8'))
        label = self.targets[idx]
        if self.mode == 'train':
            tensor = self.transform_train(pil_image)
        else:
            tensor = self.transform_test(pil_image)
        return tensor, label, idx
class ChemicalDataset(Dataset):
    """OGB molecular graph dataset filtered to a single task column.

    Reads column ``class_id`` of ``dataset.data.y`` for the rows listed by
    ``dataset.indices()``, drops every sample whose label is NaN, and keeps
    the remaining labels (as a float tensor in ``self.targets``) alongside
    the correspondingly filtered ``dataset``. A one-line class-balance
    summary is logged at INFO level.

    Items are ``(graph, target, int(idx))``.

    Args:
        dataset: Object exposing ``indices()``, a 2-D label tensor at
            ``data.y``, and boolean-mask indexing via ``dataset[mask]``.
        class_id: Column of ``data.y`` to use as the label.

    Raises:
        ValueError: If ``dataset.data.y`` is not 2-D.
    """

    def __init__(self, dataset, class_id):
        indices = dataset.indices()
        labels = dataset.data.y
        # Explicit check instead of ``assert`` so the validation survives
        # running Python with ``-O``.
        if len(labels.shape) != 2:
            raise ValueError(
                f"Expected 2-D labels in dataset.data.y, got shape {tuple(labels.shape)}."
            )

        y = labels[indices, class_id]
        not_nan = ~np.isnan(y.numpy())
        self.targets = y[not_nan].float()
        self.dataset = dataset[not_nan]

        pos = int((self.targets == 1).sum())
        total = len(self.targets)
        # Every label may be NaN for this task; report an undefined rate
        # rather than crashing with ZeroDivisionError.
        rate = pos / total if total else float("nan")
        logger.info(f"[ChemicalDataset] positive: {pos} | rate: {rate:.4f}")

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.dataset[idx], self.targets[idx], int(idx)
class GraphDataset(PygGraphPropPredDataset):
    """PygGraphPropPredDataset with integer-index support.

    For a scalar integer index the graph is fetched with
    ``self.get(self.indices()[idx])`` and tagged with its position in this
    dataset as ``item.idx`` (a ``LongTensor`` holding the single value
    ``idx``). Any other index is delegated to ``self.index_select``.
    """

    def __getitem__(self, idx):
        # ``np.integer`` covers every NumPy integer width (int32, int64, ...),
        # not just ``np.int64``, so indices coming from arbitrary NumPy arrays
        # take the single-item path as well.
        if isinstance(idx, (int, np.integer)):
            item = self.get(self.indices()[idx])
            item.idx = torch.LongTensor([idx])
            return item
        return self.index_select(idx)
class MedicalImageCSVDataset(Dataset):
    """
    General-purpose CSV-backed medical image dataset.

    Expects a CSV (or DataFrame) with at least an image path column and a
    binary label column. Rows whose label is missing are dropped. Image
    paths may be relative (resolved against ``image_root``) or absolute.

    Items are ``(image, target, idx)`` where ``image`` is the RGB image after
    ``transform`` (or the raw PIL image when ``transform`` is ``None``) and
    ``target`` is a 1-element ``float32`` array.

    Args:
        source:     Path to a metadata CSV **or** a ``pandas.DataFrame`` with
                    the required columns already loaded. Passing a DataFrame
                    avoids writing temporary files.
        image_root: Directory that relative image paths are resolved against.
                    Ignored for absolute paths.
        image_col:  Column name containing the image filename / path.
        label_col:  Column name containing the binary label (0 / 1).
        transform:  torchvision transform applied to each PIL image.
    """

    def __init__(
        self,
        source,
        image_root: str,
        image_col: str,
        label_col: str,
        transform,
    ):
        frame = source if isinstance(source, pd.DataFrame) else pd.read_csv(source)
        # ``dropna`` returns a new frame, so a caller-supplied DataFrame is
        # never modified.
        frame = frame.dropna(subset=[label_col]).reset_index(drop=True)

        self.image_root = image_root
        self.image_col = image_col
        self.transform = transform
        self.targets = frame[label_col].to_numpy().astype(np.float32)
        self.image_paths = frame[image_col].tolist()

        pos = int((self.targets == 1).sum())
        total = len(self.targets)
        # No labelled rows may remain after ``dropna``; report an undefined
        # rate rather than crashing with ZeroDivisionError.
        rate = pos / total if total else float("nan")
        logger.info(f"[MedicalImageCSVDataset] positive: {pos} | rate: {rate:.4f}")

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        rel = self.image_paths[idx]
        path = rel if os.path.isabs(rel) else os.path.join(self.image_root, rel)
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        # ``reshape(-1)`` turns the scalar label into a 1-element array.
        return image, self.targets[idx].reshape(-1), idx
# ---------------------------------------------------------------------------
# Shared transform factories
# ---------------------------------------------------------------------------

# Per-channel normalisation statistics shared by the train and test
# pipelines (the commonly used ImageNet RGB mean / std).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _medical_train_transform(image_size: int = 224) -> transforms.Compose:
    """Return the training pipeline: resize, light augmentation, normalise.

    Args:
        image_size: Side length of the square image produced by the resize.
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORMALIZE_MEAN), std=list(_NORMALIZE_STD)),
    ])


def _medical_test_transform(image_size: int = 224) -> transforms.Compose:
    """Return the deterministic evaluation pipeline: resize and normalise.

    Args:
        image_size: Side length of the square image produced by the resize.
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORMALIZE_MEAN), std=list(_NORMALIZE_STD)),
    ])


# ---------------------------------------------------------------------------
# OGB graph dataset helpers
# ---------------------------------------------------------------------------

def _safe_import_pyg_globals():
    """Register PyG safe globals for torch.serialization when available.

    Best-effort: does nothing when the running PyTorch has no
    ``add_safe_globals`` or when the PyG classes cannot be imported.
    """
    import torch.serialization

    # ``add_safe_globals`` is not present in older PyTorch releases; without
    # this guard the call below would raise AttributeError there.
    add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
    if add_safe_globals is None:
        return

    try:
        from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
        from torch_geometric.data.storage import GlobalStorage
    except ImportError:
        # PyG missing, or a version without these classes: nothing to register.
        return

    add_safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage])


# OGB dataset name → (class_id to use for binary classification)
_OGB_CLASS_IDS: dict[str, int] = {
    "ogbg-molhiv": 0,
    "ogbg-moltox21": 0,
    "ogbg-molmuv": 1,
    "ogbg-molpcba": 0,
}


def _load_ogbg(name: str, class_id: int, root_path: str, splits: list):
    """Load any OGB graph-property-prediction dataset as binary classification.

    Args:
        name:      OGB dataset name (e.g. ``"ogbg-molhiv"``).
        class_id:  Column index in ``y`` to use as the binary label.
        root_path: Root directory for dataset caching.
        splits:    Evaluation splits to return (e.g. ``["val", "test"]``).

    Returns:
        (train_dataset, eval_datasets)

    Raises:
        NotImplementedError: If ``splits`` contains anything other than
            ``"val"`` or ``"test"``.
    """
    _safe_import_pyg_globals()
    dataset = GraphDataset(name=name, root=root_path)
    split_idx = dataset.get_idx_split()

    train_dataset = ChemicalDataset(dataset[split_idx["train"]], class_id=class_id)

    eval_datasets = []
    for split in splits:
        if split not in ("val", "test"):
            raise NotImplementedError(
                f"Split '{split}' is not implemented for dataset '{name}'."
            )
        # The split dictionary uses the key "valid" for what callers of this
        # module request as "val".
        ogb_key = "valid" if split == "val" else "test"
        eval_datasets.append(ChemicalDataset(dataset[split_idx[ogb_key]], class_id=class_id))
    return train_dataset, eval_datasets


# ---------------------------------------------------------------------------
# Dataset loading
# ---------------------------------------------------------------------------
def _collect_eval_datasets(name, splits, builders):
    """Build the requested evaluation datasets, in the order of ``splits``.

    Args:
        name: Dataset name, used only in the error message.
        splits: Requested split names.
        builders: Mapping from split name to a zero-argument callable that
            constructs that split. Builders run lazily, so splits that are
            not requested are never constructed.

    Raises:
        NotImplementedError: If a requested split has no builder.
    """
    eval_datasets = []
    for split in splits:
        build = builders.get(split) if isinstance(split, str) else None
        if build is None:
            raise NotImplementedError(
                f"Split '{split}' is not implemented for dataset '{name}'."
            )
        eval_datasets.append(build())
    return eval_datasets


def _indexed_medmnist(dataset_cls, split, transform, root_path, class_id=None):
    """Construct one MedMNIST split (downloading if needed), index-wrapped."""
    return IndexedDataset(
        dataset_cls(split=split, transform=transform, download=True, root=root_path),
        class_id,
    )


def load_dataset(name: str, splits: list, **kwargs):
    """
    Load a dataset by name and return train + eval splits.

    Args:
        name:     Dataset identifier (case-insensitive).
        splits:   Evaluation splits to return, e.g. ``["val", "test"]``.
        **kwargs: Extra dataset-specific keyword arguments from the config.
                  Recognised keys: ``root_path`` (default ``"./data"``),
                  ``val_size`` (chexpert: 0.05, ddsm: 0.1), ``imratio``
                  (cifar10: 0.1), ``task`` (chestmnist label column) and
                  ``image_size`` (ddsm: 224).

    Returns:
        ``(train_dataset, eval_datasets)`` — both are
        :class:`torch.utils.data.Dataset` instances whose ``__getitem__``
        yields ``(data, label, index)`` tuples, as expected by the Trainer.
        ``eval_datasets`` is a list ordered like ``splits``.

    Raises:
        NotImplementedError: For ``"catvsdog"`` or an unsupported split name.
        ValueError: For an unknown dataset name, or (ddsm) when the
            case-description CSVs are missing or malformed.
    """
    name = name.lower()
    root_path = kwargs.get("root_path", "./data")

    # ── OGB graph datasets ───────────────────────────────────────────────────
    if name in _OGB_CLASS_IDS:
        return _load_ogbg(name, _OGB_CLASS_IDS[name], root_path, splits)

    # ── Image datasets ───────────────────────────────────────────────────────
    if name == "catvsdog":
        raise NotImplementedError(f"Dataset '{name}' is not yet implemented.")

    if name == "chexpert":
        from libauc.datasets import CheXpert
        from sklearn.model_selection import train_test_split
        from torch.utils.data import Subset

        root = os.path.join(root_path, "CheXpert-v1.0-small")
        val_size = kwargs.get("val_size", 0.05)

        def make_chexpert(csv_name, mode):
            return CheXpert(
                csv_path=os.path.join(root, csv_name),
                image_root_path=root,
                use_upsampling=False,
                use_frontal=True,
                image_size=224,
                mode=mode,
                class_index=-1,
                verbose=False,
            )

        full_train = make_chexpert('train.csv', 'train')

        # NOTE(review): this reads every training item just to obtain its
        # label; a direct label accessor on CheXpert would be cheaper if
        # one exists.
        all_targets = np.array([full_train[i][1] for i in range(len(full_train))])
        # Stratify on the last label column for multi-label targets.
        strat_labels = all_targets[:, -1] if all_targets.ndim == 2 else all_targets

        train_indices, val_indices = train_test_split(
            np.arange(len(full_train)),
            test_size=val_size,
            stratify=strat_labels,
            random_state=42,
        )

        train_dataset = IndexedDataset(Subset(full_train, train_indices))
        eval_datasets = _collect_eval_datasets(name, splits, {
            'val': lambda: IndexedDataset(Subset(full_train, val_indices)),
            # The 'test' split is read from valid.csv.
            'test': lambda: IndexedDataset(make_chexpert('valid.csv', 'valid')),
        })
        return train_dataset, eval_datasets

    if name == "cifar10":
        from libauc.datasets import CIFAR10

        train_data, train_targets = CIFAR10(root=root_path, train=True).as_array()
        test_data, test_targets = CIFAR10(root=root_path, train=False).as_array()
        imratio = kwargs.get("imratio", 0.1)
        generator = ImbalancedDataGenerator(verbose=True, random_seed=0)
        train_images, train_labels = generator.transform(train_data, train_targets, imratio=imratio)
        test_images, test_labels = generator.transform(test_data, test_targets, imratio=0.5)

        train_dataset = ImageDataset(train_images, train_labels)
        eval_datasets = _collect_eval_datasets(name, splits, {
            # NOTE(review): 'val' reuses the training arrays (with test-time
            # transforms); there is no held-out validation set here.
            'val': lambda: ImageDataset(train_images, train_labels, mode='test'),
            'test': lambda: ImageDataset(test_images, test_labels, mode='test'),
        })
        return train_dataset, eval_datasets

    if name == "pneumoniamnist":
        from medmnist import PneumoniaMNIST

        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5], std=[.5]),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5], std=[.5]),
        ])

        train_dataset = _indexed_medmnist(PneumoniaMNIST, 'train', train_transform, root_path)
        eval_datasets = _collect_eval_datasets(name, splits, {
            'val': lambda: _indexed_medmnist(PneumoniaMNIST, 'val', test_transform, root_path),
            'test': lambda: _indexed_medmnist(PneumoniaMNIST, 'test', test_transform, root_path),
        })
        return train_dataset, eval_datasets

    if name == "breastmnist":
        from medmnist import BreastMNIST

        transform = transforms.Compose([transforms.ToTensor()])

        train_dataset = _indexed_medmnist(BreastMNIST, 'train', transform, root_path)
        eval_datasets = _collect_eval_datasets(name, splits, {
            'val': lambda: _indexed_medmnist(BreastMNIST, 'val', transform, root_path),
            'test': lambda: _indexed_medmnist(BreastMNIST, 'test', transform, root_path),
        })
        return train_dataset, eval_datasets

    if name == "chestmnist":
        from medmnist import ChestMNIST

        transform = transforms.Compose([transforms.ToTensor()])
        # Optional label column; ``None`` keeps all label columns.
        task = kwargs.get("task", None)

        train_dataset = _indexed_medmnist(ChestMNIST, 'train', transform, root_path, task)
        eval_datasets = _collect_eval_datasets(name, splits, {
            'val': lambda: _indexed_medmnist(ChestMNIST, 'val', transform, root_path, task),
            'test': lambda: _indexed_medmnist(ChestMNIST, 'test', transform, root_path, task),
        })
        return train_dataset, eval_datasets

    if name == "melanoma":
        from libauc.datasets import Melanoma

        root = os.path.join(root_path, "melanoma")

        train_dataset = IndexedDataset(Melanoma(root=root, is_test=False, test_size=0.2))
        eval_datasets = _collect_eval_datasets(name, splits, {
            # NOTE(review): 'val' is built with the same arguments as the
            # training dataset.
            'val': lambda: IndexedDataset(Melanoma(root=root, is_test=False, test_size=0.2)),
            'test': lambda: IndexedDataset(Melanoma(root=root, is_test=True, test_size=0.2)),
        })
        return train_dataset, eval_datasets

    if name == "ddsm":
        from sklearn.model_selection import train_test_split

        root = os.path.join(root_path, "ddsm")
        csv_dir = os.path.join(root, "csv")
        jpeg_dir = os.path.join(root, "jpeg")

        # ── 1. Build UID → absolute jpeg path lookup from dicom_info.csv ────
        dicom_df = pd.read_csv(os.path.join(csv_dir, "dicom_info.csv"))

        # Fix known labelling bug: 8-bit "Unknown" entries are ROI masks.
        dicom_df.loc[
            (dicom_df["SeriesDescription"] == "Unknown") & (dicom_df["BitsAllocated"] == 8),
            "SeriesDescription",
        ] = "ROI mask images"

        full_mammo = dicom_df[dicom_df["SeriesDescription"] == "full mammogram images"].copy()
        full_mammo["uid"] = full_mammo["image_path"].str.extract(r"jpeg/([^/]+)/")
        full_mammo["abs_path"] = full_mammo["image_path"].apply(
            lambda p: p.replace("CBIS-DDSM/jpeg", jpeg_dir)
        )
        uid_to_path = full_mammo.set_index("uid")["abs_path"].to_dict()

        # ── 2. Load case-description CSVs ───────────────────────────────────
        def load_case_csv(csv_name):
            """Return the normalised case table, or ``None`` if the file is absent."""
            path = os.path.join(csv_dir, csv_name)
            if not os.path.exists(path):
                return None
            df = pd.read_csv(path)
            col = next((c for c in df.columns if "image file path" in c.lower()), None)
            if col is None:
                # Previously surfaced as a bare StopIteration.
                raise ValueError(f"No 'image file path' column found in '{path}'.")
            df = df.rename(columns={col: "image_file_path"})
            df["uid"] = df["image_file_path"].str.extract(r"/([^/]+)/[^/]+$")
            df["pathology"] = df["pathology"].replace("BENIGN_WITHOUT_CALLBACK", "BENIGN")
            df["label"] = (df["pathology"] == "MALIGNANT").astype(np.float32)
            return df[["patient_id", "uid", "label", "pathology"]]

        def load_cases(csv_names):
            """Concatenate the available case CSVs and de-duplicate by UID."""
            frames = [df for df in map(load_case_csv, csv_names) if df is not None]
            if not frames:
                # Previously pandas' opaque "No objects to concatenate".
                raise ValueError(
                    f"None of the case-description CSVs {list(csv_names)} "
                    f"were found in '{csv_dir}'."
                )
            return pd.concat(frames, ignore_index=True).drop_duplicates(subset=["uid"])

        train_cases = load_cases(
            ("calc_case_description_train_set.csv", "mass_case_description_train_set.csv")
        )
        test_cases = load_cases(
            ("calc_case_description_test_set.csv", "mass_case_description_test_set.csv")
        )

        # ── 3. Map UID → absolute image path ────────────────────────────────
        def attach_paths(df):
            """Add ``image_path`` and drop cases without a matching image."""
            df = df.copy()
            df["image_path"] = df["uid"].map(uid_to_path)
            return df.dropna(subset=["image_path"]).reset_index(drop=True)

        train_pool = attach_paths(train_cases)
        test_df = attach_paths(test_cases)

        # ── 4. Stratified val split from training pool ───────────────────────
        val_size = kwargs.get("val_size", 0.1)
        train_idx, val_idx = train_test_split(
            np.arange(len(train_pool)),
            test_size=val_size,
            stratify=train_pool["label"].values,
            random_state=42,
        )
        train_df = train_pool.iloc[train_idx].reset_index(drop=True)
        val_df = train_pool.iloc[val_idx].reset_index(drop=True)

        # ── 5. Build datasets — pass DataFrame directly to avoid temp files ──
        image_size = kwargs.get("image_size", 224)
        train_transform = _medical_train_transform(image_size)
        test_transform = _medical_test_transform(image_size)

        def _df_to_dataset(df, transform):
            return MedicalImageCSVDataset(
                source=df[["image_path", "label"]],
                image_root="",
                image_col="image_path",
                label_col="label",
                transform=transform,
            )

        train_dataset = _df_to_dataset(train_df, train_transform)
        eval_datasets = _collect_eval_datasets(name, splits, {
            "val": lambda: _df_to_dataset(val_df, test_transform),
            "test": lambda: _df_to_dataset(test_df, test_transform),
        })
        return train_dataset, eval_datasets

    raise ValueError(
        f"Unknown dataset: '{name}'. "
        "Please add a branch for it inside load_dataset()."
    )