
Support checkpoints in gradients #765

@josevalim

Description


Introduce Nx.Defn.checkpoint (or Nx.Defn.remat) to control the memory-computation tradeoff during reverse-mode AD in defn. Users and libraries (e.g. Axon) can mark computation regions whose intermediates are recomputed during the backward pass rather than stored, reducing peak memory for large models.

Background

Nx.Defn.grad performs reverse-mode AD. The current implementation (Nx.Defn.Grad) retains all intermediate activations as residuals for the backward pass. For an n-layer network this means O(n) activation memory — the dominant cost when training large models.

defn forward(params, input) do
  input
  |> dense(params.layer1)  # activation stored
  |> Nx.relu()             # activation stored
  |> dense(params.layer2)  # activation stored
  |> Nx.relu()             # activation stored
  # ... n layers
end

All 2n intermediates must be live simultaneously during backprop, quickly exhausting GPU/TPU memory for transformers and similar architectures.
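
As a rough back-of-envelope (all sizes here are assumptions for illustration, not numbers from this issue), two f32 residuals per layer across 48 layers at batch 32, sequence length 2048, hidden size 4096 cost:

per_activation = 32 * 2048 * 4096 * 4  # batch * seq * hidden * 4 bytes => 1 GiB
total = 2 * 48 * per_activation        # 2 residuals per layer, 48 layers => ~96 GiB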

Proposed solution: Nx.Defn.checkpoint

Nx.Defn.checkpoint(input, fun)

Only input, the argument passed to fun, is saved; intermediates produced inside fun are discarded and recomputed on demand during the backward pass.

API

defn forward(params, input, labels) do
  input
  |> Nx.Defn.checkpoint(&dense_block(params.layer1, &1))
  |> Nx.Defn.checkpoint(&dense_block(params.layer2, &1))
  |> Nx.Defn.checkpoint(&dense_block(params.layer3, &1))
  |> softmax_cross_entropy(labels)
end

defnp dense_block(params, x) do
  x |> Nx.dot(params.weights) |> Nx.add(params.bias) |> Nx.relu()
end

Activation memory drops from O(n) to O(sqrt(n)) with optimal checkpoint placement (Chen et al., 2016).
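
To make the sqrt(n) intuition concrete, here is a sketch assuming the proposed checkpoint/2 and the dense_block helper above (seg1/seg2/seg3 and their layer fields are illustrative names): a hypothetical 9-layer model split into 3 checkpointed segments of 3 layers each. Peak residuals are the 3 segment inputs plus at most the intermediates of the one segment currently being recomputed, instead of all 18 activations.

defn forward(params, input) do
  input
  |> Nx.Defn.checkpoint(&segment(params.seg1, &1))
  |> Nx.Defn.checkpoint(&segment(params.seg2, &1))
  |> Nx.Defn.checkpoint(&segment(params.seg3, &1))
end

# Each segment chains sqrt(n) layers; only its input survives the forward pass.
defnp segment(p, x) do
  dense_block(p.layer3, dense_block(p.layer2, dense_block(p.layer1, x)))
end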

Axon integration (axon#372):

Axon.input("input")
|> Axon.dense(512, activation: :relu)
|> Axon.checkpoint()
|> Axon.dense(256, activation: :relu)
|> Axon.checkpoint()
|> Axon.dense(10)

Semantics

  1. Outside grad: no-op — checkpoint(input, fun) is equivalent to fun.(input) (sketched below)
  2. Inside grad: backward pass re-executes fun with saved inputs instead of using stored intermediates
  3. Composability: must work with while/4, cond/3, containers, nested checkpoints, and custom_grad/2
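
A minimal sketch of semantics 1 and 3, assuming the proposed checkpoint/2 (none of this exists yet; values are illustrative):

x = Nx.tensor([1.0, 2.0])

# Semantics 1: outside grad, the wrapper adds nothing.
Nx.Defn.checkpoint(x, &Nx.sin/1)  # same result as Nx.sin(x)

# Semantics 3: nested checkpoints must also be a no-op outside grad,
# and each must rematerialize independently inside grad.
Nx.Defn.checkpoint(x, fn x ->
  Nx.Defn.checkpoint(Nx.cos(x), &Nx.sin/1)
end)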

Implementation

Expression node

New :checkpoint node in Nx.Defn.Expr, following the :optional/:while pattern:

%Nx.Defn.Expr{
  op: :checkpoint,
  args: [input_container_expr, body_expr, body_fun]
}
  • input_container_expr — inputs saved for recomputation
  • body_expr — traced forward (for shape/type inference)
  • body_fun — re-executed during backprop

Gradient transform

In Nx.Defn.Grad, when update_grads/6 encounters :checkpoint:

  1. Re-trace body_fun with saved inputs
  2. Compute gradients through the re-traced body
  3. Propagate input gradients to the parent graph

This mirrors JAX's remat_call primitive.
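
A property the rule must preserve, written as a hypothetical check against the proposed API: gradients with and without checkpoint agree, only the memory profile differs (Nx.Defn.grad/1 exists today; checkpoint/2 does not):

x = Nx.tensor([1.0, 2.0, 3.0])

# Gradient of sum(exp(x)) is exp(x), with or without rematerialization.
plain = Nx.Defn.grad(fn x -> Nx.sum(Nx.exp(x)) end).(x)

rematted =
  Nx.Defn.grad(fn x ->
    Nx.sum(Nx.Defn.checkpoint(x, &Nx.exp/1))
  end).(x)

# plain == rematted; the checkpointed version stores only x and
# re-executes Nx.exp during the backward pass.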

Backend support

  • Nx.Defn.Evaluator: pass-through (optionally remat for memory savings in eager mode)
  • EXLA: must prevent XLA's CSE from defeating rematerialization; use rematerialization hints or custom-call boundaries to signal recomputation regions to HLO
  • Torchx: delegate to torch.utils.checkpoint.checkpoint

Interaction with Nx.block (#946)

The two are complementary — Nx.block provides extensible named computation with backend dispatch; checkpoint provides memory optimization for gradients. A checkpoint can wrap a block and vice versa.

Future extensions

Rematerialization policies

An initial implementation should recompute everything. A future version could add selective policies (à la JAX's checkpoint_policies):

Nx.Defn.checkpoint(input, fun, policy: fn op, _args -> op in [:dot, :conv] end)

This saves expensive-to-recompute ops (matmuls, convolutions) while recomputing cheap elementwise ops. The initial design should accommodate this.

Comparison

Aspect                  Today (store all)   checkpoint    Full recompute
Activation memory       O(n)                O(sqrt(n))    O(1)
Compute overhead        1×                  ~1.33×        O(n)
User control            None                Explicit      N/A
CSE prevention needed   No                  Yes           N/A

The ~1.33× overhead reflects one extra forward pass per checkpointed segment (assuming the backward pass costs roughly twice the forward); recomputing everything from scratch with O(1) memory makes total compute quadratic in depth.

Implementation plan

  • Add :checkpoint node to Nx.Defn.Expr
  • Tracing support: Nx.Defn.checkpoint/2 builds the correct expression
  • Nx.Defn.Tree traversals (apply_args/4, traverse/3, etc.)
  • Gradient rule in Nx.Defn.Grad (recompute forward, then differentiate)
  • Nx.Defn.Evaluator support
  • EXLA compiler support with CSE prevention
  • Torchx backend support
  • Tests: gradient correctness, nested checkpoints, while/cond interaction, custom_grad interaction, container inputs/outputs
  • Documentation and examples
  • (Downstream) Axon per-layer/per-block checkpointing (axon#372)

References

Chen, T., Xu, B., Zhang, C., and Guestrin, C. (2016). Training Deep Nets with Sublinear Memory Cost. arXiv:1604.06174.
