from libauc.metrics import auc_prc_score, auc_roc_score, pauc_roc_score
import logging
import sys
# Module-wide logging setup: INFO and above go to stdout with a timestamped
# "[LEVEL] logger-name - message" layout.
# NOTE(review): basicConfig configures the *root* logger at import time, which
# affects every importer of this module; consider moving it to an entry point.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# ---------------------------------------------------------------------------
# Metric builder
# ---------------------------------------------------------------------------
def _make_metric(name, kwargs):
    """Resolve one metric entry into a ``(result_key, compute_fn)`` pair.

    Args:
        name (str): Metric name, matched case-insensitively.
        kwargs (dict): Keyword arguments for this entry. Only the AUROC
            entry consumes them: the presence of ``max_fpr`` or ``min_tpr``
            switches it to partial AUROC, and all kwargs are then passed to
            ``pauc_roc_score``.

    Returns:
        tuple[str, Callable] | None: The key under which the score is
        reported and a callable ``compute_fn(test_true, test_pred)``, or
        ``None`` when ``name`` is not a supported metric.
    """
    name_upper = name.upper()
    # The lambdas look up the libauc functions when the metric is evaluated,
    # not when it is resolved.
    if name_upper == "AUROC":
        if "max_fpr" in kwargs or "min_tpr" in kwargs:
            # Embed every kwarg in the key so several partial-AUROC entries
            # (and a full-AUROC entry) can coexist in one result dict.
            args = ", ".join(f"{k}={v}" for k, v in kwargs.items())
            return f"PAUROC({args})", lambda y, p: pauc_roc_score(y, p, **kwargs)
        return "AUROC", lambda y, p: auc_roc_score(y, p)
    if name_upper == "AUPRC":
        return "AUPRC", lambda y, p: auc_prc_score(y, p)
    if name_upper == "ACC":
        # Imported only when accuracy is actually requested.
        from sklearn import metrics as skmetrics

        # Binarize scores at a fixed threshold of 0.5; ``test_pred`` must
        # support ``>=`` and ``.astype`` (e.g. a numpy array).
        return "ACC", lambda y, p: skmetrics.accuracy_score(y, (p >= 0.5).astype(int))
    return None


def build_metric(metric_names, metric_kwargs):
    """
    Build an evaluation function from a list of metric names.

    The returned callable computes each requested metric after every
    evaluation epoch and returns the results as a flat dict. It does
    **not** affect the training objective — losses and optimizers are
    configured separately via ``TrainingArguments``.

    Supported metric names (case-insensitive):

    * ``"AUROC"`` — full-ROC AUROC (``libauc.metrics.auc_roc_score``).
      When ``metric_kwargs`` for this entry contains ``max_fpr`` or
      ``min_tpr``, partial AUROC (``libauc.metrics.pauc_roc_score``) is
      computed instead, all of that entry's kwargs are passed to it, and the
      result key becomes e.g. ``"PAUROC(max_fpr=0.3)"``.
    * ``"AUPRC"`` — area under the precision-recall curve
      (``libauc.metrics.auc_prc_score``).
    * ``"ACC"`` — accuracy at a fixed decision threshold of 0.5
      (``sklearn.metrics.accuracy_score``).

    Names are resolved once, when ``build_metric`` is called. Unknown names
    are skipped with a single warning logged at that time. The same name may
    appear more than once with different ``metric_kwargs`` entries (e.g. full
    AUROC and partial AUROC simultaneously). Entries that produce the same
    result key overwrite each other (last value wins).

    Args:
        metric_names (list[str] | None): Ordered list of metric names to
            compute, e.g. ``["AUROC", "AUPRC", "ACC"]``. ``None`` means no
            metrics.
        metric_kwargs (list[dict | None] | None): Per-metric keyword arguments
            aligned by position with ``metric_names``. Missing or ``None``
            entries (or ``None`` for the whole list) default to ``{}``. The
            kwargs are consumed only by the AUROC entry, to select and
            configure partial AUROC. They are ignored for full AUROC, AUPRC
            and ACC.

    Returns:
        Callable[[numpy.ndarray, numpy.ndarray], dict[str, float]]:
            A function ``metric_fn(test_true, test_pred) -> results`` where
            ``test_true`` is the 1-D array of ground-truth labels,
            ``test_pred`` is the 1-D array of model output scores, and
            ``results`` maps each metric key to its score, in request order.
    """
    metric_names = metric_names or []
    metric_kwargs = metric_kwargs or []

    # Resolve names once here instead of on every evaluation epoch, so that an
    # unknown name is reported a single time rather than once per epoch.
    plan = []
    for idx, name in enumerate(metric_names):
        entry = metric_kwargs[idx] if idx < len(metric_kwargs) else None
        # Copy so later mutation of the caller's dict cannot desync the
        # result key from the arguments actually used.
        resolved = _make_metric(name, dict(entry or {}))
        if resolved is None:
            logger.warning("Unknown metric '%s', skipping.", name)
            continue
        plan.append(resolved)

    def metric_fn(test_true, test_pred):
        # A duplicated key keeps its first position and its last value,
        # exactly like assigning into the dict in request order.
        return {key: compute(test_true, test_pred) for key, compute in plan}

    return metric_fn