
feat(hslm): gradient checkpointing for memory efficiency #317

@gHashTag


Task

Store activations only at transformer block boundaries (3 checkpoints).
Recompute intermediates during backward pass.

Scientific Background

NVIDIA NeMo Gradient Checkpointing (docs.nvidia.com)

Three variants with different tradeoffs:

  1. Full recompute: checkpoint only block inputs and recompute everything else during backward → ~30% compute overhead
  2. Selective: recompute only attention intermediates (QKᵀ dot product, softmax, dropout) → best memory saved per unit of extra compute
  3. Partial: apply full recompute to only a subset of layers → tunable tradeoff
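
The three variants differ only in which per-block tensors are kept for the backward pass; everything else is recomputed from the block input. The Python sketch below encodes that choice. The tensor names in `BLOCK_TENSORS` are hypothetical placeholders for a generic transformer block, not the actual buffers in src/hslm/transformer.zig.

```python
# Hypothetical per-block activation names (illustrative, not from transformer.zig).
BLOCK_TENSORS = (
    "block_input", "qkv_proj", "qk_scores", "attn_softmax",
    "attn_dropout", "attn_out", "ffn_hidden", "ffn_out",
)
# Tensors that selective recompute drops: QK^T scores, softmax, dropout.
ATTENTION_INTERMEDIATES = frozenset({"qk_scores", "attn_softmax", "attn_dropout"})


def saved_tensors(policy, n_blocks=3, recompute_blocks=()):
    """Return the (block, tensor) pairs kept in memory for the backward pass.

    Anything not returned must be recomputed from the block input during backward.
    """
    recompute_blocks = set(recompute_blocks)
    saved = []
    for b in range(n_blocks):
        for name in BLOCK_TENSORS:
            if policy == "none":          # no checkpointing: keep everything
                keep = True
            elif policy == "full":        # keep only the block input
                keep = name == "block_input"
            elif policy == "selective":   # drop only the attention intermediates
                keep = name not in ATTENTION_INTERMEDIATES
            elif policy == "partial":     # full recompute for the chosen blocks only
                keep = name == "block_input" or b not in recompute_blocks
            else:
                raise ValueError(f"unknown policy: {policy!r}")
            if keep:
                saved.append((b, name))
    return saved


for p in ("none", "full", "selective", "partial"):
    print(p, len(saved_tensors(p, recompute_blocks={0})))
```

For 3 blocks this keeps 24 tensors with no checkpointing, 3 under full recompute, 15 under selective recompute, and 17 when only block 0 is fully recomputed.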

Memory Analysis for 1.95M Model

  • 3 blocks × hidden=729 × context=81 × batch=128 ≈ 22.7M values ≈ 91MB activation memory (FP32, 4 bytes/value)
  • With gradient checkpointing: ~30MB (3 checkpoints only)
  • Savings: 67% → enables batch=256 or batch=384
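
The product in the first bullet can be checked with a few lines of Python, used here only as a calculator. It evaluates exactly 3 blocks × hidden × context × batch at 4 bytes per FP32 value for the batch sizes mentioned above; the function name `activation_bytes` is illustrative.

```python
BYTES_PER_FP32 = 4


def activation_bytes(n_blocks, hidden, context, batch, bytes_per_value=BYTES_PER_FP32):
    """Bytes for one [batch, context, hidden] tensor per block, summed over blocks."""
    return n_blocks * hidden * context * batch * bytes_per_value


for batch in (128, 256, 384):
    total = activation_bytes(3, 729, 81, batch)
    print(f"batch={batch}: {total:,} bytes ≈ {total / 1e6:.1f} MB")

# batch=128: 90,699,264 bytes ≈ 90.7 MB
# batch=256: 181,398,528 bytes ≈ 181.4 MB
# batch=384: 272,097,792 bytes ≈ 272.1 MB
```

Memory grows linearly in batch size, so doubling the batch doubles this term.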

Is It Worth It for 1.95M params?

Research (NVIDIA NeMo, PyTorch docs) suggests the benefit is MARGINAL for tiny models:

  • Memory savings are significant only when activations dominate total memory (true at large batch sizes)
  • The ~20% compute overhead is a real cost
  • BUT: going from batch=128 to batch=256 → better gradient estimates → faster convergence
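
The "better gradient estimates" step can be made concrete under a standard simplifying assumption: per-sample gradients g_i are drawn i.i.d. with per-coordinate variance σ². The minibatch mean then has variance σ²/B, so doubling the batch halves it:

```latex
\hat{g}_B = \frac{1}{B}\sum_{i=1}^{B} g_i,
\qquad
\operatorname{Var}\!\left[\hat{g}_B\right] = \frac{\sigma^2}{B}
\quad\Longrightarrow\quad
\frac{\operatorname{Var}\!\left[\hat{g}_{256}\right]}{\operatorname{Var}\!\left[\hat{g}_{128}\right]}
= \frac{128}{256} = \frac{1}{2}
```

This reduces the standard error by a factor of 1/√2 ≈ 0.71; how much that speeds up wall-clock convergence depends on the training setup.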

Better Alternative: Selective Checkpointing

  • Recompute only attention intermediates (their memory grows quadratically with sequence length)
  • Keep FFN activations (cheap to store, expensive to recompute)
  • For context=81: attention-score memory = 81² × heads × batch values per layer, which is small
  • Verdict: at context=81, attention memory is already tiny. Gradient checkpointing gives the most value when context ≫ 256.
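
To see how the 81² × heads × batch term scales, the sketch below evaluates it in FP32 for a few context lengths. It counts a single attention-score matrix per layer; the softmax output and dropout mask add further tensors of the same shape. The head count is not stated in the issue, so numbers are reported per head (`heads=1`).

```python
BYTES_PER_FP32 = 4


def attention_score_bytes(context, heads, batch, bytes_per_value=BYTES_PER_FP32):
    """Bytes for one layer's [batch, heads, context, context] attention-score matrix."""
    return context * context * heads * batch * bytes_per_value


# Head count is not given in the issue; report per-head numbers.
for ctx in (81, 162, 256, 1024):
    mb = attention_score_bytes(ctx, heads=1, batch=128) / 1e6
    print(f"context={ctx:4d}: {mb:7.1f} MB per head per layer")
```

Going from context=81 to 162 quadruples this term, and context=256 is roughly 10× the context=81 value, which is why the verdict ties checkpointing's value to long contexts.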

Revised Recommendation

Given context=81 (very short), gradient checkpointing provides limited benefit.
Better investment: increase batch size directly (M1 Pro 16GB can handle batch=256 without checkpointing).

IF implementing anyway (for future context=162+):

Changes

  • src/hslm/transformer.zig: save activations only at block boundaries
  • Recompute attention intermediates (softmax, QKV) during backward
  • Flag: --gradient-checkpoint to enable/disable
  • 3 blocks = 3 checkpoints = minimal overhead
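
A minimal, language-neutral sketch of the block-boundary scheme, written in Python although the real change would live in src/hslm/transformer.zig. It assumes a hypothetical scalar toy block y = x + tanh(w·x) in place of a transformer block; `block_forward`, `block_backward` and `train_step` are illustrative names, not existing functions. With `checkpoint=True` only each block's input is stored, and the tanh intermediate is recomputed during backward.

```python
import math


def block_forward(w, x):
    """Toy residual block y = x + tanh(w * x); returns (output, intermediate)."""
    a = math.tanh(w * x)
    return x + a, a


def block_backward(w, x, a, grad_y):
    """Backprop through y = x + tanh(w * x) given the intermediate a = tanh(w * x)."""
    dtanh = 1.0 - a * a
    grad_w = grad_y * x * dtanh
    grad_x = grad_y * (1.0 + w * dtanh)
    return grad_x, grad_w


def train_step(weights, x0, checkpoint):
    """Forward through all blocks, then backward; loss is the final scalar output."""
    saved = []
    x = x0
    for w in weights:
        y, a = block_forward(w, x)
        # Checkpointing keeps only the block input; otherwise keep the intermediate too.
        saved.append((x, None) if checkpoint else (x, a))
        x = y
    loss = x
    grad = 1.0  # d(loss)/d(final output)
    grads = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        x_in, a = saved[i]
        if a is None:
            # Recompute the discarded intermediate from the checkpointed block input.
            _, a = block_forward(weights[i], x_in)
        grad, grads[i] = block_backward(weights[i], x_in, a, grad)
    return loss, grads, saved


weights = [0.5, -1.2, 0.8]  # 3 blocks → 3 checkpoints
print(train_step(weights, 0.3, checkpoint=False)[:2])
print(train_step(weights, 0.3, checkpoint=True)[:2])
```

Both calls return the same loss and gradients; only the contents of `saved` differ. The trade is the one described above: one extra forward evaluation per block in exchange for not storing its intermediates.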

Expected

  • 50-67% memory reduction on activations
  • 20% compute overhead (recomputation)
  • Net benefit: enables batch=256+ → ~5% faster convergence
  • Most useful when context grows to 162+

Priority: LOW for current config (context=81), HIGH if context increases

References

  • NVIDIA NeMo activation recomputation: https://docs.nvidia.com/nemo-framework/features/optimizations/activation_recomputation.html
  • PyTorch activation checkpointing techniques: https://pytorch.org/blog/activation-checkpointing-techniques/
  • Memory analysis: https://arxiv.org/abs/2501.11847

Metadata

  • Assignees: none
  • Labels: agent:spawn (Auto-spawn agent container)
  • Project status: Done
  • Milestone: none
  • Relationships: none
  • Development: no branches or pull requests