Models Reference¶
This section provides detailed documentation of the model components and architecture available in Tabular SSL.
Architecture Overview¶
Tabular SSL uses a simplified modular architecture in which components are directly instantiated through Hydra's _target_ mechanism. This approach is cleaner and more straightforward than the previous registry-based system.
Base Components¶
BaseModel¶
The main model class that orchestrates all components.
from tabular_ssl.models.base import BaseModel
Constructor¶
def __init__(
    self,
    event_encoder,
    sequence_encoder=None,
    embedding=None,
    projection_head=None,
    prediction_head=None,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    optimizer_type: str = "adamw",
    scheduler_type: str = "cosine"
):
Key Methods¶
- forward(x): Main forward pass through all components
- training_step(batch, batch_idx): PyTorch Lightning training step
- validation_step(batch, batch_idx): PyTorch Lightning validation step
- configure_optimizers(): Optimizer and scheduler configuration
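A minimal assembly sketch, assuming the component classes documented below (the dimensions are illustrative, not defaults):
from tabular_ssl.models.base import BaseModel
from tabular_ssl.models.components import MLPEventEncoder, TransformerSequenceEncoder

# The only hard requirement is that the event encoder's output_dim
# matches the sequence encoder's input_dim.
model = BaseModel(
    event_encoder=MLPEventEncoder(input_dim=64, hidden_dims=[128, 256], output_dim=512),
    sequence_encoder=TransformerSequenceEncoder(
        input_dim=512, hidden_dim=512, num_layers=4, num_heads=8,
        dim_feedforward=2048, dropout=0.1, max_seq_length=2048
    ),
    learning_rate=1e-4,
)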
EventEncoder¶
Base class for components that encode individual events or timesteps.
from tabular_ssl.models.base import EventEncoder
All event encoders should inherit from this class and implement the forward method:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: Input tensor of shape (batch_size, seq_len, input_dim)

    Returns:
        Encoded tensor of shape (batch_size, seq_len, output_dim)
    """
SequenceEncoder¶
Base class for components that encode sequences of events.
from tabular_ssl.models.base import SequenceEncoder
Sequence encoders process the output of event encoders:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: Input tensor of shape (batch_size, seq_len, input_dim)

    Returns:
        Encoded sequence of shape (batch_size, seq_len, output_dim)
    """
EmbeddingLayer¶
Base class for embedding layers that handle categorical features.
from tabular_ssl.models.base import EmbeddingLayer
ProjectionHead¶
Base class for projection heads that transform representations.
from tabular_ssl.models.base import ProjectionHead
PredictionHead¶
Base class for prediction heads that generate final outputs.
from tabular_ssl.models.base import PredictionHead
Available Components¶
Event Encoders¶
MLPEventEncoder¶
Multi-layer perceptron for encoding individual events.
from tabular_ssl.models.components import MLPEventEncoder
Constructor Parameters:
- input_dim (int): Input feature dimension
- hidden_dims (List[int]): List of hidden layer dimensions
- output_dim (int): Output dimension
- dropout (float): Dropout rate (default: 0.1)
- activation (str): Activation function ("relu", "gelu", "leaky_relu")
- use_batch_norm (bool): Whether to use batch normalization
Example Configuration:
_target_: tabular_ssl.models.components.MLPEventEncoder
input_dim: 64
hidden_dims: [128, 256]
output_dim: 512
dropout: 0.1
activation: relu
use_batch_norm: true
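A quick shape check, assuming the constructor parameters above:
import torch
from tabular_ssl.models.components import MLPEventEncoder

encoder = MLPEventEncoder(input_dim=64, hidden_dims=[128, 256], output_dim=512)
x = torch.randn(32, 10, 64)               # (batch_size, seq_len, input_dim)
assert encoder(x).shape == (32, 10, 512)  # (batch_size, seq_len, output_dim)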
Sequence Encoders¶
TransformerSequenceEncoder¶
Transformer-based sequence encoder using self-attention.
from tabular_ssl.models.components import TransformerSequenceEncoder
Constructor Parameters:
- input_dim (int): Input dimension
- hidden_dim (int): Hidden dimension
- num_layers (int): Number of transformer layers
- num_heads (int): Number of attention heads
- dim_feedforward (int): Feedforward network dimension
- dropout (float): Dropout rate
- max_seq_length (int): Maximum sequence length for positional encoding
Example Configuration:
_target_: tabular_ssl.models.components.TransformerSequenceEncoder
input_dim: 512
hidden_dim: 512
num_layers: 4
num_heads: 8
dim_feedforward: 2048
dropout: 0.1
max_seq_length: 2048
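Because sequence encoders consume the event encoder's output, input_dim here must equal the event encoder's output_dim. A chaining sketch under that assumption:
import torch
from tabular_ssl.models.components import MLPEventEncoder, TransformerSequenceEncoder

event_encoder = MLPEventEncoder(input_dim=64, hidden_dims=[128, 256], output_dim=512)
sequence_encoder = TransformerSequenceEncoder(
    input_dim=512, hidden_dim=512, num_layers=4, num_heads=8,
    dim_feedforward=2048, dropout=0.1, max_seq_length=2048
)

x = torch.randn(8, 128, 64)             # (batch_size, seq_len, feature_dim)
h = sequence_encoder(event_encoder(x))  # (batch_size, seq_len, hidden_dim)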
S4SequenceEncoder¶
Structured State Space (S4) model for efficient long sequence processing.
from tabular_ssl.models.components import S4SequenceEncoder
Constructor Parameters:
- input_dim (int): Input dimension
- hidden_dim (int): Hidden state dimension
- num_layers (int): Number of S4 layers
- dropout (float): Dropout rate
- bidirectional (bool): Whether to use bidirectional processing
- max_sequence_length (int): Maximum sequence length
Example Configuration:
_target_: tabular_ssl.models.components.S4SequenceEncoder
input_dim: 512
hidden_dim: 64
num_layers: 2
dropout: 0.1
bidirectional: true
max_sequence_length: 2048
RNNSequenceEncoder¶
RNN-based sequence encoder (LSTM, GRU, or vanilla RNN).
from tabular_ssl.models.components import RNNSequenceEncoder
Constructor Parameters:
- input_dim (int): Input dimension
- hidden_dim (int): Hidden dimension
- num_layers (int): Number of RNN layers
- rnn_type (str): Type of RNN ("lstm", "gru", "rnn")
- dropout (float): Dropout rate
- bidirectional (bool): Whether to use a bidirectional RNN
Example Configuration:
_target_: tabular_ssl.models.components.RNNSequenceEncoder
input_dim: 128
hidden_dim: 128
num_layers: 2
rnn_type: lstm
dropout: 0.1
bidirectional: false
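Note: with bidirectional: true, a plain PyTorch RNN concatenates forward and backward hidden states, so the effective output dimension is typically 2 * hidden_dim unless the encoder projects back down internally; verify the component's output_dim before wiring it to the next stage.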
Embedding Layers¶
CategoricalEmbedding¶
Handles embedding of categorical features with flexible dimensions.
from tabular_ssl.models.components import CategoricalEmbedding
Constructor Parameters:
- categorical_features (List[Dict]): List of categorical feature specifications
- default_embedding_dim (int): Default embedding dimension
- categorical_embedding_dims (Dict[str, int]): Custom dimensions per feature
Example Configuration:
_target_: tabular_ssl.models.components.CategoricalEmbedding
categorical_features:
  - name: category_1
    num_categories: 10
    embedding_dim: 32
  - name: category_2
    num_categories: 100
    embedding_dim: 64
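A usage sketch under an explicit assumption: that the layer accepts a (batch_size, seq_len, num_categorical_features) tensor of integer category codes and concatenates the per-feature embeddings along the last dimension (check the class docstring for the exact input contract):
import torch
from tabular_ssl.models.components import CategoricalEmbedding

embedding = CategoricalEmbedding(
    categorical_features=[
        {"name": "category_1", "num_categories": 10, "embedding_dim": 32},
        {"name": "category_2", "num_categories": 100, "embedding_dim": 64},
    ]
)
codes = torch.randint(0, 10, (32, 10, 2))  # one integer code per feature
out = embedding(codes)                     # assumed last dim: 32 + 64 = 96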
Projection Heads¶
MLPProjectionHead¶
MLP-based projection head for transforming representations.
from tabular_ssl.models.components import MLPProjectionHead
Constructor Parameters:
- input_dim (int): Input dimension
- hidden_dims (List[int]): Hidden layer dimensions
- output_dim (int): Output dimension
- dropout (float): Dropout rate
- activation (str): Activation function
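An illustrative configuration, following the same pattern as the encoders above (values are examples, not defaults):
_target_: tabular_ssl.models.components.MLPProjectionHead
input_dim: 512
hidden_dims: [256]
output_dim: 128
dropout: 0.1
activation: relu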
Prediction Heads¶
ClassificationHead¶
Classification head for supervised learning tasks.
from tabular_ssl.models.components import ClassificationHead
Constructor Parameters:
- input_dim (int): Input dimension
- num_classes (int): Number of output classes
- hidden_dims (List[int]): Optional hidden layers
- dropout (float): Dropout rate
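An illustrative configuration (values are examples, not defaults):
_target_: tabular_ssl.models.components.ClassificationHead
input_dim: 512
num_classes: 2
hidden_dims: [256]
dropout: 0.1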
Corruption Strategies¶
Corruption strategies are essential components for self-supervised learning on tabular data. They create pretext tasks by transforming input data.
VIMECorruption¶
VIME (Value Imputation and Mask Estimation) corruption strategy.
from tabular_ssl.models.components import VIMECorruption
Constructor Parameters:
- corruption_rate (float): Fraction of features to corrupt (default: 0.15)
- categorical_indices (List[int]): Indices of categorical features
- numerical_indices (List[int]): Indices of numerical features
- categorical_vocab_sizes (Dict[int, int]): Vocabulary sizes per categorical feature
- numerical_distributions (Dict[int, Tuple[float, float]]): (mean, std) per numerical feature
Example Configuration:
_target_: tabular_ssl.models.components.VIMECorruption
corruption_rate: 0.3
categorical_indices: [0, 1, 2]
numerical_indices: [3, 4, 5, 6]
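A sketch of the resulting pretext task, assuming the module's forward returns the corrupted batch together with the binary corruption mask, as in the original VIME formulation (verify the return signature against the class docstring):
import torch
from tabular_ssl.models.components import VIMECorruption

corruption = VIMECorruption(
    corruption_rate=0.3,
    categorical_indices=[0, 1, 2],
    numerical_indices=[3, 4, 5, 6],
)
x = torch.randn(32, 10, 7)         # 3 categorical + 4 numerical features
x_corrupted, mask = corruption(x)  # assumed return: (corrupted batch, mask)
# Pretext objectives: reconstruct x from x_corrupted and predict mask.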
SCARFCorruption¶
SCARF (Self-Supervised Contrastive Learning using Random Feature Corruption) strategy.
from tabular_ssl.models.components import SCARFCorruption
Constructor Parameters:
- corruption_rate (float): Fraction of features to corrupt (default: 0.6)
- corruption_strategy (str): "random_swap" or "marginal_sampling"
Example Configuration:
_target_: tabular_ssl.models.components.SCARFCorruption
corruption_rate: 0.6
corruption_strategy: "random_swap"
ReConTabCorruption¶
Multi-task reconstruction-based corruption strategy.
from tabular_ssl.models.components import ReConTabCorruption
Constructor Parameters:
- corruption_rate (float): Base corruption rate for masking (default: 0.15)
- categorical_indices (List[int]): Indices of categorical features
- numerical_indices (List[int]): Indices of numerical features
- corruption_types (List[str]): Types of corruption to apply
- masking_strategy (str): "random", "column_wise", or "block"
- noise_std (float): Standard deviation for noise corruption
- swap_probability (float): Probability of feature swapping
Example Configuration:
_target_: tabular_ssl.models.components.ReConTabCorruption
corruption_rate: 0.15
corruption_types: ["masking", "noise", "swapping"]
masking_strategy: "random"
noise_std: 0.1
Simple Corruption Strategies¶
RandomMasking¶
_target_: tabular_ssl.models.components.RandomMasking
corruption_rate: 0.15
GaussianNoise¶
_target_: tabular_ssl.models.components.GaussianNoise
noise_std: 0.1
SwappingCorruption¶
_target_: tabular_ssl.models.components.SwappingCorruption
swap_prob: 0.15
Utility Functions¶
create_mlp¶
Utility function for creating MLP layers.
from tabular_ssl.models.base import create_mlp
mlp = create_mlp(
    input_dim=64,
    hidden_dims=[128, 256],
    output_dim=512,
    dropout=0.1,
    activation="relu",
    use_batch_norm=True
)
Component Instantiation¶
Components are instantiated using Hydra's _target_ mechanism:
import hydra

# Direct instantiation from an inline config dict
encoder = hydra.utils.instantiate({
    "_target_": "tabular_ssl.models.components.MLPEventEncoder",
    "input_dim": 64,
    "hidden_dims": [128, 256],
    "output_dim": 512,
})

# From a composed Hydra configuration
encoder = hydra.utils.instantiate(config.model.event_encoder)
Model Assembly¶
The BaseModel class assembles all components:
# configs/model/default.yaml
defaults:
  - event_encoder: mlp
  - sequence_encoder: transformer
  - projection_head: mlp
  - prediction_head: classification

_target_: tabular_ssl.models.base.BaseModel
learning_rate: 1.0e-4
weight_decay: 0.01
optimizer_type: adamw
scheduler_type: cosine
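A hedged end-to-end sketch: compose the config with Hydra, instantiate the model, and hand it to a PyTorch Lightning Trainer. The config_name and data group are assumptions about the project layout:
import hydra
import pytorch_lightning as pl

@hydra.main(config_path="configs", config_name="train", version_base=None)
def main(config):
    model = hydra.utils.instantiate(config.model)
    datamodule = hydra.utils.instantiate(config.data)  # assumed config group
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=datamodule)

if __name__ == "__main__":
    main()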
Creating Custom Components¶
To create custom components, inherit from the appropriate base class:
import torch
import torch.nn as nn

from tabular_ssl.models.base import EventEncoder

class CustomEventEncoder(EventEncoder):
    def __init__(self, input_dim: int, output_dim: int, **kwargs):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear applies over the last dimension, mapping
        # (batch_size, seq_len, input_dim) to (batch_size, seq_len, output_dim)
        return self.linear(x)
Then create a configuration file:
# configs/model/event_encoder/custom.yaml
_target_: path.to.your.CustomEventEncoder
input_dim: 64
output_dim: 128
Best Practices¶
- Component Compatibility: Ensure output dimensions of one component match input dimensions of the next
- Memory Management: Use appropriate batch sizes and sequence lengths for your hardware
- Hyperparameter Tuning: Start with provided experiment configurations and adjust as needed
- Testing: Test custom components with different input shapes before integration
Common Patterns¶
MLP-Only Model¶
defaults:
  - event_encoder: mlp
  - sequence_encoder: null  # No sequence processing
Transformer Model¶
defaults:
  - event_encoder: mlp
  - sequence_encoder: transformer
Long Sequence Model¶
defaults:
  - event_encoder: mlp
  - sequence_encoder: s4  # Efficient for long sequences
Troubleshooting¶
Dimension Mismatches¶
Check that component dimensions are compatible:
assert event_encoder.output_dim == sequence_encoder.input_dim
Memory Issues¶
- Reduce batch size or sequence length
- Use gradient accumulation
- Enable mixed precision training
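Gradient accumulation and mixed precision are both one-line Trainer flags in PyTorch Lightning (the "16-mixed" precision string assumes Lightning 2.x):
import pytorch_lightning as pl

trainer = pl.Trainer(
    accumulate_grad_batches=4,  # effective batch size = 4 x batch_size
    precision="16-mixed",       # automatic mixed precision
)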
Training Instability¶
- Lower learning rate
- Add gradient clipping
- Use layer normalization or batch normalization
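Gradient clipping is likewise a Trainer flag:
import pytorch_lightning as pl

trainer = pl.Trainer(gradient_clip_val=1.0)  # clip gradient norm at 1.0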