"""Source code for ``libauc.trainer.core.callbacks``."""

import logging
import sys
from typing import Any, Dict, List, Optional

from ..config.args import TrainingArguments

logger = logging.getLogger(__name__)


class TrainerState:
    """State object to track training progress."""

    def __init__(self) -> None:
        # Index of the current epoch (advanced by callbacks at epoch end).
        self.epoch: int = 0
        # Total number of epochs planned for the run.
        self.total_epoch: int = 0
        # Global step counter across all epochs.
        self.step: int = 0
        # Structured per-epoch log entries appended during training.
        self.train_log: list = []
        # Final run summary (e.g. best scores), filled at the end of training.
        self.train_summary: dict = {}
class TrainerCallback:
    r"""
    Base class for training lifecycle callbacks.

    Every hook below is a no-op by default, so a subclass only overrides the
    events it actually cares about. Instances are registered with
    :class:`~trainer.core.callbacks.CallbackHandler`, which invokes each hook
    in registration order and forwards a consistent set of keyword arguments
    (``model``, ``optimizer``, ``loss_fn``, plus any extra kwargs the
    :class:`~trainer.core.trainer.Trainer` supplies for that event).

    Lifecycle order during a typical training run::

        on_init_end
        on_train_begin
        for each epoch:
            on_epoch_begin
            for each step:
                on_step_begin
                on_step_end
            on_evaluate
            on_epoch_end
            [on_save — called periodically inside the epoch loop]
        on_train_end

    All callback methods are optional and can be overridden in subclasses.

    Example::

        >>> class MyCallback(TrainerCallback):
        ...     def on_epoch_end(self, args, state, **kwargs):
        ...         print(f"Epoch {state.epoch} done, loss={kwargs['train_loss']:.4f}")
        ...
        >>> trainer = Trainer(..., callbacks=[MyCallback()])
    """

    def __init__(self) -> None:
        pass

    def on_init_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of trainer initialization."""

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the beginning of training."""

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of training."""

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the beginning of an epoch."""

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of an epoch."""

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the beginning of a training step."""

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of a substep during gradient accumulation."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of a training step."""

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called after an evaluation phase."""

    def on_predict(self, args: TrainingArguments, state: TrainerState, metrics, **kwargs):
        """Event called after a successful prediction."""

    def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called after a checkpoint save."""

    def on_log(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called after logging the last logs."""

    def on_prediction_step(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called after a prediction step."""
class CallbackHandler(TrainerCallback):
    r"""
    Multiplexer that owns a list of :class:`~TrainerCallback` instances and
    fans out every lifecycle event to each of them in registration order.

    ``CallbackHandler`` itself inherits from :class:`~TrainerCallback` so it
    can be used polymorphically, but its primary role is orchestration rather
    than providing hook implementations of its own.

    Args:
        callbacks (list[TrainerCallback]): Initial callback list.
        model: The model being trained (forwarded to every callback via
            ``kwargs["model"]``).
        optimizer: The active optimizer (forwarded via ``kwargs["optimizer"]``).
        loss_fn: The active loss function (forwarded via ``kwargs["loss_fn"]``).

    Example::

        >>> handler = CallbackHandler(
        ...     [CLICallback()],
        ...     model=model, optimizer=optimizer, loss_fn=loss_fn,
        ... )
        >>> handler.on_train_begin(args, state)
    """

    def __init__(self, callbacks: List[TrainerCallback], model, optimizer, loss_fn):
        # Registered callbacks, in call order; ``add_callback`` normalizes
        # class-vs-instance inputs and warns about duplicates.
        self.callbacks: List[TrainerCallback] = []
        for cb in callbacks:
            self.add_callback(cb)
        # Shared objects forwarded to every hook invocation.
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def add_callback(self, callback):
        """Add a callback to the handler.

        ``callback`` may be a :class:`TrainerCallback` subclass (instantiated
        with no arguments) or an instance. A warning is logged when a callback
        of the same class is already registered; the new one is still appended.
        """
        cb = callback() if isinstance(callback, type) else callback
        cb_class = callback if isinstance(callback, type) else callback.__class__
        if cb_class in [c.__class__ for c in self.callbacks]:
            logger.warning(
                f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. "
                f"The current list of callbacks is:\n{self.callback_list}"
            )
        self.callbacks.append(cb)

    def pop_callback(self, callback):
        """Remove and return a callback.

        ``callback`` may be a class (first instance of it is removed) or an
        instance (matched by equality). Returns the removed callback, or
        ``None`` when no match is found.
        """
        if isinstance(callback, type):
            for cb in self.callbacks:
                if isinstance(cb, callback):
                    self.callbacks.remove(cb)
                    return cb
        else:
            for cb in self.callbacks:
                if cb == callback:
                    self.callbacks.remove(cb)
                    return cb
        # Fix: explicit return instead of silently falling off the end.
        return None

    def remove_callback(self, callback):
        """Remove a callback without returning it.

        Raises:
            ValueError: when an instance not present in the list is passed
                (``list.remove`` semantics, unchanged from the original).
        """
        if isinstance(callback, type):
            for cb in self.callbacks:
                if isinstance(cb, callback):
                    self.callbacks.remove(cb)
                    return
        else:
            self.callbacks.remove(callback)

    @property
    def callback_list(self):
        """Get a string representation of all callbacks (one class name per line)."""
        return "\n".join(cb.__class__.__name__ for cb in self.callbacks)

    def on_init_end(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_init_end", args, state)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_train_begin", args, state)

    def on_train_end(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_train_end", args, state)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_epoch_begin", args, state)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, metrics, **kwargs):
        return self._call_event("on_epoch_end", args, state, metrics=metrics, **kwargs)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_step_begin", args, state)

    def on_substep_end(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_substep_end", args, state)

    def on_step_end(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_step_end", args, state)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_evaluate", args, state)

    def on_predict(self, args: TrainingArguments, state: TrainerState, metrics):
        return self._call_event("on_predict", args, state, metrics=metrics)

    def on_save(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_save", args, state)

    def on_log(self, args: TrainingArguments, state: TrainerState, logs):
        return self._call_event("on_log", args, state, logs=logs)

    def on_prediction_step(self, args: TrainingArguments, state: TrainerState):
        return self._call_event("on_prediction_step", args, state)

    def _call_event(self, event, args, state, **kwargs):
        """Call the specified event on all callbacks, forwarding shared objects.

        Hook return values are ignored (fix: dropped the unused ``result``
        local that the original assigned and never read).
        """
        for callback in self.callbacks:
            getattr(callback, event)(
                args,
                state,
                model=self.model,
                optimizer=self.optimizer,
                loss_fn=self.loss_fn,
                **kwargs,
            )
class DefaultCallback(TrainerCallback):
    """Default callback with basic functionality."""

    def __init__(self) -> None:
        super().__init__()

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the beginning of an epoch."""
        opt = kwargs.get("optimizer")
        # Decay only at the configured epochs, and only when an optimizer
        # was actually forwarded by the handler.
        if not opt or state.epoch not in args.decay_epochs:
            return
        if getattr(opt, "model_ref", None) is not None:
            opt.update_regularizer(decay_factor=10)
        else:
            opt.update_lr(decay_factor=10)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of an epoch."""
        state.epoch += 1

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of a training step."""
        state.step += 1
def _format_metrics(log: dict, skip_keys: tuple = ("epoch", "train_loss", "lr")) -> str: """Format metric key-value pairs into a display string, excluding skip_keys.""" return " | ".join(f"{k}: {v:.4f}" for k, v in log.items() if k not in skip_keys) def _build_log_dict( metrics: list, train_loss: float, lr: float, epoch: int, ) -> dict: """ Build a flat log dict from epoch metrics, suitable for console output and wandb. Returns a dict with keys: epoch, train_loss, lr, and one entry per metric per dataset. Single-dataset runs use bare metric names; multi-dataset runs prefix with 'ds{N}/'. """ log: dict[str, float] = { "epoch": epoch, "train_loss": train_loss, "lr": lr, } single = len(metrics) == 1 for ds_idx, ds_metrics in enumerate(metrics): if not isinstance(ds_metrics, dict): continue prefix = "" if single else f"eval_splits{ds_idx + 1}/" for k, v in ds_metrics.items(): if k in ("epoch", "lr", "loss"): continue try: log[f"{prefix}{k}"] = float(v) except (ValueError, TypeError): pass return log
class CLICallback(TrainerCallback):
    r"""
    Console and Weights & Biases logging callback.

    On ``on_train_begin`` it initialises a W&B run (silently falls back to
    console-only when W&B is not installed or the run cannot be initialised)
    and pretty-prints the full
    :class:`~trainer.config.args.TrainingArguments` config.

    On ``on_epoch_end`` it:

    * appends a structured entry to ``state.train_log``;
    * renders a progress bar (``verbose=1``) or a per-epoch line
      (``verbose=2``) to stdout;
    * ships the flat log dict to W&B via ``wandb.log``.

    On ``on_train_end`` it prints a training summary (best validation and test
    scores) and calls ``wandb.finish()``.

    Args:
        (none — all configuration is read from
        :class:`~trainer.config.args.TrainingArguments` at runtime)

    Note:
        W&B logging is silently disabled when ``wandb`` is not installed or
        when ``wandb.log`` raises an exception.

    Example::

        >>> trainer = Trainer(..., callbacks=[CLICallback()])
        >>> trainer.train()
        ============================================================
        {'batch_size': 128, 'epochs': 50, ...}
        ============================================================
        Epoch [██████████████████············] 20/50 | Loss: 0.3241 | AUROC: 0.8712 | LR: 0.100000
    """

    # Width of the progress bar fill (verbose=1)
    _BAR_WIDTH = 30

    def __init__(self) -> None:
        super().__init__()
        self._use_wandb = True

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _wandb_log(self, log: dict, step: int) -> None:
        """Log a dict to wandb, silently skipping if unavailable."""
        if not self._use_wandb:
            return
        try:
            import wandb
            wandb.log(log, step=step)
        except Exception as e:
            logger.warning(f"wandb logging failed: {e}")

    def _render_bar(self, epoch: int, total: int, log: dict) -> str:
        """Render the ``\\r``-prefixed progress-bar line for verbose=1."""
        filled = int(self._BAR_WIDTH * epoch / total) if total else 0
        empty = self._BAR_WIDTH - filled
        bar = "█" * filled + "·" * empty
        metrics = _format_metrics(log)
        metrics_str = f" | {metrics}" if metrics else ""
        return (
            f"\rEpoch [{bar}] {epoch}/{total} | "
            f"Loss: {log.get('train_loss', 0):.4f}"
            f"{metrics_str} | "
            f"LR: {log.get('lr', 0):.6f}"
        )

    # ------------------------------------------------------------------
    # Callback events
    # ------------------------------------------------------------------
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Initialise wandb (best effort) and pretty-print the run config."""
        try:
            import wandb
            config = {
                k: v for k, v in vars(args).items() if not k.startswith("_")
            }
            wandb.init(project=args.project_name, name=args.experiment_name,
                       reinit=True, config=config)
        except ImportError:
            logger.warning("wandb not installed; skipping wandb logging")
            self._use_wandb = False
        except Exception as e:
            # Fix: a failed init (network/auth error) previously propagated
            # and aborted training; now it degrades to console-only logging.
            logger.warning(f"wandb init failed: {e}; skipping wandb logging")
            self._use_wandb = False
        if args.verbose == 0:
            return
        import pprint
        config = {k: v for k, v in vars(args).items() if not k.startswith("_")}
        print("=" * 60)
        pprint.pprint(config, indent=2, sort_dicts=False)
        print("=" * 60)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the beginning of an epoch."""
        optimizer = kwargs.get("optimizer")
        if optimizer and state.epoch in args.decay_epochs:
            # Optimizers carrying a model reference decay their regularizer;
            # plain optimizers decay the learning rate instead.
            if getattr(optimizer, "model_ref", None) is not None:
                optimizer.update_regularizer(decay_factor=10)
            else:
                optimizer.update_lr(decay_factor=10)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of an epoch."""
        metrics: list = kwargs.get("metrics", [])
        train_loss: float = kwargs.get("train_loss", 0)
        lr: float = kwargs.get("lr", 0)
        state.train_log.append({
            "metrics": metrics,
            "epoch": state.epoch + 1,
            "lr": lr,
            "train_loss": train_loss,
        })
        log = _build_log_dict(metrics, train_loss, lr, epoch=state.epoch + 1)
        # ---- Console output (mode-dependent) ----------------------------
        if args.verbose == 1:
            # Overwrite the same line with an updated progress bar
            bar_str = self._render_bar(state.epoch + 1, state.total_epoch, log)
            sys.stdout.write(bar_str)
            sys.stdout.flush()
            # Print a newline only on the very last epoch so the bar stays
            # on screen after training ends
            if state.epoch + 1 >= state.total_epoch:
                sys.stdout.write("\n")
                sys.stdout.flush()
        elif args.verbose == 2:
            # One line per epoch (original behaviour)
            display_parts = [
                f"Epoch {state.epoch + 1}/{state.total_epoch}",
                f"Loss: {train_loss:.4f}",
            ]
            display_parts += [
                f"{k}: {v:.4f}"
                for k, v in log.items()
                if k not in ("epoch", "train_loss", "lr")
            ]
            display_parts.append(f"LR: {lr:.6f}")
            print(" | ".join(display_parts))
        # ---- wandb logging (always active unless unavailable) -----------
        self._wandb_log(log, step=state.epoch + 1)
        state.epoch += 1

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Summarise the best epoch and close the wandb run.

        Raises:
            ValueError: when no evaluation records were collected, or when
                the first record contains no dataset splits.
        """
        if args.verbose != 0:
            print("-" * 50)
            print("Training complete.")
        if self._use_wandb:
            try:
                import wandb
                wandb.finish()
            except ImportError:
                pass
        train_log = state.train_log
        if not train_log:
            raise ValueError("Training should have at least one evaluation record.")
        num_evals = len(train_log[0]['metrics'])
        # Fix: this guard originally ran AFTER indexing metrics[0], so an
        # empty split list raised IndexError instead of this ValueError.
        if num_evals == 0:
            raise ValueError("Evaluation should contain at least one dataset split.")
        train_summary = {}
        # The first metric key of the first (validation) split is used as the
        # model-selection target.
        target = list(train_log[0]['metrics'][0].keys())[0]
        # Epoch index with the best validation target score (fix: renamed
        # from `id`, which shadowed the builtin).
        best_idx = max(
            range(len(train_log)),
            key=lambda i: train_log[i]['metrics'][0][target],
        )
        if num_evals == 1:
            val = train_log[best_idx]['metrics'][0][target]
            logger.info(f"best validation {target}: {val}")
            train_summary["val"] = val
        elif num_evals == 2:
            val = train_log[best_idx]['metrics'][0][target]
            score = train_log[best_idx]['metrics'][1][target]
            logger.info(f"best validation {target}: {val}, best test {target}: {score}")
            train_summary["val"] = val
            train_summary["test"] = score
        else:
            val = train_log[best_idx]['metrics'][0][target]
            score = sum(
                train_log[best_idx]['metrics'][x][target] for x in range(1, num_evals)
            ) / (num_evals - 1)
            logger.info(f"best validation {target}: {val}, best test avg. {target}: {score}")
            train_summary["val"] = val
            train_summary["test"] = score
        state.train_summary = train_summary

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs):
        """Event called at the end of a training step."""
        state.step += 1