PAUCAtBudgetLoss

Differentiable partial-AUC-over-an-FPR-band loss with an optional memory queue. Optimizes the normalized partial AUC over a false-positive-rate band [alpha, beta] that brackets a target operating point (e.g. FPR ≈ 0.005 / 50 bps), rather than the full AUC or a single-threshold recall. Useful when you care about recall at a fixed, low false-alarm budget (fraud, screening, alerting).

imbalanced_losses.pauc_loss.PAUCAtBudgetLoss

Bases: _QueuedRankingLoss

For each class, FPR-band edges t_alpha/t_beta are estimated from the pooled iid negatives (stop-gradient) as score quantiles, a scale-aware sigmoid temperature tau_eff is derived from a detached robust dispersion of those negatives, and the normalized partial AUC over [alpha, beta] is optimized via a trapezoid (default) or band-restricted pairwise surrogate. Loss is 1 - pauc.
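
As a rough illustration of the pipeline described above (detached quantile thresholds, an IQR-scaled temperature, and a trapezoid over soft-TPR), the following is a minimal sketch and not the library's implementation. The function name `soft_pauc_trapezoid` and the mean-of-sigmoids form of soft-TPR are assumptions made for illustration.

```python
import torch


def soft_pauc_trapezoid(pos, neg, alpha=0.0, beta=0.005, n_knots=2, temperature=0.1):
    """Hypothetical sketch of a normalized soft pAUC over the FPR band [alpha, beta]."""
    neg = neg.detach()  # thresholds and temperature are stop-gradient
    # FPR knots across the band and the score threshold matching each knot.
    fpr = torch.linspace(alpha, beta, n_knots, dtype=neg.dtype)
    thresholds = torch.quantile(neg, 1.0 - fpr, interpolation="higher")
    # Scale-aware temperature from the IQR of the negatives (assumed nonzero here).
    iqr = torch.quantile(neg, 0.75) - torch.quantile(neg, 0.25)
    tau_eff = temperature * iqr
    # Soft TPR at each knot: mean over positives of sigmoid((score - threshold) / tau_eff).
    soft_tpr = torch.sigmoid((pos[:, None] - thresholds[None, :]) / tau_eff).mean(dim=0)
    # Trapezoid rule over the FPR knots, normalized by the band width.
    area = (0.5 * (soft_tpr[1:] + soft_tpr[:-1]) * (fpr[1:] - fpr[:-1])).sum()
    return area / (beta - alpha)


neg = torch.linspace(-1.0, 1.0, 2001)                      # toy negative scores
pos = torch.tensor([0.5, 0.9, 1.2], requires_grad=True)    # toy positive scores
loss = 1.0 - soft_pauc_trapezoid(pos, neg)
loss.backward()  # only the positives receive gradient; negatives are detached
```

In this sketch, raising any positive score can only lower the loss, which is the behaviour the trapezoid surrogate is described as having ("gradient flows through positives only").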

Multi-class: one-vs-rest per class using logits[:, c], then reduce. Binary: logits[:, 0] with targets in {0, 1}.
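
The one-vs-rest split can be pictured with a few lines of tensor indexing. This is an illustrative sketch with made-up data; the variable names are hypothetical and the real class additionally merges the queue and the iid mask.

```python
import torch

IGNORE_INDEX = -100  # matches the documented default for ignore_index

logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 1.5, 0.2],
                       [0.0, 0.0, 3.0],
                       [1.0, 1.0, 1.0]])
targets = torch.tensor([0, 1, 2, IGNORE_INDEX])

keep = targets != IGNORE_INDEX  # drop padded rows first
for c in range(logits.shape[1]):
    scores = logits[keep, c]       # the class-c column is the ranking score
    is_pos = targets[keep] == c    # rows labeled c are positives
    is_neg = ~is_pos               # every other kept row is a negative
    print(c, scores[is_pos].tolist(), scores[is_neg].tolist())
```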

Inherits queue management, DDP gather, ignore-index filtering, subsampling, and reduction logic from _QueuedRankingLoss.
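
For the inherited reduction logic, the `reduction` parameter description states that `'none'` returns a `[C]` tensor with `nan` for invalid classes, while `'mean'` and `'sum'` aggregate over valid classes only. A caller who requests `'none'` can reproduce the other two along these lines; the tensor below is a stand-in, not output from the library.

```python
import torch

# Stand-in for the output of reduction='none' with C = 4; nan marks an invalid class.
per_class = torch.tensor([0.20, float("nan"), 0.60, 0.40])

valid = ~torch.isnan(per_class)
mean_over_valid = per_class[valid].mean()   # analogous to reduction='mean'
sum_over_valid = per_class[valid].sum()     # analogous to reduction='sum'
print(int(valid.sum()), float(mean_over_valid), float(sum_over_valid))
```

Masking before aggregating matters because a plain `.mean()` over a tensor containing `nan` returns `nan`.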

Degenerate-dispersion guard: if the robust dispersion of iid-negative scores for a class is at or below _SCALE_EPS (all-equal scores, a collapsed band with tau_scale='band', or too few iid negatives to resolve the tail quantile), that class is marked INVALID and excluded from reduction rather than silently computing with a near-zero temperature. A one-time UserWarning is emitted on the first such occurrence. To avoid degenerate classes, increase queue_size or ensure iid negatives cover a meaningful score range.
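
To see why a near-zero temperature is worth guarding against, consider a single soft step `sigmoid(margin / tau)` and its derivative with respect to the margin. The helper below is a hypothetical, standalone illustration in plain Python, not code from the library.

```python
import math


def soft_step_and_grad(margin, tau):
    """Return sigmoid(margin / tau) and its derivative with respect to margin."""
    z = margin / tau
    # Numerically stable sigmoid: never exponentiate a large positive number.
    if z >= 0:
        s = 1.0 / (1.0 + math.exp(-z))
    else:
        e = math.exp(z)
        s = e / (1.0 + e)
    return s, s * (1.0 - s) / tau


for tau in (1e-1, 1e-3, 1e-12):
    # A small nonzero margin versus a score sitting exactly on the threshold.
    print(tau, soft_step_and_grad(0.1, tau), soft_step_and_grad(0.0, tau))
```

As `tau` shrinks, a score slightly off the threshold saturates and its gradient vanishes, while a score exactly on the threshold gets a gradient of `0.25 / tau`, which grows without bound.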

The band-edge approximation of population FPR is reliable when the pooled iid-negative count substantially exceeds 1/alpha. Monitor band_neg_count in diagnostics and set queue_size accordingly.
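
A back-of-envelope check can help when choosing `queue_size`. Assuming band membership is decided by empirical quantiles, roughly `n * (beta - alpha)` of the `n` pooled iid negatives land inside the band. The helper name and pool sizes below are illustrative; 1,280 corresponds to the default `queue_size=1024` plus a 256-row batch in which nearly every row is a negative.

```python
def expected_band_negatives(n_iid_neg, alpha, beta):
    """Approximate number of iid negatives whose FPR rank lies in [alpha, beta]."""
    return n_iid_neg * (beta - alpha)


for pool in (1_280, 10_000, 100_000):
    print(pool, expected_band_negatives(pool, alpha=0.0, beta=0.005))
```

With the default band, a pool near the default size leaves only a handful of negatives in the band, which is one reason to watch `band_neg_count`.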

The recommended band convention is alpha ≈ 0, beta ≈ budget (where budget is the target FPR, e.g. beta=0.005 for a 50 bps operating point). This sets the upper threshold t_alpha = quantile(neg, 1.0) = max(neg) and the lower threshold t_beta = quantile(neg, 1 - budget), so the band covers every false-positive that falls above the budget threshold, i.e. all negatives scoring at or above the operating point.
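
The two thresholds for the recommended band can be reproduced with `torch.quantile`. This sketch uses synthetic standard-normal scores as a stand-in for pooled iid negatives and the documented default interpolation `'higher'`.

```python
import torch

torch.manual_seed(0)
neg = torch.randn(10_000)   # stand-in for pooled iid-negative scores
budget = 0.005              # 50 bps target FPR

t_alpha = torch.quantile(neg, 1.0, interpolation="higher")           # alpha = 0
t_beta = torch.quantile(neg, 1.0 - budget, interpolation="higher")   # beta = budget
in_band = (neg >= t_beta) & (neg <= t_alpha)

print(float(t_beta), float(t_alpha), int(in_band.sum()))
```

Here `t_alpha` coincides with the maximum negative score, and the band holds about `budget * len(neg)` negatives, i.e. those scoring at or above the operating threshold.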

The older convention [budget/2, 1.5·budget] — e.g. [0.0025, 0.0075] for a 50 bps point — excludes the highest-scoring (worst) negatives via its lower edge alpha = budget/2 and extends below the operating threshold via its upper edge beta = 1.5·budget. A band sweep (8 seeds, synthetic contested-top extreme-imbalance data, 50 bps budget) found coverage@budget to be monotone in both edges: smaller alpha and smaller beta are better in every cell, and the old convention sits in the poorly-performing high-alpha region (the worst cell being alpha=budget/2, beta=2.5*budget). The recommended alpha=0, beta=budget band fixes both defects by contrasting positives against all false-positives at the budget.

Caveats: the sweep evidence is synthetic, at a single budget (50 bps), and in the contested-top regime. The improvement is concentrated at pos_rate ≪ budget; once pos_rate ≥ budget, coverage@budget is mechanically capped at budget/pos_rate and band choice is irrelevant.
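
The cap in the caveat follows from counting, under the assumption that coverage@budget means the fraction of all positives ranked within the top `budget` fraction of rows (the metric is not defined here, so this reading is an assumption). At most `budget * N` rows fit under the budget while there are `pos_rate * N` positives. The helper name is hypothetical.

```python
def coverage_cap(budget, pos_rate):
    """Upper bound on the share of positives that fit in the top `budget` fraction."""
    return min(1.0, budget / pos_rate)


for pos_rate in (0.0005, 0.005, 0.01, 0.05):
    print(pos_rate, coverage_cap(0.005, pos_rate))
```

When `pos_rate` is well below the budget the cap is 1.0 and ranking quality decides the outcome; once `pos_rate` exceeds the budget the cap itself becomes the binding limit.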

Parameters:

Name Type Description Default
num_classes int

Number of output classes. Use 1 for binary mode.

required
alpha float

Lower FPR band edge. Must satisfy 0 <= alpha < beta <= 1. alpha=0 sets t_alpha = max(neg_iid), contrasting positives against all negatives above the budget threshold. Default: 0.0 (recommended for contested-top extreme-imbalance).

0.0
beta float

Upper FPR band edge. Must satisfy 0 <= alpha < beta <= 1. Set to your target operating-point FPR (e.g. 0.005 for 50 bps). Default: 0.005.

0.005
surrogate ('trapezoid', 'pairwise')

pAUC estimator. 'trapezoid' (default) integrates soft-TPR over n_knots FPR knots; gradient flows through positives only. 'pairwise' compares positives against band negatives drawn from the gradient pool (band negatives carry gradient). Default: 'trapezoid'.

'trapezoid'
n_knots int

Number of equally-spaced FPR knots in [alpha, beta] for the trapezoid surrogate (knot 0 = alpha, knot n_knots-1 = beta). Must be >= 2. Ignored when surrogate='pairwise'. Default: 2.

The default of 2 (trapezoid rule) is accurate for narrow bands where TPR(FPR) is approximately linear over [alpha, beta]; the integration error scales as (beta - alpha)^3 * TPR''. For wide bands where TPR curvature is non-negligible, n_knots >= 3 is recommended.

2
tau_scale ('iqr', 'band')

Robust dispersion used to make the temperature scale-aware. 'iqr' (default) uses IQR(neg_iid) -- a stable bulk statistic (pair with small temperature, e.g. 0.1). 'band' uses t_alpha - t_beta -- sized directly to the operating region (pair with temperature near 1.0; recommended for wide/volatile bands). Default: 'iqr'.

'iqr'
pos_numerator ('pool', 'live')

Which positives form the soft-TPR numerator (and the pairwise positive set). 'pool' (default) uses all pooled positives (live batch + queue), matching the queue's stabilising role but diluting the live-positive gradient by 1 / |P_pool| when the queue holds many detached positives. 'live' uses only the live-batch positives, giving an undiluted gradient -- useful at extreme imbalance where the queue swamps the few live positives -- at the cost of a higher-variance TPR estimate (mean over the ~few live positives). Thresholds and tau_eff always use the full pooled iid negatives regardless of this setting. A class with no live positives in a step is skipped (invalid) under 'live'. Default: 'pool'.

'pool'
queue_size int

Circular buffer size (rows). Larger queues stabilise the quantile-based band edges -- at low FPR you need many negatives for a meaningful tail quantile. Set to 0 to disable. Default: 1024.

DDP note: when gather_distributed=True, the all-gather runs before the enqueue, so each rank stores global-batch rows. The effective pool per forward pass is already global_batch_size + queue_size.

1024
temperature float

Dimensionless multiplier on tau_eff = temperature * scale. Larger values give smoother gradients but bias soft-TPR toward 0.5; smaller values approximate true TPR but risk sigmoid saturation. Default: 0.1.

0.1
reduction ('mean', 'sum', 'none')

How to aggregate per-class losses.

- 'mean': scalar average over valid classes.
- 'sum': scalar sum over valid classes.
- 'none': tensor of shape [C]; invalid classes are nan.

Default: 'mean'.

'mean'
ignore_index int

Target value marking padded positions. Excluded from threshold estimation and the positive set. Default: -100.

-100
update_queue_in_eval bool

If False (default), the queue is frozen during eval mode. Default: False.

False
gather_distributed bool or None

Whether to all-gather logits, targets, and the iid mask across DDP workers before computing the loss. None (default) auto-detects: gathers when torch.distributed is initialized with world_size > 1. Set False to explicitly disable. Resolved once on first forward call. Default: None.

None
quantile_interpolation str

Interpolation method passed to torch.quantile for the band edges. 'higher' is the conservative default. One of ('linear', 'lower', 'higher', 'nearest', 'midpoint'). Default: 'higher'.

'higher'
max_pool_size int or None

Maximum number of rows in the ranking pool (live batch + queue after ignore_index filtering). When exceeded, minimum-quota subsampling caps it. See RecallAtQuantileLoss for details. None (default) disables the cap.

None
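
The gradient dilution mentioned under `pos_numerator` can be reproduced with plain autograd. This is a toy sketch with hypothetical tensors (two live positives, 98 detached queued positives, all at score 0), not the library's code path.

```python
import torch

live = torch.zeros(2, requires_grad=True)   # few live-batch positives (carry gradient)
queue = torch.zeros(98)                     # many queued positives (detached)

# 'pool'-style numerator: mean over live + queued positives.
tpr_pool = torch.sigmoid(torch.cat([live, queue])).mean()
(grad_pool,) = torch.autograd.grad(tpr_pool, live)

# 'live'-style numerator: mean over the live positives only.
tpr_live = torch.sigmoid(live).mean()
(grad_live,) = torch.autograd.grad(tpr_live, live)

print(grad_pool, grad_live, grad_live / grad_pool)
```

The per-positive gradient differs by the ratio of the two denominators (100 versus 2 in this toy case), which is the `1 / |P_pool|` dilution the description refers to.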

Examples:

>>> import torch
>>> from imbalanced_losses.pauc_loss import PAUCAtBudgetLoss
>>> loss_fn = PAUCAtBudgetLoss(num_classes=4, alpha=0.0, beta=0.005)
>>> logits  = torch.randn(256, 4)
>>> targets = torch.randint(0, 4, (256,))
>>> loss = loss_fn(logits, targets)
>>> loss.backward()

Notes

The iid-negative band edges depend only on rows flagged iid_mask=True; appending non-iid negatives (caller-side densification) does not shift t_alpha/t_beta. iid_mask=None treats all rows as iid (the common case when negatives are never densified by class).
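
A small experiment shows why restricting the quantile to iid rows matters. The data are synthetic and the shift of `+3.0` for the appended hard negatives is an arbitrary choice for illustration.

```python
import torch

torch.manual_seed(0)
iid_neg = torch.randn(5_000)            # negatives that represent the population
extra_neg = torch.randn(2_000) + 3.0    # hard negatives appended by the caller

scores = torch.cat([iid_neg, extra_neg])
iid_mask = torch.cat([torch.ones(5_000, dtype=torch.bool),
                      torch.zeros(2_000, dtype=torch.bool)])

q = 1.0 - 0.005
t_beta_iid = torch.quantile(scores[iid_mask], q, interpolation="higher")
t_beta_all = torch.quantile(scores, q, interpolation="higher")
print(float(t_beta_iid), float(t_beta_all))
```

The masked threshold is identical to the one computed before densification, whereas the unmasked one is pulled upward by the appended rows and would no longer correspond to the intended population FPR.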

Trapezoid cost is O(|P| x n_knots); pairwise cost is O(|P| x |band|). No O(M^2) path.
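
Since trapezoid cost grows linearly with `n_knots`, it can be useful to see what extra knots buy. The sketch below integrates a toy concave curve, `TPR(f) = sqrt(f)`, over a deliberately wide band; the curve and band are purely illustrative and the helper name is hypothetical.

```python
import math


def trapezoid_pauc(tpr, alpha, beta, n_knots):
    """Normalized trapezoid estimate of the area under tpr(fpr) on [alpha, beta]."""
    h = (beta - alpha) / (n_knots - 1)
    ys = [tpr(alpha + i * h) for i in range(n_knots)]
    area = h * (0.5 * ys[0] + sum(ys[1:-1]) + 0.5 * ys[-1])
    return area / (beta - alpha)


tpr = math.sqrt                      # toy concave ROC segment, purely illustrative
alpha, beta = 0.0, 0.25
# Closed-form normalized area under sqrt(f) on [alpha, beta].
exact = (2.0 / 3.0) * (beta ** 1.5 - alpha ** 1.5) / (beta - alpha)

for n_knots in (2, 3, 5, 9):
    est = trapezoid_pauc(tpr, alpha, beta, n_knots)
    print(n_knots, est, abs(est - exact))
```

On this curved segment the two-knot estimate is visibly off and the error shrinks as knots are added, consistent with the guidance to use `n_knots >= 3` for wide bands.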

References

This loss is an original design, not a published method, but the partial-AUC-over-a-band objective and its estimators build on prior work:

[1] D. K. McClish (1989). "Analyzing a Portion of the ROC Curve." Medical Decision Making 9(3), 190-195. -- partial AUC over an ROC region.

[2] L. E. Dodd and M. S. Pepe (2003). "Partial AUC Estimation and Regression." Biometrics 59(3), 614-623. -- nonparametric pAUC estimator (the consistent plug-in the trapezoid surrogate relates to).

[3] H. Narasimhan and S. Agarwal (2013). "A Structural SVM Based Approach for Optimizing Partial AUC." ICML 2013. -- optimizing pAUC over an FPR band [alpha, beta] as a learning objective (the KDD 2013 "tight" variant gives the boundary-corrected estimator).

[4] D. Zhu, G. Li, B. Wang, X. Wu, and T. Yang (2022). "When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee." ICML 2022. -- deep one-/two-way pAUC surrogate optimization.

Source code in src/imbalanced_losses/pauc_loss.py
class PAUCAtBudgetLoss(_QueuedRankingLoss):
    """
    Differentiable partial-AUC-over-an-FPR-band loss with an optional memory queue.

    For each class, FPR-band edges ``t_alpha``/``t_beta`` are estimated from the
    pooled *iid negatives* (stop-gradient) as score quantiles, a scale-aware
    sigmoid temperature ``tau_eff`` is derived from a detached robust dispersion
    of those negatives, and the normalized partial AUC over ``[alpha, beta]`` is
    optimized via a trapezoid (default) or band-restricted pairwise surrogate.
    Loss is ``1 - pauc``.

    Multi-class: one-vs-rest per class using logits[:, c], then reduce.
    Binary:      logits[:, 0] with targets in {0, 1}.

    Inherits queue management, DDP gather, ignore-index filtering, subsampling,
    and reduction logic from ``_QueuedRankingLoss``.

    **Degenerate-dispersion guard:** if the robust dispersion of iid-negative
    scores for a class is at or below ``_SCALE_EPS`` (all-equal scores, a
    collapsed band with ``tau_scale='band'``, or too few iid negatives to
    resolve the tail quantile), that class is marked INVALID and excluded from
    reduction rather than silently computing with a near-zero temperature.  A
    one-time ``UserWarning`` is emitted on the first such occurrence.  To avoid
    degenerate classes, increase ``queue_size`` or ensure iid negatives cover a
    meaningful score range.

    The band-edge approximation of population FPR is reliable when the pooled
    iid-negative count substantially exceeds ``1/alpha``.  Monitor
    ``band_neg_count`` in diagnostics and set ``queue_size`` accordingly.

    The recommended band convention is ``alpha ≈ 0, beta ≈ budget`` (where
    ``budget`` is the target FPR, e.g. ``beta=0.005`` for a 50 bps operating
    point).  This sets the upper threshold ``t_alpha = quantile(neg, 1.0) =
    max(neg)`` and the lower threshold ``t_beta = quantile(neg, 1 - budget)``,
    so the band covers every false-positive that falls above the budget
    threshold, i.e. all negatives scoring at or above the operating point.

    The older convention ``[budget/2, 1.5·budget]`` — e.g. ``[0.0025, 0.0075]``
    for a 50 bps point — excludes the highest-scoring (worst) negatives via its
    lower edge ``alpha = budget/2`` and extends below the operating threshold via
    its upper edge ``beta = 1.5·budget``.  A band sweep (8 seeds, synthetic
    contested-top extreme-imbalance data, 50 bps budget) found coverage@budget
    to be monotone in both edges: smaller ``alpha`` and smaller ``beta`` are
    better in every cell, and the old convention sits in the poorly-performing
    high-``alpha`` region (the worst cell being ``alpha=budget/2, beta=2.5*budget``).
    The recommended ``alpha=0, beta=budget`` band fixes both defects by
    contrasting positives against all false-positives at the budget.

    **Caveats:** the sweep evidence is synthetic, at a single budget (50 bps),
    and in the contested-top regime.  The improvement is concentrated at
    ``pos_rate ≪ budget``; once ``pos_rate ≥ budget``, coverage@budget is
    mechanically capped at ``budget/pos_rate`` and band choice is irrelevant.

    Parameters
    ----------
    num_classes : int
        Number of output classes. Use 1 for binary mode.
    alpha : float, optional
        Lower FPR band edge. Must satisfy ``0 <= alpha < beta <= 1``.
        ``alpha=0`` sets ``t_alpha = max(neg_iid)``, contrasting positives
        against all negatives above the budget threshold.
        Default: 0.0 (recommended for contested-top extreme-imbalance).
    beta : float, optional
        Upper FPR band edge. Must satisfy ``0 <= alpha < beta <= 1``.
        Set to your target operating-point FPR (e.g. ``0.005`` for 50 bps).
        Default: 0.005.
    surrogate : {'trapezoid', 'pairwise'}, optional
        pAUC estimator. ``'trapezoid'`` (default) integrates soft-TPR over
        ``n_knots`` FPR knots; gradient flows through positives only.
        ``'pairwise'`` compares positives against band negatives drawn from the
        gradient pool (band negatives carry gradient). Default: 'trapezoid'.
    n_knots : int, optional
        Number of equally-spaced FPR knots in ``[alpha, beta]`` for the
        trapezoid surrogate (knot 0 = alpha, knot n_knots-1 = beta). Must be
        >= 2. Ignored when ``surrogate='pairwise'``. Default: 2.

        The default of 2 (trapezoid rule) is accurate for narrow bands where
        TPR(FPR) is approximately linear over ``[alpha, beta]``; the
        integration error scales as ``(beta - alpha)^3 * TPR''``.  For wide
        bands where TPR curvature is non-negligible, ``n_knots >= 3`` is
        recommended.
    tau_scale : {'iqr', 'band'}, optional
        Robust dispersion used to make the temperature scale-aware.
        ``'iqr'`` (default) uses ``IQR(neg_iid)`` -- a stable bulk statistic
        (pair with small ``temperature``, e.g. 0.1). ``'band'`` uses
        ``t_alpha - t_beta`` -- sized directly to the operating region (pair
        with ``temperature`` near 1.0; recommended for wide/volatile bands).
        Default: 'iqr'.
    pos_numerator : {'pool', 'live'}, optional
        Which positives form the soft-TPR numerator (and the pairwise positive
        set). ``'pool'`` (default) uses all pooled positives (live batch + queue),
        matching the queue's stabilising role but diluting the live-positive
        gradient by ``1 / |P_pool|`` when the queue holds many detached positives.
        ``'live'`` uses only the live-batch positives, giving an undiluted
        gradient -- useful at extreme imbalance where the queue swamps the few
        live positives -- at the cost of a higher-variance TPR estimate (mean over
        the ~few live positives). Thresholds and ``tau_eff`` always use the full
        pooled iid negatives regardless of this setting. A class with no live
        positives in a step is skipped (invalid) under ``'live'``. Default: 'pool'.
    queue_size : int, optional
        Circular buffer size (rows). Larger queues stabilise the quantile-based
        band edges -- at low FPR you need many negatives for a meaningful tail
        quantile. Set to 0 to disable. Default: 1024.

        **DDP note:** when ``gather_distributed=True``, the all-gather runs
        *before* the enqueue, so each rank stores global-batch rows. The
        effective pool per forward pass is already
        ``global_batch_size + queue_size``.
    temperature : float, optional
        Dimensionless multiplier on ``tau_eff = temperature * scale``. Larger
        values give smoother gradients but bias soft-TPR toward 0.5; smaller
        values approximate true TPR but risk sigmoid saturation. Default: 0.1.
    reduction : {'mean', 'sum', 'none'}, optional
        How to aggregate per-class losses.
        - 'mean': scalar average over valid classes.
        - 'sum':  scalar sum over valid classes.
        - 'none': tensor of shape [C]; invalid classes are nan.
        Default: 'mean'.
    ignore_index : int, optional
        Target value marking padded positions. Excluded from threshold
        estimation and the positive set. Default: -100.
    update_queue_in_eval : bool, optional
        If False (default), the queue is frozen during eval mode. Default: False.
    gather_distributed : bool or None, optional
        Whether to all-gather logits, targets, and the iid mask across DDP
        workers before computing the loss. ``None`` (default) auto-detects:
        gathers when ``torch.distributed`` is initialized with world_size > 1.
        Set ``False`` to explicitly disable. Resolved once on first forward
        call. Default: None.
    quantile_interpolation : str, optional
        Interpolation method passed to torch.quantile for the band edges.
        'higher' is the conservative default. One of ('linear', 'lower',
        'higher', 'nearest', 'midpoint'). Default: 'higher'.
    max_pool_size : int or None, optional
        Maximum number of rows in the ranking pool (live batch + queue after
        ignore_index filtering). When exceeded, minimum-quota subsampling caps
        it. See ``RecallAtQuantileLoss`` for details. ``None`` (default)
        disables the cap.

    Examples
    --------
    >>> import torch
    >>> from imbalanced_losses.pauc_loss import PAUCAtBudgetLoss
    >>> loss_fn = PAUCAtBudgetLoss(num_classes=4, alpha=0.0, beta=0.005)
    >>> logits  = torch.randn(256, 4)
    >>> targets = torch.randint(0, 4, (256,))
    >>> loss = loss_fn(logits, targets)
    >>> loss.backward()

    Notes
    -----
    The iid-negative band edges depend only on rows flagged
    ``iid_mask=True``; appending non-iid negatives (caller-side densification)
    does not shift ``t_alpha``/``t_beta``. ``iid_mask=None`` treats all rows as
    iid (the common case when negatives are never densified by class).

    Trapezoid cost is O(|P| x n_knots); pairwise cost is O(|P| x |band|).
    No O(M^2) path.

    References
    ----------
    This loss is an original design, not a published method, but the
    partial-AUC-over-a-band objective and its estimators build on prior work:

    .. [1] D. K. McClish (1989). "Analyzing a Portion of the ROC Curve."
       Medical Decision Making 9(3), 190-195. -- partial AUC over an ROC region.
    .. [2] L. E. Dodd and M. S. Pepe (2003). "Partial AUC Estimation and
       Regression." Biometrics 59(3), 614-623. -- nonparametric pAUC estimator
       (the consistent plug-in the trapezoid surrogate relates to).
    .. [3] H. Narasimhan and S. Agarwal (2013). "A Structural SVM Based Approach
       for Optimizing Partial AUC." ICML 2013. -- optimizing pAUC over an FPR
       band [alpha, beta] as a learning objective (the KDD 2013 "tight" variant
       gives the boundary-corrected estimator).
    .. [4] D. Zhu, G. Li, B. Wang, X. Wu, and T. Yang (2022). "When AUC meets
       DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence
       Guarantee." ICML 2022. -- deep one-/two-way pAUC surrogate optimization.
    """

    _VALID_INTERPOLATIONS = ("linear", "lower", "higher", "nearest", "midpoint")
    _VALID_SURROGATES = ("trapezoid", "pairwise")
    _VALID_TAU_SCALES = ("iqr", "band")
    _VALID_POS_NUMERATORS = ("pool", "live")

    # Floor on the detached dispersion to avoid div-by-zero when all iid
    # negative scores are (near) equal.
    _SCALE_EPS = 1e-12

    def __init__(
        self,
        num_classes: int,
        alpha: float = 0.0,
        beta: float = 0.005,
        surrogate: Literal["trapezoid", "pairwise"] = "trapezoid",
        n_knots: int = 2,
        tau_scale: Literal["iqr", "band"] = "iqr",
        pos_numerator: Literal["pool", "live"] = "pool",
        queue_size: int = 1024,
        temperature: float = 0.1,
        reduction: Literal["mean", "sum", "none"] = "mean",
        ignore_index: int = -100,
        update_queue_in_eval: bool = False,
        gather_distributed: bool | None = None,
        quantile_interpolation: str = "higher",
        max_pool_size: int | None = None,
    ) -> None:
        if not (0.0 <= alpha < beta <= 1.0):
            raise ValueError(
                f"alpha and beta must satisfy 0 <= alpha < beta <= 1, "
                f"got alpha={alpha}, beta={beta}"
            )
        if not isinstance(n_knots, int) or n_knots < 2:
            raise ValueError(f"n_knots must be an int >= 2, got {n_knots}")
        if surrogate not in self._VALID_SURROGATES:
            raise ValueError(
                f"surrogate must be one of {self._VALID_SURROGATES}, got '{surrogate}'"
            )
        if tau_scale not in self._VALID_TAU_SCALES:
            raise ValueError(
                f"tau_scale must be one of {self._VALID_TAU_SCALES}, got '{tau_scale}'"
            )
        if pos_numerator not in self._VALID_POS_NUMERATORS:
            raise ValueError(
                f"pos_numerator must be one of {self._VALID_POS_NUMERATORS}, "
                f"got '{pos_numerator}'"
            )
        if quantile_interpolation not in self._VALID_INTERPOLATIONS:
            raise ValueError(
                f"quantile_interpolation must be one of {self._VALID_INTERPOLATIONS}, "
                f"got '{quantile_interpolation}'"
            )

        super().__init__(
            num_classes=num_classes,
            queue_size=queue_size,
            temperature=temperature,
            reduction=reduction,
            ignore_index=ignore_index,
            update_queue_in_eval=update_queue_in_eval,
            gather_distributed=gather_distributed,
            max_pool_size=max_pool_size,
        )

        self.alpha = float(alpha)
        self.beta = float(beta)
        self.surrogate = surrogate
        self.n_knots = n_knots
        self.tau_scale = tau_scale
        self.pos_numerator = pos_numerator
        self.quantile_interpolation = quantile_interpolation

        # Warn once per instance when a class is skipped due to near-zero
        # iid-negative dispersion (mirrors the _subsample_warned pattern).
        self._degenerate_warned = False

    # ------------------------------------------------------------------
    # Backward-compatible access to queue internals
    # ------------------------------------------------------------------
    # Tests and external code may access _q_logits, _q_targets, _q_ptr
    # directly on the loss instance. These properties delegate to the
    # nested _MemoryQueue submodule.

    @property
    def _q_logits(self):
        return self._queue._q_logits

    @_q_logits.setter
    def _q_logits(self, value):
        self._queue._q_logits = value

    @property
    def _q_targets(self):
        return self._queue._q_targets

    @_q_targets.setter
    def _q_targets(self, value):
        self._queue._q_targets = value

    @property
    def _q_ptr(self):
        return self._queue._q_ptr

    @_q_ptr.setter
    def _q_ptr(self, value):
        self._queue._q_ptr = value

    # ------------------------------------------------------------------
    # Backward-compatible queue methods
    # ------------------------------------------------------------------

    @torch.no_grad()
    def _enqueue(self, logits: torch.Tensor, targets: torch.Tensor) -> None:
        """Delegate to the internal ``_MemoryQueue``."""
        self._queue.enqueue(logits, targets)

    def _merge_with_queue(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the internal ``_MemoryQueue``."""
        return self._queue.merge(logits, targets)

    # ------------------------------------------------------------------
    # Core algorithm
    # ------------------------------------------------------------------

    def _band_thresholds_and_scale(
        self, neg: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute detached band edges and the raw robust dispersion.

        Parameters
        ----------
        neg : torch.Tensor, shape [n_iid_neg]
            Detached iid-negative scores for one class.

        Returns
        -------
        t_alpha : torch.Tensor, scalar
            Lower-FPR band edge ``quantile(neg, 1 - alpha)`` (detached).
        t_beta : torch.Tensor, scalar
            Upper-FPR band edge ``quantile(neg, 1 - beta)`` (detached);
            always ``t_beta <= t_alpha`` since ``alpha < beta``.
        scale : torch.Tensor, scalar
            Raw (unclamped) robust dispersion of ``neg`` -- IQR or band
            width depending on ``tau_scale``.  The caller must test this
            against ``_SCALE_EPS`` before computing ``tau_eff``.
        """
        t_alpha = torch.quantile(
            neg, 1.0 - self.alpha, interpolation=self.quantile_interpolation
        )
        t_beta = torch.quantile(
            neg, 1.0 - self.beta, interpolation=self.quantile_interpolation
        )

        if self.tau_scale == "iqr":
            q75 = torch.quantile(
                neg, 0.75, interpolation=self.quantile_interpolation
            )
            q25 = torch.quantile(
                neg, 0.25, interpolation=self.quantile_interpolation
            )
            scale = q75 - q25
        else:  # "band"
            scale = t_alpha - t_beta

        return t_alpha, t_beta, scale

    def _compute_pauc(
        self,
        scores: torch.Tensor,
        is_pos: torch.Tensor,
        is_neg: torch.Tensor,
        is_iid: torch.Tensor,
        is_live: torch.Tensor,
    ) -> tuple[torch.Tensor, bool, dict[str, Any]]:
        """
        Compute normalized partial AUC over ``[alpha, beta]`` for one class.

        Parameters
        ----------
        scores : torch.Tensor, shape [M]
            Pooled scores for one class (live + queue, padding stripped).
            Gradient flows through live-batch scores; queue scores are
            already detached upstream.
        is_pos : torch.Tensor, shape [M], dtype=bool
            Positive mask for this class.
        is_neg : torch.Tensor, shape [M], dtype=bool
            Negative mask for this class (``~is_pos``).
        is_iid : torch.Tensor, shape [M], dtype=bool
            Per-row iid-eligibility flag.
        is_live : torch.Tensor, shape [M], dtype=bool
            Per-row live-batch flag (True = live-batch row, False = queue).
            Consulted only when ``self.pos_numerator == "live"``.

        Returns
        -------
        pauc : torch.Tensor, scalar
            Normalized partial AUC estimate in [0, 1]. Zero (no gradient) for
            invalid classes.
        valid : bool
            False if there are no positives, no iid negatives, the iid-negative
            dispersion is near-zero (degenerate), or (pairwise) no band negatives.
            When ``pos_numerator="live"``, also False if there are no live
            positives for this class (no gradient signal this step).
            Invalid classes are excluded from reduction.
        diag : dict
            Diagnostic scalars for this class (all detached) when
            ``self._want_diag`` is True. Keys: ``t_alpha``, ``t_beta``,
            ``tau_eff``, ``band_neg_count``, ``pauc_var``. Returns an empty
            dict ``{}`` for invalid/degenerate classes or when
            ``self._want_diag`` is False; the caller (``_compute_per_class``)
            tolerates empty dicts via ``if diag:`` and leaves ``self._last_diag``
            at its nan/0 sentinel default for those classes.

        Notes
        -----
        Band edges ``t_alpha``/``t_beta`` and ``tau_eff`` are computed from the
        DETACHED iid negatives regardless of ``pos_numerator`` — the queue
        still stabilizes thresholds even when the numerator is restricted to
        live positives.  In trapezoid mode gradient reaches the numerator
        positives only; in pairwise mode it reaches numerator positives and
        band negatives.
        """
        iid_neg = is_neg & is_iid
        n_pos = int(is_pos.sum())
        n_iid_neg = int(iid_neg.sum())
        if n_pos == 0 or n_iid_neg == 0:
            return scores.new_zeros(()), False, {}

        # Determine the numerator positive set.
        if self.pos_numerator == "live":
            pos_num = is_pos & is_live
            if int(pos_num.sum()) == 0:
                # No live positives this step: no gradient signal, mark invalid.
                return scores.new_zeros(()), False, {}
        else:
            # "pool": use all pooled positives (pre-change behavior).
            pos_num = is_pos

        neg = scores[iid_neg].detach()
        t_alpha, t_beta, scale = self._band_thresholds_and_scale(neg)

        # Degeneracy guard: if the robust dispersion is ~zero, the sigmoid
        # temperature cannot be calibrated.  Mark as invalid rather than
        # computing with tau_eff ≈ 1e-13 (which yields a signal-free loss=1
        # or an exploding gradient).  Emit a one-time warning.
        if scale <= self._SCALE_EPS:
            if not self._degenerate_warned:
                _alpha_note = (
                    "alpha=0 means t_alpha=max(neg_iid); dispersion is near-zero "
                    "because iid negatives have equal (or near-equal) scores."
                    if self.alpha == 0.0 else
                    f"at least ~{1.0 / self.alpha:.4g} iid negatives are needed "
                    f"to resolve the tail quantile at alpha={self.alpha:.4g}."
                )
                warnings.warn(
                    f"{type(self).__name__}: iid-negative score dispersion is "
                    f"near-zero (scale={scale.item():.2e} <= _SCALE_EPS={self._SCALE_EPS:.2e}) "
                    f"for at least one class. This typically means all iid negatives "
                    f"have equal (or near-equal) scores, or the FPR band "
                    f"[{self.alpha}, {self.beta}] is too narrow relative to the "
                    f"available iid-negative count ({_alpha_note}) "
                    f"The affected class is skipped (marked INVALID). "
                    f"To fix: increase queue_size or ensure iid negatives cover a "
                    f"meaningful score range. "
                    f"(This warning is shown once per instance.)",
                    UserWarning,
                    stacklevel=5,
                )
                self._degenerate_warned = True
            return scores.new_zeros(()), False, {}

        # Apply the div-by-zero floor only AFTER the degeneracy check passes.
        tau_eff = self.temperature * scale.clamp_min(self._SCALE_EPS)

        if self.surrogate == "trapezoid":
            # FPR knots equally spaced over [alpha, beta]; threshold per knot.
            # dtype must match scores so torch.quantile doesn't raise on float64.
            f_k = torch.linspace(
                self.alpha, self.beta, self.n_knots,
                device=scores.device, dtype=scores.dtype
            )
            t_k = torch.quantile(
                neg, 1.0 - f_k, interpolation=self.quantile_interpolation
            )  # [n_knots], detached
            p = scores[pos_num]  # gradient flows here (numerator positive set)
            # [n_pos_num, n_knots]; each row is the contribution vector for one positive.
            contrib_mat = torch.sigmoid(
                (p.unsqueeze(1) - t_k.unsqueeze(0)) / tau_eff
            )
            # [n_knots] -- mean over numerator positives.
            tpr = contrib_mat.mean(dim=0)
            # Composite trapezoid on a uniform grid, normalized to [alpha, beta].
            pauc = (
                0.5 * tpr[0] + tpr[1:-1].sum() + 0.5 * tpr[-1]
            ) / (self.n_knots - 1)
            if not self._want_diag:
                return pauc, True, {}
            # Per-positive pAUC contribution: apply the same trapezoid weights
            # per row so that mean(v_i) == pauc.
            with torch.no_grad():
                weights = contrib_mat.new_ones(self.n_knots)
                weights[0] = 0.5
                weights[-1] = 0.5
                # [n_pos_num]
                v = (contrib_mat.detach() * weights.unsqueeze(0)).sum(dim=1) / (self.n_knots - 1)
                pauc_var = v.var(unbiased=False)
            # band_neg_count: iid negatives in the band [t_beta, t_alpha].
            band_neg_count = int(((neg >= t_beta) & (neg <= t_alpha)).sum())
            diag = {
                "t_alpha": t_alpha.detach(),
                "t_beta": t_beta.detach(),
                "tau_eff": tau_eff.detach(),
                "band_neg_count": band_neg_count,
                "pauc_var": pauc_var.detach(),
            }
            return pauc, True, diag

        # surrogate == "pairwise": band negatives from the GRADIENT POOL.
        band = is_neg & (scores >= t_beta) & (scores <= t_alpha)
        if int(band.sum()) == 0:
            return scores.new_zeros(()), False, {}
        p = scores[pos_num]   # numerator positive set (gradient flows here)
        b = scores[band]      # band negatives carry gradient (intended)
        # [n_pos_num, n_band]; each row is the per-positive contribution vector.
        contrib_mat = torch.sigmoid(
            (p.unsqueeze(1) - b.unsqueeze(0)) / tau_eff
        )
        pauc = contrib_mat.mean()
        if not self._want_diag:
            return pauc, True, {}
        # Per-positive contribution: mean over band negatives for each positive.
        with torch.no_grad():
            v = contrib_mat.detach().mean(dim=1)  # [n_pos_num]
            pauc_var = v.var(unbiased=False)
        # band_neg_count counts IID negatives in the band (consistent with
        # trapezoid), since the diagnostic semantics are population-level FPR.
        band_neg_count = int(((neg >= t_beta) & (neg <= t_alpha)).sum())
        diag = {
            "t_alpha": t_alpha.detach(),
            "t_beta": t_beta.detach(),
            "tau_eff": tau_eff.detach(),
            "band_neg_count": band_neg_count,
            "pauc_var": pauc_var.detach(),
        }
        return pauc, True, diag

    # ------------------------------------------------------------------
    # Diagnostics-aware forward override
    # ------------------------------------------------------------------

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        iid_mask: torch.Tensor | None = None,
        return_per_class: bool = False,
        return_diagnostics: bool = False,
    ) -> (
        torch.Tensor
        | tuple[torch.Tensor, dict]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]
    ):
        """
        Compute the pAUC loss, optionally returning per-class diagnostics.

        Parameters
        ----------
        logits : torch.Tensor, shape [N, C]
            Raw (un-normalised) class scores.
        targets : torch.Tensor, shape [N]
            Integer class labels.  Positions equal to ``ignore_index``
            are excluded.
        iid_mask : torch.Tensor, shape [N], dtype=bool, optional
            Per-row iid-eligibility flag.  ``None`` treats all rows as iid.
        return_per_class : bool, optional
            If True, also return per-class losses and a validity mask.
        return_diagnostics : bool, optional
            If True, also return a ``stats`` dict with per-class diagnostic
            tensors of shape ``[C]``.  Invalid or degenerate classes yield
            ``nan`` for float fields and ``0`` for count fields.

            Keys: ``t_alpha``, ``t_beta``, ``tau_eff``, ``band_neg_count``,
            ``pauc_var``, ``grad_pos_count``.

            ``grad_pos_count`` is rank-local (computed from the live pre-gather
            batch); under DDP the true gradient-carrying positive population is
            the sum across all ranks.

            When ``False`` (default), behavior is bit-identical to the
            base-class forward.

        Returns
        -------
        loss : torch.Tensor
            Scalar or shape ``[C]`` (``reduction='none'``).
        per_class_loss : torch.Tensor, shape [C]
            Only when ``return_per_class=True``.
        valid_classes : torch.Tensor, shape [C], dtype=bool
            Only when ``return_per_class=True``.
        stats : dict[str, Tensor[C]]
            Only when ``return_diagnostics=True``.  Order in the tuple
            is ``(loss, per_class, valid, stats)`` or ``(loss, stats)``
            depending on ``return_per_class``.

        Notes
        -----
        ``self._last_diag`` is transient per-call internal state; it is
        reset at the top of every forward call.  Statefulness here is a
        deliberate tradeoff to avoid changing the shared ``_QueuedRankingLoss``
        base-class contract (which cannot accept extra return values from
        ``_compute_per_class``).
        """
        # --- squeeze [N,1] targets early so grad_pos_count sees the right shape --
        if targets.ndim == 2 and targets.size(1) == 1:
            targets = targets.squeeze(1)

        # --- reset transient diagnostic state before every call ----------------
        # Always reset so _compute_per_class (called by super().forward) can
        # safely write to self._last_diag regardless of return_diagnostics.
        # Sentinel structure: float fields default to nan, count fields to 0.
        # The empty-pool early-return path in the base class skips
        # _compute_per_class, so _last_diag stays at this nan/0 default,
        # which is exactly the right diagnostic output for an empty pool.
        _nan = float("nan")
        self._last_diag: list[dict] = [
            {
                "t_alpha": _nan,
                "t_beta": _nan,
                "tau_eff": _nan,
                "band_neg_count": 0,
                "pauc_var": _nan,
            }
            for _ in range(self.num_classes)
        ]

        # --- diagnostics flag ---------------------------------------------------
        # Gates all diagnostic tensor ops in _compute_pauc.
        # Must be set BEFORE super().forward() calls _compute_per_class.
        self._want_diag = bool(return_diagnostics)

        if not return_diagnostics:
            # Fast path: no diagnostics needed, so this is bit-identical to base forward.
            return super().forward(
                logits, targets, iid_mask=iid_mask, return_per_class=return_per_class
            )

        # --- grad_pos_count: live-batch positives per class (after ignore_index) -
        # Counted here, on the rank-local pre-gather batch, before the base
        # forward merges live rows with the queue.
        valid_mask = targets != self.ignore_index
        filtered_targets = targets[valid_mask]
        grad_pos_count = logits.new_zeros(self.num_classes, dtype=torch.long)
        if self.num_classes == 1:
            grad_pos_count[0] = int(filtered_targets.bool().sum())
        else:
            for c in range(self.num_classes):
                grad_pos_count[c] = int((filtered_targets == c).sum())

        # --- delegate to base forward (runs _compute_per_class as side effect) --
        base_out = super().forward(
            logits, targets, iid_mask=iid_mask, return_per_class=return_per_class
        )

        # --- assemble stats dict from _last_diag --------------------------------
        dev = logits.device
        dtype = logits.dtype

        def _scalar_or_nan(val, is_nan_sentinel):
            """Return a float tensor from val, or nan if is_nan_sentinel."""
            if is_nan_sentinel:
                return torch.tensor(float("nan"), device=dev, dtype=dtype)
            if isinstance(val, torch.Tensor):
                return val.to(device=dev, dtype=dtype)
            return torch.tensor(float(val), device=dev, dtype=dtype)

        t_alpha_vals, t_beta_vals, tau_eff_vals = [], [], []
        band_neg_counts, pauc_var_vals = [], []

        for c in range(self.num_classes):
            d = self._last_diag[c]
            is_invalid = isinstance(d.get("t_alpha"), float) and (
                d["t_alpha"] != d["t_alpha"]  # nan check
            )
            t_alpha_vals.append(_scalar_or_nan(d["t_alpha"], is_invalid))
            t_beta_vals.append(_scalar_or_nan(d["t_beta"], is_invalid))
            tau_eff_vals.append(_scalar_or_nan(d["tau_eff"], is_invalid))
            band_neg_counts.append(
                torch.tensor(0 if is_invalid else d["band_neg_count"],
                             device=dev, dtype=torch.long)
            )
            pauc_var_vals.append(_scalar_or_nan(d["pauc_var"], is_invalid))

        stats: dict[str, torch.Tensor] = {
            "t_alpha":        torch.stack(t_alpha_vals),
            "t_beta":         torch.stack(t_beta_vals),
            "tau_eff":        torch.stack(tau_eff_vals),
            "band_neg_count": torch.stack(band_neg_counts),
            "pauc_var":       torch.stack(pauc_var_vals),
            "grad_pos_count": grad_pos_count.to(device=dev),
        }

        # --- build return value -------------------------------------------------
        if return_per_class:
            # base_out is (loss, per_class, valid)
            loss, per_class, valid = base_out
            return loss, per_class, valid, stats
        else:
            return base_out, stats

    # ------------------------------------------------------------------
    # Per-class dispatch (required by _QueuedRankingLoss)
    # ------------------------------------------------------------------

    def _compute_per_class(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        is_iid: torch.Tensor,
        is_live: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute 1 - pAUC for each class via one-vs-rest decomposition.

        Parameters
        ----------
        logits : torch.Tensor, shape [M, C]
            Pooled logits (live batch + queue, ignore-index rows removed,
            subsampling applied).
        targets : torch.Tensor, shape [M]
            Corresponding integer targets.
        is_iid : torch.Tensor, shape [M], dtype=bool
            Per-row iid-eligibility flag for FPR-band threshold estimation.
        is_live : torch.Tensor, shape [M], dtype=bool
            Per-row live-batch flag; threaded into ``_compute_pauc`` so
            ``pos_numerator="live"`` can restrict the numerator positive set.

        Returns
        -------
        loss_vec : torch.Tensor, shape [C]
            Per-class loss values (1 - pAUC).
        valid_vec : torch.Tensor, shape [C], dtype=bool
            True for classes with numerator positives, iid negatives,
            non-degenerate iid-negative dispersion, and (pairwise only)
            band negatives.
        """
        if self.num_classes == 1:
            # Binary mode: warn on out-of-range targets
            bad = targets[(targets != 0) & (targets != 1)]
            if bad.numel() > 0:
                warnings.warn(
                    f"Binary mode (num_classes=1) expects targets in {{0, 1}}, "
                    f"but found values: {bad[:8].tolist()}. "
                    "Non-zero values are treated as positive.",
                    UserWarning,
                    stacklevel=4,
                )
            is_pos = targets.bool()
            pauc, is_valid, diag = self._compute_pauc(
                logits[:, 0], is_pos, ~is_pos, is_iid, is_live
            )
            loss_vals = [1.0 - pauc]
            valid_mask = [is_valid]
            if diag:
                self._last_diag[0] = diag
        else:
            loss_vals, valid_mask = [], []
            for c in range(self.num_classes):
                is_pos = targets == c
                pauc, is_valid, diag = self._compute_pauc(
                    logits[:, c], is_pos, ~is_pos, is_iid, is_live
                )
                loss_vals.append(1.0 - pauc)
                valid_mask.append(is_valid)
                if diag:
                    self._last_diag[c] = diag

        loss_vec = torch.stack(loss_vals)
        valid_vec = torch.tensor(valid_mask, device=logits.device)
        return loss_vec, valid_vec

forward(logits, targets, iid_mask=None, return_per_class=False, return_diagnostics=False)

Compute the pAUC loss, optionally returning per-class diagnostics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| logits | Tensor, shape [N, C] | Raw (un-normalised) class scores. | required |
| targets | Tensor, shape [N] | Integer class labels. Positions equal to ignore_index are excluded. | required |
| iid_mask | Tensor, shape [N], dtype=bool | Per-row iid-eligibility flag. None treats all rows as iid. | None |
| return_per_class | bool | If True, also return per-class losses and a validity mask. | False |
| return_diagnostics | bool | If True, also return a stats dict with per-class diagnostic tensors of shape [C]. Invalid or degenerate classes yield nan for float fields and 0 for count fields. Keys: t_alpha, t_beta, tau_eff, band_neg_count, pauc_var, grad_pos_count. grad_pos_count is rank-local (computed from the live pre-gather batch); under DDP the true gradient-carrying positive population is the sum across all ranks. When False (default), behavior is bit-identical to the base-class forward. | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| loss | Tensor | Scalar or shape [C] (reduction='none'). |
| per_class_loss | Tensor, shape [C] | Only when return_per_class=True. |
| valid_classes | Tensor, shape [C], dtype=bool | Only when return_per_class=True. |
| stats | dict[str, Tensor[C]] | Only when return_diagnostics=True. Order in the tuple is (loss, per_class, valid, stats) or (loss, stats) depending on return_per_class. |

Notes

self._last_diag is transient per-call internal state; it is reset at the top of every forward call. Statefulness here is a deliberate tradeoff to avoid changing the shared _QueuedRankingLoss base-class contract (which cannot accept extra return values from _compute_per_class).
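Because the shape of the return value depends on two flags, call sites that toggle them can be easier to read with a small adapter. The sketch below is not part of the library: `ForwardResult` and `unpack_forward` are hypothetical names, and the only thing taken from the documentation above is the tuple order.

```python
from typing import Any, NamedTuple, Optional


class ForwardResult(NamedTuple):
    """Hypothetical container for forward()'s outputs (not a library type)."""
    loss: Any
    per_class_loss: Optional[Any] = None
    valid_classes: Optional[Any] = None
    stats: Optional[dict] = None


def unpack_forward(out: Any, return_per_class: bool, return_diagnostics: bool) -> ForwardResult:
    """Map the flag-dependent return value onto named fields."""
    if return_per_class and return_diagnostics:
        loss, per_class, valid, stats = out      # (loss, per_class, valid, stats)
        return ForwardResult(loss, per_class, valid, stats)
    if return_per_class:
        loss, per_class, valid = out             # (loss, per_class, valid)
        return ForwardResult(loss, per_class, valid)
    if return_diagnostics:
        loss, stats = out                        # (loss, stats)
        return ForwardResult(loss, stats=stats)
    return ForwardResult(out)                    # bare loss tensor
```

With this in place, `unpack_forward(loss_fn(logits, targets, return_diagnostics=True), False, True).stats` reads the same way regardless of which flags were set.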

Source code in src/imbalanced_losses/pauc_loss.py
def forward(
    self,
    logits: torch.Tensor,
    targets: torch.Tensor,
    iid_mask: torch.Tensor | None = None,
    return_per_class: bool = False,
    return_diagnostics: bool = False,
) -> (
    torch.Tensor
    | tuple[torch.Tensor, dict]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]
):
    """
    Compute the pAUC loss, optionally returning per-class diagnostics.

    Parameters
    ----------
    logits : torch.Tensor, shape [N, C]
        Raw (un-normalised) class scores.
    targets : torch.Tensor, shape [N]
        Integer class labels.  Positions equal to ``ignore_index``
        are excluded.
    iid_mask : torch.Tensor, shape [N], dtype=bool, optional
        Per-row iid-eligibility flag.  ``None`` treats all rows as iid.
    return_per_class : bool, optional
        If True, also return per-class losses and a validity mask.
    return_diagnostics : bool, optional
        If True, also return a ``stats`` dict with per-class diagnostic
        tensors of shape ``[C]``.  Invalid or degenerate classes yield
        ``nan`` for float fields and ``0`` for count fields.

        Keys: ``t_alpha``, ``t_beta``, ``tau_eff``, ``band_neg_count``,
        ``pauc_var``, ``grad_pos_count``.

        ``grad_pos_count`` is rank-local (computed from the live pre-gather
        batch); under DDP the true gradient-carrying positive population is
        the sum across all ranks.

        When ``False`` (default), behavior is bit-identical to the
        base-class forward.

    Returns
    -------
    loss : torch.Tensor
        Scalar or shape ``[C]`` (``reduction='none'``).
    per_class_loss : torch.Tensor, shape [C]
        Only when ``return_per_class=True``.
    valid_classes : torch.Tensor, shape [C], dtype=bool
        Only when ``return_per_class=True``.
    stats : dict[str, Tensor[C]]
        Only when ``return_diagnostics=True``.  Order in the tuple
        is ``(loss, per_class, valid, stats)`` or ``(loss, stats)``
        depending on ``return_per_class``.

    Notes
    -----
    ``self._last_diag`` is transient per-call internal state; it is
    reset at the top of every forward call.  Statefulness here is a
    deliberate tradeoff to avoid changing the shared ``_QueuedRankingLoss``
    base-class contract (which cannot accept extra return values from
    ``_compute_per_class``).
    """
    # --- squeeze [N,1] targets early so grad_pos_count sees the right shape --
    if targets.ndim == 2 and targets.size(1) == 1:
        targets = targets.squeeze(1)

    # --- reset transient diagnostic state before every call ----------------
    # Always reset so _compute_per_class (called by super().forward) can
    # safely write to self._last_diag regardless of return_diagnostics.
    # Sentinel structure: float fields default to nan, count fields to 0.
    # The empty-pool early-return path in the base class skips
    # _compute_per_class, so _last_diag stays at this nan/0 default,
    # which is exactly the right diagnostic output for an empty pool.
    _nan = float("nan")
    self._last_diag: list[dict] = [
        {
            "t_alpha": _nan,
            "t_beta": _nan,
            "tau_eff": _nan,
            "band_neg_count": 0,
            "pauc_var": _nan,
        }
        for _ in range(self.num_classes)
    ]

    # --- diagnostics flag ---------------------------------------------------
    # Gates all diagnostic tensor ops in _compute_pauc.
    # Must be set BEFORE super().forward() calls _compute_per_class.
    self._want_diag = bool(return_diagnostics)

    if not return_diagnostics:
        # Fast path: no diagnostics needed, so this is bit-identical to base forward.
        return super().forward(
            logits, targets, iid_mask=iid_mask, return_per_class=return_per_class
        )

    # --- grad_pos_count: live-batch positives per class (after ignore_index) -
    # Counted here, on the rank-local pre-gather batch, before the base
    # forward merges live rows with the queue.
    valid_mask = targets != self.ignore_index
    filtered_targets = targets[valid_mask]
    grad_pos_count = logits.new_zeros(self.num_classes, dtype=torch.long)
    if self.num_classes == 1:
        grad_pos_count[0] = int(filtered_targets.bool().sum())
    else:
        for c in range(self.num_classes):
            grad_pos_count[c] = int((filtered_targets == c).sum())

    # --- delegate to base forward (runs _compute_per_class as side effect) --
    base_out = super().forward(
        logits, targets, iid_mask=iid_mask, return_per_class=return_per_class
    )

    # --- assemble stats dict from _last_diag --------------------------------
    dev = logits.device
    dtype = logits.dtype

    def _scalar_or_nan(val, is_nan_sentinel):
        """Return a float tensor from val, or nan if is_nan_sentinel."""
        if is_nan_sentinel:
            return torch.tensor(float("nan"), device=dev, dtype=dtype)
        if isinstance(val, torch.Tensor):
            return val.to(device=dev, dtype=dtype)
        return torch.tensor(float(val), device=dev, dtype=dtype)

    t_alpha_vals, t_beta_vals, tau_eff_vals = [], [], []
    band_neg_counts, pauc_var_vals = [], []

    for c in range(self.num_classes):
        d = self._last_diag[c]
        is_invalid = isinstance(d.get("t_alpha"), float) and (
            d["t_alpha"] != d["t_alpha"]  # nan check
        )
        t_alpha_vals.append(_scalar_or_nan(d["t_alpha"], is_invalid))
        t_beta_vals.append(_scalar_or_nan(d["t_beta"], is_invalid))
        tau_eff_vals.append(_scalar_or_nan(d["tau_eff"], is_invalid))
        band_neg_counts.append(
            torch.tensor(0 if is_invalid else d["band_neg_count"],
                         device=dev, dtype=torch.long)
        )
        pauc_var_vals.append(_scalar_or_nan(d["pauc_var"], is_invalid))

    stats: dict[str, torch.Tensor] = {
        "t_alpha":        torch.stack(t_alpha_vals),
        "t_beta":         torch.stack(t_beta_vals),
        "tau_eff":        torch.stack(tau_eff_vals),
        "band_neg_count": torch.stack(band_neg_counts),
        "pauc_var":       torch.stack(pauc_var_vals),
        "grad_pos_count": grad_pos_count.to(device=dev),
    }

    # --- build return value -------------------------------------------------
    if return_per_class:
        # base_out is (loss, per_class, valid)
        loss, per_class, valid = base_out
        return loss, per_class, valid, stats
    else:
        return base_out, stats

Quick examples

Optimize pAUC around a 50 bps operating point

from imbalanced_losses import PAUCAtBudgetLoss
import torch

# Recommended band: alpha=0, beta=budget.
# t_alpha = max(neg_iid), t_beta = quantile at the budget threshold.
# Contrasts positives against every false-positive above the operating cutoff.
loss_fn = PAUCAtBudgetLoss(num_classes=4, alpha=0.0, beta=0.005, queue_size=1024)
logits  = torch.randn(256, 4, requires_grad=True)  # stand-in for model outputs
targets = torch.randint(0, 4, (256,))

loss = loss_fn(logits, targets)
loss.backward()

Binary classification

loss_fn = PAUCAtBudgetLoss(num_classes=1, alpha=0.0, beta=0.005, queue_size=1024)
logits  = torch.randn(256, 1, requires_grad=True)  # stand-in for model outputs
targets = torch.randint(0, 2, (256,))

loss = loss_fn(logits, targets)

Per-class logging

loss, per_class, valid = loss_fn(logits, targets, return_per_class=True)
loss.backward()

for c in valid.nonzero(as_tuple=True)[0].tolist():
    print(f"Class {c} pAUC-loss: {per_class[c].item():.4f}")

Diagnostics — detect band starvation

loss, stats = loss_fn(logits, targets, return_diagnostics=True)
# stats: per-class [C] tensors
print(stats["band_neg_count"])   # iid negatives landing in the band
print(stats["grad_pos_count"])   # live positives carrying gradient (rank-local)
print(stats["t_alpha"], stats["t_beta"], stats["tau_eff"], stats["pauc_var"])

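If band_neg_count is small, too few iid negatives land in the band to resolve its edges; increase queue_size. If grad_pos_count sits near 1 and pauc_var fluctuates from step to step, the loss is starved of gradient signal; increase the effective batch (DDP all-gather) or densify positives upstream.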
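A check like this can be automated in a training loop. The helper below is a minimal sketch, not a library function: `starved_classes` and both thresholds are illustrative assumptions, and it works on plain lists so it can be fed `stats["band_neg_count"].tolist()` and `stats["grad_pos_count"].tolist()`.

```python
def starved_classes(band_neg_count, grad_pos_count, min_band_neg=5, min_grad_pos=2):
    """Return {class_index: [reasons]} for classes that look starved.

    Both thresholds are illustrative assumptions, not library defaults.
    """
    flagged = {}
    for c, (n_band, n_pos) in enumerate(zip(band_neg_count, grad_pos_count)):
        reasons = []
        if n_band < min_band_neg:
            reasons.append("few iid negatives in band")
        if n_pos < min_grad_pos:
            reasons.append("few live positives")
        if reasons:
            flagged[c] = reasons
    return flagged


# Hypothetical per-class counts for a 3-class problem.
print(starved_classes([6, 0, 7], [12, 3, 1]))
```

Invalid or degenerate classes report a band_neg_count of 0, so they are flagged by the first condition as well.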

Marking densified negatives (advanced)

If a caller densifies negatives by class (e.g. hard-negative mining), pass iid_mask so the FPR band edges are still estimated from an iid sample and beta keeps meaning population FPR:

# iid_mask[i] = True for rows drawn iid; False for injected/densified rows.
loss = loss_fn(logits, targets, iid_mask=iid_mask)

When iid_mask=None (the default) every negative is treated as iid — correct for any pipeline that never densifies negatives by class.
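To see why the mask matters, the following standalone sketch compares the lower band threshold estimated from iid negatives only with the one estimated from all negatives after hard negatives are injected. It uses plain Python rather than the library; `quantile_higher` is a hypothetical helper that rounds the rank up, which is assumed to mirror the "higher" interpolation default, and the scores are synthetic.

```python
import math


def quantile_higher(values, q):
    """Empirical quantile with the rank rounded up ('higher' interpolation)."""
    ordered = sorted(values)
    return ordered[math.ceil(q * (len(ordered) - 1))]


budget = 0.005
iid_neg = [i / 1000 for i in range(1000)]   # 1000 iid negatives, scores 0.000..0.999
mined_neg = [0.999] * 50                     # 50 injected hard negatives (not iid)

t_beta_iid = quantile_higher(iid_neg, 1 - budget)
t_beta_all = quantile_higher(iid_neg + mined_neg, 1 - budget)

# Share of iid negatives at or above each threshold: the population FPR
# that the threshold actually corresponds to.
fpr_iid = sum(s >= t_beta_iid for s in iid_neg) / len(iid_neg)
fpr_all = sum(s >= t_beta_all for s in iid_neg) / len(iid_neg)

print(t_beta_iid, fpr_iid)   # threshold from iid rows only
print(t_beta_all, fpr_all)   # threshold polluted by mined rows
```

In this toy setup the iid-only threshold sits at 0.995 and corresponds to the intended 0.5% FPR, while the polluted threshold rises to 0.999 and corresponds to only 0.1%, so beta no longer means what it was configured to mean.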

Parameter guidance

| Parameter | Default | Notes |
| --- | --- | --- |
| num_classes | required | Use 1 for binary |
| alpha | 0.0 | Lower FPR band edge; 0 <= alpha < beta <= 1. alpha=0 sets t_alpha=max(neg_iid), so the band includes all top-scoring negatives. |
| beta | 0.005 | Upper FPR band edge; set to your target operating-point FPR (e.g. 0.005 for 50 bps). |
| surrogate | "trapezoid" | "trapezoid" integrates soft-TPR over the band (gradient through positives only); "pairwise" compares positives against band negatives (band negatives carry gradient) and suits wide or volatile bands |
| n_knots | 2 | Trapezoid FPR knots; 2 is accurate for narrow bands, >= 3 for wide bands |
| tau_scale | "iqr" | Scale used for the scale-aware temperature: "iqr" (stable bulk statistic; pair with small temperature) or "band" (sized to the operating region; pair with temperature near 1.0) |
| pos_numerator | "pool" | Positives in the soft-TPR numerator (and the pairwise positive set): "pool" (all pooled positives) or "live" (live-batch only). "live" gives an undiluted gradient when the queue swamps the few live positives at extreme imbalance; it is most beneficial for "trapezoid", while "pairwise" usually prefers "pool" to keep the positive×band-negative contrast populated |
| temperature | 0.1 | Dimensionless multiplier in tau_eff = temperature * scale, not a temperature in raw-logit units. Larger = smoother, with soft-TPR biased toward 0.5; smaller = sharper but risks saturation |
| queue_size | 1024 | Larger queues stabilise the tail quantile; at low FPR you need many pooled negatives |
| reduction | "mean" | "none" returns [C]; invalid classes are nan |
| ignore_index | -100 | Excluded from threshold estimation and the positive set |
| quantile_interpolation | "higher" | Conservative default for the band edges |
| max_pool_size | None | Minimum-quota subsampling cap for very large pools (seq2seq) |
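How surrogate="trapezoid", n_knots and temperature interact can be seen in a dependency-free re-statement of the trapezoid computation from the source listing. This is a sketch under stated assumptions: `soft_pauc_trapezoid`, `quantile_higher` and `sigmoid` are illustrative names, the dispersion `scale` is passed in directly because its computation is not shown on this page, and no gradients are involved.

```python
import math


def quantile_higher(values, q):
    """Empirical quantile with the rank rounded up ('higher' interpolation)."""
    ordered = sorted(values)
    return ordered[math.ceil(q * (len(ordered) - 1))]


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def soft_pauc_trapezoid(pos, neg, alpha, beta, n_knots, temperature, scale):
    """Normalized soft partial AUC over the FPR band [alpha, beta]."""
    tau_eff = temperature * scale
    # FPR knots equally spaced over [alpha, beta], one score threshold per knot.
    fpr_knots = [alpha + (beta - alpha) * k / (n_knots - 1) for k in range(n_knots)]
    thresholds = [quantile_higher(neg, 1.0 - f) for f in fpr_knots]
    # Soft TPR at each knot: mean sigmoid margin of the positives over the threshold.
    tpr = [sum(sigmoid((p - t) / tau_eff) for p in pos) / len(pos) for t in thresholds]
    # Composite trapezoid on a uniform grid, normalized by the number of intervals.
    return (0.5 * tpr[0] + sum(tpr[1:-1]) + 0.5 * tpr[-1]) / (n_knots - 1)


neg = [i / 1000 for i in range(1000)]   # synthetic negatives
pos = [0.9, 0.997, 1.2]                 # synthetic positives
for temperature in (0.01, 0.1, 10.0):
    value = soft_pauc_trapezoid(pos, neg, 0.0, 0.005, 2, temperature, scale=0.5)
    print(temperature, round(value, 4))
```

As temperature grows, every sigmoid flattens and the value drifts toward 0.5 regardless of the ranking, which is the bias noted in the temperature row; as it shrinks, the value approaches a hard count of positives above each threshold.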

Band selection guidance

Recommended convention: alpha ≈ 0, beta ≈ budget. For a 50 bps (FPR = 0.005) operating point, use alpha=0.0, beta=0.005. This sets t_alpha = max(neg_iid) and t_beta = quantile(neg_iid, 1 - budget), so the band covers every false-positive that scores above the operating threshold.
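A small worked example makes the two thresholds concrete. It is a sketch on synthetic integer scores, not library code: `quantile_higher` and `band_edges` are hypothetical helpers, and the rank is rounded up on the assumption that this matches the "higher" interpolation default. The older [budget/2, 1.5·budget] band is included for comparison.

```python
import math


def quantile_higher(values, q):
    """Empirical quantile with the rank rounded up ('higher' interpolation)."""
    ordered = sorted(values)
    return ordered[math.ceil(q * (len(ordered) - 1))]


def band_edges(neg_scores, alpha, beta):
    """Return (t_alpha, t_beta): score thresholds at FPR alpha and FPR beta."""
    return quantile_higher(neg_scores, 1.0 - alpha), quantile_higher(neg_scores, 1.0 - beta)


neg = list(range(2000))   # 2000 iid negatives with scores 0..1999
budget = 0.005

bands = {"recommended": (0.0, budget), "older": (budget / 2, 1.5 * budget)}
for name, (alpha, beta) in bands.items():
    t_alpha, t_beta = band_edges(neg, alpha, beta)
    in_band = sum(t_beta <= s <= t_alpha for s in neg)   # negatives contrasted
    missed_top = sum(s > t_alpha for s in neg)           # top negatives left out
    print(name, t_alpha, t_beta, in_band, missed_top)
```

Here the recommended band spans scores 1990 to 1999, exactly the 10 highest negatives (0.5% of 2000). The older band spans 1985 to 1995, so it holds 11 negatives but leaves out the 4 highest-scoring ones.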

The older convention [budget/2, 1.5·budget] (e.g. [0.0025, 0.0075] for 50 bps) has two defects: its lower edge alpha = budget/2 excludes the highest-scoring (worst) negatives from the contrast, and its upper edge beta = 1.5·budget extends below the budget threshold into negatives that don't matter for coverage. A band sweep on synthetic contested-top extreme-imbalance data found coverage@budget to be monotone in both edges (smaller is better), with the old convention in the poorly-performing high-alpha region (the single worst cell is alpha=budget/2, beta=2.5·budget), well below the alpha≈0, beta≈budget optimum.

Caveat: this evidence is synthetic, at a single budget (50 bps), in the contested-top regime. The improvement is concentrated at pos_rate ≪ budget. Once pos_rate ≥ budget, coverage@budget is mechanically capped at budget/pos_rate and band choice is irrelevant.

The band edges are estimated as score quantiles of the iid negatives and approximate population FPR only when the pooled iid-negative count is adequate. With alpha=0, the upper threshold is always the maximum negative score (no tail-quantile bias for t_alpha); only t_beta requires enough negatives to resolve quantile(neg, 1 - beta). Use queue_size to accumulate enough negatives, and the band_neg_count diagnostic as the empirical check. A class whose iid-negative score dispersion is degenerate (≈ 0) is skipped (marked invalid) with a one-time warning.
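A back-of-the-envelope sizing check can precede any training run. The sketch below is plain arithmetic, not a library utility: both function names are hypothetical, the target of 10 negatives in the band is an arbitrary illustration, and the pool is assumed to consist almost entirely of iid negatives.

```python
import math


def expected_band_negatives(n_iid_neg, alpha, beta):
    """Expected number of iid negatives whose FPR rank falls in [alpha, beta]."""
    return n_iid_neg * (beta - alpha)


def min_negatives_for(beta, min_in_band=10, alpha=0.0):
    """Smallest iid-negative count whose expected band population reaches min_in_band."""
    return math.ceil(min_in_band / (beta - alpha))


# Defaults from the table above: queue_size=1024 plus a live batch of 256 rows,
# treated here as all negatives (an assumption that holds at extreme imbalance).
pool = 1024 + 256
print(expected_band_negatives(pool, 0.0, 0.005))   # about 6 negatives define the band
print(min_negatives_for(0.005))                    # pool size for about 10 in the band
```

If the first number is in the single digits, t_beta rests on a handful of scores; raising queue_size is the lever the page recommends, with band_neg_count as the empirical confirmation.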

Choosing among the ranking losses

PAUCAtBudgetLoss sits between the two other ranking losses in how much of the curve it targets:

  • SmoothAPLoss optimizes the whole precision–recall curve (Average Precision).
  • PAUCAtBudgetLoss optimizes a band of the ROC around your operating point.
  • RecallAtQuantileLoss optimizes recall at a single score threshold.

Reach for PAUCAtBudgetLoss when your business constraint is a fixed false-alarm budget (a region, not the whole curve and not one hard point).