SigmoidFocalLoss
Binary / multi-label focal loss operating on raw logits with a sigmoid activation. Drop-in replacement for BCEWithLogitsLoss for imbalanced problems.
Multi-label vs. multiclass
SigmoidFocalLoss applies a sigmoid independently to each logit, so every output is a separate binary prediction. Use it when a sample can belong to multiple classes at once (multi-label), or for a single yes/no decision (binary). If your classes are mutually exclusive — each sample has exactly one correct class — use SoftmaxFocalLoss instead.
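The distinction can be seen with a small standard-library sketch. It does not call the library; the logits are hypothetical values for one sample with three classes, and the helper functions `sigmoid` and `softmax` are defined here only for illustration.

```python
import math


def sigmoid(x):
    """Logistic function applied to a single logit."""
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one sample and three classes.
logits = [2.0, 1.5, -3.0]

# Sigmoid: each class gets its own probability, so several can exceed 0.5.
sigmoid_probs = [sigmoid(x) for x in logits]

# Softmax: the classes compete, so the probabilities sum to 1.
softmax_probs = softmax(logits)

print("sigmoid:", [round(p, 3) for p in sigmoid_probs])
print("softmax:", [round(p, 3) for p in softmax_probs])
```

With sigmoid, the first two classes are both predicted as present, which is what a multi-label target needs. With softmax, raising one class's probability necessarily lowers the others.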
imbalanced_losses.focal_loss.SigmoidFocalLoss
Bases: Module
Sigmoid Focal Loss as used in RetinaNet.
Binary / multi-label variant operating on raw logits with sigmoid activation. Supports optional DDP all-gather so that the global batch is seen when computing mean/sum reductions.
Note: Multi-label vs. multiclass: This loss treats every output logit as an independent binary prediction (sigmoid per element). Use it when a sample can belong to multiple classes simultaneously (multi-label), or for a single yes/no prediction (binary). If your classes are mutually exclusive, with each sample belonging to exactly one class, use SoftmaxFocalLoss instead, which couples the outputs via softmax.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `alpha` | float | Weighting factor in [0, 1] to balance positives vs negatives, or -1 to ignore. | 0.25 |
| `gamma` | float | Exponent of the modulating factor (1 - p_t). | 2.0 |
| `reduction` | str | One of 'none', 'mean', or 'sum'. | 'mean' |
| `gather_distributed` | bool or None | Whether to all-gather inputs and targets across DDP workers before computing the loss. | None |
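For reference, `alpha` and `gamma` enter the focal loss from the RetinaNet paper as shown below. This is the standard formulation, on the assumption that the class follows it. The symbols $x$ (a logit), $y$ (its 0/1 target), and $\sigma$ (the sigmoid) are notation introduced here, not names from the source.

$$
p = \sigma(x), \qquad
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{if } y = 0 \end{cases}, \qquad
\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{if } y = 0 \end{cases}
$$

$$
\mathrm{FL}(x, y) = -\,\alpha_t \,(1 - p_t)^{\gamma}\, \log p_t
$$

The factor $(1 - p_t)^{\gamma}$ shrinks toward zero for well-classified elements, where $p_t$ is close to 1, so they contribute little to the total. When `alpha` is -1 the $\alpha_t$ factor is omitted.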
Source code in src/imbalanced_losses/focal_loss.py
forward(inputs, targets)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | Tensor | Raw logits, arbitrary shape. | required |
| `targets` | Tensor | Same shape as `inputs`, float 0/1 labels. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Scalar or per-element loss depending on `reduction`. |
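Quick example
```python
import torch
from imbalanced_losses import SigmoidFocalLoss

loss_fn = SigmoidFocalLoss(alpha=0.25, gamma=2.0)

# Binary case: one logit per sample, with float 0/1 targets of the same shape.
# requires_grad=True stands in for model outputs so that backward() has a graph.
logits = torch.randn(32, 1, requires_grad=True)
targets = torch.randint(0, 2, (32, 1)).float()

loss = loss_fn(logits, targets)  # scalar, since reduction defaults to "mean"
loss.backward()
```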
Parameter guidance
| Parameter | Default | Effect |
|---|---|---|
| `alpha` | 0.25 | Balances positives against negatives; set to -1 to disable the weighting |
| `gamma` | 2.0 | Higher values put more focus on hard examples; 0 removes the focusing term, leaving alpha-weighted BCE (vanilla BCE when `alpha` is -1) |
| `reduction` | "mean" | "mean" averages over elements; "sum" returns the total; "none" returns the per-element tensor |
| `gather_distributed` | None | Auto-detects DDP; set False to opt out |
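The effects in the table can be checked on single values with a standard-library sketch. It assumes the standard RetinaNet formulation, in which positives are weighted by `alpha` and negatives by `1 - alpha`. The functions `bce_with_logits` and `focal_loss_scalar` are hypothetical helpers written for this illustration, not the library's implementation.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(logit, target):
    """Binary cross-entropy for one logit and a 0/1 target."""
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p  # probability given to the true label
    return -math.log(p_t)


def focal_loss_scalar(logit, target, alpha=0.25, gamma=2.0):
    """Hypothetical scalar reference under the assumptions stated above."""
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p
    loss = bce_with_logits(logit, target) * (1.0 - p_t) ** gamma
    if alpha >= 0:  # alpha = -1 disables the class weighting
        alpha_t = alpha if target == 1 else 1.0 - alpha
        loss *= alpha_t
    return loss


# gamma=0 with alpha=-1 leaves plain BCE.
plain = focal_loss_scalar(0.7, 1, alpha=-1, gamma=0.0)
reference = bce_with_logits(0.7, 1)

# Easy positive (confidently right) vs. hard positive (confidently wrong).
easy = focal_loss_scalar(4.0, 1, alpha=-1, gamma=2.0)
hard = focal_loss_scalar(-4.0, 1, alpha=-1, gamma=2.0)

# The three reductions, applied to the per-element losses.
per_element = [easy, hard]         # reduction="none"
total = sum(per_element)           # reduction="sum"
mean = total / len(per_element)    # reduction="mean"

print(f"plain={plain:.4f} reference={reference:.4f}")
print(f"easy={easy:.2e} hard={hard:.4f} mean={mean:.4f}")
```

The hard example dominates the mean far more strongly than it would under plain BCE, which is the intended effect of a larger `gamma`.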