Distributed Utilities
Helper functions for DDP all-gather with correct gradient handling. Located in imbalanced_losses.distributed.
imbalanced_losses.distributed.all_gather_with_grad(tensor)
All-gather a tensor across all workers, preserving gradients for the local rank's slice.
The standard dist.all_gather fills its output list with tensors that are
detached from the autograd graph. This function replaces the local rank's
slice of the output with the original tensor, so gradients flow back to the
local model parameters. Other workers' slices remain stop-gradient, matching
DDP semantics (each worker optimizes its own parameters via all-reduced gradients).
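The slice-swap idea can be sketched as follows. This is a minimal illustration, not the library's implementation: `all_gather_with_grad_sketch` is a hypothetical name, and the sketch assumes an already-initialized process group and identical tensor shapes on every rank. It omits the padding path and the `world_size == 1` shortcut described elsewhere on this page.

```python
import torch
import torch.distributed as dist


def all_gather_with_grad_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: all-gather along dim 0, keeping autograd on the local slice.

    Assumes an initialized process group and identical shapes on every rank.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # Receive buffers; the collective fills them outside the autograd graph.
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor.detach().contiguous())
    # Swap this rank's detached copy for the original tensor so gradients
    # from a loss on the concatenated result reach local parameters.
    gathered[rank] = tensor
    # Remote slices stay detached (stop-gradient).
    return torch.cat(gathered, dim=0)
```

Only the swapped-in slice is connected to autograd, so `backward()` on a loss over the concatenated result populates gradients for this rank's inputs alone.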
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Local tensor to gather. Typically logits or embeddings. | required |
Returns:

| Type | Description |
|---|---|
| Tensor | Concatenation of all workers' tensors along dim 0, shape `[sum(N_i), ...]`, where `N_i` is rank i's dim-0 size. |
Notes
Dim 0 may vary across ranks (e.g. unequal last-batch sizes). When sizes differ, tensors are zero-padded to the max for the collective, then trimmed back to their true lengths before concatenation. An equal-size fast path skips padding when all ranks contribute the same number of rows.
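The pad, gather, and trim steps can be sketched like this. The helper names `pad_rows`, `trim_and_concat`, and `all_gather_variable_sketch` are hypothetical, the sketch assumes an initialized process group, and it returns a detached result (the gradient-preserving slice swap is a separate step). Row counts are exchanged first so each rank knows how far to pad and where to trim.

```python
from typing import List

import torch
import torch.distributed as dist


def pad_rows(t: torch.Tensor, max_rows: int) -> torch.Tensor:
    """Zero-pad `t` along dim 0 to `max_rows` rows (hypothetical helper)."""
    if t.shape[0] == max_rows:
        # Equal-size fast path: nothing to pad.
        return t
    pad = t.new_zeros((max_rows - t.shape[0],) + tuple(t.shape[1:]))
    return torch.cat([t, pad], dim=0)


def trim_and_concat(padded: List[torch.Tensor], sizes: List[int]) -> torch.Tensor:
    """Keep the first `sizes[i]` rows of each rank's buffer, then concatenate (hypothetical helper)."""
    return torch.cat([buf[:n] for buf, n in zip(padded, sizes)], dim=0)


def all_gather_variable_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Gather along dim 0 when ranks may hold different row counts (result is detached)."""
    world_size = dist.get_world_size()
    # 1) Exchange row counts so every rank knows how far to pad and where to trim.
    local_n = torch.tensor([tensor.shape[0]], dtype=torch.long, device=tensor.device)
    size_bufs = [torch.zeros_like(local_n) for _ in range(world_size)]
    dist.all_gather(size_bufs, local_n)
    sizes = [int(s.item()) for s in size_bufs]
    # 2) Pad to the common maximum so every buffer has the same shape.
    send = pad_rows(tensor.detach(), max(sizes)).contiguous()
    recv = [torch.empty_like(send) for _ in range(world_size)]
    dist.all_gather(recv, send)
    # 3) Drop each rank's padding rows and concatenate in rank order.
    return trim_and_concat(recv, sizes)
```

Padding to a common size keeps every receive buffer identically shaped, so the sketch does not depend on whether a given backend accepts uneven sizes. The cost is one extra small collective for the row counts.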
All workers' queues stay synchronized automatically: because every worker
calls all_gather before passing its data to the loss, every worker enqueues
the same global-batch data. No extra synchronization is needed.
Raises:

| Type | Description |
|---|---|
| RuntimeError | If `torch.distributed` is not available or not initialized. |
Examples:

Typical usage in a DDP training step:

```python
from imbalanced_losses.distributed import all_gather_no_grad, all_gather_with_grad

# logits: this rank's [N_i, C] tensor; targets: this rank's [N_i] labels
logits_global = all_gather_with_grad(logits)  # [sum(N_i), C]
targets_global = all_gather_no_grad(targets)  # [sum(N_i)]
loss = loss_fn(logits_global, targets_global)
loss.backward()
```
Source code in src/imbalanced_losses/distributed.py
imbalanced_losses.distributed.all_gather_no_grad(tensor)
All-gather a tensor across all workers without gradient tracking.
Intended for targets / labels, which are integer tensors with no gradient.
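A gather without gradient tracking reduces to a plain collective followed by concatenation. The sketch below is illustrative only: `all_gather_no_grad_sketch` is a hypothetical name, and it assumes an initialized process group and identical shapes on every rank (no padding path). The `torch.no_grad()` decorator is belt-and-braces, since the collective itself does not record autograd history.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def all_gather_no_grad_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: plain all-gather along dim 0 with no autograd tracking.

    Assumes an initialized process group and identical shapes on every rank.
    """
    world_size = dist.get_world_size()
    # Receive buffers with the same dtype (e.g. int64 labels) and shape as the input.
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor.contiguous())
    return torch.cat(gathered, dim=0)
```

No gradient path is needed for labels anyway: PyTorch only lets floating-point and complex tensors require gradients, so calling `requires_grad_()` on an integer tensor raises a `RuntimeError`.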
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Local tensor to gather. Typically integer targets or labels. | required |
Returns:

| Type | Description |
|---|---|
| Tensor | Concatenation of all workers' tensors along dim 0, shape `[sum(N_i), ...]`, where `N_i` is rank i's dim-0 size. |
Notes
Dim 0 may vary across ranks (e.g. unequal last-batch sizes). When sizes differ, tensors are zero-padded to the max for the collective, then trimmed back to their true lengths before concatenation. An equal-size fast path skips padding when all ranks contribute the same number of rows.
Raises:

| Type | Description |
|---|---|
| RuntimeError | If `torch.distributed` is not available or not initialized. |
Usage pattern

```python
from imbalanced_losses.distributed import all_gather_with_grad, all_gather_no_grad

# In a DDP training step:
logits_global = all_gather_with_grad(logits)  # [sum(N_i), C], gradient flows through the local slice
targets_global = all_gather_no_grad(targets)  # [sum(N_i)], no gradient
loss = loss_fn(logits_global, targets_global)
loss.backward()
```
Behavior summary

| Function | Gradient | Use for |
|---|---|---|
| `all_gather_with_grad` | Flows through the local rank's slice | Logits, embeddings |
| `all_gather_no_grad` | None | Integer targets, labels |
Both functions:

- Raise `RuntimeError` if `torch.distributed` is not available or not initialized.
- Are no-ops (return the input unchanged) when `world_size == 1`.
- Support variable dim-0 sizes across ranks (e.g. an unequal last batch without `drop_last=True`). Tensors are padded to the max size for the collective, then trimmed. An equal-size fast path skips this overhead when all ranks have the same batch size.
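The shared preconditions listed above can be sketched as a small guard. `gather_or_passthrough` and its `gather_fn` argument are hypothetical names used only for illustration; the library may structure these checks differently.

```python
from typing import Callable

import torch
import torch.distributed as dist


def gather_or_passthrough(
    tensor: torch.Tensor,
    gather_fn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Hypothetical sketch: error without a process group, no-op at world_size 1."""
    if not dist.is_available():
        raise RuntimeError("torch.distributed is not available in this build")
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized; call init_process_group first")
    if dist.get_world_size() == 1:
        # Single worker: nothing to gather, so return the input unchanged.
        return tensor
    # Multiple workers: delegate to the actual collective (e.g. pad, gather, trim).
    return gather_fn(tensor)
```

Failing loudly when no process group exists avoids silently computing a per-rank loss that was meant to be global.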