.. _scent_extreme_classification:

Extreme Classification by Optimizing Compositional Entropic Risk (SCENT and SOX)
================================================================================================================================
------------------------------------------------------------------------------------

.. container:: cell markdown

   | **Author**: Xiyuan Wei

Introduction
------------------------------------------------------------------------------------

In this tutorial, you will learn how to train a linear model by optimizing
`Compositional Entropic Risk <https://arxiv.org/abs/2602.02877>`__ (CERM) using
SCENT and SOX on a subset of TreeOfLife-10M, a biology dataset covering 163K
species. The tutorial is implemented in PyTorch. It is recommended to run this
notebook in a GPU-enabled environment, e.g.,
`Google Colab <https://colab.research.google.com/>`__.

**Reference**

If you find this tutorial helpful in your work, please cite our library paper
and the following paper:

.. code-block:: RST

   @article{wei2026geometry,
     title={A Geometry-Aware Efficient Algorithm for Compositional Entropic Risk Minimization},
     author={Wei, Xiyuan and Zhou, Linli and Wang, Bokun and Lin, Chih-Jen and Yang, Tianbao},
     journal={arXiv preprint arXiv:2602.02877},
     year={2026}
   }

Prerequisites
--------------------------------------------------------------------------------

Download the pre-computed features with ``gdown`` (``pip install gdown`` if needed).

.. container:: cell code

   .. code:: python

      !gdown --folder 'https://drive.google.com/drive/folders/10cY2Azqz9Gnci-r4fXcGIpH_e5YaKXH_?usp=sharing' -O ./features

Install LibAUC
------------------------------------------------------------------------------------

Let's start with installing our library here.

.. container:: cell code

   .. code:: python

      !pip install libauc

Import required packages
------------------------------------------------------------------------------------

.. container:: cell code

   .. code:: python

      import os
      import logging
      import pathlib
      import json
      import random
      import math

      import numpy as np
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      from torch.utils.data import Dataset, DataLoader
      from torch.optim.lr_scheduler import CosineAnnealingLR

      from libauc.losses import EntLossClassification
      from libauc.optimizers import SCENT

Helper Functions
--------------------------------------------------------------------------------

Model architecture
^^^^^^^^^^^^^^^^^^^^

.. code:: python

   class LinearClassifier(nn.Module):
       """Linear classifier: maps feature vectors to logits for m classes.

       forward(x): expects shape (B, D) and returns logits (B, m).
       """

       def __init__(self, feature_dim: int, num_classes: int):
           super().__init__()
           self.feature_dim = feature_dim
           self.num_classes = num_classes
           self.fc = nn.Linear(feature_dim, num_classes, bias=False)
           nn.init.normal_(self.fc.weight, mean=0.0, std=0.01)

       def forward(self, x: torch.Tensor, labels: torch.Tensor,
                   classes: torch.Tensor | str | None = None) -> torch.Tensor:
           """Compute logits for a candidate set of classes, relative to each
           sample's own (positive) class.

           Args:
               x: tensor (B, D)
               labels: 1D LongTensor of class indices (B,)
               classes: sampled classes, or 'all' (meaning all classes),
                   or None (meaning the unique classes in `labels`)

           Returns:
               logits (B, K)
           """
           w_pos = self.fc.weight[labels]  # (B, D)
           mask = None
           if isinstance(classes, str):
               assert classes == "all"
               w_sampled = self.fc.weight
           elif isinstance(classes, torch.Tensor):
               w_sampled = self.fc.weight[classes]
           elif classes is None:
               unique_labels = torch.unique(labels)  # (K,)
               w_sampled = self.fc.weight[unique_labels]  # (K, D)
               mask = (labels.unsqueeze(1) == unique_labels.unsqueeze(0))  # (B, K)
           else:
               raise ValueError(f"Unknown classes type: {type(classes)}")
           # Logit of each candidate class minus the logit of the sample's own class.
           logits = x @ w_sampled.T - torch.sum(torch.mul(x, w_pos), dim=1, keepdim=True)  # (B, K)
           if mask is not None:
               # Exclude each sample's own class from the candidate set.
               logits.masked_fill_(mask.to(logits.device), float("-inf"))
           return logits
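Before wiring the classifier into the training loop, it can help to sanity-check
the masked-logit computation on random inputs. Below is a minimal, illustrative
check; the feature dimension, class count, and labels are made up for this example.

.. code:: python

   clf = LinearClassifier(feature_dim=8, num_classes=16)
   x = torch.randn(4, 8)
   y = torch.tensor([3, 3, 7, 12])

   # classes=None: score each sample against the unique labels in the batch
   # ([3, 7, 12] here), with each sample's own class masked to -inf.
   print(clf(x, y).shape)                           # torch.Size([4, 3])

   # classes="all": score against every class; each sample's own class gets
   # logit 0 (its score is subtracted from itself), not -inf.
   print(clf(x, y, "all").shape)                    # torch.Size([4, 16])

   # classes as a tensor: score against an explicit set of sampled classes.
   print(clf(x, y, torch.tensor([0, 3, 9])).shape)  # torch.Size([4, 3])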
Dataset of features
^^^^^^^^^^^^^^^^^^^

.. code:: python

   class FeaturesDataset(Dataset):
       """Dataset for precomputed features.

       Expects features_path (N, D) and labels_path (N,), where labels are
       ints in 0..C-1.
       """

       def __init__(self, features_path: str, labels_path: str):
           self.features = torch.load(features_path)
           self.labels = torch.load(labels_path)
           assert self.features.shape[0] == self.labels.shape[0], "features/labels length mismatch"

       def __len__(self):
           return self.features.shape[0]

       def __getitem__(self, idx):
           x = self.features[idx]
           y = self.labels[idx]
           return x, y, idx

Training helper functions
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python

   def setup_logging(out_log_file=None):
       logging.root.setLevel(level=logging.INFO)
       loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
       for logger in loggers:
           logger.setLevel(level=logging.INFO)
       formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s',
                                     datefmt='%Y-%m-%d,%H:%M:%S')
       if out_log_file is not None:
           file_handler = logging.FileHandler(out_log_file)
           logging.root.addHandler(file_handler)
       for handler in logging.root.handlers:
           handler.setFormatter(formatter)


   def build_dataloaders(root_data_dir, batch_size):
       dataloader_list = []
       for split, shuffle in zip(["train", "val", "test"], [True, False, False]):
           data_dir = os.path.join(root_data_dir, split)
           if not os.path.exists(data_dir):
               if split == "train" or split == "val":
                   raise FileNotFoundError(f"Data directory {data_dir} does not exist.")
               else:
                   # A missing test split is allowed; its loader is None.
                   dataloader_list.append(None)
                   continue
           features = os.path.join(data_dir, "features.pt")
           labels = os.path.join(data_dir, "labels.pt")
           ds = FeaturesDataset(features, labels)
           dataloader = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=4)
           dataloader_list.append(dataloader)
       train_loader, val_loader, test_loader = dataloader_list
       return train_loader, val_loader, test_loader
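If you want to verify the data pipeline without downloading the real features,
you can write a tiny synthetic dataset in the layout ``build_dataloaders``
expects (``<root>/<split>/features.pt`` and ``labels.pt``). The directory name
and sizes below are made up for illustration.

.. code:: python

   root = "./features_dummy"
   for split, n in [("train", 64), ("val", 16)]:
       d = os.path.join(root, split)
       os.makedirs(d, exist_ok=True)
       torch.save(torch.randn(n, 8), os.path.join(d, "features.pt"))
       torch.save(torch.randint(0, 16, (n,)), os.path.join(d, "labels.pt"))

   # The test split is absent, so its loader comes back as None.
   dummy_train, dummy_val, dummy_test = build_dataloaders(root, batch_size=32)
   print(len(dummy_train.dataset), len(dummy_val.dataset), dummy_test)  # 64 16 None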
.. code:: python

   def train_one_epoch(model, loader, optimizer, criterion, device):
       model.train()
       total_loss = 0.0
       total = 0
       for i, batch in enumerate(loader):
           feats, labels, indices = batch
           feats = feats.to(device)
           # Logits against the unique classes in the batch (classes=None),
           # with each sample's own class masked out.
           logits = model(feats, labels)
           loss_dict = criterion(logits, indices)
           loss = loss_dict["loss"]
           # Also track the full-softmax cross-entropy, for logging only.
           with torch.no_grad():
               model.eval()
               labels = labels.to(device, dtype=torch.long)
               cross_entropy_loss = F.cross_entropy(model.fc(feats), labels)
               loss_dict["cross_entropy_loss"] = cross_entropy_loss
               model.train()
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
           total_loss += cross_entropy_loss.item() * feats.size(0)
           total += feats.size(0)
           if i % 100 == 0:
               log_str = f" Batch {i} / {len(loader)}:"
               for key, value in loss_dict.items():
                   log_str += f" {key}={value.item():.6f}"
               logging.info(log_str)
       return total_loss / total


   @torch.no_grad()
   def evaluate(model, loader, device):
       model.eval()
       total_loss = 0.0
       correct = 0
       total = 0
       for i, batch in enumerate(loader):
           feats, labels, _ = batch
           feats = feats.to(device)
           labels = labels.to(device, dtype=torch.long)
           logits = model.fc(feats)
           loss = F.cross_entropy(logits, labels)
           total_loss += loss.item() * feats.size(0)
           preds = logits.argmax(dim=1)
           correct += (preds == labels).sum().item()
           total += feats.size(0)
           if i % 100 == 0:
               logging.info(f" Batch {i} / {len(loader)}: loss={loss.item():.6f}")
       return total_loss / total, correct / total


   def set_seed(seed):
       """Set random seeds for reproducibility."""
       random.seed(seed)
       np.random.seed(seed)
       torch.manual_seed(seed)
       if torch.cuda.is_available():
           torch.cuda.manual_seed(seed)
           torch.cuda.manual_seed_all(seed)
       # Make cuDNN deterministic.
       torch.backends.cudnn.deterministic = True
       torch.backends.cudnn.benchmark = False
       logging.info(f"Set random seed to {seed}")

Universal Hyperparameters for SCENT and SOX
--------------------------------------------------------------------------------

.. code:: python

   data_dir = "./features/treeoflife10m_subset"
   root_out_dir = "./outputs/"
   seed = 2026
   device = "cuda"
   feature_dim = 512
   num_classes = 163002
   data_size = 762654
   save_frequency = 1
   epochs = 20
   batch_size = 128

   setup_logging()

Training (SCENT)
--------------------------------------------------------------------------------

.. code:: python

   name = "scent_treeoflife10m_subset"
   lr = 0.002
   alpha = 0.0
   momentum = 0.9
   alpha_multiplier = 0.03

Creating model and optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python

   out_dir = pathlib.Path(root_out_dir) / name
   os.makedirs(out_dir / "checkpoints", exist_ok=True)
   out_log_file = out_dir / "out.log"

   set_seed(seed)
   train_loader, val_loader, test_loader = build_dataloaders(data_dir, batch_size)
   model = LinearClassifier(feature_dim, num_classes).to(device)
   optimizer = SCENT(model.parameters(), lr=lr, momentum=momentum)
   lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
   criterion = EntLossClassification(data_size=data_size, alpha=alpha,
                                     is_scent=True, alpha_multiplier=alpha_multiplier)
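For intuition about the objective: the classical (non-compositional) entropic
risk of a loss is ``(1/alpha) * log(E[exp(alpha * loss)])``, which reduces to
the plain average as ``alpha -> 0`` and up-weights large losses as ``alpha``
grows. The sketch below illustrates this on a toy loss vector; refer to the
paper cited above for the exact compositional form that
``EntLossClassification`` optimizes.

.. code:: python

   def entropic_risk(losses: torch.Tensor, alpha: float) -> torch.Tensor:
       """(1/alpha) * log(mean(exp(alpha * losses))); the plain mean at alpha == 0."""
       if alpha == 0.0:
           return losses.mean()
       return (torch.logsumexp(alpha * losses, dim=0) - math.log(len(losses))) / alpha

   toy_losses = torch.tensor([0.1, 0.2, 5.0])
   print(entropic_risk(toy_losses, 0.0))  # plain average, ~1.77
   print(entropic_risk(toy_losses, 1.0))  # ~3.92, tilted toward the largest loss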
Training loop
^^^^^^^^^^^^^

.. code:: python

   best_val_acc = 0.0
   best_test_acc = 0.0
   best_test_loss = float("inf")
   start_epoch = 0

   for epoch in range(start_epoch + 1, epochs + 1):
       logging.info(f"Epoch {epoch}/{epochs}, learning_rate={lr_scheduler.get_last_lr()[0]:.6f}")
       if hasattr(criterion, "adjust_gamma"):
           criterion.adjust_gamma(epoch, epochs)
       cross_entropy_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
       lr_scheduler.step()

       logging.info("Evaluating on validation set")
       val_loss, val_acc = evaluate(model, val_loader, device)
       if val_acc > best_val_acc:
           logging.info("New best model found")
           best_val_acc = val_acc
           if test_loader is not None:
               logging.info("Evaluating on test set")
               best_test_loss, best_test_acc = evaluate(model, test_loader, device)

       logging.info(f"Epoch {epoch}: cross_entropy_loss={cross_entropy_loss:.6f} "
                    f"val_loss={val_loss:.6f} val_acc={val_acc:.6f}")
       logging.info(f" Best val_acc={best_val_acc:.6f}")
       if test_loader is not None:
           logging.info(f" Best test acc={best_test_acc:.6f} Best test loss={best_test_loss:.6f}")

       eval_results = {
           "epoch": epoch,
           "cross_entropy_loss": cross_entropy_loss,
           "val_loss": val_loss,
           "val_acc": val_acc,
           "best_val_acc": best_val_acc,
       }
       if test_loader is not None:
           eval_results.update({
               "best_test_acc": best_test_acc,
               "best_test_loss": best_test_loss,
           })
       with open(out_dir / f"eval_{name}.jsonl", "a") as f:
           f.write(json.dumps(eval_results) + "\n")

       if epoch % save_frequency == 0 or epoch == epochs:
           save_dict = {
               "model_state": model.state_dict(),
               "optimizer_state": optimizer.state_dict(),
               "epoch": epoch,
           }
           if hasattr(criterion, "nu"):
               save_dict["criterion_nu"] = criterion.nu.cpu()
           torch.save(save_dict, out_dir / "checkpoints" / f"epoch_{epoch}.pt")
           logging.info(f"Saved checkpoint for epoch {epoch}.")
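The checkpoints include the optimizer state and, when the loss maintains one,
its running estimate ``nu``. If a run is interrupted, a sketch like the one
below can restore the state before re-entering the loop; ``epoch_10.pt`` is a
hypothetical file name, and the ``nu`` handling assumes it is a tensor (as the
``criterion.nu.cpu()`` call above suggests).

.. code:: python

   resume_path = out_dir / "checkpoints" / "epoch_10.pt"  # hypothetical checkpoint
   if resume_path.exists():
       ckpt = torch.load(resume_path, map_location=device)
       model.load_state_dict(ckpt["model_state"])
       optimizer.load_state_dict(ckpt["optimizer_state"])
       if "criterion_nu" in ckpt and hasattr(criterion, "nu"):
           criterion.nu = ckpt["criterion_nu"].to(criterion.nu.device)
       start_epoch = ckpt["epoch"]  # the loop then resumes at start_epoch + 1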
Training (SOX)
--------------------------------------------------------------------------------

SOX reuses the same model, optimizer, and training loop; only the learning rate
and the loss configuration (``is_scent=False`` with a ``gamma`` parameter in
place of ``alpha``) change.

.. code:: python

   name = "sox_treeoflife10m_subset"
   lr = 0.001
   gamma = 0.6
   momentum = 0.9

Creating model and optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python

   out_dir = pathlib.Path(root_out_dir) / name
   os.makedirs(out_dir / "checkpoints", exist_ok=True)
   out_log_file = out_dir / "out.log"

   set_seed(seed)
   train_loader, val_loader, test_loader = build_dataloaders(data_dir, batch_size)
   model = LinearClassifier(feature_dim, num_classes).to(device)
   optimizer = SCENT(model.parameters(), lr=lr, momentum=momentum)
   lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
   criterion = EntLossClassification(data_size=data_size, gamma=gamma, is_scent=False)

Training loop
^^^^^^^^^^^^^

The loop is identical to the SCENT training loop above; it picks up the SOX
``name`` and ``criterion`` defined in this section.

.. code:: python

   best_val_acc = 0.0
   best_test_acc = 0.0
   best_test_loss = float("inf")
   start_epoch = 0

   for epoch in range(start_epoch + 1, epochs + 1):
       logging.info(f"Epoch {epoch}/{epochs}, learning_rate={lr_scheduler.get_last_lr()[0]:.6f}")
       if hasattr(criterion, "adjust_gamma"):
           criterion.adjust_gamma(epoch, epochs)
       cross_entropy_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
       lr_scheduler.step()

       logging.info("Evaluating on validation set")
       val_loss, val_acc = evaluate(model, val_loader, device)
       if val_acc > best_val_acc:
           logging.info("New best model found")
           best_val_acc = val_acc
           if test_loader is not None:
               logging.info("Evaluating on test set")
               best_test_loss, best_test_acc = evaluate(model, test_loader, device)

       logging.info(f"Epoch {epoch}: cross_entropy_loss={cross_entropy_loss:.6f} "
                    f"val_loss={val_loss:.6f} val_acc={val_acc:.6f}")
       logging.info(f" Best val_acc={best_val_acc:.6f}")
       if test_loader is not None:
           logging.info(f" Best test acc={best_test_acc:.6f} Best test loss={best_test_loss:.6f}")

       eval_results = {
           "epoch": epoch,
           "cross_entropy_loss": cross_entropy_loss,
           "val_loss": val_loss,
           "val_acc": val_acc,
           "best_val_acc": best_val_acc,
       }
       if test_loader is not None:
           eval_results.update({
               "best_test_acc": best_test_acc,
               "best_test_loss": best_test_loss,
           })
       with open(out_dir / f"eval_{name}.jsonl", "a") as f:
           f.write(json.dumps(eval_results) + "\n")

       if epoch % save_frequency == 0 or epoch == epochs:
           save_dict = {
               "model_state": model.state_dict(),
               "optimizer_state": optimizer.state_dict(),
               "epoch": epoch,
           }
           if hasattr(criterion, "nu"):
               save_dict["criterion_nu"] = criterion.nu.cpu()
           torch.save(save_dict, out_dir / "checkpoints" / f"epoch_{epoch}.pt")
           logging.info(f"Saved checkpoint for epoch {epoch}.")

Results and Visualization
--------------------------------------------------------------------------------

.. code:: python

   # Plot training loss, validation loss, and validation accuracy per epoch.
   import json
   import math

   import matplotlib.pyplot as plt

   for metric, ylabel in zip(["cross_entropy_loss", "val_loss", "val_acc"],
                             ["Training Loss", "Validation Loss", "Validation Accuracy"]):
       # Index 0 holds the value before training: log(num_classes) for the
       # cross-entropy of a uniform predictor, 1/num_classes for chance accuracy.
       if "loss" in metric:
           results_scent = [-math.log(1.0 / num_classes)] + [0.0] * epochs
           results_sox = [-math.log(1.0 / num_classes)] + [0.0] * epochs
       else:
           results_scent = [1.0 / num_classes] + [0.0] * epochs
           results_sox = [1.0 / num_classes] + [0.0] * epochs
       with open("./outputs/scent_treeoflife10m_subset/eval_scent_treeoflife10m_subset.jsonl", "r") as f:
           for line in f:
               eval_results = json.loads(line)
               results_scent[eval_results["epoch"]] = eval_results[metric]
       with open("./outputs/sox_treeoflife10m_subset/eval_sox_treeoflife10m_subset.jsonl", "r") as f:
           for line in f:
               eval_results = json.loads(line)
               results_sox[eval_results["epoch"]] = eval_results[metric]
       fig, ax = plt.subplots()
       ax.plot(results_scent, label="SCENT")
       ax.plot(results_sox, label="SOX")
       ax.set_xlabel("Epoch")
       ax.set_ylabel(ylabel)
       ax.legend()
       plt.show()

.. image:: ./imgs/SCENT_Extreme_Classification_27_0.png

.. image:: ./imgs/SCENT_Extreme_Classification_27_1.png

.. image:: ./imgs/SCENT_Extreme_Classification_27_2.png

.. list-table:: Final training/validation cross-entropy and validation accuracy (example run)
   :header-rows: 1
   :widths: 18 18 22 22

   * - Method
     - Training Loss
     - Validation Loss
     - Validation Accuracy
   * - SCENT
     - 9.24
     - 9.60
     - 0.26
   * - SOX
     - 9.73
     - 9.95
     - 0.20
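Extreme-classification results are often reported as top-k accuracy alongside
top-1. As an optional follow-up, the sketch below evaluates top-5 accuracy with
the same full-softmax scoring used by ``evaluate``; ``evaluate_topk`` is a
helper introduced here for illustration, not part of LibAUC.

.. code:: python

   @torch.no_grad()
   def evaluate_topk(model, loader, device, k=5):
       """Fraction of samples whose true class is among the top-k logits."""
       model.eval()
       correct = 0
       total = 0
       for feats, labels, _ in loader:
           feats = feats.to(device)
           labels = labels.to(device, dtype=torch.long)
           topk = model.fc(feats).topk(k, dim=1).indices      # (B, k)
           correct += (topk == labels.unsqueeze(1)).any(dim=1).sum().item()
           total += feats.size(0)
       return correct / total

   print(f"val top-5 acc: {evaluate_topk(model, val_loader, device):.4f}")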