import logging
import torch
import numpy as np
from libauc.sampler import DualSampler, TriSampler
import os
from typing import Callable, List, Optional, Mapping
from torch.utils.data import Dataset
import importlib
from ..config.args import TrainingArguments
from .callbacks import CallbackHandler, TrainerCallback, TrainerState
logger = logging.getLogger(__name__)
# Keys that identify classification-head parameters in both built-in libauc
# models (fc / linear) and HuggingFace models (classifier / head).
_HEAD_PARTS: frozenset = frozenset({"fc", "linear", "classifier", "head"})
class _ChannelExpand:
"""DataLoader collate that expands single-channel images to 3 channels.
Used when the dataset contains grayscale images (1 channel) but the
model expects RGB input (3 channels). The single channel is repeated
three times: ``(N, 1, H, W)`` → ``(N, 3, H, W)``.
Defined at module level for DataLoader worker pickling.
"""
def __call__(self, batch):
from torch.utils.data import default_collate
images, targets, indices = default_collate(batch)
images = images.expand(-1, 3, -1, -1).contiguous()
return images, targets, indices
class _HFCollate:
"""DataLoader collate function that applies an ``AutoImageProcessor`` to a batch.
Defined at module level so it is picklable by DataLoader worker processes.
Image preprocessing (resize + model-specific normalisation) happens here,
on CPU inside a DataLoader worker, **before** the batch is transferred to
the GPU. This avoids the costly GPU→CPU→GPU round-trip that occurs when
preprocessing runs inside :meth:`~_HFModelWrapper.forward`.
Single-channel (grayscale) inputs are automatically expanded to 3 channels
before the processor is called, so the dataset requires no special handling.
.. note::
Assumes images are ``float32`` tensors in the ``[0, 1]`` range (i.e.
the output of ``torchvision.transforms.ToTensor()``).
``do_rescale=False`` is set to avoid multiplying by 1/255 again.
Args:
processor: An ``AutoImageProcessor`` instance.
"""
def __init__(self, processor) -> None:
self.processor = processor
def __call__(self, batch):
from torch.utils.data import default_collate
images, targets, indices = default_collate(batch)
# Expand grayscale (1 ch) to RGB (3 ch) before the processor.
if images.shape[1] == 1:
images = images.expand(-1, 3, -1, -1).contiguous()
out = self.processor(images=images, return_tensors="pt", do_rescale=False)
return out["pixel_values"], targets, indices
class _HFModelWrapper(torch.nn.Module):
"""Thin wrapper so HuggingFace image-classification models return raw logits.
HuggingFace models return a ``SequenceClassifierOutput`` (or similar
``ModelOutput``) whose ``.logits`` field holds the raw class scores.
This wrapper unpacks that so the rest of the Trainer sees a plain
:class:`torch.Tensor`, identical in shape to a built-in libauc backbone.
Image preprocessing (resize + normalise) is handled upstream by
:class:`_HFCollate` in the DataLoader worker. By the time
:meth:`forward` is called the tensor is already in the model's expected
format, so no CPU/GPU transfers are needed here.
Args:
hf_model (torch.nn.Module): Any ``transformers`` model that accepts
a ``pixel_values`` keyword argument and returns an object with a
``.logits`` attribute (e.g. loaded via
``AutoModelForImageClassification``).
"""
def __init__(self, hf_model: torch.nn.Module) -> None:
super().__init__()
self.hf_model = hf_model
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Return raw logits as a plain tensor."""
return self.hf_model(pixel_values=x).logits
[docs]
class Trainer:
r"""
Full training loop for image-classification models supported by libauc.
``Trainer`` wires together a model, an AUC-specific loss function, a libauc
optimizer, dual/tri-sampled data loaders, and an optional evaluation
pipeline behind a unified :meth:`train` entry point. Progress is
surfaced through a :class:`~trainer.core.callbacks.CallbackHandler` so
any number of :class:`~trainer.core.callbacks.TrainerCallback` subclasses
can observe or alter the training loop without touching ``Trainer``
internals.
The class is intentionally thin: heavy lifting (data loading, model
construction, loss/optimizer instantiation) is delegated to private
helpers so subclasses like :class:`~trainer.core.graph_trainer.GraphTrainer`
can override only the parts they need.
Args:
train_args (TrainingArguments): Fully populated training configuration
produced by :class:`~trainer.config.args.TrainingArguments`.
model_cfg (dict): Architecture config forwarded to
:meth:`_build_model`. Must contain at least a ``"name"`` key.
Built-in architectures: ``resnet20``, ``resnet18``,
``densenet121``. Any HuggingFace model repo ID containing
a ``/`` is also accepted (e.g.
``"google/vit-base-patch16-224"``,
``"openai/clip-vit-base-patch32"``). For HF models the
``AutoImageProcessor`` is applied inside the trainer — the
dataset needs no HF-specific transforms.
train_dataset (Dataset): PyTorch ``Dataset`` for the training split.
Must expose a ``.targets`` attribute (list or array of labels).
eval_dataset (list[Dataset], optional): One or more evaluation
datasets. ``None`` disables evaluation (default: ``None``).
metric (callable, optional): ``(y_true, y_pred) -> dict[str, float]``
function returned by :func:`~trainer.helpers.build_metric`.
``None`` disables metric computation (default: ``None``).
callbacks (list[TrainerCallback], optional): Callbacks invoked at
every lifecycle hook. When ``None`` the handler is created with
an empty list (default: ``None``).
Example::
>>> from trainer.config.args import TrainingArguments
>>> from trainer.core.image_trainer import Trainer
>>> from trainer.core.callbacks import CLICallback
>>> train_args = TrainingArguments(
... optimizer="PESG", optimizer_kwargs={"lr": 0.1},
... loss="AUCMLoss", loss_kwargs={"margin": 1.0},
... SEED=42, batch_size=128, eval_batch_size=128,
... sampling_rate=0.5, epochs=50, decay_epochs=[],
... num_workers=2, output_path="./output", num_tasks=1,
... resume_from_checkpoint=False, save_checkpoint_every=5,
... project_name="libauc", experiment_name="demo", verbose=1,
... )
>>> trainer = Trainer(
... train_args=train_args,
... model_cfg={"name": "resnet18"},
... train_dataset=train_ds,
... eval_dataset=[val_ds],
... metric=metric_fn,
... callbacks=[CLICallback()],
... )
>>> log = trainer.train()
HuggingFace model — no dataset changes needed, the trainer preprocesses
internally::
>>> trainer = Trainer(
... train_args=train_args,
... model_cfg={
... "name": "google/vit-base-patch16-224",
... },
... train_dataset=train_ds,
... eval_dataset=[val_ds],
... metric=metric_fn,
... )
>>> log = trainer.train()
"""
def __init__(self,
train_args: TrainingArguments,
model_cfg: dict,
train_dataset: Dataset,
eval_dataset: Optional[List[Dataset]] = None,
metric: Optional[Callable[[torch.Tensor, torch.Tensor], Mapping[str, float]]] = None,
callbacks: Optional[List[TrainerCallback]] = None):
"""
Initialize the trainer.
Args:
train_args: Training configuration arguments.
model_cfg: Model architecture config (dict with at least a ``"name"`` key).
train_dataset: Training dataset.
eval_dataset: Optional list of evaluation datasets.
metric: Evaluation metric function ``(y_true, y_pred) -> dict``.
callbacks: Optional list of training callbacks.
"""
self.args = train_args
self.train_dataset = train_dataset # set before _build_model so it
self._build_model(model_cfg) # can infer in_channels from data
self.eval_dataset = eval_dataset
self.state = TrainerState()
self.state.total_epoch = self.args.epochs
# Setup data loaders
self.sampler, self.trainloader = self._get_train_dataloader(self.args)
self.evalloaders = []
if self.eval_dataset:
for dataset in self.eval_dataset:
self.evalloaders.append(self._get_eval_dataloader(dataset, self.args))
# Calculate dataset statistics
if isinstance(self.sampler.pos_len, list):
self.data_len = self.sampler.pos_len[0] + self.sampler.neg_len[0]
else:
self.data_len = self.sampler.pos_len + self.sampler.neg_len
self.pos_len = self.sampler.pos_len
self.neg_len = self.sampler.neg_len
# Setup optimizer and loss
self.loss_fn, self.optimizer = self._construct_optimizer_and_loss(self.model, train_args)
# Setup metric and callbacks
self.metric = metric
if callbacks is None:
self.callback_handler = CallbackHandler([], self.model, self.optimizer, self.loss_fn)
else:
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.loss_fn)
self.callback_handler.on_init_end(self.args, self.state)
[docs]
def add_callback(self, callback):
"""Add a callback to the trainer."""
self.callback_handler.add_callback(callback)
def _build_model(self, model_cfg: dict):
"""
Build a model from the ``model`` config block.
Expected keys
-------------
name Architecture name. Use one of the built-in shortcuts
(``resnet20``, ``resnet18``, ``densenet121``) **or**
any HuggingFace model repo ID that contains a ``/``
(e.g. ``"google/vit-base-patch16-224"``,
``"openai/clip-vit-base-patch32"``).
HF models are loaded via
``AutoModelForImageClassification.from_pretrained``
and wrapped in :class:`_HFModelWrapper`. The
``AutoImageProcessor`` is applied in the DataLoader
collate function so the external dataset requires
**no** HF-specific transforms — a plain
``ToTensor()`` is sufficient.
pretrained bool — load a local checkpoint from
``pretrained_path`` (default ``False``).
pretrained_remote bool — use the architecture's own pretrained
ImageNet weights where supported (default ``False``).
Ignored for HF models.
pretrained_path str — path to a local ``.pt`` checkpoint. Required
when ``pretrained=True``. Head parameters identified
by :data:`_HEAD_PARTS` are excluded from loading and
re-initialised so the model can be fine-tuned on a
different number of classes.
Attributes set on ``self``
--------------------------
image_processor ``AutoImageProcessor`` instance for the HF model
(``None`` for built-in models). The trainer uses
it via :class:`_HFCollate` in the DataLoader —
do **not** apply it again in your dataset transforms.
"""
name = model_cfg.get("name", "")
name_lower = name.lower()
pretrained = model_cfg.get("pretrained", False)
pretrained_remote = model_cfg.get("pretrained_remote", False)
is_hf_model = "/" in name
num_classes = 1 if self.args.num_tasks <= 2 else self.args.num_tasks
in_channels = 3 # sensible default (RGB)
try:
first_sample = self.train_dataset[0]
img = first_sample[0] if isinstance(first_sample, (tuple, list)) else first_sample
if hasattr(img, "shape") and img.ndim >= 3:
in_channels = img.shape[0]
except Exception:
pass
if in_channels not in (1, 3):
raise ValueError(
f"Dataset images have {in_channels} channels. "
f"Only 1 (grayscale) and 3 (RGB) are supported."
)
# If grayscale, flag expansion; all model constructors receive in_channels=3.
self._expand_to_rgb = (in_channels == 1)
if self._expand_to_rgb:
logger.info(
"Grayscale input (1 channel) detected — images will be "
"expanded to 3 channels (L→RGB repeat) in the DataLoader "
"collate function."
)
in_channels = 3
logger.info(f"Auto-detected: num_classes={num_classes}, in_channels={in_channels}")
self.image_processor = None # public reference (for inspection)
self._hf_processor = None # used by _make_collate_fn()
if name_lower == "resnet20":
from libauc.models import resnet20
model = resnet20(last_activation=None, num_classes=num_classes)
elif name_lower == "resnet18":
from libauc.models import resnet18
model = resnet18(pretrained=pretrained_remote, last_activation=None, num_classes=num_classes)
elif name_lower == "densenet121":
from libauc.models import densenet121
model = densenet121(pretrained=pretrained_remote, last_activation=None, activations='relu', num_classes=num_classes)
elif is_hf_model:
from transformers import AutoImageProcessor, AutoModelForImageClassification
try:
_proc = AutoImageProcessor.from_pretrained(name)
self.image_processor = _proc # public reference
self._hf_processor = _proc # picked up by DataLoaders
_sz = getattr(_proc, "size", "unknown")
logger.info(
f"HF image processor for '{name}': native size = {_sz}. "
"Preprocessing applied in DataLoader collate_fn."
)
except Exception as exc:
logger.warning(
f"Could not load AutoImageProcessor for '{name}': {exc}. "
"Raw tensors will be passed directly to the model — ensure "
"your dataset already matches the model's expected resolution "
"and normalisation."
)
# ── Model ────────────────────────────────────────────────────────
hf_model = AutoModelForImageClassification.from_pretrained(
name,
num_labels=num_classes,
ignore_mismatched_sizes=True,
)
model = _HFModelWrapper(hf_model)
logger.info(f"Loaded HuggingFace model '{name}' | num_labels={num_classes}")
else:
raise ValueError(
f"Unknown model '{name}'. Use a built-in name "
f"(resnet20, resnet18, densenet121) or a HuggingFace repo ID "
f"containing '/' (e.g. 'google/vit-base-patch16-224')."
)
model = model.cuda()
if pretrained:
pretrained_path = model_cfg.get("pretrained_path")
if not pretrained_path:
raise ValueError(
"pretrained=True but 'pretrained_path' is not set in model_cfg."
)
state_dict = torch.load(pretrained_path, weights_only=False)
if "model_state_dict" in state_dict:
state_dict = state_dict["model_state_dict"]
filtered = {
k: v for k, v in state_dict.items()
if not any(part in k.split(".") for part in _HEAD_PARTS)
}
msg = model.load_state_dict(filtered, strict=False)
logger.info(msg)
# Re-initialise the head so it starts from scratch.
if is_hf_model:
inner = model.hf_model
for attr in ("classifier", "head"):
head = getattr(inner, attr, None)
if head is not None and hasattr(head, "reset_parameters"):
head.reset_parameters()
else:
for attr in ("fc", "linear"):
head = getattr(model, attr, None)
if head is not None and hasattr(head, "reset_parameters"):
head.reset_parameters()
self.model = model
@staticmethod
def _get_optimizer(name: str):
"""Return the optimizer class *name* from ``libauc.optimizers``."""
opt = importlib.import_module("libauc.optimizers")
return getattr(opt, name, None)
@staticmethod
def _get_loss(name: str):
"""Return the loss class *name*, checking ``libauc.losses`` then ``torch.nn``."""
libauc_losses = importlib.import_module("libauc.losses")
loss_cls = getattr(libauc_losses, name, None)
if loss_cls is None:
import torch.nn as nn
loss_cls = getattr(nn, name, None)
if loss_cls is None:
raise ValueError(f"Loss function '{name}' not found in libauc.losses or torch.nn")
return loss_cls
def _construct_optimizer_and_loss(self, model, train_args: TrainingArguments):
"""Construct optimizer and loss function based on configuration."""
# Setup loss function
loss_cls = self._get_loss(train_args.loss)
if train_args.loss in ["BCELoss", "CrossEntropyLoss"]:
train_args.loss_kwargs.pop("num_labels", None)
if train_args.loss in ["pAUCLoss", "MultiLabelpAUCLoss"]:
if train_args.loss_kwargs["mode"] in ['SOPA']:
loss_fn = loss_cls(data_len=self.data_len, pos_len=self.pos_len, **train_args.loss_kwargs)
else:
loss_fn = loss_cls(data_len=self.data_len, **train_args.loss_kwargs)
elif train_args.loss in ["mAPLoss", "APLoss", "pAUC_DRO_Loss", "tpAUC_KL_Loss"]:
loss_fn = loss_cls(data_len=self.data_len, **train_args.loss_kwargs)
elif train_args.loss in ["pAUC_CVaR_Loss"]:
loss_fn = loss_cls(data_len=self.data_len, pos_len=self.pos_len, **train_args.loss_kwargs)
elif train_args.loss in ["tpAUC_CVaR_loss"]:
loss_fn = loss_cls(data_length=self.data_len, **train_args.loss_kwargs)
else:
loss_fn = loss_cls(**train_args.loss_kwargs)
# Setup optimizer
opt_cls = self._get_optimizer(train_args.optimizer)
optimizer = opt_cls(model.parameters(), loss_fn=loss_fn, **train_args.optimizer_kwargs)
return loss_fn, optimizer
def _make_collate_fn(self):
"""Return the appropriate collate function for this trainer instance.
Priority:
- HF model → :class:`_HFCollate` (handles both channel expansion and
processor-based resize/normalise).
- Built-in model + grayscale input → :class:`_ChannelExpand` (repeats
the single channel three times to produce RGB tensors).
- Built-in model + RGB input → ``None`` (PyTorch default collation).
"""
proc = getattr(self, "_hf_processor", None)
expand = getattr(self, "_expand_to_rgb", False)
if proc is not None:
return _HFCollate(proc) # _HFCollate handles expansion internally
if expand:
return _ChannelExpand()
return None
def _get_train_dataloader(self, train_args: TrainingArguments):
"""Create training data loader with dual sampling."""
if train_args.num_tasks >= 3:
sampler = TriSampler(self.train_dataset, train_args.batch_size, sampling_rate=train_args.sampling_rate)
else:
sampler = DualSampler(self.train_dataset, train_args.batch_size, sampling_rate=train_args.sampling_rate)
trainloader = torch.utils.data.DataLoader(
self.train_dataset,
batch_size=train_args.batch_size,
sampler=sampler,
num_workers=train_args.num_workers,
collate_fn=self._make_collate_fn(),
)
return sampler, trainloader
def _get_eval_dataloader(self, dataset, train_args: TrainingArguments):
"""Create evaluation data loader."""
evalloader = torch.utils.data.DataLoader(
dataset,
batch_size=train_args.eval_batch_size,
shuffle=False,
num_workers=train_args.num_workers,
collate_fn=self._make_collate_fn(),
)
return evalloader
[docs]
def train(self):
"""
Main training loop.
Returns:
List of training logs with metrics for each epoch
"""
self.callback_handler.on_train_begin(self.args, self.state)
model = self.model.cuda()
self.loss_fn = self.loss_fn.cuda()
# Load checkpoint if resuming
if self.args.resume_from_checkpoint:
latest_checkpoint = self.get_latest_checkpoint(os.path.join(self.args.output_path, self.args.experiment_name))
if latest_checkpoint:
self.load_checkpoint(latest_checkpoint)
logger.info(f"Resuming training from epoch {self.state.epoch}")
else:
logger.info("No checkpoint found in output folder, starting from scratch")
for epoch in range(self.state.epoch, self.args.epochs):
self.callback_handler.on_epoch_begin(self.args, self.state)
step_losses = []
model.train()
# Training loop
for data, targets, index in self.trainloader:
self.callback_handler.on_step_begin(self.args, self.state)
data, targets = data.cuda(), targets.cuda()
y_pred = model(data)
# Compute loss
if self.args.loss == "CrossEntropyLoss":
loss = self.loss_fn(y_pred, targets)
elif self.args.loss == "BCELoss":
y_pred = torch.sigmoid(y_pred)
loss = self.loss_fn(y_pred, targets)
else:
y_pred = torch.sigmoid(y_pred)
if isinstance(index, list): # Multi-label
index, task_id = index
loss = self.loss_fn(y_pred, targets, index=index.cuda(), task_id=task_id)
else:
loss = self.loss_fn(y_pred, targets, index=index.cuda())
# Optimizer step
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
step_losses.append(loss.item())
self.callback_handler.on_step_end(self.args, self.state)
# Evaluation
model.eval()
avg_train_loss = float(np.mean(step_losses))
metrics, test_true, test_pred = self.evaluate_loop(model)
self.callback_handler.on_epoch_end(
self.args, self.state,
metrics=metrics,
train_loss=avg_train_loss,
lr=self.optimizer.lr,
test_true=test_true,
test_pred=test_pred,
)
# Save checkpoint periodically
if (epoch + 1) % self.args.save_checkpoint_every == 0 or (epoch + 1) == self.args.epochs:
checkpoint_path = os.path.join(self.args.output_path, self.args.experiment_name, f"epoch_{epoch + 1}.pt")
self.save_checkpoint(checkpoint_path)
self.callback_handler.on_train_end(self.args, self.state)
return self.state.train_log
[docs]
def evaluate(self, loader, model):
"""
Evaluate model on a given data loader.
Args:
loader: Data loader for evaluation
model: Model to evaluate
Returns:
Tuple of (dictionary of evaluation metrics, test_true, test_pred)
"""
test_pred_list = []
test_true_list = []
for test_data, test_targets, _ in loader:
test_data = test_data.cuda()
test_pred = model(test_data)
# Apply sigmoid to convert logits to probabilities
test_pred = torch.sigmoid(test_pred)
test_pred_list.append(test_pred.cpu().detach().numpy())
test_true_list.append(test_targets.numpy())
test_true = np.concatenate(test_true_list)
test_pred = np.concatenate(test_pred_list)
# Flatten if needed (for binary classification)
if test_pred.ndim > 1:
test_pred = test_pred.flatten()
if test_true.ndim > 1:
test_true = test_true.flatten()
result = self.metric(test_true, test_pred)
return result, test_true, test_pred
[docs]
def evaluate_loop(self, model):
"""
Evaluate model on all evaluation datasets.
Args:
model: Model to evaluate
Returns:
Tuple of (dictionary of metrics from all evaluation datasets, test_true, test_pred)
test_true and test_pred are from the first evaluation dataset, or None if no eval datasets
"""
metrics = []
test_true = None
test_pred = None
if not self.evalloaders:
self.callback_handler.on_evaluate(self.args, self.state)
return metrics, test_true, test_pred
for loader in self.evalloaders:
result, eval_true, eval_pred = self.evaluate(loader, model)
metrics.append(result)
# Store test_true and test_pred from the first evaluation dataset
if test_true is None:
test_true = eval_true
test_pred = eval_pred
self.callback_handler.on_evaluate(self.args, self.state)
return metrics, test_true, test_pred
[docs]
def save_checkpoint(self, checkpoint_path: str):
# Ensure checkpoint directory exists
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
checkpoint = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss_fn_state_dict': self.loss_fn.state_dict(),
'loss_fn': self.loss_fn,
'state': self.state,
'args': self.args
}
# Save checkpoint
torch.save(checkpoint, checkpoint_path)
logger.info(f"Checkpoint saved to {checkpoint_path}")
[docs]
def load_checkpoint(self, checkpoint_path: str):
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
logger.info(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, weights_only=False)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if hasattr(self.loss_fn, 'a') and hasattr(self.loss_fn, 'b') and hasattr(self.loss_fn, 'alpha'):
self.loss_fn.load_state_dict(checkpoint['loss_fn_state_dict'])
else:
self.loss_fn = checkpoint['loss_fn']
self.state = checkpoint['state']
# have to check if the args are the same as the current args
self.args = checkpoint['args']
logger.info(f"Checkpoint loaded successfully. Resuming from epoch {self.state.epoch}")
return checkpoint
[docs]
def get_latest_checkpoint(self, output_path: str):
if not os.path.exists(output_path):
return None
checkpoint_files = []
for file in os.listdir(output_path):
if file.startswith("epoch_") and file.endswith(".pt"):
try:
epoch_num = int(file.split("_")[1].split(".")[0])
checkpoint_files.append((epoch_num, os.path.join(output_path, file)))
except (ValueError, IndexError):
continue
if not checkpoint_files:
return None
# Sort by epoch number and return the latest
checkpoint_files.sort(key=lambda x: x[0])
return checkpoint_files[-1][1]