Introduce Nx.Defn.checkpoint (or Nx.Defn.remat) to control the memory-computation tradeoff during reverse-mode AD in defn. Users and libraries (e.g. Axon) can mark computation regions whose intermediates are recomputed during the backward pass rather than stored, reducing peak memory for large models.
Background
Nx.Defn.grad performs reverse-mode AD. The current implementation (Nx.Defn.Grad) retains all intermediate activations as residuals for the backward pass. For an n-layer network this means O(n) activation memory — the dominant cost when training large models.
```elixir
defn forward(params, input) do
  input
  |> dense(params.layer1)  # activation stored
  |> Nx.relu()             # activation stored
  |> dense(params.layer2)  # activation stored
  |> Nx.relu()             # activation stored
  # ... n layers
end
```
All 2n intermediates must be live simultaneously during backprop, quickly exhausting GPU/TPU memory for transformers and similar architectures.
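To make the cost concrete, here is a minimal sketch of the corresponding training-step gradient, reusing the forward/2 above plus the hypothetical dense/2 and softmax_cross_entropy/2 helpers from this issue (grad/2 itself is existing Nx.Defn API). Every residual produced while evaluating the loss stays live until the backward pass consumes it:

```elixir
defn loss(params, input, labels) do
  params
  |> forward(input)
  |> softmax_cross_entropy(labels)
end

defn grad_step(params, input, labels) do
  # Reverse-mode AD: the residuals of all n layers traced inside forward/2
  # are held in memory between the forward and backward passes of this call.
  grad(params, &loss(&1, input, labels))
end
```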
Proposed solution: Nx.Defn.checkpoint
The proposal adds a checkpoint combinator that wraps a function fun: only the inputs to fun are saved; intermediates produced inside fun are discarded after the forward pass and recomputed on demand during the backward pass.
API
```elixir
defn forward(params, input, labels) do
  input
  |> Nx.Defn.checkpoint(&dense_block(params.layer1, &1))
  |> Nx.Defn.checkpoint(&dense_block(params.layer2, &1))
  |> Nx.Defn.checkpoint(&dense_block(params.layer3, &1))
  |> softmax_cross_entropy(labels)
end

defnp dense_block(params, x) do
  x |> Nx.dot(params.weights) |> Nx.add(params.bias) |> Nx.relu()
end
```
Activation memory drops from O(n) to O(sqrt(n)) with optimal checkpoint placement (Chen et al., 2016): splitting n layers into k checkpointed segments keeps only the k segment-boundary activations plus the ~n/k intermediates of the segment currently being recomputed, which is minimized around k ≈ sqrt(n).
Axon integration (axon#372):
```elixir
Axon.input("input")
|> Axon.dense(512, activation: :relu)
|> Axon.checkpoint()
|> Axon.dense(256, activation: :relu)
|> Axon.checkpoint()
|> Axon.dense(10)
```
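Training code would be unchanged by the checkpoint layers; they only affect which activations the compiled gradient keeps as residuals. A hypothetical loop with the checkpointed model above, using standard Axon.Loop APIs:

```elixir
# `model` is the checkpointed network defined above; `train_data` is any
# enumerable of {input, label} batches.
model
|> Axon.Loop.trainer(
  &Axon.Losses.categorical_cross_entropy(&1, &2, from_logits: true, reduction: :mean),
  :adam
)
|> Axon.Loop.run(train_data, %{}, epochs: 5, compiler: EXLA)
```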
Semantics
- Outside grad: no-op, checkpoint(fun) ≡ fun.()
- Inside grad: the backward pass re-executes fun with the saved inputs instead of using stored intermediates (sketch below)
- Composability: must work with while/4, cond/3, containers, nested checkpoints, and custom_grad/2
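A minimal sketch of these semantics, assuming only the proposed Nx.Defn.checkpoint pipe form from the API section (everything else is existing Nx API). The checkpointed and plain versions must produce identical values and identical gradients; only what is stored versus recomputed during backprop differs:

```elixir
defmodule CheckpointSemanticsSketch do
  import Nx.Defn

  defn plain(x) do
    x |> Nx.sin() |> Nx.cos() |> Nx.sum()
  end

  # Same computation, but the sin/cos intermediates are dropped after the
  # forward pass and recomputed from x when the backward pass needs them.
  defn checkpointed(x) do
    x
    |> Nx.Defn.checkpoint(fn t -> t |> Nx.sin() |> Nx.cos() end)
    |> Nx.sum()
  end

  # Outside grad: checkpoint is a no-op, so plain(x) and checkpointed(x) agree.
  # Inside grad: the gradients must also agree; only peak memory changes.
  defn grads(x) do
    {grad(x, &plain/1), grad(x, &checkpointed/1)}
  end
end
```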
Implementation
Expression node
New :checkpoint node in Nx.Defn.Expr, following the :optional/:while pattern:
```elixir
%Nx.Defn.Expr{
  op: :checkpoint,
  args: [input_container_expr, body_expr, body_fun]
}
```
- input_container_expr — inputs saved for recomputation
- body_expr — traced forward (for shape/type inference)
- body_fun — re-executed during backprop
Gradient transform
In Nx.Defn.Grad, when update_grads/6 encounters :checkpoint:
- Re-trace body_fun with the saved inputs
- Compute gradients through the re-traced body
- Propagate input gradients to the parent graph
This mirrors JAX's remat_call primitive.
Backend support
- Nx.Defn.Evaluator: pass-through (optionally remat for memory savings in eager mode)
- EXLA: must prevent XLA's CSE from defeating rematerialization; use rematerialization hints or custom-call boundaries to signal recomputation regions to HLO
- Torchx: delegate to torch.utils.checkpoint.checkpoint
Interaction with Nx.block (#946)
The two are complementary — Nx.block provides extensible named computation with backend dispatch; checkpoint provides memory optimization for gradients. A checkpoint can wrap a block and vice versa.
Future extensions
Rematerialization policies
An initial implementation should recompute everything. A future version could add selective policies (à la JAX's checkpoint_policies):
```elixir
Nx.Defn.checkpoint(fun, policy: fn op, _args -> op in [:dot, :conv] end)
```
This saves expensive-to-recompute ops (matmuls, convolutions) while recomputing cheap elementwise ops. The initial design should accommodate this.
Comparison
| Aspect | Today (store all) | checkpoint | Full recompute |
| --- | --- | --- | --- |
| Activation memory | O(n) | O(sqrt(n)) | O(1) |
| Compute overhead | 1× | ~1.33× | 2× |
| User control | None | Explicit | N/A |
| CSE prevention needed | No | Yes | N/A |
Implementation plan
- Add :checkpoint node to Nx.Defn.Expr
- Nx.Defn.checkpoint/2 builds the correct expression
- Nx.Defn.Tree traversals (apply_args/4, traverse/3, etc.)
- Nx.Defn.Grad (recompute forward, then differentiate)
- Nx.Defn.Evaluator support
- while/cond interaction, custom_grad interaction, container inputs/outputs
References
- Relevant code: nx/lib/nx/defn/grad.ex, expr.ex, tree.ex, evaluator.ex; exla/lib/exla/defn.ex; torchx/lib/torchx/backend.ex
- jax.checkpoint / jax.remat and JAX PR #1749
- PyTorch torch.utils.checkpoint and the PyTorch min-cut recomputation blog post
- Chen et al., "Training Deep Nets with Sublinear Memory Cost" (2016)