Source code for libauc.trainer.core.callbacks

import logging
import sys
from typing import List

from ..config.args import TrainingArguments

logger = logging.getLogger(__name__)


class TrainerState:
    """Mutable record of how far a training run has progressed.

    Every field starts at a zero/empty value; the callbacks in this module
    read and update the fields while training runs.
    """

    def __init__(self):
        # Progress counters: current epoch, planned number of epochs, and
        # the running count of optimizer steps.
        self.epoch = self.total_epoch = self.step = 0

        # Per-epoch records accumulate in ``train_log``; ``train_summary``
        # holds the end-of-run summary.  Each instance gets its own fresh
        # containers.
        self.train_log = list()
        self.train_summary = dict()
class TrainerCallback:
    r"""
    Base class for training lifecycle callbacks.

    Each hook below does nothing and returns ``None``; a subclass overrides
    only the hooks it is interested in.  Instances are registered with a
    ``CallbackHandler``, which invokes every hook in registration order and
    forwards a consistent set of keyword arguments (``model``,
    ``optimizer``, ``loss_fn``, plus any extra kwargs supplied for that
    event).

    Lifecycle order during a typical training run::

        on_init_end
        on_train_begin
        for each epoch:
            on_epoch_begin
            for each step:
                on_step_begin
                on_step_end
            on_evaluate
            on_epoch_end
            [on_save — called periodically inside the epoch loop]
        on_train_end

    Every hook receives the training arguments as ``args`` and the shared
    :class:`TrainerState` as ``state``; remaining data arrives through
    ``**kwargs``.
    """

    def __init__(self) -> None:
        """Create the callback; the base class keeps no state."""
        return None

    def on_init_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired when trainer initialization finishes (base: no-op)."""
        return None

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired once, ahead of the first epoch (base: no-op)."""
        return None

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired once, after the last epoch (base: no-op)."""
        return None

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired when an epoch starts (base: no-op)."""
        return None

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired when an epoch finishes (base: no-op)."""
        return None

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired ahead of every optimizer step (base: no-op)."""
        return None

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired after every gradient-accumulation sub-step (base: no-op)."""
        return None

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired after every optimizer step (base: no-op)."""
        return None

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired after every evaluation pass (base: no-op)."""
        return None

    def on_predict(self, args: TrainingArguments, state: TrainerState, metrics, **kwargs):
        """Hook fired after a prediction pass; ``metrics`` is required (base: no-op)."""
        return None

    def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired after a checkpoint has been saved (base: no-op)."""
        return None

    def on_log(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired after metrics have been logged (base: no-op)."""
        return None

    def on_prediction_step(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Hook fired after every prediction step (base: no-op)."""
        return None
class CallbackHandler(TrainerCallback):
    r"""
    Multiplexer that owns a list of :class:`TrainerCallback` instances and
    fans out every lifecycle event to each of them in registration order.

    ``CallbackHandler`` itself inherits from :class:`TrainerCallback` so it
    can be used polymorphically, but its primary role is orchestration
    rather than providing hook implementations of its own.  Every hook
    accepts extra ``**kwargs`` and forwards them to the registered
    callbacks.

    Args:
        callbacks (list[TrainerCallback]): Initial callback list.  Entries
            may be instances or classes; classes are instantiated with no
            arguments.
        model: The model being trained.
        optimizer: The active optimizer.
        loss_fn: The active loss function.
    """

    # Keyword names the handler always supplies to callbacks itself.
    _INJECTED_KEYS = ("model", "optimizer", "loss_fn")

    def __init__(self, callbacks: List[TrainerCallback], model, optimizer, loss_fn):
        self.callbacks = []
        for cb in callbacks:
            self.add_callback(cb)
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def add_callback(self, callback):
        """Register a callback, given either as an instance or as a class.

        A class is instantiated with no arguments.  Registering a second
        callback of an already-present class is allowed; it only triggers a
        warning that lists the callbacks registered so far.
        """
        is_class = isinstance(callback, type)
        cb = callback() if is_class else callback
        cb_class = callback if is_class else callback.__class__
        if any(existing.__class__ == cb_class for existing in self.callbacks):
            logger.warning(
                f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. "
                f"The current list of callbacks is:\n{self.callback_list}"
            )
        self.callbacks.append(cb)

    def pop_callback(self, callback):
        """Remove and return the first matching callback.

        ``callback`` may be a class (matched with ``isinstance``) or an
        instance (matched with ``==``).  Returns ``None`` when nothing
        matches.
        """
        match_by_type = isinstance(callback, type)
        for cb in self.callbacks:
            matched = isinstance(cb, callback) if match_by_type else cb == callback
            if matched:
                # Safe to mutate during iteration: we return immediately.
                self.callbacks.remove(cb)
                return cb
        return None

    def remove_callback(self, callback):
        """Remove a callback without returning it.

        When ``callback`` is a class, the first instance of it is removed
        and nothing happens if there is none.  When it is an instance,
        ``list.remove`` is used, so a ``ValueError`` is raised if it is not
        registered.
        """
        if isinstance(callback, type):
            for cb in self.callbacks:
                if isinstance(cb, callback):
                    self.callbacks.remove(cb)
                    return
        else:
            self.callbacks.remove(callback)

    @property
    def callback_list(self):
        """Newline-separated class names of all registered callbacks."""
        return "\n".join(cb.__class__.__name__ for cb in self.callbacks)

    def on_init_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_init_end", args, state, **kwargs)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_train_begin", args, state, **kwargs)

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_train_end", args, state, **kwargs)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_epoch_begin", args, state, **kwargs)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, metrics, **kwargs):
        return self._call_event("on_epoch_end", args, state, metrics=metrics, **kwargs)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_step_begin", args, state, **kwargs)

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_substep_end", args, state, **kwargs)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_step_end", args, state, **kwargs)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_evaluate", args, state, **kwargs)

    def on_predict(self, args: TrainingArguments, state: TrainerState, metrics, **kwargs):
        return self._call_event("on_predict", args, state, metrics=metrics, **kwargs)

    def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_save", args, state, **kwargs)

    def on_log(self, args: TrainingArguments, state: TrainerState, logs, **kwargs):
        return self._call_event("on_log", args, state, logs=logs, **kwargs)

    def on_prediction_step(self, args: TrainingArguments, state: TrainerState, **kwargs):
        return self._call_event("on_prediction_step", args, state, **kwargs)

    def _call_event(self, event, args, state, **kwargs):
        """Invoke *event* on every registered callback in order.

        ``model``, ``optimizer``, and ``loss_fn`` are always taken from the
        handler and passed as keyword arguments, so every callback receives
        them as ``kwargs["model"]``, ``kwargs["optimizer"]``, and
        ``kwargs["loss_fn"]`` regardless of what the caller passes.  Any
        other keyword arguments from the caller are appended after them.

        Returns ``None``; return values of the individual callbacks are
        discarded.
        """
        # Build one merged mapping up front.  Passing the injected names as
        # explicit keywords next to ``**kwargs`` would raise ``TypeError``
        # ("multiple values for keyword argument") whenever the caller
        # supplied one of them; here the handler's own objects win instead.
        merged = {
            "model": self.model,
            "optimizer": self.optimizer,
            "loss_fn": self.loss_fn,
        }
        merged.update(
            (key, value)
            for key, value in kwargs.items()
            if key not in self._INJECTED_KEYS
        )
        for callback in self.callbacks:
            getattr(callback, event)(args, state, **merged)
def _format_metrics(log: dict, skip_keys: tuple = ("epoch", "train_loss", "lr")) -> str: """Format metric key-value pairs into a display string, excluding skip_keys.""" return " | ".join(f"{k}: {v:.4f}" for k, v in log.items() if k not in skip_keys) def _build_log_dict( metrics: list, train_loss: float, lr: float, epoch: int, ) -> dict: """ Build a flat log dict from epoch metrics, suitable for console output and wandb. Returns a dict with keys: epoch, train_loss, lr, and one entry per metric per dataset. Single-dataset runs use bare metric names; multi-dataset runs prefix with 'ds{N}/'. """ log: dict[str, float] = { "epoch": epoch, "train_loss": train_loss, "lr": lr, } single = len(metrics) == 1 for ds_idx, ds_metrics in enumerate(metrics): if not isinstance(ds_metrics, dict): continue prefix = "" if single else f"eval_splits{ds_idx + 1}/" for k, v in ds_metrics.items(): if k in ("epoch", "lr", "loss"): continue try: log[f"{prefix}{k}"] = float(v) except (ValueError, TypeError): pass return log
class CLICallback(TrainerCallback):
    r"""
    Console and Weights & Biases logging callback.

    On ``on_train_begin`` it initialises a W&B run (falls back to
    console-only, with a warning, when W&B is not installed) and
    pretty-prints the training-arguments config.

    On ``on_epoch_begin`` it applies a factor-10 decay through the optimizer
    when the current epoch is listed in ``args.decay_epochs``.

    On ``on_epoch_end`` it:

    * appends a structured entry to ``state.train_log``;
    * renders a progress bar (``verbose=1``) or a per-epoch line
      (``verbose=2``) to stdout;
    * ships the flat log dict to W&B via ``wandb.log``;
    * increments ``state.epoch``.

    On ``on_step_end`` it increments ``state.step``.

    On ``on_train_end`` it calls ``wandb.finish()``, reports the best
    validation (and averaged test) score through the module logger, and
    stores the summary in ``state.train_summary``.

    .. note::
        This callback takes no constructor arguments.  All configuration is
        read at runtime from the ``args`` parameter passed to each lifecycle
        hook.

    Note:
        W&B logging is skipped for the whole run when ``wandb`` is not
        installed.  An exception raised by ``wandb.log`` or
        ``wandb.finish`` is reported as a warning and does not interrupt
        training.
    """

    # Width of the progress bar fill (verbose=1)
    _BAR_WIDTH = 30

    def __init__(self) -> None:
        super().__init__()
        # Flipped to False by ``on_train_begin`` when wandb cannot be imported.
        self._use_wandb = True

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _public_config(args) -> dict:
        """Return the attributes of *args* whose names do not start with ``_``."""
        return {k: v for k, v in vars(args).items() if not k.startswith("_")}

    def _wandb_log(self, log: dict, step: int) -> None:
        """Log a dict to wandb; skip if disabled and warn if the call fails."""
        if not self._use_wandb:
            return
        try:
            import wandb
            wandb.log(log, step=step)
        except Exception as e:
            logger.warning(f"wandb logging failed: {e}")

    def _render_bar(self, epoch: int, total: int, log: dict) -> str:
        """Build the single-line progress bar string used by ``verbose=1``.

        The returned string starts with a carriage return so that writing it
        overwrites the previously drawn bar.
        """
        filled = int(self._BAR_WIDTH * epoch / total) if total else 0
        # Clamp so the bar keeps its fixed width even if ``epoch`` runs past
        # ``total`` (or is negative).
        filled = max(0, min(self._BAR_WIDTH, filled))
        bar = "█" * filled + "·" * (self._BAR_WIDTH - filled)
        metrics = _format_metrics(log)
        metrics_str = f" | {metrics}" if metrics else ""
        return (
            f"\rEpoch [{bar}] {epoch}/{total} | "
            f"Loss: {log.get('train_loss', 0):.4f}"
            f"{metrics_str} | "
            f"LR: {log.get('lr', 0):.6f}"
        )

    # ------------------------------------------------------------------
    # Callback events — what each hook does in CLICallback
    # ------------------------------------------------------------------

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Initialise W&B and print the training config.

        Attempts to start a W&B run using ``args.project_name`` and
        ``args.experiment_name`` with the public attributes of ``args`` as
        the run config.  If ``wandb`` cannot be imported a warning is logged
        and ``self._use_wandb`` is set to ``False`` so all subsequent W&B
        calls are no-ops.  Any other exception raised while starting the run
        propagates to the caller.

        When ``args.verbose != 0``, pretty-prints the config between two
        ``====`` separator lines so the user can confirm all hyperparameters
        before training starts.
        """
        try:
            import wandb
            wandb.init(
                project=args.project_name,
                name=args.experiment_name,
                reinit=True,
                config=self._public_config(args),
            )
        except ImportError:
            logger.warning("wandb not installed; skipping wandb logging")
            self._use_wandb = False

        if args.verbose == 0:
            return

        import pprint
        print("=" * 60)
        pprint.pprint(self._public_config(args), indent=2, sort_dicts=False)
        print("=" * 60)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Apply learning-rate / regulariser decay if the epoch is scheduled.

        Does nothing unless an ``optimizer`` keyword argument was supplied
        and ``state.epoch`` is listed in ``args.decay_epochs`` (a missing or
        ``None`` schedule means "never decay").  When decay is due, a fixed
        factor of 10 is used:

        * if the optimizer has a non-``None`` ``model_ref`` attribute,
          ``optimizer.update_regularizer(decay_factor=10)`` is called;
        * otherwise ``optimizer.update_lr(decay_factor=10)`` is called.
        """
        optimizer = kwargs.get("optimizer")
        decay_epochs = getattr(args, "decay_epochs", None)
        if optimizer is None or decay_epochs is None:
            return
        if state.epoch not in decay_epochs:
            return

        if getattr(optimizer, "model_ref", None) is not None:
            optimizer.update_regularizer(decay_factor=10)
        else:
            optimizer.update_lr(decay_factor=10)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Record metrics, update the console display, and log to W&B.

        Reads ``metrics`` (list of per-split metric dicts; missing or
        ``None`` means no splits), ``train_loss`` and ``lr`` from
        ``kwargs``, then:

        1. **Appends** a structured record ``{metrics, epoch, lr,
           train_loss}`` to ``state.train_log`` for later retrieval (e.g. by
           ``on_train_end``).
        2. **Console output** (controlled by ``args.verbose``):

           * ``verbose=1`` — overwrites a single progress-bar line in place
             using a carriage return, showing a block-character bar, loss,
             all eval metrics, and the current LR.  A newline is written
             only on the final epoch so the bar persists after training.
           * ``verbose=2`` — prints one pipe-separated line per epoch
             (``Epoch N/T | Loss: X | <metric>: Y | LR: Z``).
           * any other value — no console output.
        3. **W&B** — ships the flat log dict to ``wandb.log`` (no-op if W&B
           is unavailable).
        4. **Increments** ``state.epoch``.
        """
        metrics = kwargs.get("metrics", [])
        if metrics is None:
            # An explicit ``metrics=None`` means the same as "no eval splits".
            metrics = []
        train_loss: float = kwargs.get("train_loss", 0)
        lr: float = kwargs.get("lr", 0)
        epoch_no = state.epoch + 1  # records and display are 1-based

        state.train_log.append({
            "metrics": metrics,
            "epoch": epoch_no,
            "lr": lr,
            "train_loss": train_loss,
        })

        log = _build_log_dict(metrics, train_loss, lr, epoch=epoch_no)

        # ---- Console output (mode-dependent) ----------------------------
        if args.verbose == 1:
            # Overwrite the same line with an updated progress bar
            sys.stdout.write(self._render_bar(epoch_no, state.total_epoch, log))
            # Write a newline only on the very last epoch so the bar stays
            # on screen after training ends
            if epoch_no >= state.total_epoch:
                sys.stdout.write("\n")
            sys.stdout.flush()
        elif args.verbose == 2:
            # One line per epoch
            display_parts = [
                f"Epoch {epoch_no}/{state.total_epoch}",
                f"Loss: {train_loss:.4f}",
            ]
            metrics_part = _format_metrics(log)
            if metrics_part:
                display_parts.append(metrics_part)
            display_parts.append(f"LR: {lr:.6f}")
            print(" | ".join(display_parts))

        # ---- wandb logging (always active unless unavailable) -----------
        self._wandb_log(log, step=epoch_no)

        state.epoch += 1

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Close the W&B run and record a training summary.

        1. Prints a ``"Training complete."`` separator when ``verbose != 0``.
        2. Calls ``wandb.finish()`` to close the W&B run; a failure there is
           logged as a warning so the summary below is still produced.
        3. Scans ``state.train_log`` to find the best epoch (highest value
           of the first metric on the first eval split) and reports through
           the module logger:

           * **1 eval split** — best validation score.
           * **2+ eval splits** — best validation score + average of all
             remaining split scores at that epoch.
        4. Stores the summary dict (``"val"`` and, with 2+ splits,
           ``"test"``) under ``state.train_summary``.

        The summary is skipped, leaving ``state.train_summary`` untouched,
        when the first record has no eval splits or its first split carries
        no metrics.

        Raises:
            ValueError: If ``state.train_log`` is empty.
        """
        if args.verbose != 0:
            print("-" * 50)
            print("Training complete.")

        if self._use_wandb:
            try:
                import wandb
                wandb.finish()
            except ImportError:
                pass
            except Exception as e:
                # Best effort: closing the run must not cost us the summary.
                logger.warning(f"wandb finish failed: {e}")

        train_log = state.train_log
        if not train_log:
            raise ValueError("Training should have at least one evaluation record.")

        # No eval datasets registered — nothing to summarise.
        first_metrics = train_log[0]['metrics']
        if not first_metrics:
            logger.info("No evaluation datasets registered; skipping training summary.")
            return

        # The first metric of the first split is the model-selection target;
        # without one there is nothing to rank epochs by.
        primary = first_metrics[0]
        if not isinstance(primary, dict) or not primary:
            logger.info("First evaluation split has no metrics; skipping training summary.")
            return
        target = next(iter(primary))

        # ``max`` keeps the earliest record on ties.
        best = max(train_log, key=lambda rec: rec['metrics'][0][target])
        val = best['metrics'][0][target]
        train_summary = {"val": val}

        num_evals = len(first_metrics)
        if num_evals == 1:
            logger.info(f"best validation {target}: {val}")
        else:
            score = sum(
                best['metrics'][x][target] for x in range(1, num_evals)
            ) / (num_evals - 1)
            logger.info(f"best validation {target}: {val}, best test avg. {target}: {score}")
            train_summary["test"] = score

        state.train_summary = train_summary

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Increment the global step counter.

        Adds 1 to ``state.step`` each time the hook fires, keeping a running
        total of training steps across all epochs.
        """
        state.step += 1