Memory Queue Design
The problem: sparse positives per batch
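SmoothAPLoss estimates Average Precision by comparing every positive in the pool against every other sample. At a 1% positive rate with a batch size of 32, you expect only 0.32 positives per batch, and (if positives occur independently) roughly 72% of batches contain none at all. A batch with no positives contributes no AP term, and one with a single positive or two gives an AP estimate dominated by where those few samples happen to rank. Either way there is essentially no usable gradient signal.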
The memory queue solves this by accumulating past batches. With queue_size=1024 and a batch size of 32, the total pool is 1056 samples, which yields about 10 positives at a 1% rate: enough for a stable AP estimate.
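The arithmetic above can be checked with a quick binomial back-of-the-envelope calculation. This is a minimal sketch that assumes each sample is positive independently with probability pos_rate; the function names are illustrative only.

```python
# Binomial back-of-the-envelope: each sample is assumed to be positive
# independently with probability pos_rate.
def expected_positives(n_samples: int, pos_rate: float) -> float:
    """Mean number of positives among n_samples."""
    return n_samples * pos_rate


def prob_no_positives(n_samples: int, pos_rate: float) -> float:
    """Probability that none of the n_samples is positive."""
    return (1.0 - pos_rate) ** n_samples


batch_size, queue_size, pos_rate = 32, 1024, 0.01
pool_size = batch_size + queue_size  # live batch + queue = 1056

print(f"batch: {expected_positives(batch_size, pos_rate):.2f} positives, "
      f"P(none) = {prob_no_positives(batch_size, pos_rate):.3f}")
print(f"pool:  {expected_positives(pool_size, pos_rate):.2f} positives, "
      f"P(none) = {prob_no_positives(pool_size, pos_rate):.1e}")
# batch: 0.32 positives, P(none) = 0.725
# pool:  10.56 positives, P(none) = 2.5e-05
```

Going from the batch to the full pool moves the chance of an empty positive set from about three in four to effectively zero.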
How the circular buffer works
The queue is a fixed-size circular buffer of (logits, targets) rows. On every training forward pass:
- The live batch is appended to the queue contents to form the full pool
- AP or recall is computed on the full pool
- The live batch is written into the queue, overwriting the oldest entries
Queue entries are stored detached: no gradient flows through queued logits, only through the live batch's portion of the pool. This is a requirement rather than a choice: the queued logits were produced by earlier versions of the model, so there is no current computation graph connecting them to the parameters being updated.
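The per-step procedure above (form the pool, compute the loss, then enqueue) can be sketched as a small circular buffer. This is a framework-free illustration, not the library's implementation: MemoryQueue and training_step are hypothetical names, plain floats stand in for tensors, and copying values into the buffer stands in for storing detached tensors (in PyTorch one would enqueue logits.detach()).

```python
from typing import List, Sequence, Tuple


class MemoryQueue:
    """Hypothetical fixed-size circular buffer of (logit, target) rows.

    Plain floats stand in for tensors; storing values by copy stands in for
    storing detached tensors (in PyTorch one would enqueue logits.detach()).
    """

    def __init__(self, queue_size: int) -> None:
        self.queue_size = queue_size
        self.logits: List[float] = [0.0] * queue_size
        self.targets: List[int] = [0] * queue_size
        self.ptr = 0    # next slot to write; once full, this is the oldest row
        self.count = 0  # number of valid rows, at most queue_size

    def pool(self, live_logits: Sequence[float],
             live_targets: Sequence[int]) -> Tuple[List[float], List[int]]:
        """Full pool = live batch followed by the valid queued rows."""
        # Valid rows always occupy slots [0, count): slots fill in order
        # until the buffer is full, after which every slot is valid.
        return (list(live_logits) + self.logits[: self.count],
                list(live_targets) + self.targets[: self.count])

    def enqueue(self, live_logits: Sequence[float],
                live_targets: Sequence[int]) -> None:
        """Write the live batch over the oldest rows, wrapping around."""
        for logit, target in zip(live_logits, live_targets):
            self.logits[self.ptr] = float(logit)
            self.targets[self.ptr] = int(target)
            self.ptr = (self.ptr + 1) % self.queue_size
            self.count = min(self.count + 1, self.queue_size)

    def reset(self) -> None:
        """Forget all queued rows (stale slot contents are simply ignored)."""
        self.ptr = 0
        self.count = 0


def training_step(queue: MemoryQueue, live_logits, live_targets, loss_fn):
    """One step: build the pool, compute the loss on it, then enqueue."""
    pool_logits, pool_targets = queue.pool(live_logits, live_targets)
    loss = loss_fn(pool_logits, pool_targets)
    queue.enqueue(live_logits, live_targets)
    return loss


q = MemoryQueue(queue_size=4)
for batch in ([0.1, 0.2], [0.3, 0.4], [0.5, 0.6]):
    q.enqueue(batch, [0, 0])  # the third batch overwrites the two oldest rows
print(q.pool([0.9], [1])[0])  # [0.9, 0.5, 0.6, 0.3, 0.4]
```

Writing row by row keeps the wrap-around logic obvious; a tensor implementation would more likely use slice assignment split at the end of the buffer.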
Why detaching queue logits is correct
At first this seems like it would bias the soft-rank computation: the rank of a live positive is estimated relative to a pool that includes stale, detached logits. In practice this bias is small because:
- The queue rotates, so entries are never more than queue_size / batch_size steps old
- Soft ranks are a sum over the pool, so the live-batch contribution is fully differentiable
- The queue mainly provides a reference distribution for the rank, not a gradient signal
This is the same reasoning used in MoCo-style contrastive learning, where negative key embeddings are also kept in a queue with detached gradients.
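To make the "reference distribution" argument concrete, here is a sketch of a soft-rank AP estimate over a pool. The document does not give SmoothAPLoss's exact formula, so this assumes the Smooth-AP relaxation, where the hard indicator [s_j > s_i] is replaced by a sigmoid with a temperature tau; smooth_ap, sigmoid, and the default tau are illustrative choices.

```python
import math
from typing import Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def smooth_ap(scores: Sequence[float], targets: Sequence[int],
              tau: float = 0.01) -> float:
    """Smooth-AP style estimate of AP over a pool (live batch + queue).

    For each positive i, the hard indicator [s_j > s_i] is replaced by
    sigmoid((s_j - s_i) / tau), giving a soft rank among all pool samples
    and a soft rank among the pool's positives. AP is the mean ratio.
    """
    pos = [i for i, t in enumerate(targets) if t == 1]
    if not pos:
        return float("nan")  # no positives in the pool: AP is undefined
    ap_sum = 0.0
    for i in pos:
        rank_all = 1.0 + sum(sigmoid((scores[j] - scores[i]) / tau)
                             for j in range(len(scores)) if j != i)
        rank_pos = 1.0 + sum(sigmoid((scores[j] - scores[i]) / tau)
                             for j in pos if j != i)
        ap_sum += rank_pos / rank_all
    return ap_sum / len(pos)


# Live batch first, then queued rows. In an autograd framework the queued
# scores would be detached constants: they shift the soft ranks of live
# positives but receive no gradient themselves.
live_scores, live_targets = [2.0, -1.0], [1, 0]
queued_scores, queued_targets = [1.0, -2.0], [1, 0]
ap = smooth_ap(live_scores + queued_scores, live_targets + queued_targets)
loss = 1.0 - ap
print(f"AP ~ {ap:.4f}, loss ~ {loss:.4f}")
```

Every term in the rank sums that involves a live score is still differentiable with respect to that score, which is why detaching the queued half does not cut the gradient path for the live batch.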
Queue poisoning and the phase switch
When LossWarmupWrapper switches from BCE warmup to AP loss, the queue contains logits produced during the BCE warmup. These warmup-era logits can have a very different score distribution from the AP-phase model's, so ranks computed against them say little about how the current model orders samples.
If these stale entries remained in the queue, the AP loss would compute ranks against a corrupted reference distribution for the first queue_size / batch_size batches of the AP phase, until every warmup-era entry had been overwritten.
LossWarmupWrapper prevents this by automatically calling main_loss.reset_queue() at the exact step of the phase switch. After the reset, the pool consists of the live batch plus whatever AP-phase entries have accumulated so far; the queue is full again, and the full-size pool is in use, after queue_size / batch_size batches.
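The switch-and-reset behaviour can be sketched as follows. This is not LossWarmupWrapper's code: WarmupSwitchSketch and QueuedLossStub are hypothetical stand-ins, and the assumption that the main loss still sees warmup batches (which is how warmup-era logits end up in its queue) is labelled in the comments.

```python
import math
from typing import Callable, List, Sequence


def bce_with_logits(logits: Sequence[float], targets: Sequence[int]) -> float:
    """Mean binary cross-entropy on raw logits (numerically stable form)."""
    return sum(max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
               for z, t in zip(logits, targets)) / len(logits)


class QueuedLossStub:
    """Hypothetical stand-in for a queued AP loss: owns a queue, exposes
    reset_queue(), and records the pool size it saw on each call."""

    def __init__(self, queue_size: int) -> None:
        self.queue_size = queue_size
        self.rows: List[tuple] = []
        self.pool_sizes: List[int] = []

    def __call__(self, logits: Sequence[float], targets: Sequence[int]) -> float:
        live = list(zip(logits, targets))
        self.pool_sizes.append(len(live) + len(self.rows))
        self.rows = (self.rows + live)[-self.queue_size:]  # keep newest rows
        return 0.0  # placeholder; a real loss would return 1 - smoothed AP

    def reset_queue(self) -> None:
        self.rows = []


class WarmupSwitchSketch:
    """Hypothetical BCE-warmup-then-AP wrapper (not LossWarmupWrapper)."""

    def __init__(self, warmup_loss: Callable, main_loss: QueuedLossStub,
                 warmup_steps: int) -> None:
        self.warmup_loss = warmup_loss
        self.main_loss = main_loss
        self.warmup_steps = warmup_steps
        self.step = 0

    def __call__(self, logits, targets) -> float:
        if self.step < self.warmup_steps:
            # Assumption: the main loss still sees warmup batches, so its
            # queue fills with warmup-era logits.
            self.main_loss(logits, targets)
            loss = self.warmup_loss(logits, targets)
        else:
            if self.step == self.warmup_steps:
                self.main_loss.reset_queue()  # discard BCE-era entries, once
            loss = self.main_loss(logits, targets)
        self.step += 1
        return loss


ap_loss = QueuedLossStub(queue_size=8)
wrapper = WarmupSwitchSketch(bce_with_logits, ap_loss, warmup_steps=3)
for _ in range(5):
    wrapper([0.3, -1.2], [1, 0])
print(ap_loss.pool_sizes)  # [2, 4, 6, 2, 4]
```

The drop from 6 back to 2 at the fourth call is the reset: the first AP-phase step sees only its live batch, and the pool then regrows from AP-phase entries alone.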
When to reset manually
- Between training and validation: calling loss_fn.reset_queue() before the validation loop prevents training logits from appearing in validation-phase AP estimates
- After changing the model architecture or loading a checkpoint
- At every epoch boundary (set reset_queue_each_epoch=True in LossWarmupWrapper to automate this): useful when the model changes significantly from epoch to epoch and stale logits would bias the ranking
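A minimal sketch of the validation case, assuming (as described earlier) that the queue is only updated on training forward passes. StubQueuedLoss and validate are hypothetical; the stub tracks only queue length and the pool size each call sees.

```python
from typing import List, Sequence


class StubQueuedLoss:
    """Hypothetical stand-in for loss_fn: tracks queue length and pool sizes."""

    def __init__(self, queue_size: int) -> None:
        self.queue_size = queue_size
        self.queue_len = 0
        self.pool_sizes: List[int] = []

    def __call__(self, batch: Sequence[float], training: bool = True) -> float:
        self.pool_sizes.append(len(batch) + self.queue_len)
        if training:  # assumption: only training passes update the queue
            self.queue_len = min(self.queue_len + len(batch), self.queue_size)
        return 0.0  # placeholder loss value

    def reset_queue(self) -> None:
        self.queue_len = 0


def validate(loss_fn: StubQueuedLoss, val_batches) -> None:
    # Reset first so that no training logits appear in val-phase AP estimates.
    loss_fn.reset_queue()
    for batch in val_batches:
        loss_fn(batch, training=False)


loss_fn = StubQueuedLoss(queue_size=1024)
for _ in range(3):
    loss_fn([0.0] * 32)               # training steps grow the pool
validate(loss_fn, [[0.0] * 32] * 2)   # each val pool is the live batch only
print(loss_fn.pool_sizes)  # [32, 64, 96, 32, 32]
```

The same reset_queue() call is what you would place right after loading a checkpoint or swapping the architecture, since every queued logit then comes from weights that no longer exist.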
Queue size vs. pool size limits
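The core AP computation compares each positive against every sample in the pool, so it costs O(|P| × M), where |P| is the number of positives in the pool and M = batch size + queue size. Because |P| is roughly the positive rate times M, this is still quadratic in M, but with a small constant at low positive rates: at a 0.5% rate it is about 0.5% of the M² pairs a fully dense comparison would need. M still has a practical upper limit of about 4096 for reasonable training step times on a single GPU. At a 0.5% positive rate with M = 4096, you get about 20 positives, a comfortable signal.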
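The sizing trade-off can be tabulated with a rough cost model. This sketch only counts comparisons under the "each positive against the whole pool" description; it says nothing about actual GPU time, and pool_cost is an illustrative name.

```python
def pool_cost(pool_size: int, pos_rate: float) -> dict:
    """Rough cost model: each positive is compared against the whole pool."""
    n_pos = pool_size * pos_rate              # expected |P|
    return {
        "expected_positives": n_pos,
        "comparisons": n_pos * pool_size,     # |P| x M
        "dense_pairs": pool_size ** 2,        # M^2, for comparison
    }


for m in (1056, 2048, 4096):
    c = pool_cost(m, pos_rate=0.005)
    print(f"M={m}: |P|~{c['expected_positives']:.1f}, "
          f"|P|xM~{c['comparisons']:.0f}, M^2={c['dense_pairs']}")
```

The ratio of comparisons to dense pairs is exactly the positive rate, so doubling M roughly quadruples the work while doubling the expected number of positives.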
DDP queue synchronization
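In distributed training, every worker calls all_gather on its logits and targets before passing them to the loss. Every worker therefore sees the same global batch and enqueues the same data, so no explicit queue synchronization across workers is needed: the queues are identical by construction. Note that the live batch is then the global batch (world size × per-worker batch size), which is the batch size that matters for the pool-size and queue-rotation arithmetic above.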
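The "identical by construction" argument can be simulated in a single process. This sketch replaces a real collective such as torch.distributed.all_gather with a hypothetical simulated_all_gather that hands every rank the concatenation of all ranks' batches in rank order; each rank then appends to its own queue.

```python
from typing import List


def simulated_all_gather(per_rank: List[List[float]]) -> List[List[float]]:
    """Stand-in for an all_gather collective: every rank receives every
    rank's data, concatenated in rank order."""
    gathered = [x for rank_batch in per_rank for x in rank_batch]
    return [list(gathered) for _ in per_rank]


def ddp_step(queues: List[List[float]], per_rank: List[List[float]],
             queue_size: int) -> None:
    """Each rank enqueues the gathered global batch into its own queue."""
    for rank, global_batch in enumerate(simulated_all_gather(per_rank)):
        queues[rank] = (queues[rank] + global_batch)[-queue_size:]


world_size, per_rank_batch, queue_size = 2, 4, 16
queues: List[List[float]] = [[] for _ in range(world_size)]
for step in range(5):
    # Different data on each rank, as when each rank reads its own shard.
    per_rank = [[float(100 * step + 10 * r + i) for i in range(per_rank_batch)]
                for r in range(world_size)]
    ddp_step(queues, per_rank, queue_size)

print(all(q == queues[0] for q in queues))  # True
```

With a global batch of 8 and queue_size=16, each queue turns over every 2 steps, i.e. queue_size / (world_size × per-rank batch size) rather than queue_size / per-rank batch size.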