Creating Custom Components¶
Time to complete: 20 minutes
In this tutorial, you'll learn how to create your own custom components for Tabular SSL. We'll start with a simple example and gradually build up your understanding.
What You'll Learn¶
- How to create a simple custom event encoder
- How to configure it using YAML files
- How to test and use your custom component
- Basic patterns for extending the library
Prerequisites¶
- Completed the Getting Started tutorial
- Basic PyTorch knowledge
- Understanding of neural networks
The Goal¶
We'll create a custom event encoder that uses a different activation function and architecture than the default MLP encoder. This will teach you the patterns for creating any type of custom component.
Step 1: Understanding Component Structure¶
First, let's look at what makes a component in Tabular SSL:
- Inherits from a base class (like
EventEncoder
) - Has a constructor with parameters
- Implements a
forward
method for processing data - Can be configured via YAML files
Step 2: Create Your First Custom Component¶
Let's create a simple custom event encoder. Create a new file:
# src/tabular_ssl/models/my_components.py
import torch
import torch.nn as nn
from typing import List
from tabular_ssl.models.base import EventEncoder
class SimpleCustomEncoder(EventEncoder):
"""A simple custom encoder with GELU activation and layer normalization."""
def __init__(
self,
input_dim: int,
output_dim: int,
hidden_dim: int = 128,
dropout: float = 0.1
):
super().__init__()
# Store parameters
self.input_dim = input_dim
self.output_dim = output_dim
self.hidden_dim = hidden_dim
# Create layers
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process the input tensor."""
# x shape: (batch_size, seq_len, input_dim)
batch_size, seq_len, _ = x.shape
# Reshape to process all timesteps at once
x_flat = x.reshape(-1, self.input_dim)
# Apply our layers
output_flat = self.layers(x_flat)
# Reshape back to sequence format
output = output_flat.reshape(batch_size, seq_len, self.output_dim)
return output
Step 3: Test Your Component¶
Before using it in training, let's test that it works:
# test_my_component.py
import torch
from tabular_ssl.models.my_components import SimpleCustomEncoder
def test_custom_encoder():
# Create test data
batch_size, seq_len, input_dim = 8, 10, 32
x = torch.randn(batch_size, seq_len, input_dim)
# Create encoder
encoder = SimpleCustomEncoder(
input_dim=32,
output_dim=64,
hidden_dim=128
)
# Test forward pass
output = encoder(x)
# Check output shape
expected_shape = (batch_size, seq_len, 64)
assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
print("✅ Custom encoder test passed!")
if __name__ == "__main__":
test_custom_encoder()
Run this test:
python test_my_component.py
Step 4: Create a Configuration File¶
Now create a YAML configuration for your component:
# configs/model/event_encoder/simple_custom.yaml
_target_: tabular_ssl.models.my_components.SimpleCustomEncoder
input_dim: 64
output_dim: 256
hidden_dim: 128
dropout: 0.1
Step 5: Use Your Component in an Experiment¶
Create an experiment that uses your custom component:
# configs/experiments/my_custom_experiment.yaml
# @package _global_
defaults:
- override /model/event_encoder: simple_custom
- override /model/sequence_encoder: null # Keep it simple for now
- override /model/projection_head: mlp
- override /model/prediction_head: classification
tags: ["custom", "simple"]
model:
learning_rate: 1.0e-3
weight_decay: 0.01
trainer:
max_epochs: 5 # Short run for testing
data:
batch_size: 32
Step 6: Run Your Custom Experiment¶
First, make sure Python can find your module:
# Add this to your train.py or create a new script
import tabular_ssl.models.my_components # This imports your custom components
Then run your experiment:
python train.py +experiment=my_custom_experiment
Step 7: Compare with Default¶
Let's compare your custom component with the default:
# Run with default MLP encoder
python train.py +experiment=simple_mlp trainer.max_epochs=5
# Run with your custom encoder
python train.py +experiment=my_custom_experiment
Look at the training logs to see how they perform differently!
Understanding What You Built¶
Your custom encoder differs from the default in several ways:
- GELU activation instead of ReLU
- Layer normalization instead of batch normalization
- Simpler architecture with just one hidden layer
- Different parameter initialization
These choices can affect training dynamics and final performance.
Next Steps: Making It Better¶
Now that you understand the basics, try these improvements:
1. Add Multiple Hidden Layers¶
class BetterCustomEncoder(EventEncoder):
def __init__(self, input_dim: int, output_dim: int, hidden_dims: List[int]):
super().__init__()
layers = []
dims = [input_dim] + hidden_dims + [output_dim]
for i in range(len(dims) - 1):
layers.extend([
nn.Linear(dims[i], dims[i + 1]),
nn.LayerNorm(dims[i + 1]),
nn.GELU(),
nn.Dropout(0.1)
])
self.layers = nn.Sequential(*layers[:-2]) # Remove last activation and dropout
2. Add Residual Connections¶
class ResidualCustomEncoder(EventEncoder):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# ... process x ...
residual = x if x.size(-1) == output.size(-1) else None
if residual is not None:
output = output + residual
return output
Key Patterns You've Learned¶
- Inherit from base classes (
EventEncoder
,SequenceEncoder
, etc.) - Use
_target_
in YAML to specify your component class - Test components independently before integration
- Import your modules so Hydra can find them
- Handle tensor shapes carefully (especially batch and sequence dimensions)
Common Pitfalls to Avoid¶
- Forgetting to import your custom module
- Shape mismatches between components
- Not testing before using in experiments
- Complex components that are hard to debug
What's Next?¶
🎯 Ready for more advanced patterns? Check out: - How-to: Advanced Component Patterns - Reference: Component API - Explanation: Architecture Design
Congratulations! 🎉 You've created your first custom component. You now understand the core patterns for extending Tabular SSL with your own innovations.