Skip to content

Commit b525127

Browse files
committed
NNX: vocab tiling custom_vjp with output-head carve-out
Replaces the PR9.5 NNX vocab-tiling MVP (chunked forward + default autograd backward) with a jax.custom_vjp that mirrors the Linen path's backward-memory savings, then carves out the output-head params so the custom_vjp's residuals + grad accumulator scale with LM-head size, not with the full model. Linen vocab_tiling_linen_loss is byte-for-byte unchanged. Call sites in train.py / pyconfig_deprecated.py / configs/types.py are unchanged. Custom_vjp + output-head carve-out (vocabulary_tiling.py): - Outside the custom_vjp: 3-way nnx.split with a callable path filter (_is_output_head_param_path) matching {token_embedder, shared_embedding, decoder_norm, logits_dense} — the only nnx.Param paths apply_output_head touches. Returns (graphdef, head_params, other_params, rest). - Custom_vjp primals: (head_params, other_params, rest, hidden_states, labels, segmentation). Only head_params and hidden_states are differentiated; other_params + rest are threaded through as non-differentiated primals so their tracers don't have to cross both the custom_vjp and the inner lax.scan boundary (which previously caused UnexpectedTracerError under logits_via_embedding=True). - Forward (_chunked_cross_entropy_loss_fwd): reshapes to (num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body rebuilds the model per chunk via nnx.merge(graphdef, chunk_head, chunk_other, chunk_rest, copy=True) and calls logits_from_hidden_states. Initial scan accumulator is fp32 (was hidden_states.dtype previously — caused a lax.scan carry dtype mismatch with bf16 hidden_states since cross_entropy_with_logits always returns fp32). Residuals are (chunk_head, chunk_other, chunk_rest, reshaped_*, batch/seq/emb). - Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose body builds loss_fn_for_vjp = lambda p, h: ..., calls jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk), accumulates grad_head via tree.map(add), emits per-chunk grad_hidden. Chain-rules grad_head *= loss_cotangent and dtype-casts to each primal's dtype (custom_vjp requires this). chunk_other_params and chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero pytrees, NOT None — None makes JAX synthesize zeros at AOT trace time with axis-0 stacking (jax.scan convention) for nnx.scan-stacked transformer-layer params, which carry axis-1 stacking (nnx convention), and the cotangent-shape check fails as "Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M], but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent shape to the primal shape exactly. - Correctness: logits_from_hidden_states provably depends only on head_params; the gradient w.r.t. other_params through this loss is exactly zero. When train.py also calls the full model forward (which produces hidden_states), transformer-layer gradients flow back through grad_hidden_states → outer backward, unaffected by the carve-out. Supporting fixes (touched for the carve-out to work end-to-end): - nnx_decoders.py::apply_output_head logits_via_embedding=True branch reads embedding_table = shared_embedding.embedding[...] instead of the deprecated .value shim. The .value shim registers the access in NNX mutation tracking, which JAX detects as a tracer leak when the embedding is closure-captured / threaded across the custom_vjp + lax.scan boundaries. The Linen branch is unchanged. - models.py: deletes dead-code self.hidden_states = None and if num_vocab_tiling > 1: self.hidden_states = hidden_state from the NNX Transformer class. Two lines left over from an early PR5 implementation idea — neither path actually reads model.hidden_states (Linen reads via mutable=["intermediates"]; NNX reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown ("decoder", "hidden_states") intermediate). Without this fix, AOT compile under pure_nnx=True + num_vocab_tiling>1 raised ValueError: Cannot assign data value of type 'LinearizeTracer' to static attribute 'hidden_states' of Pytree type 'Transformer' — would have silently broken any post-PR11 user with vocab tiling on. Tests (tiling_test.py — new VocabTilingNNXTest class with 9 TPU tests): - test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss + grad parity vs. full-vocab xent reference for both LM-head modes. - test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second tuple element matches the reference (was untested before). - test_nnx_vocab_tiling_padded_segmentation: half-padded mask; exercises the segmentation != 0 mask branch and asserts padded loss is strictly less than unpadded. - test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1 differentiation; exercises the custom_vjp's second-primal cotangent path (grad_reshaped_hidden_states), shape + dtype + value parity. - test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast preserves the primal's dtype). Caught the fp32-accumulator bug. - test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0; total_z_loss == 0.0 exactly and grad parity holds. - test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8} and asserts identical loss + grads (catches off-by-one in vocab_tile_size and scan/reshape interactions). - test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out invariant): asserts every non-head leaf has gradient exactly zero AND at least one head leaf has non-zero gradient (so the test can't trivially pass by zeroing everything). Catches filter bugs (e.g. forgetting that NNX names the embedder token_embedder while Linen names it shared_embedding) and bwd zero-shape bugs. AOT compile coverage (train_compile_test.py): - Removed the now-stale pytest.skip("Vocab tiling not supported on NNX.") in test_vocab_tiling_bf16. - Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the models.py dead-code regression and the cotangent-axis-ordering issue the explicit-zeros bwd fixes. Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx, grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx) — regression check for the nnx_decoders.py change.
1 parent 88417d0 commit b525127

5 files changed

Lines changed: 538 additions & 44 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,10 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
987987
if cfg.logits_via_embedding:
988988
# Use the transpose of embedding matrix for logit transform.
989989
if isinstance(shared_embedding, nnx.Module):
990-
embedding_table = shared_embedding.embedding.value
990+
# Modern NNX API; the deprecated `.value` shim registers the access in NNX's
991+
# mutation tracking, which JAX detects as a tracer leak when the embedding is
992+
# closure-captured across a custom_vjp boundary (e.g. vocab_tiling_nnx_loss).
993+
embedding_table = shared_embedding.embedding[...]
991994
else:
992995
embedding_table = shared_embedding.variables["params"]["embedding"]
993996
if isinstance(embedding_table, nn.spmd.LogicallyPartitioned):

src/maxtext/models/models.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,6 @@ def __init__(
347347
else:
348348
decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
349349
self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
350-
self.hidden_states = None
351350

352351
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
353352
dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
@@ -545,10 +544,6 @@ def __call__(
545544
mutable=mutable_collections,
546545
) # pytype: disable=wrong-keyword-args
547546

548-
# Materialize hidden state when vocab tiling is enabled
549-
if self.config.num_vocab_tiling > 1:
550-
self.hidden_states = hidden_state
551-
552547
# If we are initializing the model AND MTP is enabled, we must create
553548
# dummy target tensors. This allows Flax to trace the MTPBlock and create
554549
# all its necessary parameters, without requiring the main training pipeline

src/maxtext/utils/vocabulary_tiling.py

Lines changed: 166 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import functools
1818

1919
from flax import linen as nn
20+
from flax import nnx
2021

2122
import jax
2223
import jax.numpy as jnp
@@ -29,6 +30,25 @@
2930
from maxtext.utils import max_utils
3031

3132

33+
# Submodule-name keys whose `nnx.Param` leaves are touched by
34+
# `Transformer.logits_from_hidden_states` (= `decoder.apply_output_head`):
35+
# * `token_embedder` / `shared_embedding` — token embedder; used for tied logits.
36+
# * `decoder_norm` — final layer norm.
37+
# * `logits_dense` — LM-head dense; used for non-tied logits.
38+
# Path filter for the 3-way `nnx.split` in `vocab_tiling_nnx_loss`'s output-head
39+
# carve-out: matching leaves go into `head_params` (the custom_vjp's differentiated
40+
# primal); everything else ends up in `other_params` and is threaded through as a
41+
# non-differentiated primal so the bwd can rebuild the model without crossing
42+
# trace boundaries.
43+
_OUTPUT_HEAD_PATH_KEYS = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense")
44+
45+
46+
def _is_output_head_param_path(path, _value):
47+
"""nnx.split callable filter: True iff `path` lies under an output-head submodule."""
48+
keys = [str(getattr(k, "key", k)) for k in path]
49+
return any(k in keys for k in _OUTPUT_HEAD_PATH_KEYS)
50+
51+
3252
def vocab_tiling_linen_loss(
3353
hidden_states,
3454
data,
@@ -252,14 +272,18 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
252272
def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train):
253273
"""Calculates cross-entropy loss using vocab tiling for NNX models.
254274
255-
NNX equivalent of `vocab_tiling_linen_loss`. Iterates the vocab dimension via
256-
`jax.lax.scan` with `model.logits_from_hidden_states` per chunk; the model
257-
carries its parameters internally so no explicit gather is needed.
275+
NNX equivalent of `vocab_tiling_linen_loss`. The model is partitioned via
276+
`nnx.split` into output-head params (`token_embedder`/`shared_embedding`,
277+
`decoder_norm`, `logits_dense`), other params (transformer layers, etc.), and
278+
non-Param state (rngs). Only the output-head params are the differentiated
279+
primal of the custom_vjp; other params + rest are threaded through as
280+
non-differentiated primals (bwd returns explicit zero pytrees of the same
281+
shape/dtype as each primal). Forward and backward scans both rebuild the model
282+
per chunk via `nnx.merge(..., copy=True)` and call `logits_from_hidden_states`.
258283
259-
This is a memory-efficient forward (chunked logits) but uses the default
260-
autograd path (no custom_vjp), so backward memory savings vs. the Linen
261-
custom_vjp path are not yet realized. TODO: add a custom_vjp using
262-
`nnx.split`/`nnx.merge` if backward memory becomes a concern.
284+
Backward memory is bounded by one chunk's logits (same as the Linen path).
285+
The output-head carve-out additionally shrinks the custom_vjp's residual +
286+
grad-accumulator scope from O(model params) to O(head params).
263287
264288
Args:
265289
model: The NNX model instance (must implement `logits_from_hidden_states`).
@@ -321,36 +345,145 @@ def _reshape(inputs, out_shape, out_sharding):
321345
labels = _maybe_shard_with_name(labels, label_spec)
322346
segmentation = _maybe_shard_with_name(segmentation, label_spec)
323347

324-
batch_size, seq_len, emb_dim = hidden_states.shape
325-
vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling
348+
# 3-way split outside the custom_vjp:
349+
# * head_params — only leaves `logits_from_hidden_states` touches
350+
# (token_embedder/shared_embedding, decoder_norm, logits_dense). Differentiated.
351+
# * other_params — every other `nnx.Param` (transformer layers, etc.).
352+
# Threaded through the custom_vjp as a primal; bwd returns explicit zeros.
353+
# * rest — non-Param state (rngs). Threaded through as a primal too.
354+
# Threading non-head primals (instead of closure-capture) is required to avoid
355+
# `UnexpectedTracerError` when the embedded variables are accessed through the
356+
# custom_vjp + lax.scan boundaries (manifests on `logits_via_embedding=True`).
357+
graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...)
358+
359+
def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk):
360+
local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True)
361+
chunk_logits = local_model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode)
362+
return _maybe_shard_with_name(chunk_logits, chunked_logits_spec)
326363

327-
reshaped_hidden_states = _reshape(
328-
hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec
329-
)
330-
reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
331-
reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
332-
333-
def _scan_body(accumulators, chunk_data):
334-
loss_accumulator, z_loss_accumulator = accumulators
335-
hidden_chunk, label_chunk, segmentation_chunk = chunk_data
336-
hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
337-
label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
338-
segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)
339-
340-
chunk_logits = model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode)
341-
chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec)
342-
one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size)
343-
chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits(
344-
chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier
364+
@jax.custom_vjp
365+
def chunked_cross_entropy_loss(chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation):
366+
(total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(
367+
chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation
368+
)
369+
return total_loss, total_z_loss
370+
371+
def _chunked_cross_entropy_loss_fwd(
372+
chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation
373+
):
374+
batch_size, seq_len, emb_dim = hidden_states.shape
375+
vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling
376+
377+
reshaped_hidden_states = _reshape(
378+
hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec
345379
)
380+
reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
381+
reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
382+
383+
def _fwd_scan_body(accumulators, chunk_data):
384+
loss_accumulator, z_loss_accumulator = accumulators
385+
hidden_chunk, label_chunk, segmentation_chunk = chunk_data
386+
hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
387+
label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
388+
segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)
389+
390+
chunk_logits = _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk)
391+
one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size)
392+
chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits(
393+
chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier
394+
)
346395

347-
masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0))
348-
masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0))
396+
masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0))
397+
masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0))
349398

350-
return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None
399+
return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None
351400

352-
initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype))
353-
(total_loss, total_z_loss), _ = jax.lax.scan(
354-
_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
401+
# Always accumulate in fp32 — `cross_entropy_with_logits` returns fp32 regardless of
402+
# logits dtype, and a bf16 carry would mismatch the body output type under lax.scan.
403+
initial_acc = (jnp.zeros((), dtype=jnp.float32), jnp.zeros((), dtype=jnp.float32))
404+
(total_loss, total_z_loss), _ = jax.lax.scan(
405+
_fwd_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
406+
)
407+
residuals = (
408+
chunk_head_params,
409+
chunk_other_params,
410+
chunk_rest,
411+
reshaped_hidden_states,
412+
reshaped_labels,
413+
reshaped_segmentation,
414+
batch_size,
415+
seq_len,
416+
emb_dim,
417+
)
418+
return (total_loss, total_z_loss), residuals
419+
420+
def _chunked_cross_entropy_loss_bwd(residuals, cotangents):
421+
# z_loss is folded into the xent loss inside cross_entropy_with_logits.
422+
loss_cotangent, _ = cotangents
423+
424+
(
425+
chunk_head_params,
426+
chunk_other_params,
427+
chunk_rest,
428+
reshaped_hidden_states,
429+
reshaped_labels,
430+
reshaped_segmentation,
431+
batch_size,
432+
seq_len,
433+
emb_dim,
434+
) = residuals
435+
436+
def _single_chunk_loss_fn(input_head_params, input_hidden_chunk, input_label_chunk, input_segmentation_chunk):
437+
chunk_logits = _logits_for_chunk(input_head_params, chunk_other_params, chunk_rest, input_hidden_chunk)
438+
one_hot_label_chunk = jax.nn.one_hot(input_label_chunk, config.vocab_size)
439+
xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier)
440+
return jnp.sum(xent * (input_segmentation_chunk != 0))
441+
442+
def _bwd_scan_body(grad_head_acc, chunk_data):
443+
hidden_chunk, label_chunk, segmentation_chunk = chunk_data
444+
hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
445+
label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
446+
segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)
447+
448+
# pylint: disable=unnecessary-lambda-assignment
449+
loss_fn_for_vjp = lambda p, h: _single_chunk_loss_fn(p, h, label_chunk, segmentation_chunk)
450+
_, vjp_fn = jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk)
451+
(grad_head_update, grad_hidden_chunk) = vjp_fn(1.0)
452+
grad_hidden_chunk = _maybe_shard_with_name(grad_hidden_chunk, chunked_hidden_spec)
453+
454+
grad_head_acc = jax.tree_util.tree_map(lambda acc, update: acc + update, grad_head_acc, grad_head_update)
455+
return grad_head_acc, grad_hidden_chunk
456+
457+
initial_grad_head = jax.tree_util.tree_map(jnp.zeros_like, chunk_head_params)
458+
459+
grad_head, grad_reshaped_hidden_states = jax.lax.scan(
460+
_bwd_scan_body, initial_grad_head, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
461+
)
462+
grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
463+
grad_head = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_head)
464+
grad_head = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), chunk_head_params, grad_head)
465+
grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
466+
467+
# Explicit zero cotangents for `chunk_other_params` and `chunk_rest`. Returning `None`
468+
# makes JAX synthesize zeros at AOT time with the wrong axis convention for nnx-scanned
469+
# transformer layer params (axis-0 instead of nnx's axis-1 stacking), causing
470+
# `Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M], but got
471+
# bfloat16[L,E,M]` at trace check. Materializing the zeros here ties the cotangent
472+
# shape to the primal shape exactly.
473+
grad_other = jax.tree_util.tree_map(jnp.zeros_like, chunk_other_params)
474+
grad_rest = jax.tree_util.tree_map(jnp.zeros_like, chunk_rest)
475+
return (
476+
grad_head,
477+
grad_other,
478+
grad_rest,
479+
grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
480+
None,
481+
None,
482+
)
483+
484+
chunked_cross_entropy_loss.defvjp(_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd)
485+
486+
total_loss, total_z_loss = chunked_cross_entropy_loss(
487+
head_params, other_params, rest, hidden_states, labels, segmentation
355488
)
356489
return total_loss, total_z_loss

0 commit comments

Comments
 (0)