Commit a56510e

NNX: vocab tiling custom_vjp with output-head carve-out
Replaces the PR9.5 NNX vocab-tiling MVP (chunked forward + default autograd backward) with a jax.custom_vjp that mirrors the Linen path's backward-memory savings, then carves out the output-head params so the custom_vjp's residuals + grad accumulator scale with LM-head size, not with the full model.

Linen vocab_tiling_linen_loss is byte-for-byte unchanged. Call sites in train.py / pyconfig_deprecated.py / configs/types.py are unchanged.

Custom_vjp + output-head carve-out (vocabulary_tiling.py):
- Outside the custom_vjp: 3-way nnx.split with a callable path filter (_is_output_head_param_path) matching {token_embedder, shared_embedding, decoder_norm, logits_dense}, the only nnx.Param paths apply_output_head touches. Returns (graphdef, head_params, other_params, rest).
- Custom_vjp primals: (head_params, other_params, rest, hidden_states, labels, segmentation). Only head_params and hidden_states are differentiated; other_params + rest are threaded through as non-differentiated primals so their tracers don't have to cross both the custom_vjp and the inner lax.scan boundary (which previously caused UnexpectedTracerError under logits_via_embedding=True).
- Forward (_chunked_cross_entropy_loss_fwd): reshapes to (num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body rebuilds the model per chunk via nnx.merge(graphdef, chunk_head, chunk_other, chunk_rest, copy=True) and calls logits_from_hidden_states. Initial scan accumulator is fp32 (was hidden_states.dtype previously, which caused a lax.scan carry dtype mismatch with bf16 hidden_states since cross_entropy_with_logits always returns fp32). Residuals are (chunk_head, chunk_other, chunk_rest, reshaped_*, batch/seq/emb).
- Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose body builds loss_fn_for_vjp = lambda p, h: ..., calls jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk), accumulates grad_head via tree.map(add), emits per-chunk grad_hidden. Chain-rules grad_head *= loss_cotangent and dtype-casts to each primal's dtype (custom_vjp requires this). chunk_other_params and chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero pytrees, NOT None. None makes JAX synthesize zeros at AOT trace time with axis-0 stacking (jax.lax.scan convention) for nnx.scan-stacked transformer-layer params, which carry axis-1 stacking (nnx convention), and the cotangent-shape check fails as "Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M], but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent shape to the primal shape exactly.
- Correctness: logits_from_hidden_states provably depends only on head_params; the gradient w.r.t. other_params through this loss is exactly zero. When train.py also calls the full model forward (which produces hidden_states), transformer-layer gradients flow back through grad_hidden_states → outer backward, unaffected by the carve-out.

Supporting fixes (touched for the carve-out to work end-to-end):
- nnx_decoders.py::apply_output_head logits_via_embedding=True branch reads embedding_table = shared_embedding.embedding[...] instead of the deprecated .value shim. The .value shim registers the access in NNX mutation tracking, which JAX detects as a tracer leak when the embedding is closure-captured / threaded across the custom_vjp + lax.scan boundaries. The Linen branch is unchanged.
- models.py: deletes dead-code self.hidden_states = None and if num_vocab_tiling > 1: self.hidden_states = hidden_state from the NNX Transformer class. Two lines left over from an early PR5 implementation idea; neither path actually reads model.hidden_states (Linen reads via mutable=["intermediates"]; NNX reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown ("decoder", "hidden_states") intermediate). Without this fix, AOT compile under pure_nnx=True + num_vocab_tiling>1 raised ValueError: Cannot assign data value of type 'LinearizeTracer' to static attribute 'hidden_states' of Pytree type 'Transformer'. This would have silently broken any post-PR11 user with vocab tiling on.

Tests (tiling_test.py, new VocabTilingNNXTest class with 9 TPU tests):
- test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss + grad parity vs. full-vocab xent reference for both LM-head modes.
- test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second tuple element matches the reference (was untested before).
- test_nnx_vocab_tiling_padded_segmentation: half-padded mask; exercises the segmentation != 0 mask branch and asserts padded loss is strictly less than unpadded.
- test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1 differentiation; exercises the custom_vjp's second-primal cotangent path (grad_reshaped_hidden_states), shape + dtype + value parity.
- test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast preserves the primal's dtype). Caught the fp32-accumulator bug.
- test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0; total_z_loss == 0.0 exactly and grad parity holds.
- test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8} and asserts identical loss + grads (catches off-by-one in vocab_tile_size and scan/reshape interactions).
- test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out invariant): asserts every non-head leaf has gradient exactly zero AND at least one head leaf has non-zero gradient (so the test can't trivially pass by zeroing everything). Catches filter bugs (e.g. forgetting that NNX names the embedder token_embedder while Linen names it shared_embedding) and bwd zero-shape bugs.

AOT compile coverage (train_compile_test.py):
- Removed the now-stale pytest.skip("Vocab tiling not supported on NNX.") in test_vocab_tiling_bf16.
- Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the models.py dead-code regression and the cotangent-axis-ordering issue the explicit-zeros bwd fixes.

Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx, grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx) as a regression check for the nnx_decoders.py change.
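The carry dtype mismatch mentioned in the Forward bullet can be reproduced outside MaxText. The sketch below is a minimal illustration under stated assumptions, not the MaxText code: `chunk_sum` is a hypothetical stand-in for a per-chunk loss that always returns float32, and `scan_total` is a hypothetical helper that takes the accumulator dtype as a parameter.

```python
import jax
import jax.numpy as jnp


def chunk_sum(chunk):
  """Hypothetical per-chunk loss: returns float32 whatever the input dtype is."""
  return jnp.sum(chunk.astype(jnp.float32) ** 2)


def scan_total(chunks, acc_dtype):
  """Sums chunk_sum over the leading axis with a scan carry of dtype acc_dtype."""

  def body(acc, chunk):
    # acc + float32 promotes to float32, so the carry keeps its type only
    # when the accumulator starts out as float32.
    return acc + chunk_sum(chunk), None

  total, _ = jax.lax.scan(body, jnp.zeros((), dtype=acc_dtype), chunks)
  return total


chunks = jnp.ones((4, 8), dtype=jnp.bfloat16)  # 4 chunks of bf16 "hidden states"
total = scan_total(chunks, jnp.float32)  # float32 carry matches the body output

try:
  scan_total(chunks, jnp.bfloat16)  # bf16 carry in, float32 carry out
  carry_mismatch_rejected = False
except TypeError:
  carry_mismatch_rejected = True
```

Starting the accumulator in float32 keeps the carry type stable across iterations regardless of the dtype of the scanned inputs.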
1 parent f9d94d1 commit a56510e

5 files changed

Lines changed: 507 additions & 48 deletions

src/maxtext/layers/nnx_decoders.py

Lines changed: 3 additions & 1 deletion
@@ -984,7 +984,9 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
     if cfg.logits_via_embedding:
       # Use the transpose of embedding matrix for logit transform.
       if isinstance(shared_embedding, nnx.Module):
-        embedding_table = shared_embedding.embedding.value
+        # Use [...] not the deprecated .value: .value records the read in NNX's mutation
+        # tracking, which leaks a tracer out of vocab_tiling_nnx_loss's custom_vjp.
+        embedding_table = shared_embedding.embedding[...]
       else:
         embedding_table = shared_embedding.variables["params"]["embedding"]
       if isinstance(embedding_table, nn.spmd.LogicallyPartitioned):
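For readers unfamiliar with the `logits_via_embedding` branch, the following is a minimal sketch of a tied output head: logits are the hidden states multiplied by the transpose of the embedding table, which is why the table receives gradients from the loss. `tied_logits` and the shapes are hypothetical, and any scaling or sharding the real `apply_output_head` performs is omitted.

```python
import jax
import jax.numpy as jnp


def tied_logits(embedding_table, hidden):
  """Maps [..., emb] hidden states to [..., vocab] logits via the transposed table."""
  return hidden @ embedding_table.T


vocab_size, emb_dim = 16, 4  # hypothetical sizes
embedding_table = jnp.arange(vocab_size * emb_dim, dtype=jnp.float32).reshape(vocab_size, emb_dim)
hidden = jnp.ones((2, 3, emb_dim), dtype=jnp.float32)  # [batch, seq, emb]

logits = tied_logits(embedding_table, hidden)

# The embedding table is a differentiated input of the output head when logits are tied.
grad_table = jax.grad(lambda table: tied_logits(table, hidden).sum())(embedding_table)
```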

src/maxtext/models/models.py

Lines changed: 0 additions & 5 deletions
@@ -347,7 +347,6 @@ def __init__(
     else:
       decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
       self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
-    self.hidden_states = None

     batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
     dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
@@ -541,10 +540,6 @@ def __call__(
         mutable=mutable_collections,
     )  # pytype: disable=wrong-keyword-args

-    # Materialize hidden state when vocab tiling is enabled
-    if self.config.num_vocab_tiling > 1:
-      self.hidden_states = hidden_state
-
     # If we are initializing the model AND MTP is enabled, we must create
     # dummy target tensors. This allows Flax to trace the MTPBlock and create
     # all its necessary parameters, without requiring the main training pipeline

src/maxtext/utils/vocabulary_tiling.py

Lines changed: 145 additions & 38 deletions
@@ -30,6 +30,18 @@
 from maxtext.utils import max_utils


+# Submodule names whose params are used by logits_from_hidden_states_for_vocab_tiling:
+# the final norm, the LM-head dense, and the embedding table when logits are tied.
+# vocab_tiling_nnx_loss splits these out as the only params the loss differentiates.
+_OUTPUT_HEAD_PATH_KEYS = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense")
+
+
+def _is_output_head_param_path(path, _value):
+  """Filter for nnx.split: True when the param path belongs to the output head."""
+  keys = [str(getattr(k, "key", k)) for k in path]
+  return any(k in keys for k in _OUTPUT_HEAD_PATH_KEYS)
+
+
 def vocab_tiling_linen_loss(
     hidden_states,
     data,
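The path filter added in this hunk can be exercised on its own. The sketch below copies `_OUTPUT_HEAD_PATH_KEYS` and `_is_output_head_param_path` from the diff and runs them on made-up paths; the example paths and the `FakeKey` class (a stand-in for a path entry that stores its name in a `.key` attribute) are assumptions for illustration, not paths taken from a real MaxText model.

```python
from dataclasses import dataclass

_OUTPUT_HEAD_PATH_KEYS = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense")


def _is_output_head_param_path(path, _value):
  """Filter for nnx.split: True when the param path belongs to the output head."""
  keys = [str(getattr(k, "key", k)) for k in path]
  return any(k in keys for k in _OUTPUT_HEAD_PATH_KEYS)


@dataclass(frozen=True)
class FakeKey:
  """Hypothetical path entry that wraps its name in a `.key` attribute."""

  key: str


# Hypothetical param paths; the second element is unused by the filter.
head_paths = [
    ("decoder", "logits_dense", "kernel"),
    ("token_embedder", "embedding"),
    ("decoder", FakeKey("decoder_norm"), "scale"),
]
other_paths = [
    ("decoder", "layers", "self_attention", "query", "kernel"),
    ("decoder", "layers", 0, "mlp", "wi_0", "kernel"),
    # Components are compared for equality, so a longer name does not match.
    ("decoder", "logits_dense_extra", "kernel"),
]

head_results = [_is_output_head_param_path(p, None) for p in head_paths]
other_results = [_is_output_head_param_path(p, None) for p in other_paths]
```

Because the filter checks whole path components rather than substrings, a module is only routed to `head_params` when one of its path entries is exactly one of the four listed names.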
@@ -253,12 +265,12 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
 def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train):
   """Computes cross-entropy loss with vocab tiling for NNX models.

-  NNX equivalent of ``vocab_tiling_linen_loss``. Scans the vocab dimension
-  and calls ``model.logits_from_hidden_states_for_vocab_tiling`` per chunk. The NNX model
-  carries its own parameters, so no explicit gather is needed.
-
-  Uses default autograd; a custom_vjp for backward memory savings can be
-  added later if needed.
+  NNX equivalent of `vocab_tiling_linen_loss`. A `custom_vjp` runs the loss in
+  vocab chunks via `jax.lax.scan` so the backward only holds one chunk's logits
+  at a time, matching the Linen path's memory profile. `nnx.split` separates the
+  output-head params (which the loss differentiates) from everything else; the
+  rest of the model is passed through but not differentiated, so the scan's
+  residuals stay small.

   Args:
     model: NNX model exposing ``logits_from_hidden_states_for_vocab_tiling``.
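To make "one chunk's logits at a time" concrete, here is a small worked calculation. It uses the tile-size formula from this commit, `vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling`; the batch, sequence, and vocabulary sizes are hypothetical, and `logits_bytes` is a helper introduced only for this example.

```python
def logits_bytes(num_tokens, vocab_size, bytes_per_element=4):
  """Size of a [num_tokens, vocab_size] logits array (float32 by default)."""
  return num_tokens * vocab_size * bytes_per_element


# Hypothetical workload, not taken from the commit.
batch_size, seq_len, vocab_size = 8, 2048, 32_000
num_vocab_tiling = 4

total_tokens = batch_size * seq_len
# The reshape to (num_vocab_tiling, vocab_tile_size, ...) needs an exact split.
assert total_tokens % num_vocab_tiling == 0
vocab_tile_size = total_tokens // num_vocab_tiling

full_logits = logits_bytes(total_tokens, vocab_size)  # all tokens at once
chunk_logits = logits_bytes(vocab_tile_size, vocab_size)  # one scan iteration

print(f"tokens per chunk: {vocab_tile_size}")
print(f"full logits:      {full_logits / 2**20:.0f} MiB")
print(f"per-chunk logits: {chunk_logits / 2**20:.0f} MiB")
```

With these numbers the chunks are slices of the flattened batch and sequence axes, and each holds a quarter of the logits that the unchunked loss would materialize.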
@@ -320,42 +332,137 @@ def _reshape(inputs, out_shape, out_sharding):
   labels = _maybe_shard_with_name(labels, label_spec)
   segmentation = _maybe_shard_with_name(segmentation, label_spec)

-  batch_size, seq_len, emb_dim = hidden_states.shape
-  vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling
+  # head_params is what the loss differentiates; other_params (transformer layers) and
+  # rest (rngs) are passed through the custom_vjp but not differentiated. They go through
+  # as primals rather than closure captures: capturing them leaks tracers across the
+  # custom_vjp + lax.scan boundary, which fails for tied embeddings.
+  graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...)

-  reshaped_hidden_states = _reshape(
-      hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec
-  )
-  reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
-  reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
-
-  # Rebuild the model per chunk inside the scan: the output head pulls an rng stream, and
-  # mutating the outer model's rng inside scan's sub-trace raises TraceContextError.
-  # nnx.merge(..., copy=True) makes fresh Variables local to each iteration.
-  graphdef, model_state = nnx.split(model)
-
-  def _scan_body(accumulators, chunk_data):
-    loss_accumulator, z_loss_accumulator = accumulators
-    hidden_chunk, label_chunk, segmentation_chunk = chunk_data
-    hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
-    label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
-    segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)
-
-    chunk_model = nnx.merge(graphdef, model_state, copy=True)
-    chunk_logits = chunk_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode)
-    chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec)
-    one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size)
-    chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits(
-        chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier
+  def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk):
+    local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True)
+    chunk_logits = local_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode)
+    return _maybe_shard_with_name(chunk_logits, chunked_logits_spec)
+
+  @jax.custom_vjp
+  def chunked_cross_entropy_loss(chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation):
+    (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(
+        chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation
     )
+    return total_loss, total_z_loss

-    masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0))
-    masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0))
+  def _chunked_cross_entropy_loss_fwd(
+      chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation
+  ):
+    batch_size, seq_len, emb_dim = hidden_states.shape
+    vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling

-    return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None
+    reshaped_hidden_states = _reshape(
+        hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec
+    )
+    reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
+    reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)

-  initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype))
-  (total_loss, total_z_loss), _ = jax.lax.scan(
-      _scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
+    def _fwd_scan_body(accumulators, chunk_data):
+      loss_accumulator, z_loss_accumulator = accumulators
+      hidden_chunk, label_chunk, segmentation_chunk = chunk_data
+      hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
+      label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
+      segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)
+
+      chunk_logits = _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk)
+      one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size)
+      chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits(
+          chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier
+      )
+
+      masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0))
+      masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0))
+
+      return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None
+
+    # Always accumulate in fp32 — `cross_entropy_with_logits` returns fp32 regardless of
+    # logits dtype, and a bf16 carry would mismatch the body output type under lax.scan.
+    initial_acc = (jnp.zeros((), dtype=jnp.float32), jnp.zeros((), dtype=jnp.float32))
+    (total_loss, total_z_loss), _ = jax.lax.scan(
+        _fwd_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
+    )
+    residuals = (
+        chunk_head_params,
+        chunk_other_params,
+        chunk_rest,
+        reshaped_hidden_states,
+        reshaped_labels,
+        reshaped_segmentation,
+        batch_size,
+        seq_len,
+        emb_dim,
+    )
+    return (total_loss, total_z_loss), residuals
+
+  def _chunked_cross_entropy_loss_bwd(residuals, cotangents):
+    # z_loss is folded into the xent loss inside cross_entropy_with_logits.
+    loss_cotangent, _ = cotangents
+
+    (
+        chunk_head_params,
+        chunk_other_params,
+        chunk_rest,
+        reshaped_hidden_states,
+        reshaped_labels,
+        reshaped_segmentation,
+        batch_size,
+        seq_len,
+        emb_dim,
+    ) = residuals
+
+    def _single_chunk_loss_fn(input_head_params, input_hidden_chunk, input_label_chunk, input_segmentation_chunk):
+      chunk_logits = _logits_for_chunk(input_head_params, chunk_other_params, chunk_rest, input_hidden_chunk)
+      one_hot_label_chunk = jax.nn.one_hot(input_label_chunk, config.vocab_size)
+      xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier)
+      return jnp.sum(xent * (input_segmentation_chunk != 0))
+
+    def _bwd_scan_body(grad_head_acc, chunk_data):
+      hidden_chunk, label_chunk, segmentation_chunk = chunk_data
+      hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
+      label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
+      segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec)
+
+      # pylint: disable=unnecessary-lambda-assignment
+      loss_fn_for_vjp = lambda p, h: _single_chunk_loss_fn(p, h, label_chunk, segmentation_chunk)
+      _, vjp_fn = jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk)
+      (grad_head_update, grad_hidden_chunk) = vjp_fn(1.0)
+      grad_hidden_chunk = _maybe_shard_with_name(grad_hidden_chunk, chunked_hidden_spec)
+
+      grad_head_acc = jax.tree_util.tree_map(lambda acc, update: acc + update, grad_head_acc, grad_head_update)
+      return grad_head_acc, grad_hidden_chunk
+
+    initial_grad_head = jax.tree_util.tree_map(jnp.zeros_like, chunk_head_params)
+
+    grad_head, grad_reshaped_hidden_states = jax.lax.scan(
+        _bwd_scan_body, initial_grad_head, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
+    )
+    grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
+    grad_head = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_head)
+    grad_head = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), chunk_head_params, grad_head)
+    grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
+
+    # Return explicit zeros for other_params and rest, not None. With None, JAX builds
+    # the zero cotangents with the wrong layer-axis order for scanned params, and the
+    # AOT trace fails the cotangent shape check.
+    grad_other = jax.tree_util.tree_map(jnp.zeros_like, chunk_other_params)
+    grad_rest = jax.tree_util.tree_map(jnp.zeros_like, chunk_rest)
+    return (
+        grad_head,
+        grad_other,
+        grad_rest,
+        grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
+        None,
+        None,
+    )
+
+  chunked_cross_entropy_loss.defvjp(_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd)
+
+  total_loss, total_z_loss = chunked_cross_entropy_loss(
+      head_params, other_params, rest, hidden_states, labels, segmentation
   )
   return total_loss, total_z_loss
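The forward/backward structure of this hunk can be studied in isolation with a toy model. The sketch below is a simplified illustration, not the MaxText implementation: it replaces the NNX output head with a single hypothetical weight matrix `w`, drops `nnx.split`/`nnx.merge`, sharding, and z-loss, and uses a hypothetical constant `NUM_CHUNKS` in place of `config.num_vocab_tiling`. It scales both returned gradients by the incoming cotangent.

```python
import jax
import jax.numpy as jnp

NUM_CHUNKS = 4  # hypothetical stand-in for config.num_vocab_tiling


def chunk_loss(w, h, labels, mask):
  """Masked cross-entropy summed over one chunk; w is a toy [emb, vocab] LM head."""
  logits = (h @ w).astype(jnp.float32)
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  xent = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
  return jnp.sum(xent * (mask != 0))


def _to_chunks(hidden, labels, mask):
  """Flattens batch and seq, then splits the tokens into NUM_CHUNKS equal tiles."""
  emb = hidden.shape[-1]
  return (hidden.reshape(NUM_CHUNKS, -1, emb), labels.reshape(NUM_CHUNKS, -1), mask.reshape(NUM_CHUNKS, -1))


@jax.custom_vjp
def chunked_loss(w, hidden, labels, mask):
  return _fwd(w, hidden, labels, mask)[0]


def _fwd(w, hidden, labels, mask):
  def body(acc, chunk):
    return acc + chunk_loss(w, *chunk), None

  loss, _ = jax.lax.scan(body, jnp.zeros((), jnp.float32), _to_chunks(hidden, labels, mask))
  return loss, (w, hidden, labels, mask)  # residuals hold inputs only, no logits


def _bwd(residuals, loss_cotangent):
  w, hidden, labels, mask = residuals

  def body(grad_w_acc, chunk):
    h_c, y_c, m_c = chunk
    # Recompute this chunk's logits and differentiate w.r.t. the head and the hidden chunk.
    _, vjp_fn = jax.vjp(lambda w_, h_: chunk_loss(w_, h_, y_c, m_c), w, h_c)
    grad_w, grad_h = vjp_fn(jnp.ones((), jnp.float32))
    return grad_w_acc + grad_w, grad_h

  grad_w, grad_h = jax.lax.scan(body, jnp.zeros_like(w), _to_chunks(hidden, labels, mask))
  grad_w = (grad_w * loss_cotangent).astype(w.dtype)
  grad_h = (grad_h.reshape(hidden.shape) * loss_cotangent).astype(hidden.dtype)
  return grad_w, grad_h, None, None  # integer labels and mask get no cotangent


chunked_loss.defvjp(_fwd, _bwd)


def reference_loss(w, hidden, labels, mask):
  """Same loss with all logits materialized at once, for comparison."""
  emb = hidden.shape[-1]
  return chunk_loss(w, hidden.reshape(-1, emb), labels.reshape(-1), mask.reshape(-1))
```

A parity check in the spirit of the commit's tests would call `jax.grad(chunked_loss, argnums=(0, 1))` and `jax.grad(reference_loss, argnums=(0, 1))` on the same inputs and compare the results within a small tolerance.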
