Task
Store activations only at transformer block boundaries (3 checkpoints).
Recompute intermediates during backward pass.
Scientific Background
NVIDIA NeMo Gradient Checkpointing (docs.nvidia.com)
Three variants with different tradeoffs:
- Full recompute: checkpoint only block inputs, recompute all → 30% compute overhead
- Selective: recompute only attention intermediates (softmax, dropout, QKV dot) → best ratio of memory saved to compute added
- Partial: recompute subset of layers → tunable
Memory Analysis for 1.95M Model
- Per block boundary: hidden=729 × context=81 × batch=128 ≈ 7.56M values ≈ 30MB (FP32)
- With gradient checkpointing: 3 checkpoints ≈ 91MB, plus the intermediates of one block at a time during backward
- Savings: estimated at 67% of activation memory → enables batch=256 or batch=384
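The product of hidden size, context length, and batch size can be checked with a few lines of Python. This is a rough sketch that assumes one FP32 hidden-state tensor of shape (batch, context, hidden) per block boundary and nothing else; the function names are illustrative and do not come from the codebase.

```python
BYTES_FP32 = 4  # assumption: activations are stored as 32-bit floats


def boundary_bytes(hidden: int, context: int, batch: int,
                   bytes_per_value: int = BYTES_FP32) -> int:
    """Bytes for one hidden-state tensor of shape (batch, context, hidden)."""
    return hidden * context * batch * bytes_per_value


def checkpoint_bytes(blocks: int, hidden: int, context: int, batch: int) -> int:
    """Bytes needed to keep one boundary tensor per block."""
    return blocks * boundary_bytes(hidden, context, batch)


if __name__ == "__main__":
    # Configuration taken from the analysis: 3 blocks, hidden=729, context=81.
    for batch in (128, 256, 384):
        mb = checkpoint_bytes(3, 729, 81, batch) / 1e6
        print(f"batch={batch}: {mb:.1f} MB for 3 block-boundary tensors")
```

Intermediates stored inside each block (attention scores, FFN hidden layer) come on top of this figure and depend on implementation details not given here.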
Is It Worth It for 1.95M params?
Research says: MARGINAL for tiny models (NVIDIA NeMo, PyTorch docs):
- Memory savings significant only when activations dominate (true for large batch)
- The ~20% compute overhead is a real cost
- BUT: enabling batch=256 from batch=128 → better gradient estimates → faster convergence
Better Alternative: Selective Checkpointing
- Only checkpoint attention layers (quadratic memory in sequence length)
- Keep FFN activations (cheap to store, expensive to recompute)
- For context=81: attention memory = 81² × heads × batch = small
- Verdict: at context=81, attention memory is already tiny. Gradient checkpointing gives most value when context>>256.
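To see how the quadratic attention term compares with a linear hidden-state tensor as context grows, a small Python sketch follows. The head count is not stated in this write-up, so `heads=3` below is a purely hypothetical value, and the function names are illustrative.

```python
def attention_score_bytes(context: int, heads: int, batch: int,
                          bytes_per_value: int = 4) -> int:
    """Bytes for one (batch, heads, context, context) attention-score tensor."""
    return context * context * heads * batch * bytes_per_value


def hidden_state_bytes(hidden: int, context: int, batch: int,
                       bytes_per_value: int = 4) -> int:
    """Bytes for one (batch, context, hidden) hidden-state tensor."""
    return hidden * context * batch * bytes_per_value


def score_to_hidden_ratio(context: int, heads: int, hidden: int) -> float:
    """Attention-score size relative to a hidden-state tensor.

    Batch and bytes-per-value cancel, leaving context * heads / hidden.
    """
    return context * heads / hidden


if __name__ == "__main__":
    HEADS = 3  # hypothetical: the head count is an assumption, not from the source
    for context in (81, 162, 324, 648):
        ratio = score_to_hidden_ratio(context, HEADS, hidden=729)
        print(f"context={context}: scores are {ratio:.2f}x one hidden-state tensor")
```

The ratio grows linearly with context, which is why recomputing attention intermediates pays off mainly at longer sequence lengths.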
Revised Recommendation
Given context=81 (very short), gradient checkpointing provides limited benefit.
Better investment: increase batch size directly (M1 Pro 16GB can handle batch=256 without checkpointing).
IF implementing anyway (for future context=162+):
Changes
src/hslm/transformer.zig: save activations only at block boundaries
- Recompute attention intermediates (softmax, QKV) during backward
- Flag: --gradient-checkpoint to enable/disable
- 3 blocks = 3 checkpoints = minimal overhead
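The store-at-boundaries, recompute-in-backward pattern can be illustrated with a toy example. The sketch below is Python rather than Zig and uses scalar "blocks" (tanh of a weighted input) instead of real transformer blocks; it is not the transformer.zig implementation, and every name in it is made up for illustration. It shows that keeping only each block's input and recomputing intermediates yields the same gradients as storing everything.

```python
import math


def block_forward(x: float, w: float):
    """Toy block: returns the output and the intermediates backward needs."""
    pre = w * x
    out = math.tanh(pre)
    return out, (x, pre, out)


def block_backward(grad_out: float, w: float, saved):
    """Backward through one toy block given its saved intermediates."""
    x, _pre, out = saved
    grad_pre = grad_out * (1.0 - out * out)  # derivative of tanh
    return grad_pre * w, grad_pre * x        # (grad wrt input, grad wrt weight)


def forward_checkpointed(x: float, weights):
    """Forward pass that keeps only each block's input (the checkpoint)."""
    checkpoints = []
    for w in weights:
        checkpoints.append(x)
        x, _ = block_forward(x, w)  # intermediates are dropped here
    return x, checkpoints


def backward_checkpointed(grad_out: float, weights, checkpoints):
    """Backward pass that recomputes each block's intermediates on demand."""
    grads = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        _, saved = block_forward(checkpoints[i], weights[i])  # recompute
        grad_out, grads[i] = block_backward(grad_out, weights[i], saved)
    return grad_out, grads


def forward_backward_full(x: float, weights):
    """Reference: store every intermediate, no recomputation."""
    saved_all = []
    for w in weights:
        x, saved = block_forward(x, w)
        saved_all.append(saved)
    grad, grads = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grad, grads[i] = block_backward(grad, weights[i], saved_all[i])
    return x, grad, grads
```

With 3 weights this stores 3 checkpoints and runs each block's forward twice, which is the extra compute the overhead figures refer to.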
Expected
- 50-67% memory reduction on activations
- 20% compute overhead (recomputation)
- Net benefit: enables batch=256+ → ~5% faster convergence
- Most useful when context grows to 162+
Priority: LOW for current config (context=81), HIGH if context increases
References
- NVIDIA activation recomputation: https://docs.nvidia.com/nemo-framework/features/optimizations/activation_recomputation.html
- PyTorch activation checkpointing: https://pytorch.org/blog/activation-checkpointing-techniques/
- Memory analysis: https://arxiv.org/abs/2501.11847