# Source code for libauc.trainer.helpers


from libauc.metrics import auc_prc_score, auc_roc_score, pauc_roc_score
import logging
import sys

# Root-logger setup performed at import time: INFO level, timestamped records,
# written to standard output rather than the default standard error.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"

logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Module-scoped logger used by the helpers in this file.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Metric builder
# ---------------------------------------------------------------------------

def build_metric(metric_names, metric_kwargs=None):
    """Build an evaluation function from a list of metric names.

    The returned callable computes each requested metric and returns the
    results as a flat dict. It does **not** affect the training objective —
    losses and optimizers are configured separately via ``TrainingArguments``.

    Supported metric names (case-insensitive):

    * ``"AUROC"`` — full-ROC AUROC (``libauc.metrics.auc_roc_score``). When
      the ``metric_kwargs`` entry for this metric contains ``max_fpr`` or
      ``min_tpr``, partial AUROC (``libauc.metrics.pauc_roc_score``) is
      computed instead and the result key becomes e.g.
      ``"PAUROC(max_fpr=0.3)"``.
    * ``"AUPRC"`` — area under the precision-recall curve
      (``libauc.metrics.auc_prc_score``).
    * ``"ACC"`` — accuracy at a fixed decision threshold of 0.5
      (``sklearn.metrics.accuracy_score``).

    Unknown names are skipped with a warning. The same name may appear more
    than once with different ``metric_kwargs`` entries (e.g. full AUROC and
    partial AUROC simultaneously). Entries that produce the same result key
    overwrite each other; the last one wins.

    Args:
        metric_names (list[str]): Ordered list of metric names to compute,
            e.g. ``["AUROC", "AUPRC", "ACC"]``.
        metric_kwargs (list[dict] | None): Per-metric keyword arguments,
            matched to ``metric_names`` by position. Only the partial-AUROC
            path forwards them to the scoring function; for every other
            metric they are ignored. ``None``, a missing entry, or a ``None``
            entry is treated as ``{}``.

    Returns:
        Callable[[numpy.ndarray, numpy.ndarray], dict[str, float]]:
        A function ``metric_fn(test_true, test_pred) -> results`` where
        ``test_true`` holds the ground-truth labels, ``test_pred`` holds the
        model output scores, and ``results`` maps each result key to its
        score.
    """
    import numpy as np

    def metric_fn(test_true, test_pred):
        results = {}
        for idx, name in enumerate(metric_names):
            # Positional lookup; anything absent or None falls back to {}.
            kwargs = {}
            if metric_kwargs is not None and idx < len(metric_kwargs):
                kwargs = metric_kwargs[idx] or {}

            name_upper = name.upper()
            if name_upper == "AUROC":
                if "max_fpr" in kwargs or "min_tpr" in kwargs:
                    # Encode the kwargs in the key so several partial-AUROC
                    # settings can coexist in one results dict.
                    args = ", ".join(f"{k}={v}" for k, v in kwargs.items())
                    results[f"PAUROC({args})"] = pauc_roc_score(
                        test_true, test_pred, **kwargs
                    )
                else:
                    results["AUROC"] = auc_roc_score(test_true, test_pred)
            elif name_upper == "AUPRC":
                results["AUPRC"] = auc_prc_score(test_true, test_pred)
            elif name_upper == "ACC":
                # scikit-learn is needed for accuracy only, so it is imported
                # here instead of being required by every caller.
                from sklearn import metrics as skmetrics

                # np.asarray lets plain sequences be thresholded as well;
                # comparing a Python list with a float raises TypeError.
                hard_pred = (np.asarray(test_pred) >= 0.5).astype(int)
                results["ACC"] = skmetrics.accuracy_score(test_true, hard_pred)
            else:
                logger.warning("Unknown metric '%s', skipping.", name)
        return results

    return metric_fn