Extreme Classification by Optimizing Compositional Entropic Risk (SCENT and SOX)
Author: Xiyuan Wei
Introduction
In this tutorial, you will learn how to train a linear model by Compositional Entropic Risk Minimization (CERM) using the SCENT and SOX algorithms on a subset of TreeOfLife-10M, a biology dataset; the subset used here covers about 163K species. This version is implemented in PyTorch. We recommend running this notebook in a GPU-enabled environment, e.g., Google Colab.
Reference
If you find this tutorial helpful in your work, please cite our library paper and the following paper:
@article{wei2026geometry,
title={A Geometry-Aware Efficient Algorithm for Compositional Entropic Risk Minimization},
author={Wei, Xiyuan and Zhou, Linli and Wang, Bokun and Lin, Chih-Jen and Yang, Tianbao},
journal={arXiv preprint arXiv:2602.02877},
year={2026}
}
Prerequisites
Download the pre-computed features with gdown (run `pip install gdown` first if it is not installed).
!gdown --folder 'https://drive.google.com/drive/folders/10cY2Azqz9Gnci-r4fXcGIpH_e5YaKXH_?usp=sharing' -O ./features
Install LibAUC
Let's start by installing our library.
!pip install libauc
Import required packages
import os
import logging
import pathlib
import json
import sys
import random
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from libauc.losses import EntLossClassification
from libauc.optimizers import SCENT
Helper Functions
Model architecture
class LinearClassifier(nn.Module):
"""Linear classifier: maps feature vectors to logits for m classes.
forward(x): expects shape (B, D) and returns logits (B, m)
"""
def __init__(self, feature_dim: int, num_classes: int):
super().__init__()
self.feature_dim = feature_dim
self.num_classes = num_classes
self.fc = nn.Linear(feature_dim, num_classes, bias=False)
nn.init.normal_(self.fc.weight, mean=0.0, std=0.01)
def forward(self, x: torch.Tensor, labels: torch.Tensor, classes: torch.Tensor | str | None = None) -> torch.Tensor:
"""Compute logits for only the classes in `labels`.
Args:
x: tensor (B, D)
labels: 1D LongTensor of class indices (B,)
classes: sampled classes, or 'all' (meaning all classes), or None (meaning classes in labels)
Returns:
logits (B, K)
"""
w_pos = self.fc.weight[labels] # (B, D)
mask = None
if isinstance(classes, str):
assert classes == "all"
w_sampled = self.fc.weight
elif isinstance(classes, torch.Tensor):
w_sampled = self.fc.weight[classes]
elif classes is None:
unique_labels = torch.unique(labels) # (K,)
w_sampled = self.fc.weight[unique_labels] # (K, D)
mask = (labels.unsqueeze(1) == unique_labels.unsqueeze(0)) # (B, K)
else:
raise ValueError(f"Unknown classes type: {type(classes)}")
logits = x @ w_sampled.T - torch.sum(torch.mul(x, w_pos), dim=1, keepdim=True) # (B, K)
if mask is not None:
logits.masked_fill_(mask.to(logits.device), float("-inf"))
return logits
Dataset of features
class FeaturesDataset(Dataset):
    """Dataset over precomputed feature vectors and their integer class labels.

    ``features_path`` must hold a tensor of shape (N, D) and ``labels_path`` a tensor of
    shape (N,) with labels in 0..C-1. Each item is a ``(feature, label, index)`` triple; the
    sample index is returned so the training loop can pass it on to the loss.
    """

    def __init__(self, features_path: str, labels_path: str):
        # NOTE(review): depending on the PyTorch version, torch.load may unpickle arbitrary
        # objects. Only load files from a trusted source; these are downloaded from a
        # shared drive folder.
        self.features = torch.load(features_path)
        self.labels = torch.load(labels_path)
        # Raise instead of assert so the check is not stripped under `python -O`.
        if self.features.shape[0] != self.labels.shape[0]:
            raise ValueError(
                "features/labels length mismatch: "
                f"{self.features.shape[0]} features vs {self.labels.shape[0]} labels"
            )

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.labels[idx]
        return x, y, idx
Training helper functions
def setup_logging(out_log_file=None):
    """Configure INFO-level logging with a ``time | level | message`` format.

    Sets the root logger and every logger created so far to INFO and applies the
    formatter to every root handler. If the root logger has no handlers yet, a stderr
    ``StreamHandler`` is added first. Without it the formatter would apply to nothing, and
    the first module-level ``logging.info`` call would install a default-format handler
    via ``logging.basicConfig``.

    Args:
        out_log_file: optional path. If given, records are also appended to this file.
            Calling again with the same path does not attach a second handler, which would
            duplicate every line.
    """
    logging.root.setLevel(level=logging.INFO)
    # Snapshot the registry names first: getLogger() may add entries to it.
    for logger_name in list(logging.root.manager.loggerDict):
        logging.getLogger(logger_name).setLevel(level=logging.INFO)

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S"
    )
    if not logging.root.handlers:
        logging.root.addHandler(logging.StreamHandler())
    if out_log_file is not None:
        # FileHandler stores the absolute path in `baseFilename`; compare against that.
        log_path = os.path.abspath(os.fspath(out_log_file))
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
            for handler in logging.root.handlers
        )
        if not already_attached:
            logging.root.addHandler(logging.FileHandler(log_path))
    for handler in logging.root.handlers:
        handler.setFormatter(formatter)
def build_dataloaders(root_data_dir, batch_size, num_workers=4):
    """Build DataLoaders for the ``train``, ``val`` and ``test`` splits of a feature directory.

    Each split directory must contain ``features.pt`` and ``labels.pt``, which are read
    through ``FeaturesDataset``. Only the training split is shuffled.

    Args:
        root_data_dir: directory containing ``train/``, ``val/`` and optionally ``test/``.
        batch_size: batch size used for every split.
        num_workers: number of DataLoader worker processes per split. Defaults to 4, the
            previously hard-coded value; lower it on machines with few CPU cores.

    Returns:
        ``(train_loader, val_loader, test_loader)``. ``test_loader`` is ``None`` when the
        ``test/`` directory does not exist.

    Raises:
        FileNotFoundError: if the ``train/`` or ``val/`` directory does not exist.
    """
    loaders = {}
    for split in ("train", "val", "test"):
        data_dir = os.path.join(root_data_dir, split)
        if not os.path.exists(data_dir):
            if split == "test":
                # The test split is optional.
                loaders[split] = None
                continue
            raise FileNotFoundError(f"Data directory {data_dir} does not exist.")
        dataset = FeaturesDataset(
            os.path.join(data_dir, "features.pt"),
            os.path.join(data_dir, "labels.pt"),
        )
        loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=num_workers,
        )
    return loaders["train"], loaders["val"], loaders["test"]
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the epoch's mean cross-entropy.

    For every batch ``(features, labels, indices)``:
      1. ``criterion(model(features, labels), indices)`` returns a dict whose ``"loss"``
         entry is back-propagated and used for the optimizer step.
      2. Before that step, and without gradients, the model is switched to eval mode and the
         standard cross-entropy of ``model.fc`` over all classes is computed. It is stored
         in the dict as ``"cross_entropy_loss"`` for monitoring, and the model is switched
         back to train mode.
    Every 100th batch, all entries of the dict are logged.

    Returns:
        The monitored cross-entropy averaged over all samples of the epoch.
    """
    model.train()
    ce_sum = 0.0
    seen = 0
    for step, (feats, labels, indices) in enumerate(loader):
        feats = feats.to(device)
        loss_dict = criterion(model(feats, labels), indices)
        loss = loss_dict["loss"]

        # Monitoring only: full-softmax cross-entropy at the current (pre-update) weights.
        with torch.no_grad():
            model.eval()
            labels = labels.to(device, dtype=torch.long)
            ce = F.cross_entropy(model.fc(feats), labels)
            loss_dict["cross_entropy_loss"] = ce
            model.train()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = feats.size(0)
        ce_sum += ce.item() * n
        seen += n
        if step % 100 == 0:
            parts = [f" Batch {step} / {len(loader)}:"]
            parts.extend(f" {key}={value.item():.6f}" for key, value in loss_dict.items())
            logging.info("".join(parts))
    return ce_sum / seen
@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate ``model.fc`` on ``loader`` without gradients, in eval mode.

    Every 100th batch, that batch's mean loss is logged.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the cross-entropy averaged over all
        samples; ``accuracy`` is the fraction of samples whose top-scoring class equals the
        label.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    for step, (feats, labels, _) in enumerate(loader):
        feats = feats.to(device)
        labels = labels.to(device, dtype=torch.long)
        logits = model.fc(feats)
        batch_loss = F.cross_entropy(logits, labels)
        batch_size = feats.size(0)
        # Weight each batch mean by its size so a short last batch is not over-counted.
        loss_sum += batch_loss.item() * batch_size
        num_correct += (logits.argmax(dim=1) == labels).sum().item()
        num_seen += batch_size
        if step % 100 == 0:
            logging.info(f" Batch {step} / {len(loader)}: loss={batch_loss.item():.6f}")
    return loss_sum / num_seen, num_correct / num_seen
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour.

    CUDA generators are seeded only when a GPU is available. cuDNN benchmark mode is turned
    off because its autotuner may pick different kernels from run to run.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed)
    # Make cudnn deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    logging.info(f"Set random seed to {seed}")
Hyperparameters Shared by SCENT and SOX
# Settings shared by the SCENT and SOX runs below.
data_dir = "./features/treeoflife10m_subset"  # must contain train/ and val/ (test/ optional)
root_out_dir = "./outputs/"  # each run writes to <root_out_dir>/<run name>/
seed = 2026
# Fall back to CPU when no GPU is visible instead of crashing; expect CPU training to be slow.
# NOTE(review): confirm that libauc's loss/optimizer also work on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
feature_dim = 512  # dimensionality of the precomputed feature vectors
num_classes = 163002  # number of species (classes) in the subset
data_size = 762654  # passed to EntLossClassification; presumably the training-set size — confirm
save_frequency = 1  # save a checkpoint every `save_frequency` epochs (and at the last one)
epochs = 20
batch_size = 128
setup_logging()
Training (SCENT)
# SCENT run: output sub-directory name plus optimizer and loss settings.
name = "scent_treeoflife10m_subset"
lr = 2e-3  # initial learning rate of the SCENT optimizer
momentum = 0.9  # momentum of the SCENT optimizer
alpha = 0.0  # passed to EntLossClassification as `alpha`
alpha_multiplier = 3e-2  # passed to EntLossClassification as `alpha_multiplier`
Creating model and optimizer
# Prepare the SCENT run: output directory, seeding, data, model, optimizer, LR schedule and loss.
out_dir = pathlib.Path(root_out_dir) / name
(out_dir / "checkpoints").mkdir(parents=True, exist_ok=True)
# NOTE(review): in the visible code out_log_file is never passed to setup_logging,
# so nothing is written to this file yet.
out_log_file = out_dir / "out.log"
set_seed(seed)  # seed before the model is built so the weight initialisation is reproducible
train_loader, val_loader, test_loader = build_dataloaders(data_dir, batch_size)
model = LinearClassifier(feature_dim, num_classes).to(device)
optimizer = SCENT(model.parameters(), lr=lr, momentum=momentum)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)  # cosine decay spanning all epochs
criterion = EntLossClassification(
    data_size=data_size,
    alpha=alpha,
    is_scent=True,
    alpha_multiplier=alpha_multiplier,
)
Training loop
# SCENT training loop. Each epoch:
#   - train, then step the per-epoch cosine schedule;
#   - evaluate on the validation set, and re-evaluate on the test set whenever validation
#     accuracy improves, so the "best test" numbers belong to the best-validation epoch;
#   - append one JSON line of metrics to eval_<name>.jsonl;
#   - save a checkpoint every `save_frequency` epochs and at the final epoch.
best_val_acc = 0.0
best_test_acc = 0.0
best_test_loss = float("inf")
start_epoch = 0  # epochs already completed; training covers start_epoch+1 .. epochs
for epoch in range(start_epoch + 1, epochs + 1):
    logging.info(f"Epoch {epoch}/{epochs}, learning_rate={lr_scheduler.get_last_lr()[0]:.6f}")
    # Let the loss update its own per-epoch parameters, if it exposes such a hook.
    if hasattr(criterion, "adjust_gamma"):
        criterion.adjust_gamma(epoch, epochs)
    cross_entropy_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    lr_scheduler.step()

    logging.info("Evaluating on validation set")
    val_loss, val_acc = evaluate(model, val_loader, device)
    if val_acc > best_val_acc:
        logging.info("New best model found")
        best_val_acc = val_acc
        if test_loader is not None:
            logging.info("Evaluating on test set")
            best_test_loss, best_test_acc = evaluate(model, test_loader, device)
    logging.info(f"Epoch {epoch}: cross_entropy_loss={cross_entropy_loss:.6f} val_loss={val_loss:.6f} val_acc={val_acc:.6f}")
    logging.info(f" Best val_acc={best_val_acc:.6f}")
    if test_loader is not None:
        logging.info(f" Best test acc={best_test_acc:.6f} Best test loss={best_test_loss:.6f}")

    eval_results = {
        "epoch": epoch,
        "cross_entropy_loss": cross_entropy_loss,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "best_val_acc": best_val_acc,
    }
    if test_loader is not None:
        eval_results.update({
            "best_test_acc": best_test_acc,
            "best_test_loss": best_test_loss,
        })
    # Append mode: re-running this cell appends new lines for the same epochs.
    with open(out_dir / f"eval_{name}.jsonl", "a") as f:
        f.write(json.dumps(eval_results) + "\n")

    if epoch % save_frequency == 0 or epoch == epochs:
        # The LR-scheduler state and best validation accuracy are saved too, so a resumed
        # run continues the cosine schedule and best-model tracking instead of restarting them.
        save_dict = {
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "lr_scheduler_state": lr_scheduler.state_dict(),
            "best_val_acc": best_val_acc,
            "epoch": epoch,
        }
        # Loss-internal state `nu`, if the criterion has one (stored on CPU).
        if hasattr(criterion, "nu"):
            save_dict["criterion_nu"] = criterion.nu.cpu()
        torch.save(save_dict, out_dir / "checkpoints" / f"epoch_{epoch}.pt")
        logging.info(f"Saved checkpoint for epoch {epoch}.")
Training (SOX)
# SOX run: output sub-directory name plus optimizer and loss settings.
name = "sox_treeoflife10m_subset"
lr = 1e-3  # initial learning rate of the optimizer
momentum = 0.9  # momentum of the optimizer
gamma = 0.6  # passed to EntLossClassification as `gamma`
Creating model and optimizer
# Prepare the SOX run: output directory, seeding, data, model, optimizer, LR schedule and loss.
out_dir = pathlib.Path(root_out_dir) / name
(out_dir / "checkpoints").mkdir(parents=True, exist_ok=True)
# NOTE(review): in the visible code out_log_file is never passed to setup_logging,
# so nothing is written to this file yet.
out_log_file = out_dir / "out.log"
set_seed(seed)  # same seed as the SCENT run, so both start from identical initial weights
train_loader, val_loader, test_loader = build_dataloaders(data_dir, batch_size)
model = LinearClassifier(feature_dim, num_classes).to(device)
# NOTE(review): the SOX run also uses the SCENT optimizer class; in the visible code the two
# runs differ only in the EntLossClassification settings (is_scent=False, gamma) and in the
# hyperparameter values. Confirm this is intended.
optimizer = SCENT(model.parameters(), lr=lr, momentum=momentum)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)  # cosine decay spanning all epochs
criterion = EntLossClassification(
    data_size=data_size,
    gamma=gamma,
    is_scent=False,
)
Training loop
# SOX training loop. Each epoch:
#   - train, then step the per-epoch cosine schedule;
#   - evaluate on the validation set, and re-evaluate on the test set whenever validation
#     accuracy improves, so the "best test" numbers belong to the best-validation epoch;
#   - append one JSON line of metrics to eval_<name>.jsonl;
#   - save a checkpoint every `save_frequency` epochs and at the final epoch.
best_val_acc = 0.0
best_test_acc = 0.0
best_test_loss = float("inf")
start_epoch = 0  # epochs already completed; training covers start_epoch+1 .. epochs
for epoch in range(start_epoch + 1, epochs + 1):
    logging.info(f"Epoch {epoch}/{epochs}, learning_rate={lr_scheduler.get_last_lr()[0]:.6f}")
    # Let the loss update its own per-epoch parameters, if it exposes such a hook.
    if hasattr(criterion, "adjust_gamma"):
        criterion.adjust_gamma(epoch, epochs)
    cross_entropy_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    lr_scheduler.step()

    logging.info("Evaluating on validation set")
    val_loss, val_acc = evaluate(model, val_loader, device)
    if val_acc > best_val_acc:
        logging.info("New best model found")
        best_val_acc = val_acc
        if test_loader is not None:
            logging.info("Evaluating on test set")
            best_test_loss, best_test_acc = evaluate(model, test_loader, device)
    logging.info(f"Epoch {epoch}: cross_entropy_loss={cross_entropy_loss:.6f} val_loss={val_loss:.6f} val_acc={val_acc:.6f}")
    logging.info(f" Best val_acc={best_val_acc:.6f}")
    if test_loader is not None:
        logging.info(f" Best test acc={best_test_acc:.6f} Best test loss={best_test_loss:.6f}")

    eval_results = {
        "epoch": epoch,
        "cross_entropy_loss": cross_entropy_loss,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "best_val_acc": best_val_acc,
    }
    if test_loader is not None:
        eval_results.update({
            "best_test_acc": best_test_acc,
            "best_test_loss": best_test_loss,
        })
    # Append mode: re-running this cell appends new lines for the same epochs.
    with open(out_dir / f"eval_{name}.jsonl", "a") as f:
        f.write(json.dumps(eval_results) + "\n")

    if epoch % save_frequency == 0 or epoch == epochs:
        # The LR-scheduler state and best validation accuracy are saved too, so a resumed
        # run continues the cosine schedule and best-model tracking instead of restarting them.
        save_dict = {
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "lr_scheduler_state": lr_scheduler.state_dict(),
            "best_val_acc": best_val_acc,
            "epoch": epoch,
        }
        # Loss-internal state `nu`, if the criterion has one (stored on CPU).
        if hasattr(criterion, "nu"):
            save_dict["criterion_nu"] = criterion.nu.cpu()
        torch.save(save_dict, out_dir / "checkpoints" / f"epoch_{epoch}.pt")
        logging.info(f"Saved checkpoint for epoch {epoch}.")
Results and Visualization
# plot
import json
import math
import matplotlib.pyplot as plt


def _load_metric_curve(run_name, metric, initial_value, num_epochs):
    """Return ``[initial_value, v_1, ..., v_num_epochs]`` for ``metric`` of run ``run_name``.

    Values are read from ``<root_out_dir>/<run_name>/eval_<run_name>.jsonl``, which holds one
    JSON object per epoch as written by the training loops.
    - Epochs without a record stay NaN, so the plot shows a gap instead of a spurious 0.
    - Records outside 1..num_epochs are ignored.
    - If an epoch appears more than once (e.g. after re-running a training cell, which
      appends), the last record wins.
    """
    curve = [initial_value] + [float("nan")] * num_epochs
    path = pathlib.Path(root_out_dir) / run_name / f"eval_{run_name}.jsonl"
    with open(path, "r") as f:
        for line in f:
            record = json.loads(line)
            epoch_idx = record["epoch"]
            if 1 <= epoch_idx <= num_epochs:
                curve[epoch_idx] = record[metric]
    return curve


# Epoch-0 reference point is the uniform predictor over all classes: cross-entropy
# log(num_classes) and chance-level accuracy 1/num_classes.
# `ylabel` (not `name`) is used so this cell does not overwrite the run-name global.
for metric, ylabel in zip(
    ["cross_entropy_loss", "val_loss", "val_acc"],
    ["Training Loss", "Validation Loss", "Validation Accuracy"],
):
    initial_value = math.log(num_classes) if "loss" in metric else 1.0 / num_classes
    results_scent = _load_metric_curve("scent_treeoflife10m_subset", metric, initial_value, epochs)
    results_sox = _load_metric_curve("sox_treeoflife10m_subset", metric, initial_value, epochs)
    fig, ax = plt.subplots()
    ax.plot(results_scent, label="SCENT")
    ax.plot(results_sox, label="SOX")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()
| Method | Training Loss | Validation Loss | Validation Accuracy |
|---|---|---|---|
| SCENT | 9.24 | 9.60 | 0.26 |
| SOX | 9.73 | 9.95 | 0.20 |