
Log Per-Class Metrics

SmoothAPLoss, RecallAtQuantileLoss, and PAUCAtBudgetLoss all support returning per-class loss values alongside the aggregated scalar, without requiring a second forward pass.

Retrieve per-class losses

Pass return_per_class=True:

import torch
from imbalanced_losses import SmoothAPLoss

loss_fn = SmoothAPLoss(num_classes=4, queue_size=1024)
logits  = torch.randn(32, 4, requires_grad=True)  # stand-in for model outputs; needs grad for backward()
targets = torch.randint(0, 4, (32,))

loss, per_class, valid = loss_fn(logits, targets, return_per_class=True)
loss.backward()

# per_class: shape [C]; NaN for degenerate classes
# valid:     shape [C], bool; True for classes with at least one positive and one negative
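The `valid` mask marks classes that are non-degenerate in one-vs-rest terms. The sketch below reproduces that rule in plain Python, assuming a class is valid exactly when the pooled integer labels (current batch plus any queued entries) contain at least one sample of that class and at least one sample of a different class. `compute_valid_mask` is a hypothetical helper for illustration, not part of imbalanced_losses.

```python
def compute_valid_mask(labels, num_classes):
    """Hypothetical helper: True for class c if the pool has >=1 positive and >=1 negative for c.

    Assumes single-label integer targets treated one-vs-rest, so the
    negatives for class c are all samples carrying a different label.
    """
    counts = [0] * num_classes
    for y in labels:
        counts[y] += 1
    total = len(labels)
    # Positives for c: counts[c]; negatives for c: every other sample.
    return [counts[c] > 0 and total - counts[c] > 0 for c in range(num_classes)]


# Class 3 never appears in this pool, so it has no positives and is degenerate.
pool = [0, 1, 1, 2, 0, 2]
print(compute_valid_mask(pool, num_classes=4))
# → [True, True, True, False]
```

A pool containing only one label is the all-positive case for that class and the all-negative case for every other class, so every entry of the mask comes back False.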

Log in PyTorch Lightning

def training_step(self, batch, batch_idx):
    logits, targets = batch
    loss, per_class, valid = self.loss_fn(logits, targets, return_per_class=True)

    self.log("train/loss", loss)
    for c in valid.nonzero(as_tuple=True)[0].tolist():
        self.log(f"train/ap_loss_class_{c}", per_class[c])

    return loss

Only classes marked True in valid are logged. Degenerate classes (all-positive or all-negative in the current pool) have NaN values, and the valid mask skips them automatically.
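The same guard works outside Lightning. In the sketch below, plain Python lists stand in for the `per_class` and `valid` tensors and a dict's `__setitem__` stands in for `self.log`; `log_valid_classes` and its `prefix` argument are hypothetical names. It shows that the NaN entry for a degenerate class never reaches the logger.

```python
import math


def log_valid_classes(per_class, valid, log, prefix="train/ap_loss_class_"):
    """Hypothetical helper: log per-class values only where the valid mask is True."""
    for c, is_valid in enumerate(valid):
        if is_valid:
            log(f"{prefix}{c}", per_class[c])


logged = {}
per_class = [0.42, 0.17, math.nan, 0.31]  # class 2 is degenerate, so its value is NaN
valid = [True, True, False, True]
log_valid_classes(per_class, valid, logged.__setitem__)
print(sorted(logged))
# → ['train/ap_loss_class_0', 'train/ap_loss_class_1', 'train/ap_loss_class_3']
```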

Use with RecallAtQuantileLoss

The same pattern applies to RecallAtQuantileLoss:

from imbalanced_losses import RecallAtQuantileLoss

loss_fn = RecallAtQuantileLoss(num_classes=4, quantile=0.005, queue_size=1024)
loss, per_class, valid = loss_fn(logits, targets, return_per_class=True)

for c in valid.nonzero(as_tuple=True)[0].tolist():
    print(f"Class {c} recall-loss: {per_class[c].item():.4f}")
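With many classes, a compact summary alongside the per-class values can make dashboards easier to scan. The plain-Python sketch below computes the mean and the worst class over valid entries only; `summarize_valid` is a hypothetical helper, and treating the largest loss as the worst class is an assumption that holds for any loss being minimized.

```python
import math


def summarize_valid(per_class, valid):
    """Hypothetical helper: mean loss and worst (largest-loss) class over valid classes.

    Returns (nan, None) when no class is valid.
    """
    vals = [(v, c) for c, (v, ok) in enumerate(zip(per_class, valid)) if ok]
    if not vals:
        return math.nan, None
    mean = sum(v for v, _ in vals) / len(vals)
    _, worst_class = max(vals)  # tuples compare by loss value first
    return mean, worst_class


mean, worst = summarize_valid([0.2, math.nan, 0.6, 0.4], [True, False, True, True])
print(f"mean={mean:.4f} worst_class={worst}")
# → mean=0.4000 worst_class=2
```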

Use with PAUCAtBudgetLoss

PAUCAtBudgetLoss supports return_per_class=True with the same three-value tuple pattern:

from imbalanced_losses import PAUCAtBudgetLoss

loss_fn = PAUCAtBudgetLoss(num_classes=4, alpha=0.0, beta=0.005, queue_size=1024)
loss, per_class, valid = loss_fn(logits, targets, return_per_class=True)

for c in valid.nonzero(as_tuple=True)[0].tolist():
    print(f"Class {c} pAUC-loss: {per_class[c].item():.4f}")

It also supports return_diagnostics=True, which returns per-class statistics alongside the loss. You can combine both:

# Diagnostics only
loss, stats = loss_fn(logits, targets, return_diagnostics=True)
# stats: dict of per-class [C] tensors with keys t_alpha, t_beta, tau_eff, band_neg_count, pauc_var, grad_pos_count

# Both per-class losses and diagnostics
loss, per_class, valid, stats = loss_fn(logits, targets, return_per_class=True, return_diagnostics=True)
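The number of returned values depends on which flags are set: the loss alone, `(loss, stats)`, `(loss, per_class, valid)`, or `(loss, per_class, valid, stats)`. If one code path must handle every combination, a small normalizing helper keeps call sites uniform. This is a plain-Python sketch; `unpack_loss_output` is a hypothetical name, and it assumes the tuple orders shown above and that a call with neither flag returns the loss alone.

```python
def unpack_loss_output(result, return_per_class=False, return_diagnostics=False):
    """Hypothetical helper: normalize the loss output to (loss, per_class, valid, stats).

    Assumes the tuple layouts used in this guide; absent parts come back as None.
    """
    if return_per_class and return_diagnostics:
        loss, per_class, valid, stats = result
    elif return_per_class:
        loss, per_class, valid = result
        stats = None
    elif return_diagnostics:
        loss, stats = result
        per_class = valid = None
    else:
        loss, per_class, valid, stats = result, None, None, None
    return loss, per_class, valid, stats


# Stand-in values: a float for the loss, lists/dicts for the per-class tensors.
print(unpack_loss_output((0.5, {"pauc_var": [0.1]}), return_diagnostics=True))
# → (0.5, None, None, {'pauc_var': [0.1]})
```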

Log both in a Lightning training step:

def training_step(self, batch, batch_idx):
    logits, targets = batch
    loss, per_class, valid, stats = self.loss_fn(
        logits, targets, return_per_class=True, return_diagnostics=True
    )

    self.log("train/loss", loss)
    for c in valid.nonzero(as_tuple=True)[0].tolist():
        self.log(f"train/pauc_loss_class_{c}", per_class[c])
        self.log(f"train/band_neg_count_class_{c}", stats["band_neg_count"][c])

    return loss
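To log every diagnostic rather than only band_neg_count, the stats dict can be flattened into log keys. The sketch below assumes `stats` maps each diagnostic name to a length-C sequence, as the `stats["band_neg_count"][c]` indexing in the training step suggests; lists stand in for tensors, and `flatten_diagnostics` and its `prefix` argument are hypothetical names. The resulting dict could then be passed to a bulk logging call such as Lightning's `self.log_dict`.

```python
def flatten_diagnostics(stats, valid, prefix="train/"):
    """Hypothetical helper: turn {name: per-class sequence} into flat log keys.

    Only classes whose valid entry is True are included.
    """
    flat = {}
    for name, values in stats.items():
        for c, ok in enumerate(valid):
            if ok:
                flat[f"{prefix}{name}_class_{c}"] = values[c]
    return flat


stats = {"band_neg_count": [12, 0, 7], "grad_pos_count": [3, 0, 2]}
valid = [True, False, True]
for key, value in flatten_diagnostics(stats, valid).items():
    print(key, value)
```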

band_neg_count and grad_pos_count are the key health indicators: near-zero values signal band or gradient starvation before the loss itself shows unusual behavior.
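One way to act on these indicators is to flag valid classes whose counts drop below a threshold. The plain-Python sketch below assumes `stats` maps indicator names to per-class sequences (lists stand in for tensors); `find_starved_classes` and the default `min_count=1` threshold are hypothetical choices, not library defaults.

```python
def find_starved_classes(stats, valid, keys=("band_neg_count", "grad_pos_count"), min_count=1):
    """Hypothetical helper: map each valid class to the indicators below min_count."""
    starved = {}
    for c, ok in enumerate(valid):
        if not ok:
            continue
        low = [k for k in keys if stats[k][c] < min_count]
        if low:
            starved[c] = low
    return starved


stats = {"band_neg_count": [40, 0, 5], "grad_pos_count": [6, 0, 0]}
print(find_starved_classes(stats, valid=[True, True, True]))
# → {1: ['band_neg_count', 'grad_pos_count'], 2: ['grad_pos_count']}
```

Raising `min_count` turns the check into an early warning rather than a report of outright starvation.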

Use with LossWarmupWrapper

**kwargs (including return_per_class=True) are forwarded to main_loss only when main_weight >= 1.0, i.e. when final_main_weight == 1.0 (the default) and the blend period has ended. During warmup and blend, or whenever final_main_weight < 1.0, they are silently ignored, so handle both return types:
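The forwarding rule can be sketched as a schedule check. Only the condition `main_weight >= 1.0` comes from this guide; the schedule itself is an assumption (main weight 0 during `warmup_steps`, a linear ramp to `final_main_weight` over `blend_steps`, then constant), and `warmup_steps`, `blend_steps`, `main_weight_at`, and `forwards_kwargs` are hypothetical names. The actual LossWarmupWrapper schedule may differ.

```python
def main_weight_at(step, warmup_steps, blend_steps, final_main_weight=1.0):
    """Hypothetical schedule: 0 during warmup, linear ramp during blend, then constant."""
    if step < warmup_steps:
        return 0.0
    if blend_steps <= 0 or step >= warmup_steps + blend_steps:
        return final_main_weight
    progress = (step - warmup_steps) / blend_steps
    return progress * final_main_weight


def forwards_kwargs(step, warmup_steps, blend_steps, final_main_weight=1.0):
    """kwargs reach main_loss only once the main weight reaches 1.0."""
    return main_weight_at(step, warmup_steps, blend_steps, final_main_weight) >= 1.0


for step in (0, 150, 300):
    print(step, forwards_kwargs(step, warmup_steps=100, blend_steps=200))
# A final_main_weight below 1.0 never satisfies the condition.
print(forwards_kwargs(300, warmup_steps=100, blend_steps=200, final_main_weight=0.8))
```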

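def training_step(self, batch, batch_idx):
    logits, targets = batch
    result = self.loss_fn(logits, targets, return_per_class=True)

    # The wrapper returns a tuple only once kwargs are forwarded to main_loss.
    if isinstance(result, tuple):
        loss, per_class, valid = result
        for c in valid.nonzero(as_tuple=True)[0].tolist():
            self.log(f"train/ap_class_{c}", per_class[c])
    else:
        loss = result

    self.log("train/loss", loss)
    return loss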

During warmup and blend, result is a plain scalar tensor. Once the blend period has ended (with final_main_weight == 1.0), it is a (loss, per_class, valid) tuple.

See also

examples/per_class_metrics_demo.py: a runnable script demonstrating return_per_class=True for SmoothAPLoss and RecallAtQuantileLoss, including the valid_mask guard pattern. PAUCAtBudgetLoss uses the same pattern and additionally accepts return_diagnostics=True.