
SoftmaxFocalLoss

Multiclass focal loss with softmax, for mutually-exclusive classification. Supports per-class alpha weighting, mean_positive reduction (RetinaNet convention), label smoothing, and logits of shape (N, C) or (N, C, *) with any number of trailing dimensions.

imbalanced_losses.focal_loss.SoftmaxFocalLoss

Bases: Module

Softmax Focal Loss for mutually-exclusive multiclass classification.

Generalises focal loss from the binary sigmoid case to C classes using softmax probabilities and standard cross-entropy as the base loss. Supports optional DDP all-gather so that positive-count-based normalisations (mean_positive) reflect the global batch.
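Reading the forward pass in the source listing, the per-position loss before reduction can be summarised as below. Here z denotes the logits, y the target class, and CE_ε the cross-entropy with label-smoothing ε; these symbols are introduced only for this summary, and α_y is taken to be 1 when alpha is None.

```latex
p_t = \operatorname{softmax}(z)_y, \qquad
\ell(z, y) = \alpha_y \,(1 - p_t)^{\gamma}\, \mathrm{CE}_{\varepsilon}(z, y), \qquad
\mathrm{CE}_{0}(z, y) = -\log p_t
```

With γ = 0 and no alpha, this reduces to ordinary cross-entropy, consistent with the gamma description.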

Parameters:

alpha (Tensor or list[float] or None; default: None)
    Per-class weighting factors of shape (C,). Typically set to the inverse class frequency or similar. None disables class weighting. When provided, each sample's loss is scaled by alpha[y] where y is the ground-truth class.

gamma (float; default: 2.0)
    Focusing exponent. gamma=0 recovers vanilla CE.

reduction (str; default: 'mean')
    'none' | 'mean' | 'mean_positive' | 'sum'.

      • 'mean': average over all valid (non-ignored) positions.
      • 'mean_positive': sum over ALL valid positions divided by the number of positive (non-background, non-ignored) positions. This is the RetinaNet convention and stabilises the loss scale when the vast majority of samples are background.

label_smoothing (float; default: 0.0)
    Label-smoothing epsilon forwarded to F.cross_entropy.

ignore_index (int; default: -100)
    Class index to ignore (passed through to F.cross_entropy).

background_class (int; default: 0)
    Class index treated as background/negative for the 'mean_positive' reduction denominator.

gather_distributed (bool or None; default: None)
    Whether to all-gather inputs and targets across DDP workers before computing the loss. None auto-detects: gathers when torch.distributed is initialized with world_size > 1. Set to False to opt out.

Notes

In DDP, mean_positive normalization is most affected by gathering: if positives are rare and unevenly distributed across ranks, the local positive count is noisy. Gathering ensures the denominator reflects the true global positive count.
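
The effect described in this note can be illustrated with plain arithmetic. The sketch below uses made-up per-rank numbers and assumes the denominator is clamped to at least 1 (the 'mean' branch in the source does this; the internal _reduce helper is not shown, so the clamp is an assumption for 'mean_positive').

```python
# Hypothetical per-rank statistics for one step (illustrative numbers only).
# Each tuple is (summed loss over valid positions, number of positive positions).
ranks = [(4.0, 1), (6.0, 4)]

# Without gathering: each rank normalises by its own positive count, and
# averaging across ranks yields the mean of the per-rank losses.
local_losses = [loss_sum / max(n_pos, 1) for loss_sum, n_pos in ranks]
averaged_local = sum(local_losses) / len(local_losses)

# With gathering: one global numerator over one global positive count.
global_sum = sum(loss_sum for loss_sum, _ in ranks)
global_pos = sum(n_pos for _, n_pos in ranks)
global_loss = global_sum / max(global_pos, 1)

print(local_losses)    # [4.0, 1.5]
print(averaged_local)  # 2.75
print(global_loss)     # 2.0
```

The rank holding a single positive dominates the averaged local value, while the gathered value weights every positive equally.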

Source code in src/imbalanced_losses/focal_loss.py
class SoftmaxFocalLoss(nn.Module):
    """
    Softmax Focal Loss for mutually-exclusive multiclass classification.

    Generalises focal loss from the binary sigmoid case to C classes using
    softmax probabilities and standard cross-entropy as the base loss.
    Supports optional DDP all-gather so that positive-count-based
    normalisations (``mean_positive``) reflect the global batch.

    Parameters
    ----------
    alpha : Tensor or list[float] or None
        Per-class weighting factors of shape (C,).  Typically set to the
        inverse class frequency or similar.  ``None`` disables class
        weighting.  When provided, each sample's loss is scaled by
        ``alpha[y]`` where ``y`` is the ground-truth class.
    gamma : float
        Focusing exponent.  ``gamma=0`` recovers vanilla CE.  Default: 2.0.
    reduction : str
        'none' | 'mean' | 'mean_positive' | 'sum'.  Default: 'mean'.

        - 'mean': average over all valid (non-ignored) positions.
        - 'mean_positive': sum over ALL valid positions divided by the number
          of positive (non-background, non-ignored) positions.  This is the
          RetinaNet convention and stabilises the loss scale when the vast
          majority of samples are background.
    label_smoothing : float
        Label-smoothing epsilon forwarded to ``F.cross_entropy``.
        Default: 0.0.
    ignore_index : int
        Class index to ignore (passed through to ``F.cross_entropy``).
        Default: -100.
    background_class : int
        Class index treated as background/negative for the
        ``'mean_positive'`` reduction denominator.  Default: 0.
    gather_distributed : bool or None, optional
        Whether to all-gather inputs and targets across DDP workers before
        computing the loss.  ``None`` (default) auto-detects: gathers when
        ``torch.distributed`` is initialized with world_size > 1.  Set to
        ``False`` to opt out.

    Notes
    -----
    In DDP, ``mean_positive`` normalization is most affected by gathering: if
    positives are rare and unevenly distributed across ranks, the local
    positive count is noisy.  Gathering ensures the denominator reflects the
    true global positive count.
    """

    def __init__(
        self,
        alpha: torch.Tensor | list[float] | None = None,
        gamma: float = 2.0,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
        ignore_index: int = -100,
        background_class: int = 0,
        gather_distributed: bool | None = None,
    ):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.ignore_index = ignore_index
        self.background_class = background_class
        self.gather_distributed = gather_distributed
        self._gather_resolved: bool | None = None

        if alpha is not None:
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
            if alpha.ndim != 1:
                raise ValueError("alpha must be a 1-D tensor of shape (C,).")
            self.register_buffer("alpha", alpha)
        else:
            self.alpha: torch.Tensor | None = None

    def _should_gather(self) -> bool:
        if self._gather_resolved is None:
            self._gather_resolved = _resolve_gather(self.gather_distributed)
        return self._gather_resolved

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs : Tensor
            Raw logits of shape ``(N, C)`` or ``(N, C, *)``.
        targets : Tensor
            Integer class labels of shape ``(N,)`` or ``(N, *)``.
            Values in ``[0, C)`` (plus ``ignore_index``).

        Returns
        -------
        Tensor
            Scalar or per-sample loss depending on ``reduction``.
        """
        if self._should_gather():
            inputs  = all_gather_with_grad(inputs)
            targets = all_gather_no_grad(targets)

        # ---- 1. Unreduced CE: shape matches targets ---------------------------
        ce_loss = F.cross_entropy(
            inputs,
            targets,
            reduction="none",
            label_smoothing=self.label_smoothing,
            ignore_index=self.ignore_index,
        )

        # ---- 2. Softmax probabilities → p_t for the true class ---------------
        log_probs = F.log_softmax(inputs, dim=1)  # (N, C, ...)

        # Reshape targets for gather along dim=1: (N, ...) → (N, 1, ...)
        gather_idx = targets.unsqueeze(1)

        # Clamp ignore_index entries so gather doesn't go out-of-bounds;
        # zero them out via valid_mask afterwards.
        valid_mask = targets != self.ignore_index
        safe_idx = gather_idx.clamp(0, inputs.size(1) - 1)

        log_p_t = log_probs.gather(1, safe_idx).squeeze(1)  # (N, ...)
        p_t = log_p_t.exp()  # probability assigned to the true class

        # ---- 3. Focal modulator: (1 - p_t)^gamma -----------------------------
        focal_weight = (1.0 - p_t) ** self.gamma
        loss = focal_weight * ce_loss

        # ---- 4. Per-class alpha weighting ------------------------------------
        if self.alpha is not None:
            safe_targets = targets.clamp(0, self.alpha.size(0) - 1)
            alpha_t = self.alpha[safe_targets]
            loss = alpha_t * loss

        # ---- 5. Mask out padding / ignored positions -------------------------
        # Always apply unconditionally — when no positions match ignore_index,
        # valid_mask is all-True and this is a no-op.
        loss = loss * valid_mask

        # ---- 6. Reduction ----------------------------------------------------
        if self.reduction == "mean":
            return loss.sum() / valid_mask.sum().clamp(min=1)

        if self.reduction == "mean_positive":
            positive_mask = valid_mask & (targets != self.background_class)
            return _reduce(loss, "mean_positive", valid_mask, positive_mask)

        return _reduce(loss, self.reduction, valid_mask)

forward(inputs, targets)

Parameters:

inputs (Tensor; required)
    Raw logits of shape (N, C) or (N, C, *).

targets (Tensor; required)
    Integer class labels of shape (N,) or (N, *). Values in [0, C) (plus ignore_index).

Returns:

Tensor
    Scalar or per-sample loss depending on reduction.
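
To make the shape contract concrete, here is a minimal standalone sketch of the unreduced computation (steps 1 to 5 of the forward pass, without alpha or label smoothing). The function name focal_per_position is hypothetical and not part of the library; it only mirrors the logic in the source listing.

```python
import torch
import torch.nn.functional as F


def focal_per_position(logits, targets, gamma=2.0, ignore_index=-100):
    """Unreduced softmax focal loss (hypothetical helper, no alpha)."""
    # Per-position cross-entropy; ignored positions are returned as 0.
    ce = F.cross_entropy(
        logits, targets, reduction="none", ignore_index=ignore_index
    )
    # Probability of the true class; ignored targets are clamped to a valid
    # index so that gather stays in bounds, then masked out below.
    log_probs = F.log_softmax(logits, dim=1)
    safe_idx = targets.unsqueeze(1).clamp(0, logits.size(1) - 1)
    p_t = log_probs.gather(1, safe_idx).squeeze(1).exp()
    valid = targets != ignore_index
    return (1.0 - p_t) ** gamma * ce * valid


torch.manual_seed(0)
logits = torch.randn(2, 5, 3, 4)          # (N, C, H, W)
targets = torch.randint(0, 5, (2, 3, 4))  # (N, H, W)
targets[0, 0, 0] = -100                   # mark one position as ignored

loss = focal_per_position(logits, targets)
print(loss.shape)  # torch.Size([2, 3, 4]), the same shape as targets
```

The class dimension is consumed, so the unreduced loss has the shape of targets, and the ignored position contributes zero.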


Quick examples

Standard multiclass

from imbalanced_losses import SoftmaxFocalLoss
import torch

loss_fn = SoftmaxFocalLoss(gamma=2.0, reduction="mean")
logits  = torch.randn(32, 10)
targets = torch.randint(0, 10, (32,))

loss = loss_fn(logits, targets)
loss.backward()

RetinaNet-style detection

loss_fn = SoftmaxFocalLoss(
    gamma=2.0,
    alpha=[0.25] * 10,         # per-class weights
    reduction="mean_positive",  # normalize by foreground count
    background_class=0,
    ignore_index=-100,
)
loss = loss_fn(logits, targets)
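
The alpha description suggests inverse class frequency as a typical choice. One possible way to build such a vector from training labels is sketched below; inverse_frequency_alpha is a hypothetical helper, not part of the library, and rescaling the weights to average 1 is a choice made for this example only.

```python
import torch


def inverse_frequency_alpha(labels, num_classes, ignore_index=-100):
    """Per-class weights proportional to 1 / class count (hypothetical helper)."""
    # Drop ignored labels; boolean indexing returns a 1-D tensor.
    valid = labels[labels != ignore_index]
    counts = torch.bincount(valid, minlength=num_classes).float()
    # Avoid division by zero for classes absent from the sample.
    weights = 1.0 / counts.clamp(min=1.0)
    # Rescale so the weights average to 1, keeping the loss scale comparable.
    return weights * num_classes / weights.sum()


labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 2, -100])
alpha = inverse_frequency_alpha(labels, num_classes=3)
print(alpha)  # approximately [0.3, 0.9, 1.8]
```

The resulting tensor could then be passed as alpha=alpha, since the parameter accepts a 1-D tensor of shape (C,).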

Dense prediction (spatial inputs)

# [N, C, H, W] logits, [N, H, W] targets
logits  = torch.randn(4, 10, 64, 64)
targets = torch.randint(0, 10, (4, 64, 64))
loss = loss_fn(logits, targets)

Parameter guidance

| Parameter | Default | Notes |
|---|---|---|
| alpha | None | Per-class 1-D tensor or list; None disables class weighting |
| gamma | 2.0 | Higher = harder focus; 0 = vanilla cross-entropy |
| reduction | "mean" | "mean_positive" normalizes by foreground count (detection tasks) |
| background_class | 0 | Class excluded from mean_positive denominator |
| ignore_index | -100 | Padded positions: zero loss, zero gradient |
| label_smoothing | 0.0 | Forwarded to F.cross_entropy |
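
To get a feel for gamma, the focal modulator (1 - p_t) ** gamma can be evaluated for an easy and a hard example. This is plain arithmetic on illustrative probabilities; focal_weight is a name used only in this sketch.

```python
def focal_weight(p_t, gamma):
    """Focal modulator applied to the cross-entropy of one example."""
    return (1.0 - p_t) ** gamma


for gamma in (0.0, 1.0, 2.0, 5.0):
    easy = focal_weight(0.9, gamma)  # confident, correct prediction
    hard = focal_weight(0.1, gamma)  # poorly classified example
    print(f"gamma={gamma}: easy={easy:.5f}, hard={hard:.5f}")
```

At gamma=2.0 the easy example keeps about 1% of its cross-entropy while the hard one keeps about 81%; at gamma=0.0 both weights are 1, which is the vanilla cross-entropy case.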

mean_positive reduction semantics

  • Numerator: sum of loss over all valid (non-ignored) positions, including background
  • Denominator: count of non-background, non-ignored positions only
  • This follows the RetinaNet normalization, in which the total loss is divided by the number of foreground anchors, and stabilizes the loss scale when positives are rare
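
The numerator and denominator rules above can be sketched as a small standalone function. The library's internal _reduce helper is not shown in this page, so the clamp on the denominator is an assumption of this example, and mean_positive here is a hypothetical function name.

```python
import torch


def mean_positive(loss, targets, background_class=0, ignore_index=-100):
    """Sum over valid positions divided by the count of positive positions."""
    valid = targets != ignore_index
    positive = valid & (targets != background_class)
    # Assumption: clamp the denominator so that a batch without positives
    # does not divide by zero.
    return (loss * valid).sum() / positive.sum().clamp(min=1)


# Hypothetical per-position losses: two background, one ignored, two foreground.
loss = torch.tensor([0.25, 0.5, 9.0, 1.25, 2.0])
targets = torch.tensor([0, 0, -100, 3, 1])

print(mean_positive(loss, targets))  # tensor(2.), i.e. 4.0 / 2
```

Background positions add to the numerator (0.25 + 0.5) but not to the denominator, and the ignored position with loss 9.0 contributes to neither.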