Architecture Decisions
This page explains the design choices made in this implementation and the reasoning behind each.
Why −∞ for Max Pooling with Masks
When computing max pooling over a masked set, we need to ensure that padded (inactive) positions never contribute to the maximum. A naïve approach would be to fill padding with 0:
# WRONG: if all active values are negative, 0-padded positions will win
tensor[mask == 0] = 0.0
result = tensor.max(dim=1)[0]
Consider the case where all active φ-outputs for a set are negative (e.g., [−0.3, −0.8, −0.1]) and padding is filled with 0. The max would incorrectly return 0 (a padding position) instead of −0.1 (the true maximum among active elements).
The correct approach is to fill padding with \(-\infty\):
# CORRECT: -inf never wins against any finite value
tensor = tensor.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
result = tensor.max(dim=1)[0]
With \(-\infty\) fill, the max over active positions is always correct, even when all values are negative. This is implemented in _masked_pool:
elif pool_type == 'max':
    if mask is not None:
        # Padded positions become -inf so they can never be selected by max
        tensor = tensor.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
    return tensor.max(dim=1, keepdim=keepdim)[0]
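The two fill values can be compared directly on the all-negative example from earlier in this section. The following is a minimal, dependency-free sketch that uses plain Python lists instead of PyTorch tensors; `masked_max` is a hypothetical helper written for illustration and is not part of the implementation.

```python
NEG_INF = float("-inf")


def masked_max(values, mask, fill):
    """Max over one set after replacing padded positions (mask == 0) with `fill`."""
    filled = [v if m else fill for v, m in zip(values, mask)]
    return max(filled)


# Three active phi-outputs (all negative) followed by two padded slots.
values = [-0.3, -0.8, -0.1, 0.0, 0.0]
mask = [1, 1, 1, 0, 0]

# Zero fill: a padded slot wins, because 0 is larger than every active value.
zero_fill_result = masked_max(values, mask, fill=0.0)

# -inf fill: the largest active value wins.
neg_inf_result = masked_max(values, mask, fill=NEG_INF)
```

With zero fill the result comes from a padded slot; with −∞ fill it is the largest active value, −0.1.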
Downstream NaN risk
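If a set has no active elements (all-zero mask), every position is filled with −∞ and the max returns −∞. Downstream layers that receive −∞ can produce NaN activations and gradients (for example, −∞ multiplied by a zero weight is NaN). Always ensure that each set in your data has at least one active element.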
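One way to catch this early is to validate masks before pooling. The sketch below is an assumption about how such a guard could look, again in plain Python rather than PyTorch; `masked_max_checked` is a hypothetical name and does not exist in the implementation.

```python
import math

NEG_INF = float("-inf")


def masked_max_checked(values, mask):
    """Masked max that refuses sets with no active elements."""
    if not any(mask):
        raise ValueError("set has no active elements; masked max would be -inf")
    return max(v if m else NEG_INF for v, m in zip(values, mask))


# An all-zero mask leaves only -inf candidates, so an unguarded max is -inf.
empty_max = max(NEG_INF for _ in range(3))

# -inf times a zero weight is NaN, which is one way NaN appears downstream.
poisoned = empty_max * 0.0
```

Raising an error is only one option; dropping empty sets during data loading avoids the check at run time.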
nn.ModuleList vs nn.Sequential in DeepSetsEquivariant
DeepSetsEquivariant stores its layers in an nn.ModuleList rather than an nn.Sequential.
The forward method then iterates manually:
for layer in self.layers:
    if isinstance(layer, PermutationEquivariantLayer):
        out = layer(out, mask)  # pass mask
    else:
        out = layer(out)        # ReLU / Dropout: no mask argument
Why not nn.Sequential?
nn.Sequential.forward calls each layer with a single argument. PermutationEquivariantLayer.forward takes two arguments: the tensor x and an optional mask. There is no standard way to thread extra arguments through nn.Sequential.
nn.ModuleList registers all submodules for parameter tracking (so model.parameters() and model.to(device) work correctly) while giving full control over the forward pass.
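The pattern can be sketched end to end as follows. This is an illustration under stated assumptions, not the library's code: `ToyMaskedLayer` is a hypothetical stand-in for PermutationEquivariantLayer (sum pooling only), and `ToyEquivariantStack` is a hypothetical stand-in for DeepSetsEquivariant.

```python
import torch
import torch.nn as nn


class ToyMaskedLayer(nn.Module):
    """Hypothetical stand-in for a layer whose forward takes (x, mask)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lam = nn.Linear(in_dim, out_dim)              # per-element term
        self.gam = nn.Linear(in_dim, out_dim, bias=False)  # pooled term

    def forward(self, x, mask=None):
        # x: (batch, set_size, in_dim); mask: (batch, set_size), 1 = active
        if mask is not None:
            x = x * mask.unsqueeze(-1)
        pooled = x.sum(dim=1, keepdim=True)
        out = self.lam(x) + self.gam(pooled)
        if mask is not None:
            out = out * mask.unsqueeze(-1)  # keep padded positions at zero
        return out


class ToyEquivariantStack(nn.Module):
    def __init__(self, dims):
        super().__init__()
        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(ToyMaskedLayer(d_in, d_out))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
        # ModuleList registers every submodule, so parameters are tracked.
        self.layers = nn.ModuleList(layers)

    def forward(self, x, mask=None):
        out = x
        for layer in self.layers:
            if isinstance(layer, ToyMaskedLayer):
                out = layer(out, mask)  # mask-aware layer
            else:
                out = layer(out)        # activation: single argument
        return out


model = ToyEquivariantStack([4, 8, 3])
x = torch.randn(2, 5, 4)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float32)
y = model(x, mask)
```

Here `model.parameters()` yields the weights of both toy layers even though the forward pass is a hand-written loop.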
The φ–Pool–ρ Decomposition in Practice
The theoretical decomposition \(f(\mathcal{X}) = \rho\left(\sum_{x \in \mathcal{X}} \varphi(x)\right)\) has a natural implementation trade-off: how deep should φ and ρ be?
Recommended heuristic (from the paper and empirical practice):
| Network | Depth | Reasoning |
|---|---|---|
| φ | Deeper (3–4 layers) | Needs to learn rich per-element representations before pooling collapses the set into a single vector |
| ρ | Shallower (2–3 layers) | Processes a fixed-size pooled vector; standard MLP task |
Pooling is an information bottleneck: anything φ doesn't capture is lost. Investing capacity in φ before the bottleneck is generally more efficient than in ρ after it.
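The bottleneck effect can be seen on a toy example. The sketch below assumes scalar inputs, sum pooling, and an identity-like ρ; `deep_sets`, `phi_weak`, `phi_rich`, and `rho` are hypothetical names chosen for illustration. Two different sets that a weak φ maps to the same pooled vector become indistinguishable, while a richer φ keeps them apart.

```python
def deep_sets(xs, phi, rho):
    """Evaluate rho(sum_x phi(x)) for a set given as a list of scalars."""
    features = [phi(x) for x in xs]
    pooled = [sum(column) for column in zip(*features)]  # sum pooling per feature
    return rho(pooled)


def phi_weak(x):
    return [x]          # only the value itself survives pooling


def phi_rich(x):
    return [x, x * x]   # value and square survive pooling


def rho(pooled):
    return tuple(pooled)  # identity-like readout, for illustration only


set_a = [1.0, 3.0]
set_b = [2.0, 2.0]

weak_a = deep_sets(set_a, phi_weak, rho)
weak_b = deep_sets(set_b, phi_weak, rho)
rich_a = deep_sets(set_a, phi_rich, rho)
rich_b = deep_sets(set_b, phi_rich, rho)

# Reordering the elements leaves the output unchanged (permutation invariance).
rich_a_reordered = deep_sets([3.0, 1.0], phi_rich, rho)
```

No choice of ρ can separate `set_a` from `set_b` after `phi_weak`, because both sets reach ρ as the same pooled vector.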
FiLM Conditioning — Design Rationale
The 'film' fusion strategy implements Feature-wise Linear Modulation (Perez et al., 2018).
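In its usual form, FiLM scales and shifts a feature vector using two learned functions of the context. The notation below is assumed for illustration (\(\mathbf{h}\) for the hidden features being modulated, \(\mathbf{c}\) for the context vector, \(\gamma\) and \(\beta\) for the learned scale and shift maps, \(\odot\) for element-wise multiplication) and may not match the identifiers used in the code:

```latex
\[
\operatorname{FiLM}(\mathbf{h} \mid \mathbf{c}) = \gamma(\mathbf{c}) \odot \mathbf{h} + \beta(\mathbf{c})
\]
```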
FiLM applies context as a multiplicative and additive modulation of intermediate features, which is strictly more expressive than simple concatenation or addition:
| Fusion | Context influence | Expressiveness | Parameters added |
|---|---|---|---|
| add | Shifts input space | Low | context_dim × input_dim |
| concat | Extends input | Moderate | context_dim × phi_hidden_dims[0] |
| film | Scales + shifts hidden features | High | 2 × context_dim × phi_hidden_dims[0] |
The reason FiLM requires phi_hidden_dims to be non-empty is that it modulates the output of the first φ layer. Without any hidden layer, there is no intermediate representation to modulate.
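A minimal numeric sketch of this modulation follows. It assumes bias-free linear maps from the context to the scale and shift vectors, which is a simplification; `film`, `matvec`, `W_gamma`, and `W_beta` are hypothetical names, and `h` stands in for the output of the first φ layer.

```python
def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def film(h, c, W_gamma, W_beta):
    """Scale and shift hidden features h using context c."""
    gamma = matvec(W_gamma, c)  # per-feature scale
    beta = matvec(W_beta, c)    # per-feature shift
    return [g * hi + b for g, hi, b in zip(gamma, h, beta)]


context_dim, hidden_dim = 2, 3

# Each matrix has shape hidden_dim x context_dim.
W_gamma = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_beta = [[0.5, 0.0], [0.0, 0.5], [0.0, 0.0]]

h = [1.0, 2.0, 3.0]  # stand-in for the first phi layer's output
c = [2.0, -1.0]      # context vector

out = film(h, c, W_gamma, W_beta)

# Two weight matrices give 2 x context_dim x hidden_dim weights in total.
n_params = sum(len(row) for row in W_gamma) + sum(len(row) for row in W_beta)
```

The weight count matches the film row of the table: one context_dim × phi_hidden_dims[0] matrix for the scale and one for the shift.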
Why pool_type='mean' is Excluded from PermutationEquivariantLayer
Lemma 3 characterises equivariant layers as \(\Lambda x_i + \Gamma \cdot \text{pool}(X)\) where pool is a fixed global aggregation. Mean pooling divides by the set size \(|\mathcal{X}|\), which is a function of the input — not a fixed symmetric function. This breaks the theoretical characterisation and can cause instability when set sizes vary within a batch.
For this reason, PermutationEquivariantLayer only supports 'sum' and 'max', both of which have clean theoretical properties in the equivariant layer formulation.
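The equivariance property itself is easy to check numerically. The sketch below assumes scalar features, so that Λ and Γ reduce to the scalars `lam` and `gam`; `equivariant_layer` is a hypothetical function for illustration, not the library's layer. Permuting the inputs should permute the outputs in the same way for both supported pooling types.

```python
def equivariant_layer(xs, lam, gam, pool):
    """Scalar version of lam * x_i + gam * pool(X)."""
    pooled = pool(xs)
    return [lam * x + gam * pooled for x in xs]


xs = [1.0, -2.0, 4.0]
perm = [2, 0, 1]  # new position k holds old element perm[k]

results = {}
for name, pool in (("sum", sum), ("max", max)):
    ys = equivariant_layer(xs, lam=0.5, gam=0.25, pool=pool)
    ys_perm = equivariant_layer([xs[i] for i in perm], lam=0.5, gam=0.25, pool=pool)
    results[name] = (ys, ys_perm)
```

For each pooling type, `ys_perm` equals `ys` reordered by `perm`, which is the definition of permutation equivariance.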
Comparison to Alternative Approaches
| Method | Time complexity | Permutation invariant | Universal approximation | Notes |
|---|---|---|---|---|
| Deep Sets | \(O(M)\) | ✓ (exact) | ✓ (Theorem 2) | Linear in set size |
| Pairwise coupling | \(O(M^2)\) | ✓ (exact) | ✓ | Expensive for large sets |
| Sort + RNN | \(O(M \log M)\) | ✗ (approximate) | — | Sort not differentiable |
| Attention (naive) | \(O(M^2)\) | ✓ (with symmetric attn) | ✓ | Quadratic memory |
| Graph neural networks | \(O(M + E)\) | ✓ (with global pool) | ✓ | Requires graph structure |
Among the methods compared here, Deep Sets is the only one that combines exact permutation invariance, \(O(M)\) time complexity, and a universal approximation guarantee, which makes it well suited to large sets where pairwise methods are prohibitively expensive.
Pooling Mathematics
For a set \(\mathcal{X} = \{x_1, \ldots, x_M\}\) with φ-outputs \(\mathbf{h}_i = \varphi(x_i) \in \mathbb{R}^d\):
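The standard sum, mean, and max pooling operators can be written in this notation as follows. This is a reconstruction from the definitions above rather than a quotation of the original formulas; the symbols \(\mathbf{s}\), \(\bar{\mathbf{h}}\), and \(\mathbf{m}\) are assumed names, and the max is taken independently for each feature dimension \(j\):

```latex
\[
\mathbf{s} = \sum_{i=1}^{M} \mathbf{h}_i,
\qquad
\bar{\mathbf{h}} = \frac{1}{M} \sum_{i=1}^{M} \mathbf{h}_i,
\qquad
m_j = \max_{1 \le i \le M} h_{i,j}, \quad j = 1, \ldots, d
\]
```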
With masking (active set \(\mathcal{A} \subseteq \{1,\ldots,M\}\)):
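Under masking, each operator ranges over the active indices only. The formulas below are a sketch under that assumption, with \(|\mathcal{A}| \ge 1\) and with assumed symbol names carrying the subscript \(\mathcal{A}\). For max pooling, restricting to \(\mathcal{A}\) gives the same result as filling inactive positions with \(-\infty\) before taking the max:

```latex
\[
\mathbf{s}_{\mathcal{A}} = \sum_{i \in \mathcal{A}} \mathbf{h}_i,
\qquad
\bar{\mathbf{h}}_{\mathcal{A}} = \frac{1}{|\mathcal{A}|} \sum_{i \in \mathcal{A}} \mathbf{h}_i,
\qquad
m_{\mathcal{A},j} = \max_{i \in \mathcal{A}} h_{i,j}, \quad j = 1, \ldots, d
\]
```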