Add OutputDiscardCheckpoint#682
Conversation
A Megatron-style activation-recompute primitive. Forward runs under no_grad; caller frees the output's storage after downstream consumption; a backward hook recomputes and shares storage back into the original output tensor objects without triggering autograd version errors. C++ share_storage extension built via torch.utils.cpp_extension.load_inline, with a Python fallback for environments without a compiler. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Monkeypatches _get_share_storage to None to exercise _fallback_share_storage on CI machines where ninja and a C++ compiler are present and the C++ extension would otherwise always be used. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a code-block example to the class docstring showing the four-step pattern (checkpoint, run downstream, discard+register, backward) and the constraints on the choice of hook_tensor. Add :param:/:returns: docstrings to the three public methods. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous Python fallback rebound dst's storage via Tensor.set_(), which swaps dst's StorageImpl for a new one. Any autograd-saved view of dst that still referenced the original StorageImpl (e.g. the 2D-reshape view that MmBackward saves when Linear is called with a 3D input) would see the original storage -- which was resize_(0)'d -- and backward would hit "tensor has non-zero numel but data is not allocated". The C++ extension path didn't have this bug: it mutates dst's existing StorageImpl in place via set_data_ptr, so saved views see the new data. Make the Python fallback equivalent by resizing dst's existing storage and copying src's bytes into it, preserving StorageImpl identity. This costs an extra allocation + copy during recompute on machines without a C++ toolchain, but is correct for tensors with saved views. Add a 3D-Linear regression test covering the failure mode. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two standalone scripts under src/scripts/ (intentionally outside src/test/ so CI does not run them): - benchmark_odc.py: compares baseline / torch.utils.checkpoint / ODC (C++ extension) / ODC (Python fallback forced) on a synthetic fat-output workload. Reports peak GPU memory and forward/backward wall time. Supports a single-shape mode and a grid sweep. - odc_ffn_integration_check.py: wraps the real olmo_core SwiGLU FeedForward with ODC around the fat (activation(w1(x)) * w3(x)) intermediate that w2 saves for backward. Runs a few iterations and asserts output and gradient parity vs the baseline FeedForward. Exits non-zero on any failure. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two GPU matrix entries that invoke the standalone scripts from src/scripts/ to validate OutputDiscardCheckpoint end-to-end on real GPU hardware. Revert before merging this PR. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Rewrite benchmark_odc.py around a BenchBlock abstraction so the wrapped region (the thing ODC discards) is interchangeable. Adds five concrete block types covering the spectrum of ODC fit: - fp32_cast: x.float() under ODC followed by an fp32 Linear (the MoE router's pattern). Recompute is trivial -> ODC should be neutral or positive even at N=1. - up_proj: fat Linear up-projection with no activation inside. Linear saves only its input -> recompute spike is small. - silu_up: silu(up(x)) inside; activation adds a saved intermediate to recompute, raising recompute peak. - swiglu: OLMo SwiGLU FFN; three fat intermediates saved during recompute -> worst-case ODC recompute footprint. - rms_norm: RMSNorm + Linear + residual; cheap recompute, modest savings. Each block runs at N = 1 (single-layer, ODC's worst case because the savings window is zero) and N = args.n_layers (default 4) so the multi-layer payoff is visible. Adds --only to restrict to a subset and --layers to override the stack depths. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Chain N ODCFeedForward blocks (default N = 4) and verify per-iteration output + gradient parity vs a baseline stack. Exercises per-block OutputDiscardCheckpoint instances and the order in which their recompute hooks fire as backward walks back through the stack -- a regression here would silently corrupt gradients in multi-FFN training. Adds --n-layers (default 4) and --layers (override). By default runs both N = 1 and N = --n-layers so the single-layer case stays covered. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Flip the outer/inner loop in main() so output is grouped by stack depth first, then iterates through all block types within each depth. Makes cross-block-type comparison at a fixed depth easier to read. Also adds an --iters flag (default 10) for the number of timed iterations, matching the workflow's invocation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- New FP32SoftmaxBlock: bf16 -> upcast -> fp32 softmax -> downcast pattern, modeling the attention/router softmax-in-fp32 path. Softmax saves its OUTPUT for backward, so the discarded tensor is the fp32 softmax probs. - --dtype now accepts multiple values (default [bf16]). Each is run as a separate top-level group so the precision-boundary effect is visible side-by-side (fp32_cast / fp32_softmax discard tensor doubles at bf16/fp16 base, neutral at fp32 base). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
So the precision-boundary effect on fp32_cast and fp32_softmax is visible in the CI output. Revert before merge along with the other temporary ODC steps. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When the input is already fp32, x.float() returns x itself (no copy), so h aliases x. ODC's discard would resize x's storage to 0, breaking backward. Same aliasing issue for torch.utils.checkpoint -- its recompute graph would also wrap an identity. Skip the checkpoint variants for this degenerate case so the benchmark row truthfully reports "no benefit" rather than crashing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Benchmarking |
TianhuaTao
left a comment
There was a problem hiding this comment.
Looks ok for now as an alpha feature.
There might be concerns regarding RNG state, autocast context, and cases where requires_grad == False, but for now we don't have them in our active paths, so it should be fine.
|
|
||
| from .output_discard_checkpoint import OutputDiscardCheckpoint | ||
|
|
||
| __all__ = ["OutputDiscardCheckpoint"] |
|
|
||
| from ..doc_utils import beta_feature | ||
|
|
||
| __all__ = ["OutputDiscardCheckpoint"] |
| __all__ = ["OutputDiscardCheckpoint"] | ||
|
|
||
|
|
||
| _SHARE_STORAGE_CPP = r""" |
There was a problem hiding this comment.
Is this the standard way to write inline C++ in Olmo-core? I think it would be cleaner to have it in a separate .cpp file with a test.
| _share_storage_fallback_warned = False | ||
|
|
||
|
|
||
| def _get_share_storage() -> Optional[Callable[[torch.Tensor, torch.Tensor], None]]: |
There was a problem hiding this comment.
Generally, I don't like using global as I find it very hard to reason about. What if we used a module-level singleton here? Something like:
@dataclasses.dataclass
class _SharedStorageLoader:
_ext: Any = None
_build_error: Optional[Exception] = None
_lock: Any = dataclasses.field(default_factory=threading.Lock)
def share(self, dst: torch.Tensor, src: torch.Tensor) -> None:
if (fn := self._load()) is not None:
fn(dst, src)
else:
self._fallback(dst, src)
....
_shared_storage_loader = _SharedStorageLoader()
| _fallback_share_storage(dst, src) | ||
|
|
||
|
|
||
| def _collect_tensor_outputs(outputs: Any) -> Tuple[torch.Tensor, ...]: |
|
|
||
|
|
||
| def _collect_tensor_outputs(outputs: Any) -> Tuple[torch.Tensor, ...]: | ||
| if isinstance(outputs, torch.Tensor): |
There was a problem hiding this comment.
What about:
def _collect_tensor_outputs(outputs: Any) -> tuple[torch.Tensor, ...]:
items = outputs if isinstance(outputs, (tuple, list)) else (outputs,)
if not all(isinstance(o, torch.Tensor) for o in items):
raise TypeError(
"OutputDiscardCheckpoint only supports tensor outputs or tuple/list of tensors."
)
return tuple(items)
| ) | ||
|
|
||
|
|
||
| def _detach_but_keep_requires_grad(x: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Inline, as this is only used once?
|
|
||
| recompute_args: list[Any] = [] | ||
| recompute_tensor_inputs: list[torch.Tensor] = [] | ||
| for is_tensor in cast(Tuple[bool, ...], ctx.arg_is_tensor): |
| dtype=dtype, | ||
| device=device, | ||
| ) | ||
| print() |
There was a problem hiding this comment.
Wdyt about writing this as a test instead?
Summary
Add
OutputDiscardCheckpoint, an activation-recompute primitive for cases where the output of a checkpointed region dominates memory (rather than its intermediates).Vanilla
torch.utils.checkpointdiscards intermediates inside the wrapped function but can't free the output -- downstream consumers and their saved-for-backward references hold it live.OutputDiscardCheckpointextends that pattern: forward runs underno_grad, the output's storage can be freed after downstream forward consumes it, and a backward hook recomputes the forward and rebinds the freed storage in place (via a C++share_storageextension, with a Python fallback). The tensor object survives so existing autograd saved-tensor references stay valid; only its underlying bytes are recycled.Useful for fat-output ops where the output is wider than the input -- precision casts (bf16 -> fp32 doubling), FFN up-projections, attention outputs before SDPA fuses them.
Contents
src/olmo_core/nn/output_discard_checkpoint.py-- the primitive (287 lines), with a code-block usage example and:param:/:returns:docstrings.src/olmo_core/nn/__init__.py-- exportsOutputDiscardCheckpoint.src/test/nn/output_discard_checkpoint_test.py-- four tests: storage discard/restore, grad parity vs a non-checkpointedSequential, Python-fallback path forced via monkeypatch, and a 3D-Linear regression test (see below).Bug found and fixed during integration testing
Writing the FFN integration script surfaced a real bug in the Python fallback.
Tensor.set_(new_storage, ...)swapsdst'sStorageImplfor a new one. Any autograd-saved view ofdstthat still referenced the originalStorageImpl-- e.g. the 2D-reshape view thatMmBackwardsaves whenLinearis called with a 3D input -- would see the original storage, which wasresize_(0)'d and never refilled. Backward then hit "tensor has non-zero numel but data is not allocated".The C++ path didn't have this bug: it mutates
dst's existingStorageImplin place viaset_data_ptr, so saved views see the new data through the sameStorageImpl. The fix makes the Python fallback equivalent: resizedst's existing storage and copysrc's bytes into it. Costs an extra allocation + copy during recompute on machines without a C++ toolchain, but is correct for tensors with saved views (which is most real workloads).The 3D-Linear regression test ensures this doesn't sneak back in.
Standalone scripts (not run in CI)
Two scripts under
src/scripts/-- intentionally outsidesrc/test/so pytest never picks them up. Useful for human-driven benchmarking and verification:src/scripts/benchmark_odc.py-- compares baseline /torch.utils.checkpoint/ ODC (C++) / ODC (Python fallback) on a fat-outputLinear -> activation -> Linearworkload. Reports peak GPU memory and forward/backward wall time. Supports--scenarios gridfor a shape sweep.src/scripts/odc_ffn_integration_check.py-- wraps the real OLMoFeedForward(SwiGLU) with ODC around the fatactivation(w1(x)) * w3(x)intermediate thatw2saves for backward. Asserts output + gradient parity vs the baseline; exits non-zero on failure.Test plan
pytest -v src/test/nn/output_discard_checkpoint_test.py-- 4 tests pass (incl. forced Python fallback and 3D-Linear regression).make checks-- isort / black / ruff / mypy all clean.python src/scripts/odc_ffn_integration_check.py --device cpu-- PASS (output and grad parity at 0 diff in fp32).python src/scripts/benchmark_odc.py-- visually verify ODC peak memory < baseline.python src/scripts/odc_ffn_integration_check.py --dtype bf16-- verify parity holds in bf16.🤖 Generated with Claude Code