
SigmoidFocalLoss

Binary / multi-label focal loss that operates on raw logits with a sigmoid activation. It takes the same (logits, targets) arguments as torch.nn.BCEWithLogitsLoss and is meant as a drop-in replacement for it on class-imbalanced problems.

Multi-label vs. multiclass

SigmoidFocalLoss applies a sigmoid independently to each logit, so every output is a separate binary prediction. Use it when a sample can belong to multiple classes at once (multi-label), or for a single yes/no decision (binary). If your classes are mutually exclusive — each sample has exactly one correct class — use SoftmaxFocalLoss instead.
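The difference shows up on a single logit vector. The sketch below uses only the Python standard library. The sigmoid and softmax functions are hand-written helpers for illustration and are not part of imbalanced_losses. Sigmoid scores each logit on its own, so several labels can be "on" at once. Softmax turns the same logits into one distribution that sums to 1.

```python
import math

def sigmoid(z: float) -> float:
    """Independent probability for one logit (the view SigmoidFocalLoss takes)."""
    return 1.0 / (1.0 + math.exp(-z))

def softmax(zs: list[float]) -> list[float]:
    """Coupled probabilities over mutually exclusive classes."""
    m = max(zs)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, -1.0]

# Multi-label view: one binary decision per label; a matching target is
# multi-hot, e.g. [1.0, 1.0, 0.0] when labels 0 and 1 are both present.
per_label = [sigmoid(z) for z in logits]

# Multiclass view: exactly one class is correct, so probabilities compete.
exclusive = softmax(logits)

print([round(p, 3) for p in per_label])  # independent; the sum is unconstrained
print([round(p, 3) for p in exclusive])  # sums to 1
```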

imbalanced_losses.focal_loss.SigmoidFocalLoss

Bases: torch.nn.Module

Sigmoid Focal Loss as used in RetinaNet.

Binary / multi-label variant operating on raw logits with sigmoid activation. Supports optional DDP all-gather so that the global batch is seen when computing mean/sum reductions.
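You can see why the global batch matters without any distributed setup. The sketch below simulates two ranks in plain Python. The per-element loss values are made up for illustration, and nothing here calls torch.distributed. When ranks hold different numbers of elements, averaging each rank's local mean gives a different value from the mean over the gathered batch. Likewise, a local sum covers only that rank's elements.

```python
# Per-element losses on two simulated DDP ranks (illustrative values only).
rank_losses = [
    [0.2, 0.4],            # rank 0: 2 elements
    [0.9, 0.1, 0.5, 0.3],  # rank 1: 4 elements
]

def mean(xs: list[float]) -> float:
    return sum(xs) / len(xs)

# Without gathering: each rank reduces only its own elements.
local_means = [mean(r) for r in rank_losses]
mean_of_local_means = mean(local_means)

# With gathering: the reduction runs over the concatenated global batch.
gathered = [x for r in rank_losses for x in r]
global_mean = mean(gathered)
global_sum = sum(gathered)

print(f"mean of local means: {mean_of_local_means:.3f}")  # 0.375
print(f"global mean:         {global_mean:.3f}")          # 0.400
print(f"global sum:          {global_sum:.3f}")           # 2.400
```

With equal-sized shards the two means coincide for the loss value. The discrepancy in the mean appears when local element counts differ.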


Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Weighting factor in [0, 1] that balances positives against negatives: positive targets are weighted by alpha and negative targets by 1 - alpha. Set to -1 to disable the weighting. Any other value outside [0, 1] raises ValueError. | 0.25 |
| gamma | float | Exponent of the modulating factor (1 - p_t). | 2.0 |
| reduction | str | 'none', 'mean', or 'sum'. | 'mean' |
| gather_distributed | bool or None | Whether to all-gather inputs and targets across DDP workers before computing the loss. None auto-detects: it gathers when torch.distributed is initialized with world_size > 1. Set to False to opt out. The decision is resolved on the first forward call and then cached. | None |
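The alpha row is easier to follow with concrete numbers. The standard-library sketch below uses check_alpha and alpha_weight, which are hypothetical helper names. They mirror the constructor's validation and the alpha_t expression in forward from the source listing on this page, and exist here for illustration only.

```python
def check_alpha(alpha: float) -> float:
    """Accept alpha in [0, 1] or the sentinel -1, as the constructor does."""
    if not (0.0 <= alpha <= 1.0) and alpha != -1:
        raise ValueError(f"alpha must be in [0, 1] or -1, got {alpha}")
    return alpha

def alpha_weight(alpha: float, target: float) -> float:
    """Per-element weight: alpha for positives, 1 - alpha for negatives."""
    if alpha < 0:  # alpha == -1 disables the weighting (weight of 1)
        return 1.0
    return alpha * target + (1 - alpha) * (1 - target)

alpha = check_alpha(0.25)
print(alpha_weight(alpha, 1.0))             # 0.25 for a positive target
print(alpha_weight(alpha, 0.0))             # 0.75 for a negative target
print(alpha_weight(check_alpha(-1), 1.0))   # 1.0, weighting disabled

try:
    check_alpha(1.5)
except ValueError as err:
    print("rejected:", err)
```

With the default alpha = 0.25, positives therefore receive a smaller weight than negatives (0.25 vs. 0.75).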
Source code in src/imbalanced_losses/focal_loss.py
```python
class SigmoidFocalLoss(nn.Module):
    """
    Sigmoid Focal Loss as used in RetinaNet.

    Binary / multi-label variant operating on raw logits with sigmoid activation.
    Supports optional DDP all-gather so that the global batch is seen when
    computing mean/sum reductions.

    .. note::
        **Multi-label vs. multiclass:** This loss treats every output logit as an
        *independent* binary prediction (sigmoid per element).  Use it when a
        sample can belong to *multiple* classes simultaneously (multi-label), or
        for a single yes/no prediction (binary).  If your classes are
        *mutually exclusive* — each sample belongs to exactly one class — use
        :class:`SoftmaxFocalLoss` instead, which couples the outputs via softmax.

    Parameters
    ----------
    alpha : float
        Weighting factor in [0, 1] to balance positives vs negatives, or -1 to
        ignore. Default: 0.25.
    gamma : float
        Exponent of the modulating factor (1 - p_t). Default: 2.
    reduction : str
        'none' | 'mean' | 'sum'. Default: 'mean'.
    gather_distributed : bool or None, optional
        Whether to all-gather inputs and targets across DDP workers before
        computing the loss.  ``None`` (default) auto-detects: gathers when
        ``torch.distributed`` is initialized with world_size > 1.  Set to
        ``False`` to opt out.
    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = "mean",
        gather_distributed: bool | None = None,
    ):
        super().__init__()
        if not (0.0 <= alpha <= 1.0) and alpha != -1:
            raise ValueError(
                f"Invalid alpha value: {alpha}. alpha must be in [0, 1] or -1 for ignore."
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.gather_distributed = gather_distributed
        self._gather_resolved: bool | None = None

    def _should_gather(self) -> bool:
        if self._gather_resolved is None:
            self._gather_resolved = _resolve_gather(self.gather_distributed)
        return self._gather_resolved

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs : Tensor
            Raw logits, arbitrary shape.
        targets : Tensor
            Same shape, float 0/1 labels.

        Returns
        -------
        Tensor
            Scalar or per-element loss depending on ``reduction``.
        """
        if self._should_gather():
            inputs  = all_gather_with_grad(inputs)
            targets = all_gather_no_grad(targets)

        p = torch.sigmoid(inputs)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        p_t = p * targets + (1 - p) * (1 - targets)
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            loss = alpha_t * loss

        return _reduce(loss, self.reduction)
```
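Written out, the forward pass computes the RetinaNet focal loss element by element. Here x denotes a logit, y its target, and σ the sigmoid; these symbols are introduced only for this derivation. The quantities p, p_t and alpha_t follow the code line by line. For hard 0/1 targets, the binary cross-entropy term returned by binary_cross_entropy_with_logits equals -log p_t.

```latex
\begin{aligned}
p &= \sigma(x) = \frac{1}{1 + e^{-x}} \\
p_t &= p\,y + (1 - p)(1 - y) \\
\alpha_t &= \alpha\,y + (1 - \alpha)(1 - y) \\
\mathrm{CE}(x, y) &= -\bigl[\,y \log p + (1 - y)\log(1 - p)\,\bigr] = -\log p_t \qquad (y \in \{0, 1\}) \\
\mathrm{FL}(x, y) &= \alpha_t\,(1 - p_t)^{\gamma}\,\mathrm{CE}(x, y)
\end{aligned}
```

When alpha = -1 the alpha_t factor is dropped. If gamma = 0 as well, the expression reduces to plain binary cross-entropy.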

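forward(inputs, targets)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | Tensor | Raw logits, arbitrary shape. | required |
| targets | Tensor | Same shape as inputs; float 0/1 labels. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | A scalar for reduction 'mean' or 'sum'; a per-element loss tensor for 'none'. When gathering is active, the loss is computed over the gathered global batch rather than only the local inputs. |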
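The sketch below shows what each reduction returns. It is dependency-free and mirrors the per-element arithmetic of forward on plain Python floats. focal_elementwise and reduce_losses are hypothetical names. The source of the library's private _reduce helper is not shown on this page, so reduce_losses simply assumes the behavior described in the parameter guidance: 'mean' averages, 'sum' totals, and 'none' returns per-element values.

```python
import math

def focal_elementwise(logit: float, target: float,
                      alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Per-element focal loss, mirroring the arithmetic in forward()."""
    p = 1.0 / (1.0 + math.exp(-logit))
    # Numerically stable binary cross-entropy on a logit (BCE-with-logits).
    ce = max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))
    p_t = p * target + (1 - p) * (1 - target)
    loss = ce * (1 - p_t) ** gamma
    if alpha >= 0:
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        loss = alpha_t * loss
    return loss

def reduce_losses(losses: list[float], reduction: str):
    """Assumed reduction semantics for 'none' | 'mean' | 'sum'."""
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    # This sketch rejects other values; the library's behavior is not shown here.
    raise ValueError(f"unknown reduction: {reduction!r}")

logits = [3.0, -2.0, 0.5, -0.5]
targets = [1.0, 0.0, 0.0, 1.0]
per_element = [focal_elementwise(x, y) for x, y in zip(logits, targets)]

print(reduce_losses(per_element, "none"))  # one value per element
print(reduce_losses(per_element, "mean"))  # a single float
print(reduce_losses(per_element, "sum"))   # a single float
```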


Quick example

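```python
import torch
from imbalanced_losses import SigmoidFocalLoss

loss_fn = SigmoidFocalLoss(alpha=0.25, gamma=2.0)

# Binary task: one logit per sample. requires_grad=True stands in for model
# outputs; without it, loss.backward() raises because nothing in the graph
# requires a gradient.
logits  = torch.randn(32, 1, requires_grad=True)
targets = torch.randint(0, 2, (32, 1)).float()  # float 0/1 labels, same shape

loss = loss_fn(logits, targets)  # scalar, since reduction="mean"
loss.backward()
```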

Parameter guidance

| Parameter | Default | Effect |
|---|---|---|
| alpha | 0.25 | Weight on positive targets (negatives get 1 - alpha); set to -1 to disable |
| gamma | 2.0 | Higher values focus more on hard examples; 0 together with alpha=-1 reduces to plain BCE |
| reduction | "mean" | "mean" averages over elements; "sum" returns the total; "none" returns the per-element tensor |
| gather_distributed | None | Auto-detects DDP; set False to opt out |
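The effect of gamma can be quantified. The snippet below evaluates the modulating factor (1 - p_t)^gamma for a well-classified element (p_t = 0.9) and a misclassified one (p_t = 0.1). It is plain Python. The two p_t values are arbitrary examples, and modulating_factor is a hypothetical helper name.

```python
def modulating_factor(p_t: float, gamma: float) -> float:
    """Multiplier applied to the cross-entropy term of one element."""
    return (1.0 - p_t) ** gamma

easy_p_t, hard_p_t = 0.9, 0.1  # confident-correct vs. confident-wrong prediction

for gamma in (0.0, 1.0, 2.0, 5.0):
    easy = modulating_factor(easy_p_t, gamma)
    hard = modulating_factor(hard_p_t, gamma)
    # Ratio = how many times more the hard element's loss counts than the easy one's.
    print(f"gamma={gamma}: easy={easy:.5f} hard={hard:.5f} ratio={hard / easy:.0f}")
```

At gamma = 0 both factors are 1, so only alpha_t separates the loss from BCE. At the default gamma = 2, the hard element's cross-entropy counts about 81 times as much as the easy element's.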