from typing import List
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from libauc.utils import ImbalancedDataGenerator
import pandas as pd
from ogb.graphproppred import PygGraphPropPredDataset
import torch
import os
# ---------------------------------------------------------------------------
# Dataset loading
# ---------------------------------------------------------------------------
class IndexedDataset(Dataset):
    """Wrap a dataset so that each item also carries its own index.

    Labels are pre-collected into a float32 numpy array; when they are
    2-D (multi-label) and ``class_id`` is given, only that column is kept.
    """

    def __init__(self, dataset, class_id=None):
        self.dataset = dataset
        self.targets = self._load_targets()
        # Narrow multi-label targets down to a single class column.
        if class_id is not None and self.targets.ndim == 2:
            self.targets = self.targets[:, class_id:class_id + 1]

    def _load_targets(self):
        """Collect every label up front as a float32 numpy array."""
        raw = [self.dataset[i][1] for i in range(len(self.dataset))]
        return np.asarray(raw, dtype=np.float32)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # TriSampler passes (index, task_id) pairs; unpack when present.
        task_id = None
        if isinstance(idx, (tuple, list)):
            idx, task_id = idx
        image, _ = self.dataset[idx]
        target = self.targets[idx]
        if task_id is None:
            return image, target, idx
        return image, target, (idx, task_id)
class ImageDataset(Dataset):
    """In-memory image dataset backed by numpy arrays.

    Applies an augmenting pipeline in ``'train'`` mode and a deterministic
    resize-only pipeline otherwise.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        # Training: random crop + horizontal flip, then resize back up.
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((image_size, image_size), antialias=True),
        ])
        # Evaluation: deterministic resize only.
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size), antialias=True),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        target = self.targets[idx]
        pil_image = Image.fromarray(self.images[idx].astype('uint8'))
        pipeline = self.transform_train if self.mode == 'train' else self.transform_test
        return pipeline(pil_image), target, idx
class ChemicalDataset(Dataset):
    """Single-task view over an OGB molecular dataset split.

    Drops molecules whose label for ``class_id`` is NaN and keeps the
    remaining labels as a float tensor.

    Args:
        dataset: PyG-style dataset slice exposing ``indices()`` and ``data.y``
            (a 2-D label tensor), and supporting boolean-mask indexing.
        class_id: Column of ``data.y`` to use as the binary target.
    """

    def __init__(self, dataset, class_id):
        # dataset.indices() gives the actual indices into the full dataset
        indices = dataset.indices()
        assert len(dataset.data.y.shape) == 2
        y = dataset.data.y[indices, class_id]
        not_nan = ~np.isnan(y.numpy())  # shape matches len(dataset)
        self.targets = y[not_nan].float()
        self.dataset = dataset[not_nan]
        # Best-effort class-balance report; never fail dataset construction.
        # (Was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit.)
        try:
            tmp = np.array(self.targets)
            pos = len(tmp[tmp == 1])
            print('positive: ' + str(pos))
            print('positive rate: ' + str(float(pos) / len(tmp)))
        except Exception:
            print('positive rate error ')

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        # Returns (graph, label, index) as expected by the Trainer.
        return self.dataset[idx], self.targets[idx], int(idx)
class GraphDataset(PygGraphPropPredDataset):
    """OGB graph dataset whose items carry their own index as ``item.idx``."""

    def __getitem__(self, idx):
        # Fancy indexing (slices, masks, index lists) falls through to
        # the superclass selection machinery.
        if not isinstance(idx, (int, np.int64)):
            return self.index_select(idx)
        item = self.get(self.indices()[idx])
        item.idx = torch.LongTensor([idx])
        return item
class TextDataset(Dataset):
    """Dataset over a dataframe's text column with float32 labels."""

    def __init__(self, dataframe, text_col, label_col):
        self.len = len(dataframe)
        self.data = dataframe
        self.text_col = text_col
        self.texts = self.data[text_col]
        self.targets = self.data[label_col].to_numpy().astype(np.float32)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Returns (text, label, index) as expected by the Trainer.
        return self.texts[index], self.targets[index], index
class MedicalImageCSVDataset(Dataset):
    """
    General-purpose CSV-backed medical image dataset.
    Expects a CSV with at least an image path column and a binary label column.
    Image paths in the CSV may be relative (resolved against ``image_root``) or
    absolute.
    Args:
        csv_path: Path to the metadata CSV.
        image_root: Directory that image paths are resolved against when they
            are not absolute. Ignored for absolute paths.
        image_col: Column name containing the image filename / path.
        label_col: Column name containing the binary label (0 / 1).
        transform: torchvision transform applied to each PIL image.
    """

    def __init__(
        self,
        csv_path: str,
        image_root: str,
        image_col: str,
        label_col: str,
        transform,
    ):
        # Rows without a label are useless for training -- drop them up front.
        frame = pd.read_csv(csv_path).dropna(subset=[label_col]).reset_index(drop=True)
        self.image_root = image_root
        self.image_col = image_col
        self.transform = transform
        self.targets = frame[label_col].to_numpy().astype(np.float32)
        self.image_paths = frame[image_col].tolist()
        total = len(self.targets)
        pos = int((self.targets == 1).sum())
        print(f"[MedicalImageCSVDataset] positive: {pos} | rate: {pos/total:.4f}")

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        entry = self.image_paths[idx]
        full_path = entry if os.path.isabs(entry) else os.path.join(self.image_root, entry)
        img = Image.open(full_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # Label is reshaped to a length-1 vector for BCE-style losses.
        return img, self.targets[idx].reshape(-1), idx
# ---------------------------------------------------------------------------
# Shared transform factories
# ---------------------------------------------------------------------------
def _medical_train_transform(image_size: int = 224) -> transforms.Compose:
    """Augmented, ImageNet-normalized pipeline for medical-image training."""
    augment = [
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    finalize = [
        transforms.ToTensor(),
        # Standard ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(augment + finalize)
def _medical_test_transform(image_size: int = 224) -> transforms.Compose:
    """Deterministic resize + ImageNet normalization for evaluation."""
    steps = [
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
# ---------------------------------------------------------------------------
# OGB graph dataset helpers (shared by ogbg-molhiv and existing datasets)
# ---------------------------------------------------------------------------
def _safe_import_pyg_globals():
"""Register PyG safe globals for torch.serialization when available."""
import torch.serialization
try:
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage
torch.serialization.add_safe_globals(
[DataEdgeAttr, DataTensorAttr, GlobalStorage]
)
except ImportError:
pass
def _build_eval_datasets(name, splits, builders):
    """Materialize evaluation datasets for the requested splits.

    Args:
        name: Dataset identifier, used only for error messages.
        splits: Requested evaluation split names (e.g. ['val', 'test']).
        builders: Mapping from split name to a zero-argument callable that
            constructs the corresponding dataset. Callables keep the original
            behavior of building one dataset per requested split.

    Returns:
        List of datasets in the same order as ``splits``.

    Raises:
        NotImplementedError: For a split name with no registered builder.
    """
    eval_datasets = []
    for split in splits:
        if split not in builders:
            raise NotImplementedError(f"Split '{split}' is not yet implemented for dataset '{name}'.")
        eval_datasets.append(builders[split]())
    return eval_datasets


def _load_medmnist(mnist_cls, name, splits, root_path, train_transform, test_transform, task=None):
    """Build (train_dataset, eval_datasets) for a MedMNIST dataset class.

    ``task`` selects a single label column for multi-label datasets
    (e.g. ChestMNIST); None keeps all labels.
    """
    train_dataset = IndexedDataset(
        mnist_cls(split='train', transform=train_transform, download=True, root=root_path),
        task,
    )
    builders = {
        'val': lambda: IndexedDataset(
            mnist_cls(split='val', transform=test_transform, download=True, root=root_path), task),
        'test': lambda: IndexedDataset(
            mnist_cls(split='test', transform=test_transform, download=True, root=root_path), task),
    }
    return train_dataset, _build_eval_datasets(name, splits, builders)


def _load_ogb_graph(name, splits, root_path, class_id):
    """Build (train_dataset, eval_datasets) for an OGB graph-property dataset."""
    _safe_import_pyg_globals()
    dataset = GraphDataset(name=name, root=root_path)
    split_idx = dataset.get_idx_split()
    train_dataset = ChemicalDataset(dataset[split_idx["train"]], class_id=class_id)
    builders = {
        'val': lambda: ChemicalDataset(dataset[split_idx["valid"]], class_id=class_id),
        'test': lambda: ChemicalDataset(dataset[split_idx["test"]], class_id=class_id),
    }
    return train_dataset, _build_eval_datasets(name, splits, builders)


def load_dataset(name: str, splits: List[str], **kwargs) -> Dataset:
    """
    Load a dataset by name and split.

    Args:
        name: Dataset identifier (e.g. "catvsdog", "chexpert"); matched
            case-insensitively.
        splits: Evaluation splits to build ('val' and/or 'test').
        **kwargs: Extra dataset-specific keyword arguments from the config
            (e.g. root_path, val_size, imratio, task, image_size).

    Returns:
        A ``(train_dataset, eval_datasets)`` pair where every dataset's
        __getitem__ yields (data, label, index) tuples, as expected by the
        Trainer, and ``eval_datasets`` follows the order of ``splits``.

    Raises:
        NotImplementedError: For a known-but-unimplemented dataset or split.
        ValueError: For an unknown dataset name.
    """
    name = name.lower()
    root_path = kwargs.get("root_path", "./data")

    if name == "catvsdog":
        # TODO: implement this dataset branch.
        raise NotImplementedError(f"Dataset '{name}' is not yet implemented.")

    elif name == "chexpert":
        from libauc.datasets import CheXpert
        from sklearn.model_selection import train_test_split
        from torch.utils.data import Subset
        root = os.path.join(root_path, "CheXpert-v1.0-small")
        # Fraction of the training data held out as validation.
        val_size = kwargs.get("val_size", 0.05)
        # Load the full training split so we can read its labels for stratification.
        full_train = CheXpert(
            csv_path=os.path.join(root, 'train.csv'),
            image_root_path=root,
            use_upsampling=False,
            use_frontal=True,
            image_size=224,
            mode='train',
            class_index=-1,
            verbose=False,
        )
        # Stratified split keeps the positive rate stable across train and val.
        # CheXpert labels may be multi-label; stratify on the last column
        # (class_index=-1). NOTE(review): this iterates every sample once to
        # read labels, which loads each image -- slow but matches prior behavior.
        all_targets = np.array([full_train[i][1] for i in range(len(full_train))])
        strat_labels = all_targets[:, -1] if all_targets.ndim == 2 else all_targets
        train_indices, val_indices = train_test_split(
            np.arange(len(full_train)),
            test_size=val_size,
            stratify=strat_labels,
            random_state=42,
        )
        train_dataset = IndexedDataset(Subset(full_train, train_indices))
        val_dataset = IndexedDataset(Subset(full_train, val_indices))
        # The official CheXpert valid.csv is used as the test split.
        test_dataset = IndexedDataset(CheXpert(
            csv_path=os.path.join(root, 'valid.csv'),
            image_root_path=root,
            use_upsampling=False,
            use_frontal=True,
            image_size=224,
            mode='valid',
            class_index=-1,
            verbose=False,
        ))
        builders = {'val': lambda: val_dataset, 'test': lambda: test_dataset}
        return train_dataset, _build_eval_datasets(name, splits, builders)

    elif name == "cifar10":
        from libauc.datasets import CIFAR10
        # Load the raw splits as numpy arrays.
        train_data, train_targets = CIFAR10(root=root_path, train=True).as_array()
        test_data, test_targets = CIFAR10(root=root_path, train=False).as_array()
        imratio = kwargs.get("imratio", 0.1)
        # Subsample training data to the requested class-imbalance ratio;
        # the test split stays balanced (imratio=0.5).
        generator = ImbalancedDataGenerator(verbose=True, random_seed=0)
        (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio)
        (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)
        train_dataset = ImageDataset(train_images, train_labels)
        # NOTE(review): 'val' evaluates on the (un-augmented) training images.
        builders = {
            'val': lambda: ImageDataset(train_images, train_labels, mode='test'),
            'test': lambda: ImageDataset(test_images, test_labels, mode='test'),
        }
        return train_dataset, _build_eval_datasets(name, splits, builders)

    elif name == "pneumoniamnist":
        from medmnist import PneumoniaMNIST
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5], std=[.5])
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5], std=[.5])
        ])
        return _load_medmnist(PneumoniaMNIST, name, splits, root_path,
                              train_transform, test_transform)

    elif name == "breastmnist":
        from medmnist import BreastMNIST
        train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        return _load_medmnist(BreastMNIST, name, splits, root_path,
                              train_transform, test_transform)

    elif name == "chestmnist":
        from medmnist import ChestMNIST
        train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # ChestMNIST is multi-label; 'task' picks a single label column.
        task = kwargs.get("task", None)
        return _load_medmnist(ChestMNIST, name, splits, root_path,
                              train_transform, test_transform, task=task)

    elif name in ("ogbg-moltox21", "ogbg-molmuv", "ogbg-molpcba", "ogbg-molhiv"):
        # Single-task class index per dataset: molmuv uses task 1, the rest task 0.
        class_id = 1 if name == "ogbg-molmuv" else 0
        return _load_ogb_graph(name, splits, root_path, class_id)

    elif name == "melanoma":
        from libauc.datasets import Melanoma
        root = os.path.join(root_path, "melanoma")
        train_dataset = IndexedDataset(Melanoma(root=root, is_test=False, test_size=0.2))
        # NOTE(review): 'val' re-wraps the training portion (is_test=False), so
        # validation runs on training data -- confirm this is intended.
        builders = {
            'val': lambda: IndexedDataset(Melanoma(root=root, is_test=False, test_size=0.2)),
            'test': lambda: IndexedDataset(Melanoma(root=root, is_test=True, test_size=0.2)),
        }
        return train_dataset, _build_eval_datasets(name, splits, builders)

    elif name == "ddsm":
        from sklearn.model_selection import train_test_split
        root = os.path.join(root_path, "ddsm")
        csv_dir = os.path.join(root, "csv")
        jpeg_dir = os.path.join(root, "jpeg")
        # ------------------------------------------------------------------
        # 1. Build a UID -> absolute jpeg path lookup from dicom_info.csv;
        #    image_path looks like "CBIS-DDSM/jpeg/<UID>/<filename>.jpg".
        # ------------------------------------------------------------------
        dicom_df = pd.read_csv(os.path.join(csv_dir, "dicom_info.csv"))
        # Fix known labeling bug: 8-bit "Unknown" entries are ROI masks.
        dicom_df.loc[
            (dicom_df["SeriesDescription"] == "Unknown") & (dicom_df["BitsAllocated"] == 8),
            "SeriesDescription",
        ] = "ROI mask images"
        full_mammo = dicom_df[dicom_df["SeriesDescription"] == "full mammogram images"].copy()
        # Extract UID (the folder immediately after "jpeg/").
        full_mammo["uid"] = full_mammo["image_path"].str.extract(r"jpeg/([^/]+)/")
        full_mammo["abs_path"] = full_mammo["image_path"].apply(
            lambda p: p.replace("CBIS-DDSM/jpeg", jpeg_dir)
        )
        uid_to_path = full_mammo.set_index("uid")["abs_path"].to_dict()

        # ------------------------------------------------------------------
        # 2. Load case-description CSVs. Their "image file path" column looks
        #    like "Mass-Training_P_00001_LEFT_CC/.../<UID>/000000.dcm"; the
        #    UID is extracted from there to join with uid_to_path.
        # ------------------------------------------------------------------
        def load_case_csv(csv_name):
            """Return (patient_id, uid, label, pathology) rows, or None if the CSV is absent."""
            path = os.path.join(csv_dir, csv_name)
            if not os.path.exists(path):
                return None
            df = pd.read_csv(path)
            # Normalise the image-path column name (may have spaces).
            col = next(c for c in df.columns if "image file path" in c.lower())
            df = df.rename(columns={col: "image_file_path"})
            # The UID is the second-to-last path component before the filename,
            # e.g. ".../1.3.6.../000000.dcm" -> UID = "1.3.6...".
            df["uid"] = df["image_file_path"].str.extract(r"/([^/]+)/[^/]+$")
            # BENIGN_WITHOUT_CALLBACK is clinically benign; binary label = malignant.
            df["pathology"] = df["pathology"].replace("BENIGN_WITHOUT_CALLBACK", "BENIGN")
            df["label"] = (df["pathology"] == "MALIGNANT").astype(np.float32)
            return df[["patient_id", "uid", "label", "pathology"]]

        train_case_dfs = [df for df in (
            load_case_csv("calc_case_description_train_set.csv"),
            load_case_csv("mass_case_description_train_set.csv"),
        ) if df is not None]
        test_case_dfs = [df for df in (
            load_case_csv("calc_case_description_test_set.csv"),
            load_case_csv("mass_case_description_test_set.csv"),
        ) if df is not None]
        train_cases = pd.concat(train_case_dfs, ignore_index=True).drop_duplicates(subset=["uid"])
        test_cases = pd.concat(test_case_dfs, ignore_index=True).drop_duplicates(subset=["uid"])

        # ------------------------------------------------------------------
        # 3. Map UID -> absolute image path; drop cases without a resolved image.
        # ------------------------------------------------------------------
        def attach_paths(df):
            df = df.copy()
            df["image_path"] = df["uid"].map(uid_to_path)
            return df.dropna(subset=["image_path"]).reset_index(drop=True)

        train_pool = attach_paths(train_cases)
        test_df = attach_paths(test_cases)

        # ------------------------------------------------------------------
        # 4. Stratified val split from the training pool.
        # ------------------------------------------------------------------
        val_size = kwargs.get("val_size", 0.1)
        train_idx, val_idx = train_test_split(
            np.arange(len(train_pool)),
            test_size=val_size,
            stratify=train_pool["label"].values,
            random_state=42,
        )
        train_df = train_pool.iloc[train_idx].reset_index(drop=True)
        val_df = train_pool.iloc[val_idx].reset_index(drop=True)

        # ------------------------------------------------------------------
        # 5. Build datasets (via temp CSVs so MedicalImageCSVDataset can load them).
        # ------------------------------------------------------------------
        image_size = kwargs.get("image_size", 224)
        train_transform = _medical_train_transform(image_size)
        test_transform = _medical_test_transform(image_size)
        import tempfile

        def _df_to_dataset(df, transform):
            """Round-trip a frame through a temp CSV into a MedicalImageCSVDataset."""
            tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w")
            df[["image_path", "label"]].to_csv(tmp.name, index=False)
            tmp.close()
            return MedicalImageCSVDataset(
                csv_path=tmp.name,
                image_root="",
                image_col="image_path",
                label_col="label",
                transform=transform,
            )

        train_dataset = _df_to_dataset(train_df, train_transform)
        builders = {
            'val': lambda: _df_to_dataset(val_df, test_transform),
            'test': lambda: _df_to_dataset(test_df, test_transform),
        }
        return train_dataset, _build_eval_datasets(name, splits, builders)

    else:
        raise ValueError(
            f"Unknown dataset: '{name}'. "
            "Please add a branch for it inside load_dataset()."
        )