Source code for libauc.trainer.core.trainer

import logging
import torch
import libauc
import numpy as np
from libauc.sampler import DualSampler, TriSampler
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, Mapping
from torch.utils.data import Dataset
import importlib

from ..config.args import TrainingArguments
from .callbacks import CallbackHandler, TrainerCallback, TrainerState

logger = logging.getLogger(__name__)


[docs] class Trainer: r""" Full training loop for image-classification models supported by libauc. ``Trainer`` wires together a model, an AUC-aware loss function, a libauc optimizer, dual/tri-sampled data loaders, and an optional evaluation pipeline behind a unified :meth:`train` entry point. Progress is surfaced through a :class:`~trainer.core.callbacks.CallbackHandler` so any number of :class:`~trainer.core.callbacks.TrainerCallback` subclasses can observe or alter the training loop without touching ``Trainer`` internals. The class is intentionally thin: heavy lifting (data loading, model construction, loss/optimizer instantiation) is delegated to private helpers so subclasses like :class:`~trainer.core.gnn_trainer.GNNTrainer` can override only the parts they need. Args: train_args (TrainingArguments): Fully populated training configuration produced by :class:`~trainer.config.args.TrainingArguments`. model_cfg (dict): Architecture config forwarded to :meth:`_build_model`. Must contain at least a ``"name"`` key matching one of the registered architectures (``resnet20``, ``resnet18``, ``densenet121``). train_dataset (Dataset): PyTorch ``Dataset`` for the training split. Must expose a ``.targets`` attribute (list or array of labels). eval_dataset (list[Dataset], optional): One or more evaluation datasets. ``None`` disables evaluation (default: ``None``). metric (callable, optional): ``(y_true, y_pred) -> dict[str, float]`` function returned by :func:`~trainer.helpers.build_metric`. ``None`` disables metric computation (default: ``None``). callbacks (list[TrainerCallback], optional): Callbacks invoked at every lifecycle hook. When ``None`` the handler is created with an empty list (default: ``None``). Example:: >>> from trainer.config.args import TrainingArguments >>> from trainer.core.trainer import Trainer >>> from trainer.core.callbacks import CLICallback >>> train_args = TrainingArguments( ... optimizer="PESG", optimizer_kwargs={"lr": 0.1}, ... loss="AUCMLoss", loss_kwargs={"margin": 1.0}, ... SEED=42, batch_size=128, eval_batch_size=128, ... sampling_rate=0.5, epochs=50, decay_epochs=[], ... num_workers=2, output_path="./output", num_tasks=1, ... resume_from_checkpoint=False, save_checkpoint_every=5, ... project_name="libauc", experiment_name="demo", verbose=1, ... ) >>> trainer = Trainer( ... train_args=train_args, ... model_cfg={"name": "resnet18", "num_classes": 1}, ... train_dataset=train_ds, ... eval_dataset=[val_ds], ... metric=metric_fn, ... callbacks=[CLICallback()], ... ) >>> log = trainer.train() """ def __init__(self, train_args: TrainingArguments, model_cfg: dict, train_dataset: Dataset, eval_dataset: Optional[List[Dataset]] = None, metric: Optional[Callable[[torch.Tensor, torch.Tensor], Mapping[str, float]]] = None, callbacks: Optional[List[TrainerCallback]] = None): """ Initialize the trainer. Args: model: The neural network model to train train_args: Training configuration arguments train_dataset: Training dataset eval_dataset: Optional evaluation datasets metric: Evaluation metric function callbacks: Optional list of training callbacks """ self.args = train_args self._build_model(model_cfg) self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.state = TrainerState() self.state.total_epoch = self.args.epochs # Setup data loaders self.sampler, self.trainloader = self._get_train_dataloader(self.args) self.evalloaders = [] if self.eval_dataset: for dataset in self.eval_dataset: self.evalloaders.append(self._get_eval_dataloader(dataset, self.args)) # Calculate dataset statistics if isinstance(self.sampler.pos_len, list): self.data_len = self.sampler.pos_len[0] + self.sampler.neg_len[0] else: self.data_len = self.sampler.pos_len + self.sampler.neg_len self.pos_len = self.sampler.pos_len self.neg_len = self.sampler.neg_len # Setup optimizer and loss self.loss_fn, self.optimizer = self._construct_optimizer_and_loss(self.model, train_args) # Setup metric and callbacks self.metric = metric if callbacks is None: self.callback_handler = CallbackHandler([], self.model, self.optimizer, self.loss_fn) else: self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.loss_fn) self.callback_handler.on_init_end(self.args, self.state)
[docs] def add_callback(self, callback): """Add a callback to the trainer.""" self.callback_handler.add_callback(callback)
def _build_model(self, model_cfg: dict): """ Build a model from the 'model' config block. Expected keys: name - architecture name (string) pretrained - bool (default False) num_classes - int (default 1 for binary classification) TODO: Register additional architectures as needed. """ name = model_cfg.get("name", "").lower() pretrained = model_cfg.get("pretrained", False) pretrained_remote = model_cfg.get("pretrained_remote", False) num_classes = model_cfg.get("num_classes", 1) in_channels = model_cfg.get("in_channels", 3) if name == "resnet20": from libauc.models import resnet20 model = resnet20(last_activation=None, num_classes=num_classes) elif name == "resnet18": from libauc.models import resnet18 model = resnet18(pretrained=pretrained_remote, last_activation=None, in_channels=in_channels, num_classes=num_classes) elif name == "densenet121": from libauc.models import densenet121 model = densenet121(pretrained=pretrained_remote, last_activation=None, activations='relu', num_classes=num_classes) else: raise ValueError(f"Unknown model '{name}'. Please add it to build_model().") model = model.cuda() if pretrained: state_dict = torch.load(model_cfg.get("pretrained_path"), weights_only = False) if 'model_state_dict' in state_dict: state_dict = state_dict['model_state_dict'] filtered = {k:v for k,v in state_dict.items() if 'fc' not in k and 'linear' not in k} msg = model.load_state_dict(filtered, False) logger.info(msg) if hasattr(model, 'fc'): model.fc.reset_parameters() if hasattr(model, 'linear'): model.linear.reset_parameters() self.model = model def _get_optimizer(self, name): """ Get an optimizer class by name. Args: name: Name of the optimizer Returns: Optimizer class """ opt = importlib.import_module("libauc.optimizers") opt_cls = getattr(opt, name, None) return opt_cls def _get_loss(self, name): """ Get a loss function class by name. Args: name: Name of the loss function Returns: Loss function class """ # Try libauc first libauc_losses = importlib.import_module("libauc.losses") loss_cls = getattr(libauc_losses, name, None) # Fall back to torch.nn if not found in libauc if loss_cls is None: import torch.nn as nn loss_cls = getattr(nn, name, None) # Raise error if not found in either if loss_cls is None: raise ValueError(f"Loss function '{name}' not found in libauc.losses or torch.nn") return loss_cls def _construct_optimizer_and_loss(self, model, train_args: TrainingArguments): """Construct optimizer and loss function based on configuration.""" # Setup loss function loss_cls = self._get_loss(train_args.loss) if train_args.loss in ["BCELoss", "CrossEntropyLoss"]: train_args.loss_kwargs.pop("num_labels", None) if train_args.loss in ["pAUCLoss", "MultiLabelpAUCLoss"]: if train_args.loss_kwargs["mode"] in ['SOPA']: loss_fn = loss_cls(data_len=self.data_len, pos_len=self.pos_len, **train_args.loss_kwargs) else: loss_fn = loss_cls(data_len=self.data_len, **train_args.loss_kwargs) elif train_args.loss in ["mAPLoss", "APLoss", "pAUC_DRO_Loss", "tpAUC_KL_Loss"]: loss_fn = loss_cls(data_len=self.data_len, **train_args.loss_kwargs) elif train_args.loss in ["pAUC_CVaR_Loss"]: loss_fn = loss_cls(data_len=self.data_len, pos_len=self.pos_len, **train_args.loss_kwargs) elif train_args.loss in ["tpAUC_CVaR_loss"]: loss_fn = loss_cls(data_length=self.data_len, **train_args.loss_kwargs) else: loss_fn = loss_cls(**train_args.loss_kwargs) # Setup optimizer opt_cls = self._get_optimizer(train_args.optimizer) optimizer = opt_cls(model.parameters(), loss_fn=loss_fn, **train_args.optimizer_kwargs) return loss_fn, optimizer def _get_train_dataloader(self, train_args: TrainingArguments): """Create training data loader with dual sampling.""" if train_args.num_tasks >= 3: sampler = TriSampler(self.train_dataset, train_args.batch_size, sampling_rate=train_args.sampling_rate) else: sampler = DualSampler(self.train_dataset, train_args.batch_size, sampling_rate=train_args.sampling_rate) trainloader = torch.utils.data.DataLoader( self.train_dataset, batch_size=train_args.batch_size, sampler=sampler, num_workers=train_args.num_workers ) return sampler, trainloader def _get_eval_dataloader(self, dataset, train_args: TrainingArguments): """Create evaluation data loader.""" evalloader = torch.utils.data.DataLoader( dataset, batch_size=train_args.eval_batch_size, shuffle=False, num_workers=train_args.num_workers ) return evalloader
[docs] def train(self): """ Main training loop. Returns: List of training logs with metrics for each epoch """ self.callback_handler.on_train_begin(self.args, self.state) model = self.model.cuda() self.loss_fn = self.loss_fn.cuda() # Load checkpoint if resuming if self.args.resume_from_checkpoint: latest_checkpoint = self.get_latest_checkpoint(os.path.join(self.args.output_path, self.args.experiment_name)) if latest_checkpoint: checkpoint = self.load_checkpoint(latest_checkpoint) logger.info(f"Resuming training from epoch {self.state.epoch}") else: logger.info("No checkpoint found in output folder, starting from scratch") for epoch in range(self.state.epoch, self.args.epochs): self.callback_handler.on_epoch_begin(self.args, self.state) train_loss = [] model.train() # Training loop for data, targets, index in self.trainloader: self.callback_handler.on_step_begin(self.args, self.state) data, targets = data.cuda(), targets.cuda() y_pred = model(data) # Compute loss if self.args.loss == "CrossEntropyLoss": loss = self.loss_fn(y_pred, targets) if self.args.loss == "BCELoss": y_pred = torch.sigmoid(y_pred) loss = self.loss_fn(y_pred, targets) else: y_pred = torch.sigmoid(y_pred) if isinstance(index, list): # Multilable index, task_id = index loss = self.loss_fn(y_pred, targets, index=index.cuda(), task_id = task_id) else: loss = self.loss_fn(y_pred, targets, index=index.cuda()) # Optimizer step self.optimizer.zero_grad() loss.backward() self.optimizer.step() train_loss.append(loss.item()) self.callback_handler.on_step_end(self.args, self.state) # Evaluation model.eval() train_loss = np.mean(train_loss) metrics, test_true, test_pred = self.evaluate_loop(model) self.callback_handler.on_epoch_end( self.args, self.state, metrics=metrics, train_loss=train_loss, lr=self.optimizer.lr, test_true=test_true, test_pred=test_pred ) # Save checkpoint periodically if (epoch + 1) % self.args.save_checkpoint_every == 0 or (epoch + 1) == self.args.epochs: checkpoint_path = os.path.join(self.args.output_path, self.args.experiment_name, f"epoch_{epoch + 1}.pt") self.save_checkpoint(checkpoint_path) self.callback_handler.on_train_end(self.args, self.state) return self.state.train_log
[docs] def evaluate(self, loader, model): """ Evaluate model on a given data loader. Args: loader: Data loader for evaluation model: Model to evaluate Returns: Tuple of (dictionary of evaluation metrics, test_true, test_pred) """ test_pred_list = [] test_true_list = [] for test_data, test_targets, _ in loader: test_data = test_data.cuda() test_pred = model(test_data) # Apply sigmoid to convert logits to probabilities test_pred = torch.sigmoid(test_pred) test_pred_list.append(test_pred.cpu().detach().numpy()) test_true_list.append(test_targets.numpy()) test_true = np.concatenate(test_true_list) test_pred = np.concatenate(test_pred_list) # Flatten if needed (for binary classification) if test_pred.ndim > 1: test_pred = test_pred.flatten() if test_true.ndim > 1: test_true = test_true.flatten() result = self.metric(test_true, test_pred) return result, test_true, test_pred
[docs] def evaluate_loop(self, model): """ Evaluate model on all evaluation datasets. Args: model: Model to evaluate Returns: Tuple of (dictionary of metrics from all evaluation datasets, test_true, test_pred) test_true and test_pred are from the first evaluation dataset, or None if no eval datasets """ metrics = [] test_true = None test_pred = None if not self.evalloaders: self.callback_handler.on_evaluate(self.args, self.state) return metrics, test_true, test_pred for loader in self.evalloaders: result, eval_true, eval_pred = self.evaluate(loader, model) metrics.append(result) # Store test_true and test_pred from the first evaluation dataset if test_true is None: test_true = eval_true test_pred = eval_pred self.callback_handler.on_evaluate(self.args, self.state) return metrics, test_true, test_pred
[docs] def save_checkpoint(self, checkpoint_path: str): # Ensure checkpoint directory exists os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) checkpoint = { 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'loss_fn_state_dict': self.loss_fn.state_dict(), 'loss_fn': self.loss_fn, 'state': self.state, 'args': self.args } # Save checkpoint torch.save(checkpoint, checkpoint_path) logger.info(f"Checkpoint saved to {checkpoint_path}")
[docs] def load_checkpoint(self, checkpoint_path: str): if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}") logger.info(f"Loading checkpoint from {checkpoint_path}") checkpoint = torch.load(checkpoint_path, weights_only=False) self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) if hasattr(self.loss_fn, 'a') and hasattr(self.loss_fn, 'b') and hasattr(self.loss_fn, 'alpha'): self.loss_fn.load_state_dict(checkpoint['loss_fn_state_dict']) else: self.loss_fn = checkpoint['loss_fn'] self.state = checkpoint['state'] # have to check if the args are the same as the current args self.args = checkpoint['args'] logger.info(f"Checkpoint loaded successfully. Resuming from epoch {self.state.epoch}") return checkpoint
[docs] def get_latest_checkpoint(self, output_path: str): if not os.path.exists(output_path): return None checkpoint_files = [] for file in os.listdir(output_path): if file.startswith("epoch_") and file.endswith(".pt"): try: epoch_num = int(file.split("_")[1].split(".")[0]) checkpoint_files.append((epoch_num, os.path.join(output_path, file))) except (ValueError, IndexError): continue if not checkpoint_files: return None # Sort by epoch number and return the latest checkpoint_files.sort(key=lambda x: x[0]) return checkpoint_files[-1][1]