Skip to content

Migrate from BCE / CrossEntropyLoss

This guide is a drop-in migration checklist for switching from BCEWithLogitsLoss or CrossEntropyLoss to the losses in this library. Each section covers one migration path.


BCE → SigmoidFocalLoss

SigmoidFocalLoss is a direct replacement for BCEWithLogitsLoss. It accepts the same input shape and handles binary ([N, 1]) and multi-label ([N, C]) tasks identically.

Before:

loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, targets.float())

After:

from imbalanced_losses import SigmoidFocalLoss

loss_fn = SigmoidFocalLoss(alpha=0.25, gamma=2.0, reduction="mean")
loss = loss_fn(logits, targets.float())

What changes: - Add alpha (class balance weight; alpha=-1 disables it) and gamma (focusing exponent) - Targets must still be float in [0, 1] - reduction options are the same: "mean", "sum", "none", plus "mean_positive"

What stays the same: - Input and target shapes - .backward() call — no changes to your training loop

Where to start: alpha=0.25, gamma=2.0 are the RetinaNet defaults and a reasonable first try. Set gamma=0 if you want alpha-weighted BCE without focusing.


CrossEntropyLoss → SoftmaxFocalLoss

SoftmaxFocalLoss is a direct replacement for CrossEntropyLoss. It accepts [N, C] logits and [N] long integer targets.

Before:

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)

After:

from imbalanced_losses import SoftmaxFocalLoss

loss_fn = SoftmaxFocalLoss(gamma=2.0, reduction="mean")
loss = loss_fn(logits, targets)

What changes: - Add gamma (focusing exponent; gamma=0 reduces to standard softmax cross-entropy) - Optionally add alpha (per-class weights, length C) to up-weight rare classes

What stays the same: - ignore_index is supported with the same semantics as CrossEntropyLoss - Input and target shapes: [N, C] logits, [N] long integer targets - Works with spatial/sequence inputs [N, C, *] the same way CrossEntropyLoss does

Computing per-class alpha from class counts:

counts = [(y_train == c).sum().item() for c in range(n_classes)]
total  = sum(counts)
alpha  = [total / (n_classes * max(c, 1)) for c in counts]
alpha  = [a / max(alpha) for a in alpha]  # normalize: max weight = 1.0

loss_fn = SoftmaxFocalLoss(alpha=alpha, gamma=2.0)


BCE / CE → SmoothAPLoss, RecallAtQuantileLoss, or PAUCAtBudgetLoss

Ranking losses require a warm start because their gradients are nearly zero when the model is random. Do not switch cold — wrap with LossWarmupWrapper to run BCE or CE for the first few epochs first.

Binary

from imbalanced_losses import SmoothAPLoss, LossWarmupWrapper

loss_fn = LossWarmupWrapper(
    warmup_loss=nn.BCEWithLogitsLoss(),
    main_loss=SmoothAPLoss(num_classes=1, queue_size=512),
    warmup_epochs=5,
    blend_epochs=2,
)

# In your training loop:
for epoch in range(total_epochs):
    loss_fn.on_train_epoch_start(epoch)
    for step, (xb, yb) in enumerate(loader):
        loss_fn.on_train_batch_start(global_step)
        loss = loss_fn(model(xb), yb.float().unsqueeze(1))  # [N, 1] float targets
        ...
        global_step += 1

Multiclass

from imbalanced_losses import SmoothAPLoss, LossWarmupWrapper

loss_fn = LossWarmupWrapper(
    warmup_loss=nn.CrossEntropyLoss(),
    main_loss=SmoothAPLoss(num_classes=n_classes, queue_size=1024),
    warmup_epochs=5,
    blend_epochs=2,
)

# CrossEntropyLoss and SmoothAPLoss share the same input format:
# logits [N, C] float, targets [N] long — no conversion needed.
for epoch in range(total_epochs):
    loss_fn.on_train_epoch_start(epoch)
    for step, (xb, yb) in enumerate(loader):
        loss_fn.on_train_batch_start(global_step)
        loss = loss_fn(model(xb), yb)
        ...
        global_step += 1

Common mistakes

Forgetting the epoch / step hooks

LossWarmupWrapper tracks phase transitions in on_train_epoch_start and temperature decay in on_train_batch_start. Missing either call means the loss never transitions from warmup.

# Wrong — loss stays in warmup forever
for epoch in range(total_epochs):
    for xb, yb in loader:
        loss = loss_fn(model(xb), yb)

# Correct
for epoch in range(total_epochs):
    loss_fn.on_train_epoch_start(epoch)   # ← required
    for step, (xb, yb) in enumerate(loader):
        loss_fn.on_train_batch_start(global_step)  # ← required for step-based or temp decay
        loss = loss_fn(model(xb), yb)
        global_step += 1

Target dtype mismatch

SigmoidFocalLoss and SmoothAPLoss(num_classes=1) expect float targets. SoftmaxFocalLoss and SmoothAPLoss(num_classes > 1) expect long integer targets. Swapping these silently produces wrong gradients or a runtime error.

# Binary — targets must be float
targets_float = targets.float()                     # or .unsqueeze(1) for [N, 1]

# Multiclass — targets must be long integer indices
targets_long = targets.long()

Skipping queue reset between training phases

The memory queue accumulates logits from previous batches. If you switch from a classifier pre-training phase to fine-tuning without resetting, stale logits from the old distribution contaminate the AP computation.

# Reset manually when starting a new training phase
loss_fn.main_loss.reset_queue()

# Or configure automatic reset in LossWarmupWrapper
loss_fn = LossWarmupWrapper(
    ...,
    reset_queue_each_epoch=True,   # useful when distribution shifts each epoch
)

Using ranking losses without warmup at all

# Wrong — gradients are flat at random initialization, training stalls
loss_fn = SmoothAPLoss(num_classes=5, queue_size=1024)

# Correct — warm up with a proxy loss first
loss_fn = LossWarmupWrapper(
    warmup_loss=nn.CrossEntropyLoss(),
    main_loss=SmoothAPLoss(num_classes=5, queue_size=1024),
    warmup_epochs=5,
)

Decision checklist

Use this to decide how far to migrate:

Situation Recommended loss
Small imbalance (< 10:1), fast iteration Start with CrossEntropyLoss + alpha weighting
Moderate imbalance, retinaNet-style detection SoftmaxFocalLoss(alpha=..., gamma=2.0)
High imbalance (> 50:1), optimize macro-AP SmoothAPLoss with warmup
High imbalance, optimize top-k recall RecallAtQuantileLoss with warmup
Fixed false-alarm budget / operating-point band PAUCAtBudgetLoss with warmup
Per-class monitoring needed Any ranking loss with return_per_class=True

See Assumptions and Failure Modes for when each loss is likely to underperform.