Source code for libauc.trainer.core.gnn_trainer

import logging
import os
from typing import Callable, List, Mapping, Optional

import numpy as np
import torch

from .trainer import Trainer
from ..config.args import TrainingArguments
from .callbacks import CallbackHandler, TrainerCallback

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Architectures available in libauc.models for graph learning
# ---------------------------------------------------------------------------
_GNN_REGISTRY = {
    "gcn":        "GCN",
    "gin":        "GIN",
    "gine":       "GINE",
    "graphsage":  "GraphSAGE",
    "gat":        "GAT",
    "mpnn":       "MPNN",
    "deepergcn":  "DeeperGCN",
    "pna":        "PNA",
}
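
# ---------------------------------------------------------------------------
# Illustrative only: how the registry is consumed by GNNTrainer._build_model
# below. Each value names a class exported by libauc.models:
#
#   >>> import libauc.models as libauc_models
#   >>> model_cls = getattr(libauc_models, _GNN_REGISTRY["gin"])   # -> GIN
#   >>> model = model_cls(num_tasks=1, emb_dim=256, num_layers=5)
# ---------------------------------------------------------------------------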

class GNNTrainer(Trainer):
    r"""
    Training loop for graph neural networks built with libauc's GNN model zoo.

    ``GNNTrainer`` extends :class:`~trainer.core.trainer.Trainer` with
    graph-aware overrides:

    * :meth:`_build_model` — looks up the requested GNN architecture in
      :data:`_GNN_REGISTRY` and constructs it via ``libauc.models``, then
      infers whether the model expects edge features (``supports_edge_attr``).
    * :meth:`_get_train_dataloader` / :meth:`_get_eval_dataloader` — use
      ``torch_geometric.loader.DataLoader`` instead of the standard PyTorch
      one, while keeping the same :class:`~libauc.sampler.DualSampler` for
      positive/negative balancing.
    * :meth:`_forward` — dispatches to the correct GNN forward signature
      (with or without ``edge_attr``).
    * :meth:`train` — adds optional learning-rate decay at specified epochs
      via ``optimizer.update_lr``.

    Supported GNN architectures
    ---------------------------
    ``gcn``, ``gin``, ``gine``, ``graphsage``, ``gat``, ``mpnn``,
    ``deepergcn``, ``pna``

    Args:
        train_args (TrainingArguments): Training configuration.
        model_cfg (dict): GNN model configuration. Required key: ``name``
            (one of the architectures listed above). Optional keys:
            ``num_tasks`` (default ``1``), ``emb_dim`` (default ``256``),
            ``num_layers`` (default ``5``), ``graph_pooling``, ``dropout``,
            ``atom_features_dims``, ``bond_features_dims``, ``act``,
            ``norm``, ``jk``, ``v2`` (GAT-only), ``aggr`` / ``t`` /
            ``learn_t`` / ``p`` / ``learn_p`` / ``block`` (DeeperGCN-only),
            ``pretrained`` (bool), ``pretrained_path`` (str).
        train_dataset: PyG-compatible graph dataset (train split).
        eval_dataset (list, optional): PyG-compatible graph datasets for
            evaluation splits (default: ``None``).
        metric (callable, optional): ``(y_true, y_pred) -> dict[str, float]``.
        callbacks (list[TrainerCallback], optional): Training callbacks.
        decay_epochs (list[int], optional): Epoch indices at which
            ``optimizer.update_lr(decay_factor=decay_factor)`` is called
            (default: no decay).
        decay_factor (float): LR divisor at each decay epoch
            (default: ``10.0``).
        train_eval_dataset: Optional dataset for an unbiased train-split
            evaluation; falls back to ``train_dataset`` when ``None``.

    Example::

        >>> trainer = GNNTrainer(
        ...     train_args=train_args,
        ...     model_cfg={"name": "gin", "num_tasks": 1, "emb_dim": 300},
        ...     train_dataset=train_ds,
        ...     eval_dataset=[val_ds, test_ds],
        ...     metric=metric_fn,
        ...     callbacks=[CLICallback()],
        ...     decay_epochs=[100, 150],
        ...     decay_factor=10.0,
        ... )
        >>> log = trainer.train()
    """

    def __init__(
        self,
        train_args: TrainingArguments,
        model_cfg: dict,
        train_dataset,
        eval_dataset: Optional[List] = None,
        metric: Optional[Callable[..., Mapping[str, float]]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        decay_epochs: Optional[List[int]] = None,
        decay_factor: float = 10.0,
        train_eval_dataset=None,
    ):
        """
        Parameters
        ----------
        train_args : TrainingArguments
        model_cfg : dict passed to _build_model()
        train_dataset : PyG-compatible graph dataset (train split)
        eval_dataset : list of PyG-compatible graph datasets (eval splits)
        metric : callable (y_true, y_pred) -> dict of floats
        callbacks : list of TrainerCallback instances
        decay_epochs : epoch indices at which to decay the LR, e.g. [100, 200]
        decay_factor : LR divisor applied at each decay epoch (default 10)
        train_eval_dataset : optional separate dataset for unbiased
            train-split evaluation; falls back to train_dataset when None
        """
        # Stash GNN-specific state before super().__init__ runs, because the
        # parent calls _get_train_dataloader(), which we override below.
        self._model_cfg_gnn = model_cfg
        self._with_edge_features = False
        self.decay_epochs = decay_epochs or []
        self.decay_factor = decay_factor
        self.train_eval_dataset = train_eval_dataset

        # The parent will call build_model() (CNN path) and produce a
        # placeholder model. We swap it out at the start of train().
        super().__init__(
            train_args    = train_args,
            model_cfg     = model_cfg,
            train_dataset = train_dataset,
            eval_dataset  = eval_dataset,
            metric        = metric,
            callbacks     = callbacks,
        )
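
    # ------------------------------------------------------------------
    # Note: the overrides below rely on state initialised by the parent
    # Trainer during super().__init__: self.args, self.state,
    # self.optimizer, self.loss_fn, self.trainloader,
    # self.callback_handler, self.metric, and the checkpoint helpers
    # (get_latest_checkpoint / load_checkpoint / save_checkpoint), all of
    # which appear in the methods that follow.
    # ------------------------------------------------------------------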

    def _build_model(self, model_cfg: dict):
        """
        Build a GNN model from libauc.models.

        Required keys
        -------------
        name       : one of gcn | gin | gine | graphsage | gat | mpnn |
                     deepergcn | pna
        num_tasks  : number of output tasks (default 1)
        emb_dim    : node/edge embedding size (default 256)
        num_layers : number of message-passing layers (default 5)

        Optional keys (forwarded verbatim to the model constructor)
        -----------------------------------------------------------
        graph_pooling, dropout, atom_features_dims, bond_features_dims,
        act, norm, jk, v2 (GAT only),
        aggr / t / learn_t / p / learn_p / block (DeeperGCN only),
        pretrained (bool), pretrained_path (str)

        Notes
        -----
        Stores the built model in ``self.model``. Whether the architecture
        consumes edge features is inferred from the model's
        ``supports_edge_attr`` class-level flag and stored in
        ``self._with_edge_features``.
        """
        name = model_cfg.get("name", "").lower()
        if name not in _GNN_REGISTRY:
            raise ValueError(
                f"Unknown GNN model '{name}'. "
                f"Supported: {list(_GNN_REGISTRY.keys())}"
            )

        # Import the class from libauc.models
        import libauc.models as libauc_models

        cls_name = _GNN_REGISTRY[name]
        model_cls = getattr(libauc_models, cls_name, None)
        if model_cls is None:
            raise ImportError(
                f"'{cls_name}' not found in libauc.models. "
                "Make sure your libauc version includes GNN support."
            )

        # ---- build constructor kwargs ------------------------------------
        constructor_kwargs = dict(
            num_tasks  = model_cfg.get("num_tasks", 1),
            emb_dim    = model_cfg.get("emb_dim", 256),
            num_layers = model_cfg.get("num_layers", 5),
        )

        # Optional shared kwargs
        for key in ("graph_pooling", "dropout", "atom_features_dims",
                    "bond_features_dims", "act", "norm", "jk"):
            if key in model_cfg:
                constructor_kwargs[key] = model_cfg[key]

        # DeeperGCN-specific kwargs
        if name == "deepergcn":
            for key in ("aggr", "t", "learn_t", "p", "learn_p", "block"):
                if key in model_cfg:
                    constructor_kwargs[key] = model_cfg[key]

        # GAT-specific kwargs
        if name == "gat" and "v2" in model_cfg:
            constructor_kwargs["v2"] = model_cfg["v2"]

        model = model_cls(**constructor_kwargs).cuda()

        # ---- infer whether this architecture uses edge features ----------
        # Every BasicGNN subclass in libauc.models declares
        # `supports_edge_attr` as a Final[bool] class attribute. DeeperGCN
        # doesn't inherit BasicGNN but always takes edge_attr in forward(),
        # so we fall back to a name-based lookup.
        if hasattr(model_cls, "supports_edge_attr"):
            with_edge_features = bool(model_cls.supports_edge_attr)
        else:
            with_edge_features = name in {"deepergcn", "gine", "gat",
                                          "mpnn", "pna"}

        # ---- optional warm-start ------------------------------------------
        if model_cfg.get("pretrained", False):
            pretrained_path = model_cfg.get("pretrained_path")
            if not pretrained_path:
                raise ValueError(
                    "pretrained=True but 'pretrained_path' is not set "
                    "in model_cfg."
                )
            state_dict = torch.load(pretrained_path, weights_only=False)
            if "model_state_dict" in state_dict:
                state_dict = state_dict["model_state_dict"]
            msg = model.load_state_dict(state_dict, strict=False)
            logger.info(f"GNN pretrained weights loaded: {msg}")
            model.graph_pred_linear.reset_parameters()

        logger.info(
            f"Built {cls_name} | emb_dim={constructor_kwargs['emb_dim']} "
            f"| num_layers={constructor_kwargs['num_layers']} "
            f"| with_edge_features={with_edge_features}"
        )
        self._with_edge_features = with_edge_features
        self.model = model

    # ------------------------------------------------------------------
    # DataLoader overrides – use PyG's DataLoader
    # ------------------------------------------------------------------
    def _get_train_dataloader(self, train_args: TrainingArguments):
        from torch_geometric.loader import DataLoader as PyGDataLoader
        from libauc.sampler import DualSampler

        sampler = DualSampler(
            self.train_dataset,
            train_args.batch_size,
            sampling_rate=train_args.sampling_rate,
        )
        loader = PyGDataLoader(
            self.train_dataset,
            batch_size  = train_args.batch_size,
            sampler     = sampler,
            num_workers = train_args.num_workers,
        )
        return sampler, loader

    def _get_eval_dataloader(self, dataset, train_args: TrainingArguments):
        from torch_geometric.loader import DataLoader as PyGDataLoader

        return PyGDataLoader(
            dataset,
            batch_size  = train_args.eval_batch_size,
            shuffle     = False,
            num_workers = train_args.num_workers,
        )

    # ------------------------------------------------------------------
    # GNN forward-pass helper
    # ------------------------------------------------------------------
    def _forward(self, model, batch):
        """
        Run a forward pass using the signature appropriate for this
        architecture.

        supports_edge_attr=False → model(x, edge_index, batch)
        supports_edge_attr=True  → model(x, edge_index, edge_attr, batch)

        Returns sigmoid probabilities, shape [N, num_tasks].
        """
        if self._with_edge_features:
            logits = model(batch.x, batch.edge_index, batch.edge_attr,
                           batch.batch)
        else:
            logits = model(batch.x, batch.edge_index, batch.batch)
        return torch.sigmoid(logits)
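
    # ------------------------------------------------------------------
    # Illustrative only: the two dispatch paths in _forward(). A PyG
    # DataLoader collates a list of Data objects into a single Batch whose
    # `batch` vector maps each node to its source graph. Per the name-based
    # fallback in _build_model, e.g.:
    #
    #   >>> data, targets, index = next(iter(trainer.trainloader))
    #   >>> # edge-agnostic path (e.g. gcn, gin, graphsage):
    #   >>> #   logits = model(data.x, data.edge_index, data.batch)
    #   >>> # edge-aware path (e.g. gine, gat, mpnn, pna, deepergcn):
    #   >>> #   logits = model(data.x, data.edge_index, data.edge_attr,
    #   >>> #                  data.batch)
    # ------------------------------------------------------------------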

    # ------------------------------------------------------------------
    # Main training loop
    # ------------------------------------------------------------------
    def train(self):
        """
        GNN training loop.

        Steps each epoch:

        1. Optional LR decay if epoch is in decay_epochs.
        2. Forward / backward over the training loader.
        3. Evaluation on the training split (unbiased loader).
        4. Evaluation on all registered eval loaders.
        5. Callbacks and periodic checkpointing.

        Returns
        -------
        list : training log produced by the state / callback system
        """
        # Swap out the placeholder model built by the parent for the GNN
        # described in model_cfg (see the note in __init__).
        self._build_model(self._model_cfg_gnn)

        self.callback_handler.on_train_begin(self.args, self.state)

        model = self.model.cuda()
        self.loss_fn = self.loss_fn.cuda()

        # Optional checkpoint resume
        if self.args.resume_from_checkpoint:
            latest = self.get_latest_checkpoint(
                os.path.join(self.args.output_path, self.args.experiment_name)
            )
            if latest:
                self.load_checkpoint(latest)
                logger.info(f"Resuming from epoch {self.state.epoch}")
            else:
                logger.info("No checkpoint found, starting from scratch.")

        for epoch in range(self.state.epoch, self.args.epochs):
            self.callback_handler.on_epoch_begin(self.args, self.state)

            # ── Optional LR decay ───────────────────────────────────────
            if epoch in self.decay_epochs:
                self.optimizer.update_lr(decay_factor=self.decay_factor)
                logger.info(
                    f"Epoch {epoch}: LR decayed by /{self.decay_factor}. "
                    f"New LR = {self.optimizer.lr:.6f}"
                )

            # ── Training ────────────────────────────────────────────────
            model.train()
            train_loss = []
            for batch in self.trainloader:
                self.callback_handler.on_step_begin(self.args, self.state)

                data, targets, index = batch
                data = data.cuda()
                targets = targets.cuda()
                index = index.cuda()

                pred = self._forward(model, data)

                # Compute loss
                if self.args.loss == "CrossEntropyLoss":
                    raise ValueError("CrossEntropyLoss is not supported for GNNs.")
                if self.args.loss == "BCELoss":
                    loss = self.loss_fn(pred, targets.reshape(-1, 1))
                else:
                    loss = self.loss_fn(pred, targets, index=index)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                train_loss.append(loss.item())
                self.callback_handler.on_step_end(self.args, self.state)

            # ── Evaluation ──────────────────────────────────────────────
            model.eval()
            avg_train_loss = float(np.mean(train_loss))
            eval_metrics, test_true, test_pred = self.evaluate_loop(model)

            self.callback_handler.on_epoch_end(
                self.args,
                self.state,
                metrics    = eval_metrics,
                train_loss = avg_train_loss,
                lr         = self.optimizer.lr,
                test_true  = test_true,
                test_pred  = test_pred,
            )

            # ── Checkpointing ───────────────────────────────────────────
            if (
                (epoch + 1) % self.args.save_checkpoint_every == 0
                or (epoch + 1) == self.args.epochs
            ):
                ckpt = os.path.join(
                    self.args.output_path,
                    self.args.experiment_name,
                    f"epoch_{epoch + 1}.pt",
                )
                self.save_checkpoint(ckpt)

        self.callback_handler.on_train_end(self.args, self.state)
        return self.state.train_log
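
    # ------------------------------------------------------------------
    # Worked example of the decay schedule above: with an initial LR of
    # 1e-3, decay_epochs=[100, 150] and decay_factor=10.0, the optimizer
    # sees
    #
    #   epochs   0-99  : lr = 1e-3
    #   epochs 100-149 : lr = 1e-4   (divided by 10.0 at epoch 100)
    #   epochs 150-end : lr = 1e-5   (divided by 10.0 again at epoch 150)
    # ------------------------------------------------------------------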

    # ------------------------------------------------------------------
    # Evaluation helpers
    # ------------------------------------------------------------------
    def _eval_single_loader(self, model, loader):
        """
        Run inference over *loader* and compute the registered metric.

        Returns
        -------
        (metrics_dict, y_true np.ndarray, y_pred np.ndarray)
        """
        pred_list, true_list = [], []
        with torch.no_grad():
            for batch in loader:
                data, targets, index = batch
                data = data.cuda()
                targets = targets.cuda()
                pred = self._forward(model, data)
                pred_list.append(pred.cpu().detach().numpy())
                true_list.append(targets.cpu().detach().numpy())

        y_true = np.concatenate(true_list)
        y_pred = np.concatenate(pred_list)
        if y_pred.ndim > 1:
            y_pred = y_pred.flatten()
        if y_true.ndim > 1:
            y_true = y_true.flatten()

        metrics = self.metric(y_true, y_pred) if self.metric else {}
        return metrics, y_true, y_pred
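
    # ------------------------------------------------------------------
    # Illustrative only: a metric callable satisfying the
    # (y_true, y_pred) -> dict[str, float] contract used above. The
    # sklearn-based body is one possible choice, not part of libauc:
    #
    #   >>> from sklearn.metrics import roc_auc_score, average_precision_score
    #   >>> def metric_fn(y_true, y_pred):
    #   ...     return {"auroc": float(roc_auc_score(y_true, y_pred)),
    #   ...             "ap": float(average_precision_score(y_true, y_pred))}
    # ------------------------------------------------------------------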

    def evaluate(self, loader, model):
        """
        Override the base ``Trainer.evaluate()`` to use the GNN forward pass.

        Args
        ----
        loader : PyG DataLoader
        model  : GNN model

        Returns
        -------
        (metrics_dict, y_true, y_pred)
        """
        return self._eval_single_loader(model, loader)
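
# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (not part of the module). Dataset and metric
# names are placeholders, and the DeeperGCN-specific keys shown ("block",
# etc.) are assumptions; only the call shape mirrors the class docstring:
#
#   >>> trainer = GNNTrainer(
#   ...     train_args=train_args,
#   ...     model_cfg={"name": "deepergcn", "emb_dim": 256,
#   ...                "num_layers": 14, "block": "res+"},
#   ...     train_dataset=train_ds,
#   ...     eval_dataset=[val_ds, test_ds],
#   ...     metric=metric_fn,
#   ...     decay_epochs=[100, 150],
#   ... )
#   >>> log = trainer.train()
#   >>> metrics, y_true, y_pred = trainer.evaluate(val_loader, trainer.model)
# ---------------------------------------------------------------------------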