How-to: Self-Supervised Learning Training¶
This guide covers advanced techniques for training self-supervised learning models with Tabular SSL's corruption strategies.
Quick Start¶
Run Demo Scripts First¶
Before training, explore the corruption strategies interactively:
# See how corruption strategies work
python demo_corruption_strategies.py
# Try with real credit card data
python demo_credit_card_data.py
Basic SSL Training¶
# VIME: Mask estimation + value imputation
python train.py +experiment=vime_ssl
# SCARF: Contrastive learning
python train.py +experiment=scarf_ssl
# ReConTab: Multi-task reconstruction
python train.py +experiment=recontab_ssl
Choosing the Right Strategy¶
VIME - When to Use¶
Best for:
- Mixed categorical/numerical tabular data
- Interpretable pretext tasks
- Downstream tasks requiring feature reconstruction
Characteristics:
- Moderate corruption rate (30%)
- Returns explicit masks
- Two complementary tasks: mask estimation + value imputation
python train.py +experiment=vime_ssl
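To make the two pretext tasks concrete, here is a minimal sketch of VIME-style corruption, assuming a purely numerical (batch, features) tensor; it is illustrative, not the library's actual class. Masked entries are replaced with values drawn from each feature's empirical marginal, and the returned mask becomes the target for mask estimation.
import torch

def vime_corrupt(x: torch.Tensor, corruption_rate: float = 0.3):
    """Illustrative VIME-style corruption (hypothetical helper, not the
    library API). Returns (x_corrupted, mask), where mask == 1 marks the
    entries the mask-estimation head must recover."""
    mask = (torch.rand_like(x) < corruption_rate).float()
    # Draw replacements from each feature's empirical marginal by shuffling
    # each column independently across the batch dimension.
    idx = torch.argsort(torch.rand(x.shape, device=x.device), dim=0)
    x_marginal = torch.gather(x, 0, idx)
    return (1 - mask) * x + mask * x_marginal, mask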
SCARF - When to Use¶
Best for:
- Large datasets with diverse features
- Pure representation learning
- High-dimensional tabular data
Characteristics:
- High corruption rate (60%+)
- Contrastive learning approach
- Requires larger batch sizes
python train.py +experiment=scarf_ssl
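The contrastive objective is also why batch size matters so much for SCARF: every other row in the batch acts as a negative. A minimal InfoNCE sketch over two corrupted views (assumed form, not the library's exact implementation):
import torch
import torch.nn.functional as F

def info_nce(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1):
    """Sketch of an InfoNCE loss; z1 and z2 are (batch, dim) embeddings of
    two corrupted views, with row i of z1 matching row i of z2."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature  # (batch, batch) cosine similarities
    targets = torch.arange(z1.size(0), device=z1.device)  # positives on the diagonal
    return F.cross_entropy(logits, targets)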
ReConTab - When to Use¶
Best for:
- Complex multi-task scenarios
- Fine-grained corruption control
- Hybrid reconstruction + contrastive approaches
Characteristics:
- Low base corruption rate (15%) but multiple types
- Detailed corruption tracking
- Flexible masking strategies
python train.py +experiment=recontab_ssl
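As an illustration of "multiple corruption types with tracking", here is a hypothetical sketch in the spirit of ReConTab (not the library's actual class): each corruption type keeps its own mask so per-task reconstruction losses can be computed.
import torch

def recontab_corrupt(x: torch.Tensor, rate: float = 0.15, noise_std: float = 0.1):
    """Hypothetical multi-type corruption: masking, noise, and swapping,
    each tracked with its own mask for per-task losses (illustrative)."""
    out, masks = x.clone(), {}
    for name in ("masking", "noise", "swapping"):
        m = (torch.rand_like(x) < rate).float()
        masks[name] = m
        if name == "masking":
            out = out * (1 - m)  # zero out masked entries
        elif name == "noise":
            out = out + m * noise_std * torch.randn_like(x)
        else:
            perm = torch.randperm(x.size(0), device=x.device)
            out = (1 - m) * out + m * x[perm]  # swap in another row's values
    return out, masks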
Customizing Corruption Parameters¶
VIME Customization¶
# Adjust corruption rate
python train.py +experiment=vime_ssl model.corruption.corruption_rate=0.5
# Use with different sequence lengths
python train.py +experiment=vime_ssl data.sequence_length=64
# Modify loss weights
python train.py +experiment=vime_ssl model.mask_estimation_weight=2.0 model.value_imputation_weight=1.0
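The two weights presumably combine the pretext losses along these lines (a sketch of the assumed form, inferred from the config keys above; check the model code for the exact definition):
import torch.nn.functional as F

def vime_loss(mask_logits, mask, x_hat, x,
              mask_estimation_weight=2.0, value_imputation_weight=1.0):
    """Assumed weighted sum of the two VIME pretext losses (illustrative)."""
    l_mask = F.binary_cross_entropy_with_logits(mask_logits, mask)
    l_imputation = F.mse_loss(x_hat, x)
    return mask_estimation_weight * l_mask + value_imputation_weight * l_imputation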
SCARF Customization¶
# Change corruption strategy
python train.py +experiment=scarf_ssl model.corruption.corruption_strategy=marginal_sampling
# Adjust temperature for contrastive loss
python train.py +experiment=scarf_ssl model.temperature=0.05
# Use larger batch size (important for SCARF)
python train.py +experiment=scarf_ssl data.batch_size=256
ReConTab Customization¶
# Enable only specific corruption types
python train.py +experiment=recontab_ssl model.corruption.corruption_types=[masking,noise]
# Use column-wise masking
python train.py +experiment=recontab_ssl model.corruption.masking_strategy=column_wise
# Adjust individual corruption parameters
python train.py +experiment=recontab_ssl model.corruption.noise_std=0.2 model.corruption.swap_probability=0.2
Working with Your Own Data¶
Preparing Data for SSL¶
1. Create your DataModule config (a quick column sanity check is sketched after this list):
# configs/data/your_data.yaml
_target_: tabular_ssl.data.base.DataModule
data_path: "path/to/your/data.csv"
sequence_length: 32
batch_size: 64
# Feature specifications
categorical_columns: ["category_col1", "category_col2"]
numerical_columns: ["num_col1", "num_col2", "num_col3"]
# Sample data generation (optional)
sample_data_config:
  n_users: 1000
  sequence_length: 32
2. Use it with SSL experiments:
# Use your data with VIME
python train.py +experiment=vime_ssl data=your_data
# Use your data with SCARF
python train.py +experiment=scarf_ssl data=your_data
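Before launching a run, it can save time to verify that the configured column names actually exist in the CSV, for example (illustrative; the names mirror the example YAML above):
import pandas as pd

# Quick sanity check that the configured columns exist in the CSV.
df = pd.read_csv("path/to/your/data.csv")
expected = ["category_col1", "category_col2", "num_col1", "num_col2", "num_col3"]
missing = [c for c in expected if c not in df.columns]
assert not missing, f"Columns not found in CSV: {missing}"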
Feature Type Detection¶
Corruption strategies need to know which features are categorical vs numerical:
# Automatic detection (default)
python train.py +experiment=vime_ssl
# Manual specification
python train.py +experiment=vime_ssl \
  model.corruption.categorical_indices=[0,1,2] \
  model.corruption.numerical_indices=[3,4,5,6,7]
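If you are unsure what automatic detection will pick, a dtype-based heuristic like the following is a common approach (hypothetical sketch; the library's actual logic may differ):
import pandas as pd

def infer_feature_types(df: pd.DataFrame, max_categories: int = 20):
    """Hypothetical heuristic: object or low-cardinality columns are treated
    as categorical, everything else as numerical (illustrative only)."""
    categorical, numerical = [], []
    for i, col in enumerate(df.columns):
        if df[col].dtype == object or df[col].nunique() <= max_categories:
            categorical.append(i)
        else:
            numerical.append(i)
    return categorical, numerical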
Advanced Training Techniques¶
Multi-GPU Training¶
# Use multiple GPUs
python train.py +experiment=vime_ssl trainer.devices=2 trainer.strategy=ddp
# Adjust batch size for multi-GPU
python train.py +experiment=scarf_ssl trainer.devices=4 data.batch_size=512
Mixed Precision Training¶
All SSL experiments support mixed precision for faster training:
# Already enabled in experiments (precision: 16-mixed)
python train.py +experiment=vime_ssl
# Disable if needed
python train.py +experiment=vime_ssl trainer.precision=32
Hyperparameter Optimization¶
Use Hydra's multirun for hyperparameter sweeps:
# Sweep corruption rates for VIME
python train.py +experiment=vime_ssl -m model.corruption.corruption_rate=0.1,0.3,0.5
# Sweep SCARF parameters
python train.py +experiment=scarf_ssl -m \
  model.corruption.corruption_rate=0.4,0.6,0.8 \
model.temperature=0.05,0.1,0.2
Monitoring Training¶
Key Metrics to Watch¶
VIME:
- train/mask_estimation_loss: should decrease steadily
- train/value_imputation_loss: should decrease steadily
- val/total_loss: overall validation performance
SCARF:
- train/contrastive_loss: should decrease and stabilize
- Representation quality metrics (if using downstream tasks)
ReConTab:
- train/masked_reconstruction: masking reconstruction quality
- train/denoising: noise removal quality
- train/unswapping: feature unswapping quality
Using Weights & Biases¶
SSL experiments are pre-configured for W&B logging:
# Logs automatically to your W&B account
python train.py +experiment=vime_ssl
# Customize project name
python train.py +experiment=vime_ssl logger.wandb.project=my-ssl-project
Troubleshooting¶
Poor Convergence¶
Problem: Training loss not decreasing
Solutions:
# Lower learning rate
python train.py +experiment=vime_ssl model.learning_rate=5e-5
# Use a learning-rate scheduler with warmup
python train.py +experiment=vime_ssl model.scheduler_type=cosine_with_warmup
# Reduce corruption rate
python train.py +experiment=vime_ssl model.corruption.corruption_rate=0.2
Memory Issues¶
Problem: CUDA out of memory
Solutions:
# Reduce batch size
python train.py +experiment=scarf_ssl data.batch_size=32
# Reduce sequence length
python train.py +experiment=vime_ssl data.sequence_length=16
# Use gradient accumulation (effective batch size = batch_size * accumulate_grad_batches)
python train.py +experiment=vime_ssl trainer.accumulate_grad_batches=2
SCARF-Specific Issues¶
Problem: Contrastive loss not decreasing
Solutions:
# Increase batch size (critical for SCARF)
python train.py +experiment=scarf_ssl data.batch_size=256
# Adjust temperature
python train.py +experiment=scarf_ssl model.temperature=0.07
# Increase corruption rate
python train.py +experiment=scarf_ssl model.corruption.corruption_rate=0.8
Evaluation and Downstream Tasks¶
Save Trained Models¶
SSL experiments automatically save checkpoints:
# Training saves to outputs/YYYY-MM-DD/HH-MM-SS/
ls outputs/ # Find your experiment
# Best checkpoint is saved automatically
ls outputs/2024-01-15/14-30-45/checkpoints/
Extract Representations¶
import torch
from tabular_ssl.models.base import BaseModel
# Load trained model
model = BaseModel.load_from_checkpoint("path/to/checkpoint.ckpt")
model.eval()
# Extract representations (your_data: a batch shaped like the training input)
with torch.no_grad():
    representations = model(your_data)
Downstream Task Training¶
Use pre-trained SSL models for downstream tasks:
# Load SSL checkpoint for fine-tuning
python train.py +experiment=classification_finetune \
model.ssl_checkpoint_path=outputs/2024-01-15/14-30-45/checkpoints/best.ckpt
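If your setup lacks a fine-tuning experiment config, a simple alternative is a linear probe on frozen SSL representations. A sketch, assuming the checkpoint loads as in the snippet above and that train_loader yields labeled batches:
import torch
import torch.nn as nn
from tabular_ssl.models.base import BaseModel

# Linear probe on frozen SSL representations (sketch; the dimensions and
# the labeled train_loader are placeholders for your downstream task).
model = BaseModel.load_from_checkpoint("path/to/checkpoint.ckpt")
model.eval()
for p in model.parameters():
    p.requires_grad = False

representation_dim, num_classes = 64, 2  # set these for your model and task
probe = nn.Linear(representation_dim, num_classes)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)

for x, y in train_loader:  # your labeled downstream data
    with torch.no_grad():
        z = model(x)  # frozen SSL representations
    loss = nn.functional.cross_entropy(probe(z), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()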
Custom Corruption Strategies¶
Create Your Own Strategy¶
# custom_corruption.py
import torch
import torch.nn as nn

class CustomCorruption(nn.Module):
    def __init__(self, corruption_rate: float = 0.2):
        super().__init__()
        self.corruption_rate = corruption_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Corrupt only during training; pass inputs through unchanged at eval time
        if not self.training:
            return x
        # Example corruption logic: randomly zero out a corruption_rate
        # fraction of entries
        mask = (torch.rand_like(x) > self.corruption_rate).float()
        return x * mask
Use Custom Strategy¶
# configs/model/corruption/custom.yaml
_target_: path.to.custom_corruption.CustomCorruption
corruption_rate: 0.2
The _target_ module path must be importable from where you launch training (e.g. on PYTHONPATH). Then select it as a config group option:
python train.py model/corruption=custom
Best Practices¶
1. Start with Demos¶
Always run demo_corruption_strategies.py first to understand how each strategy works.
2. Use Appropriate Batch Sizes¶
- VIME/ReConTab: 32-128 typically sufficient
- SCARF: 128+ recommended for effective contrastive learning
3. Monitor Feature Types¶
Ensure categorical/numerical indices are correctly specified for optimal corruption.
4. Experiment with Corruption Rates¶
- Start with paper defaults
- Tune based on downstream task performance
- Higher rates aren't always better
5. Use Mixed Precision¶
Enable precision: 16-mixed for roughly a 2x speedup with minimal quality loss.
6. Save Everything¶
SSL training can be expensive - ensure checkpointing is enabled.
Paper References¶
For implementation details and theoretical background:
- VIME: Yoon et al., "VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain" (NeurIPS 2020)
- SCARF: Bahri et al., "SCARF: Self-Supervised Contrastive Learning using Random Feature Corruption" (arXiv 2021, ICLR 2022)
- General SSL: Self-Supervised Learning Survey