API Reference
All public symbols are defined in deepsets.py. Import them directly:
from deepsets import (
DeepSetsInvariant,
PermutationEquivariantLayer,
DeepSetsEquivariant,
DeepSetsConditional,
)
The internal helper _masked_pool is also documented here for completeness.
_masked_pool
def _masked_pool(
tensor: torch.Tensor,
pool_type: str,
mask: Optional[torch.Tensor] = None,
keepdim: bool = False,
) -> torch.Tensor
Pool a (B, M, D) tensor along the set dimension (dim=1) with optional masking.
Parameters
| Parameter | Type | Description |
|---|---|---|
tensor |
Tensor shape (B, M, D) |
Input tensor to pool |
pool_type |
str |
One of 'sum', 'max', 'mean' |
mask |
Tensor shape (B, M), optional |
Float or bool; 1 = active, 0 = padding |
keepdim |
bool |
If True, keep the M dimension as size 1 in the output |
Returns
Tensor of shape (B, D) (or (B, 1, D) if keepdim=True).
Raises
ValueError if pool_type is not 'sum', 'max', or 'mean'.
Notes
sum/mean: padding positions are multiplied by zero.max: padding positions are filled with−∞so they cannot be the maximum even when all active values are negative.mean: denominator is clamped to≥ 1to avoid division by zero on empty sets.
DeepSetsInvariant
Deep Sets architecture for permutation invariant functions.
Implements Theorem 2: \(f(\mathcal{X}) = \rho\!\left(\sum_{x \in \mathcal{X}} \varphi(x)\right)\)
φ transforms each element independently; the results are pooled into a single vector and passed through ρ to produce the final output.
Constructor
DeepSetsInvariant(
input_dim: int,
phi_hidden_dims: List[int],
rho_hidden_dims: List[int],
output_dim: int,
pool_type: str = 'sum',
dropout: float = 0.0,
)
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
input_dim |
int |
— | Dimension of each set element |
phi_hidden_dims |
List[int] |
— | Hidden layer widths for the φ (element-wise) network. Pass [] to use a linear φ. |
rho_hidden_dims |
List[int] |
— | Hidden layer widths for the ρ (aggregation) network |
output_dim |
int |
— | Output dimension |
pool_type |
str |
'sum' |
Aggregation operation: 'sum', 'max', or 'mean' |
dropout |
float |
0.0 |
Dropout probability applied after each hidden layer (0 = disabled) |
Raises
ValueError if pool_type is not valid.
Attributes
| Attribute | Type | Description |
|---|---|---|
phi |
nn.Sequential |
Element-wise network |
rho |
nn.Sequential |
Aggregation network |
pool_type |
str |
Pooling mode |
forward
Parameters
| Parameter | Type | Description |
|---|---|---|
x |
Tensor shape (B, M, input_dim) |
Batch of sets |
mask |
Tensor shape (B, M), optional |
1 for active elements, 0 for padding |
Returns
Tensor of shape (B, output_dim).
Example
model = DeepSetsInvariant(
input_dim=10,
phi_hidden_dims=[64, 64],
rho_hidden_dims=[64, 32],
output_dim=1,
pool_type='sum',
)
x = torch.randn(32, 20, 10)
out = model(x) # (32, 1)
out = model(x, mask) # (32, 1), with masking
PermutationEquivariantLayer
Single permutation equivariant layer (Lemma 3).
Implements: \(f(\mathbf{x}_i) = \Lambda\mathbf{x}_i + \Gamma \cdot \text{pool}(\mathbf{X})\)
where \(\Lambda\) and \(\Gamma\) are learnable linear maps. The pooled global context is broadcast back to every element.
Constructor
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
input_dim |
int |
— | Input feature dimension per element |
output_dim |
int |
— | Output feature dimension per element |
pool_type |
str |
'max' |
Pooling for global context: 'max' or 'sum' (not 'mean') |
Raises
ValueError if pool_type is not 'sum' or 'max'.
Attributes
| Attribute | Type | Description |
|---|---|---|
lambda_net |
nn.Linear |
Per-element transform Λ: (input_dim → output_dim) |
gamma_net |
nn.Linear |
Global-context transform Γ: (input_dim → output_dim) |
pool_type |
str |
Pooling mode |
forward
Parameters
| Parameter | Type | Description |
|---|---|---|
x |
Tensor shape (B, M, input_dim) |
Batch of sets |
mask |
Tensor shape (B, M), optional |
1 for active, 0 for padding |
Returns
Tensor of shape (B, M, output_dim).
Example
layer = PermutationEquivariantLayer(input_dim=16, output_dim=32, pool_type='max')
x = torch.randn(4, 10, 16)
out = layer(x) # (4, 10, 32)
DeepSetsEquivariant
Deep Sets architecture for permutation equivariant functions.
Stacks PermutationEquivariantLayer blocks (with ReLU activations) and optionally applies a final pooling step to produce an invariant output.
Constructor
DeepSetsEquivariant(
input_dim: int,
hidden_dims: List[int],
output_dim: int,
pool_type: str = 'max',
final_pool: Optional[str] = None,
dropout: float = 0.0,
)
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
input_dim |
int |
— | Dimension of each set element |
hidden_dims |
List[int] |
— | Hidden widths for each equivariant layer. The total number of PermutationEquivariantLayer instances is len(hidden_dims) + 1. |
output_dim |
int |
— | Output dimension per element (or global if final_pool is set) |
pool_type |
str |
'max' |
Pooling inside each equivariant layer: 'sum' or 'max' |
final_pool |
str or None |
None |
If set, pool across the set after all layers to produce an invariant (B, output_dim) output. One of 'sum', 'max', 'mean'. |
dropout |
float |
0.0 |
Dropout probability between hidden layers |
Raises
ValueError if final_pool is set to an invalid value.
Attributes
| Attribute | Type | Description |
|---|---|---|
layers |
nn.ModuleList |
All layers: alternating PermutationEquivariantLayer, ReLU, optional Dropout |
final_pool |
str or None |
Final pooling mode |
forward
Parameters
| Parameter | Type | Description |
|---|---|---|
x |
Tensor shape (B, M, input_dim) |
Batch of sets |
mask |
Tensor shape (B, M), optional |
1 for active, 0 for padding |
Returns
Tensorof shape(B, M, output_dim)iffinal_pool is NoneTensorof shape(B, output_dim)iffinal_poolis set
Example
# Set-to-set (equivariant output)
model = DeepSetsEquivariant(
input_dim=8, hidden_dims=[32, 32], output_dim=4, pool_type='max'
)
out = model(torch.randn(4, 10, 8)) # (4, 10, 4)
# Set-to-vector (invariant output)
model = DeepSetsEquivariant(
input_dim=8, hidden_dims=[32, 32], output_dim=4,
pool_type='max', final_pool='sum'
)
out = model(torch.randn(4, 10, 8)) # (4, 4)
DeepSetsConditional
Deep Sets conditioned on a context vector \(\mathbf{z}\).
Implements: \(f(\mathcal{X} \mid \mathbf{z}) = \rho\!\left(\sum_{x \in \mathcal{X}} \varphi(x \mid \mathbf{z})\right)\)
Three strategies for fusing context into φ: 'concat', 'film', 'add'.
Constructor
DeepSetsConditional(
input_dim: int,
context_dim: int,
phi_hidden_dims: List[int],
rho_hidden_dims: List[int],
output_dim: int,
pool_type: str = 'sum',
fusion_type: str = 'concat',
context_in_rho: bool = True,
)
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
input_dim |
int |
— | Dimension of each set element |
context_dim |
int |
— | Dimension of conditioning vector |
phi_hidden_dims |
List[int] |
— | Hidden widths for φ. Must be non-empty when fusion_type='film'. |
rho_hidden_dims |
List[int] |
— | Hidden widths for ρ |
output_dim |
int |
— | Output dimension |
pool_type |
str |
'sum' |
Pooling: 'sum', 'max', or 'mean' |
fusion_type |
str |
'concat' |
Context fusion strategy: 'concat', 'film', or 'add' |
context_in_rho |
bool |
True |
If True, concatenate context to the ρ input so ρ has direct access to \(\mathbf{z}\) |
Raises
ValueErrorifpool_typeis invalid.ValueErroriffusion_typeis invalid.ValueErroriffusion_type='film'andphi_hidden_dimsis empty.
Attributes
Depend on fusion_type:
fusion_type='concat'
| Attribute | Description |
|---|---|
phi |
nn.Sequential — φ network with input dim input_dim + context_dim |
rho |
nn.Sequential — ρ network |
fusion_type='film'
| Attribute | Description |
|---|---|
phi_first |
nn.Linear — first φ layer: (input_dim → phi_hidden_dims[0]) |
film_gamma |
nn.Linear — scale parameter: (context_dim → phi_hidden_dims[0]) |
film_beta |
nn.Linear — shift parameter: (context_dim → phi_hidden_dims[0]) |
phi_rest |
nn.Sequential — remaining φ layers after FiLM |
rho |
nn.Sequential — ρ network |
fusion_type='add'
| Attribute | Description |
|---|---|
context_proj |
nn.Linear — project context: (context_dim → input_dim) |
phi |
nn.Sequential — φ network with input dim input_dim |
rho |
nn.Sequential — ρ network |
forward
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor
Parameters
| Parameter | Type | Description |
|---|---|---|
x |
Tensor shape (B, M, input_dim) |
Batch of sets |
context |
Tensor shape (B, context_dim) |
Conditioning vectors |
mask |
Tensor shape (B, M), optional |
1 for active, 0 for padding |
Returns
Tensor of shape (B, output_dim).
Static Method: _make_phi
Helper that constructs a φ network as [Linear → ReLU] × n. Used internally.