import logging
import os
from typing import Callable, List, Mapping, Optional, Union
import numpy as np
import torch
from torch.utils.data import Dataset
from .image_trainer import Trainer
from ..config.args import TrainingArguments
from .callbacks import CallbackHandler, TrainerCallback
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Architectures available in libauc.models for graph learning
# ---------------------------------------------------------------------------
_GNN_REGISTRY = {
"gcn": "GCN",
"gin": "GIN",
"gine": "GINE",
"graphsage": "GraphSAGE",
"gat": "GAT",
"mpnn": "MPNN",
"deepergcn": "DeeperGCN",
"pna": "PNA",
}
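# ---------------------------------------------------------------------------
# Graph trainer
# ---------------------------------------------------------------------------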
class GraphTrainer(Trainer):
    r"""
    Training loop for graph neural networks built with libauc's GNN model zoo.

    ``GraphTrainer`` extends :class:`~trainer.core.image_trainer.Trainer` with
    graph-aware overrides:

    * :meth:`_build_model` — looks up the requested GNN architecture in
      :data:`_GNN_REGISTRY` and constructs it via ``libauc.models``, then
      infers whether the model expects edge features (``supports_edge_attr``).
      A name containing ``/`` is instead loaded through
      ``transformers.AutoModel.from_pretrained``.
    * :meth:`_get_train_dataloader` / :meth:`_get_eval_dataloader` — use
      ``torch_geometric.loader.DataLoader`` instead of the standard PyTorch
      one, while keeping the same :class:`~libauc.sampler.DualSampler` for
      positive/negative balancing.
    * :meth:`_forward` — dispatches to the correct GNN forward signature
      (with or without ``edge_attr``).
    * :meth:`train` — adds optional learning-rate decay at specified epochs
      via ``optimizer.update_lr``.

    Supported GNN architectures
    ---------------------------
    ``gcn``, ``gin``, ``gine``, ``graphsage``, ``gat``, ``mpnn``,
    ``deepergcn``, ``pna``

    Args:
        train_args (TrainingArguments): Training configuration.
        model_cfg (dict): GNN model configuration. Required key: ``name``
            (one of the architectures listed above, or a HuggingFace model
            ID / path containing ``/``). Optional keys:
            ``emb_dim`` (default ``256``), ``num_layers`` (default ``5``),
            ``graph_pooling``, ``dropout``,
            ``atom_features_dims``, ``bond_features_dims``, ``act``, ``norm``,
            ``jk``, ``v2`` (GAT-only), ``aggr`` / ``t`` / ``learn_t`` / ``p``
            / ``learn_p`` / ``block`` (DeeperGCN-only), ``pretrained`` (bool),
            ``pretrained_path`` (str).
        train_dataset: PyG-compatible graph dataset (train split).
        eval_dataset (list, optional): PyG-compatible graph datasets for
            evaluation splits (default: ``None``).
        metric (callable, optional): ``(y_true, y_pred) -> dict[str, float]``
        callbacks (list[TrainerCallback], optional): Training callbacks.
        decay_epochs (list[int], optional): Epoch indices at which
            ``optimizer.update_lr(decay_factor=decay_factor)`` is called
            (default: no decay).
        decay_factor (float): Value forwarded to ``optimizer.update_lr`` at
            each decay epoch; logged as an LR divisor (default: ``10.0``).
        train_eval_dataset: Optional dataset for an unbiased train-split
            evaluation. It is only stored on the instance here; the
            fall-back to ``train_dataset`` is presumably implemented by the
            inherited evaluation loop — TODO confirm.

    Example::

        >>> trainer = GraphTrainer(
        ...     train_args=train_args,
        ...     model_cfg={"name": "gin", "emb_dim": 300},
        ...     train_dataset=train_ds,
        ...     eval_dataset=[val_ds, test_ds],
        ...     metric=metric_fn,
        ...     callbacks=[CLICallback()],
        ...     decay_epochs=[100, 150],
        ...     decay_factor=10.0,
        ... )
        >>> log = trainer.train()
    """

    def __init__(
        self,
        train_args: TrainingArguments,
        model_cfg: dict,
        train_dataset,
        eval_dataset: Optional[List] = None,
        metric: Optional[Callable[..., Mapping[str, float]]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        decay_epochs: Optional[List[int]] = None,
        decay_factor: float = 10.0,
        train_eval_dataset=None,
    ):
        """
        Parameters
        ----------
        train_args : TrainingArguments
        model_cfg : dict consumed by :meth:`_build_model`
        train_dataset : PyG-compatible graph dataset (train split)
        eval_dataset : list of PyG-compatible graph datasets (eval splits)
        metric : callable (y_true, y_pred) -> dict of floats
        callbacks : list of TrainerCallback instances
        decay_epochs : epoch indices at which to decay the LR, e.g. [100, 200]
        decay_factor : value passed to ``optimizer.update_lr`` at each decay
            epoch (default 10)
        train_eval_dataset : optional separate dataset for unbiased
            train-split evaluation; stored as ``self.train_eval_dataset``
        """
        # GNN-specific state has to exist before super().__init__() runs:
        # the parent constructor is expected to call hooks overridden below
        # (_build_model / _get_train_dataloader), which read or write these
        # attributes.  NOTE(review): the parent is not visible from this
        # file — confirm that it dispatches to the overridden hooks.
        self._model_cfg_gnn = model_cfg
        self._with_edge_features = False
        # Copy so later mutation of the caller's list cannot change the
        # schedule behind the trainer's back.
        self.decay_epochs = list(decay_epochs) if decay_epochs is not None else []
        self.decay_factor = decay_factor
        self.train_eval_dataset = train_eval_dataset

        super().__init__(
            train_args=train_args,
            model_cfg=model_cfg,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            metric=metric,
            callbacks=callbacks,
        )

    def _build_model(self, model_cfg: dict) -> None:
        """
        Build a GNN model and store it on the trainer.

        Required keys
        -------------
        name : one of gcn | gin | gine | graphsage | gat | mpnn | deepergcn | pna,
            or a HuggingFace model ID / path (anything containing ``/``)

        Optional keys (inferred automatically)
        ---------------------------------------
        num_tasks  : always taken from ``self.args.num_tasks``; ignored if
                     present in model_cfg
        emb_dim    : node/edge embedding size (default 256)
        num_layers : number of message-passing layers (default 5)

        Optional keys (forwarded verbatim to the model constructor)
        -----------------------------------------------------------
        graph_pooling, dropout, atom_features_dims, bond_features_dims,
        act, norm, jk, v2 (GAT only),
        aggr / t / learn_t / p / learn_p / block (DeeperGCN only)

        Warm-start keys (not forwarded to the constructor)
        --------------------------------------------------
        pretrained (bool), pretrained_path (str)

        Side effects
        ------------
        Sets ``self.model`` (moved to CUDA) and ``self._with_edge_features``.
        The latter comes from the model's ``supports_edge_attr`` flag when it
        is declared, otherwise from a hard-coded list of architectures.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            Unknown architecture name, or ``pretrained=True`` without
            ``pretrained_path``.
        ImportError
            The registered class is missing from ``libauc.models``.
        """
        # Keep the caller's spelling: HuggingFace ids and local paths may be
        # case-sensitive.  Only the registry lookup is case-insensitive.
        raw_name = str(model_cfg.get("name") or "")
        name = raw_name.lower()

        # ── HuggingFace Hub path ────────────────────────────────────────
        if "/" in raw_name:
            # Imported only to fail early when PyG is not installed.
            import torch_geometric  # noqa: F401
            from transformers import AutoModel

            # NOTE(review): trust_remote_code=True executes Python shipped
            # with the model repository — only use with trusted sources.
            model = AutoModel.from_pretrained(
                raw_name, trust_remote_code=True
            ).cuda()
            # Models that do not declare the flag are treated as edge-free.
            self._with_edge_features = bool(
                getattr(model, "supports_edge_attr", False)
            )
            logger.info(
                "Loaded HF model '%s' | with_edge_features=%s",
                raw_name,
                self._with_edge_features,
            )
            self.model = model
            return

        # ── libauc model-zoo path ───────────────────────────────────────
        if name not in _GNN_REGISTRY:
            raise ValueError(
                f"Unknown GNN model '{raw_name}'. "
                f"Supported: {list(_GNN_REGISTRY.keys())} "
                f"or a HuggingFace model ID (e.g. 'user/my-gcn')"
            )

        import libauc.models as libauc_models

        cls_name = _GNN_REGISTRY[name]
        model_cls = getattr(libauc_models, cls_name, None)
        if model_cls is None:
            raise ImportError(
                f"'{cls_name}' not found in libauc.models. "
                "Make sure your libauc version includes GNN support."
            )

        # Validate the warm-start request before allocating the model.
        wants_pretrained = bool(model_cfg.get("pretrained", False))
        pretrained_path = model_cfg.get("pretrained_path")
        if wants_pretrained and not pretrained_path:
            raise ValueError(
                "pretrained=True but 'pretrained_path' is not set in model_cfg."
            )

        # Keys that steer this method rather than the model constructor.
        meta_keys = {
            "name",
            "num_tasks",
            "pretrained",
            "pretrained_remote",
            "pretrained_path",
        }
        constructor_kwargs = {
            key: value for key, value in model_cfg.items() if key not in meta_keys
        }
        constructor_kwargs["num_tasks"] = self.args.num_tasks
        constructor_kwargs.setdefault("emb_dim", 256)
        constructor_kwargs.setdefault("num_layers", 5)

        model = model_cls(**constructor_kwargs).cuda()

        if hasattr(model_cls, "supports_edge_attr"):
            with_edge_features = bool(model_cls.supports_edge_attr)
        else:
            # Fallback for model classes that do not declare the flag.
            with_edge_features = name in {"deepergcn", "gine", "gat", "mpnn", "pna"}

        # ── optional warm-start ─────────────────────────────────────────
        if wants_pretrained:
            # map_location="cpu" lets checkpoints saved on any device load;
            # load_state_dict copies the tensors onto the model's device.
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            state_dict = torch.load(
                pretrained_path, map_location="cpu", weights_only=False
            )
            if "model_state_dict" in state_dict:
                state_dict = state_dict["model_state_dict"]
            msg = model.load_state_dict(state_dict, strict=False)
            logger.info("GNN pretrained weights loaded: %s", msg)
            # Re-initialise the prediction head after loading; assumes the
            # model exposes ``graph_pred_linear`` — TODO confirm for every
            # registered architecture.
            model.graph_pred_linear.reset_parameters()

        logger.info(
            "Built %s | emb_dim=%s | num_layers=%s | with_edge_features=%s",
            cls_name,
            constructor_kwargs["emb_dim"],
            constructor_kwargs["num_layers"],
            with_edge_features,
        )
        self._with_edge_features = with_edge_features
        self.model = model

    # ------------------------------------------------------------------
    # DataLoader overrides – use PyG's DataLoader
    # ------------------------------------------------------------------
    def _get_train_dataloader(self, train_args: TrainingArguments):
        """
        Build the training sampler and PyG loader.

        Returns
        -------
        (sampler, loader) : the ``DualSampler`` over ``self.train_dataset``
            and the ``torch_geometric`` DataLoader that draws from it.
        """
        from libauc.sampler import DualSampler
        from torch_geometric.loader import DataLoader as PyGDataLoader

        sampler = DualSampler(
            self.train_dataset,
            train_args.batch_size,
            sampling_rate=train_args.sampling_rate,
        )
        loader = PyGDataLoader(
            self.train_dataset,
            batch_size=train_args.batch_size,
            sampler=sampler,
            num_workers=train_args.num_workers,
        )
        return sampler, loader

    def _get_eval_dataloader(self, dataset, train_args: TrainingArguments):
        """Return a non-shuffling PyG DataLoader over *dataset*."""
        from torch_geometric.loader import DataLoader as PyGDataLoader

        return PyGDataLoader(
            dataset,
            batch_size=train_args.eval_batch_size,
            shuffle=False,
            num_workers=train_args.num_workers,
        )

    # ------------------------------------------------------------------
    # GNN forward-pass helper
    # ------------------------------------------------------------------
    def _forward(self, model, batch):
        """
        Run a forward pass using the signature appropriate for this architecture.

        ``self._with_edge_features`` False → ``model(x, edge_index, batch)``
        ``self._with_edge_features`` True  → ``model(x, edge_index, edge_attr, batch)``

        Returns the sigmoid of the model output (presumably shaped
        ``[num_graphs, num_tasks]`` — depends on the model).
        """
        if self._with_edge_features:
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        else:
            logits = model(batch.x, batch.edge_index, batch.batch)
        return torch.sigmoid(logits)

    # ------------------------------------------------------------------
    # Main training loop
    # ------------------------------------------------------------------
    def train(self):
        """
        Run the GNN training loop.

        Steps each epoch:
            1. Optional LR decay if the epoch index is in ``decay_epochs``.
            2. Forward / backward over the training loader.
            3. Evaluation through the inherited ``evaluate_loop``.
            4. ``on_epoch_end`` callbacks (metrics, mean train loss, LR).
            5. Periodic and final-epoch checkpointing.

        Returns
        -------
        ``self.state.train_log`` — the log accumulated on the trainer state.

        Raises
        ------
        ValueError
            If ``self.args.loss`` is ``"CrossEntropyLoss"``.
        """
        # Fail fast: this is a configuration error, so there is no point in
        # firing callbacks or running a forward pass first.
        if self.args.loss == "CrossEntropyLoss":
            raise ValueError("CrossEntropyLoss not supported for GNNs.")
        use_bce = self.args.loss == "BCELoss"

        self.callback_handler.on_train_begin(self.args, self.state)
        model = self.model.cuda()
        self.loss_fn = self.loss_fn.cuda()

        run_dir = os.path.join(self.args.output_path, self.args.experiment_name)

        # Optional checkpoint resume
        if self.args.resume_from_checkpoint:
            latest = self.get_latest_checkpoint(run_dir)
            if latest:
                self.load_checkpoint(latest)
                logger.info("Resuming from epoch %s", self.state.epoch)
            else:
                logger.info("No checkpoint found, starting from scratch.")

        # A falsy interval (0 / None) means "final checkpoint only" instead
        # of a ZeroDivisionError in the modulo below.
        save_every = self.args.save_checkpoint_every

        for epoch in range(self.state.epoch, self.args.epochs):
            self.callback_handler.on_epoch_begin(self.args, self.state)

            # ── Optional LR decay ───────────────────────────────────────
            if epoch in self.decay_epochs:
                self.optimizer.update_lr(decay_factor=self.decay_factor)
                logger.info(
                    "Epoch %s: LR decayed by /%s. New LR = %.6f",
                    epoch,
                    self.decay_factor,
                    self.optimizer.lr,
                )

            # ── Training ────────────────────────────────────────────────
            model.train()
            train_loss = []
            for data, targets, index in self.trainloader:
                self.callback_handler.on_step_begin(self.args, self.state)
                data = data.cuda()
                targets = targets.cuda()
                index = index.cuda()

                pred = self._forward(model, data)

                if use_bce:
                    # BCE is element-wise: align the targets with the
                    # prediction's shape and dtype (equals reshape(-1, 1)
                    # in the single-task case).
                    loss = self.loss_fn(
                        pred, targets.reshape(pred.shape).to(pred.dtype)
                    )
                else:
                    loss = self.loss_fn(pred, targets, index=index)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                train_loss.append(loss.item())
                self.callback_handler.on_step_end(self.args, self.state)

            # ── Evaluation ──────────────────────────────────────────────
            model.eval()
            # An empty loader yields NaN, without NumPy's empty-mean warning.
            avg_train_loss = float(np.mean(train_loss)) if train_loss else float("nan")
            eval_metrics, test_true, test_pred = self.evaluate_loop(model)

            self.callback_handler.on_epoch_end(
                self.args,
                self.state,
                metrics=eval_metrics,
                train_loss=avg_train_loss,
                lr=self.optimizer.lr,
                test_true=test_true,
                test_pred=test_pred,
            )

            # ── Checkpointing ───────────────────────────────────────────
            finished = epoch + 1
            periodic = bool(save_every) and finished % save_every == 0
            if periodic or finished == self.args.epochs:
                self.save_checkpoint(os.path.join(run_dir, f"epoch_{finished}.pt"))

        self.callback_handler.on_train_end(self.args, self.state)
        return self.state.train_log

    # ------------------------------------------------------------------
    # Evaluation helpers
    # ------------------------------------------------------------------
    def _eval_single_loader(self, model, loader):
        """
        Run inference over *loader* and compute the registered metric.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards.

        Returns
        -------
        (metrics_dict, y_true np.ndarray, y_pred np.ndarray)
            Both arrays are flattened to 1-D.  ``metrics_dict`` is empty when
            no metric is registered.

        Raises
        ------
        ValueError
            If *loader* yields no batches.
        """
        pred_chunks, true_chunks = [], []

        was_training = bool(getattr(model, "training", False))
        model.eval()
        try:
            with torch.no_grad():
                for data, targets, _index in loader:
                    pred = self._forward(model, data.cuda())
                    pred_chunks.append(pred.detach().cpu().numpy())
                    # Targets are only collected, so they never need to
                    # visit the GPU.
                    true_chunks.append(targets.detach().cpu().numpy())
        finally:
            model.train(was_training)

        if not pred_chunks:
            raise ValueError(
                "Evaluation loader yielded no batches; cannot compute metrics."
            )

        y_true = np.concatenate(true_chunks)
        y_pred = np.concatenate(pred_chunks)
        if y_pred.ndim > 1:
            y_pred = y_pred.flatten()
        if y_true.ndim > 1:
            y_true = y_true.flatten()

        metrics = self.metric(y_true, y_pred) if self.metric else {}
        return metrics, y_true, y_pred

    def evaluate(self, loader, model):
        """
        Override base Trainer.evaluate() to use the GNN forward pass.

        Args
        ----
        loader : PyG DataLoader
        model  : GNN model

        Returns
        -------
        (metrics_dict, y_true, y_pred)
        """
        return self._eval_single_loader(model, loader)