
[NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out #3849

Merged
copybara-service[bot] merged 1 commit into main from feat/nnx-vocab-tiling-custom-vjp on Jun 12, 2026
Conversation

@ecnal-cienet

@ecnal-cienet ecnal-cienet commented May 8, 2026

Collaborator

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: model creation, gradient accumulation, checkpointing, dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX + post-training bugfixes. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
    4.6. ✅ Linen↔NNX checkpoint comparator. (PR [NNX] NNX migration prep (4.6/N): Linen<->NNX checkpoint comparator #3846)
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling MVP on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. ✅ NNX-aware QK-Clip + NNX-format checkpoint utilities. (PR [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836)
    9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
  10. 🔄 [This PR] Vocab tiling custom_vjp for NNX (with output-head carve-out — originally scoped as PR10.5, bundled in here).
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

Replaces the PR5 NNX vocab-tiling MVP (chunked forward + default autograd backward) with a jax.custom_vjp that mirrors the Linen path's backward-memory savings. Adds an output-head carve-out so the custom_vjp's residuals + grad accumulator scale with LM-head size, not the full model. The PR10.5 carve-out is bundled here because it's required for the NNX path to AOT-compile correctly under num_vocab_tiling > 1 + pure_nnx=True.

Linen vocab_tiling_linen_loss is byte-for-byte unchanged. Call sites in train.py / pyconfig_deprecated.py / configs/types.py are unchanged.

Diff: +532 / −59 across 7 files (5 src, 2 test).

Changes

utils/vocabulary_tiling.py: vocab_tiling_nnx_loss rewrite.

  • Outside the custom_vjp: 3-way nnx.split with a path filter for {token_embedder, shared_embedding, decoder_norm, logits_dense} — the only nnx.Param paths apply_output_head touches. head_params is differentiated; other_params and rest are threaded through as non-differentiated primals (closure-capturing them leaks tracers across custom_vjp + lax.scan, which fails on logits_via_embedding=True).
  • Forward: reshape, lax.scan over chunks, nnx.merge(graphdef, chunk_head, chunk_other, chunk_rest, copy=True) per chunk. Initial accumulator is fp32 — the previous hidden_states.dtype accumulator mismatched the bf16 carry against fp32 body output from cross_entropy_with_logits under lax.scan.
  • Backward: per-chunk jax.vjp over (chunk_head_params, hidden_chunk). The chunk_other_params / chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) pytrees, not None: under AOT, None makes JAX synthesize zeros with the wrong axis order for nnx.scan-stacked layer params, and the cotangent shape check then fails.
  • Correctness: logits_from_hidden_states_for_vocab_tiling provably depends only on head_params, so the loss gradient w.r.t. other_params is exactly zero. The full model forward in train.py still produces hidden_states, so transformer-layer gradients flow back through grad_hidden_states → outer backward, unaffected by the carve-out.
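
To make the shape of this pattern concrete, here is a minimal, self-contained sketch in plain JAX. It is not the MaxText implementation: `chunk_loss`, `chunked_xent`, the single `head_w` matrix, and the `other` pytree are hypothetical stand-ins for the head params, the non-head params, and the per-chunk loss. It shows only the three ideas from the list above: an fp32 scan carry, a per-chunk `jax.vjp` in the backward pass, and explicit zero cotangents for the primals that are threaded through but not differentiated.

```python
import functools

import jax
import jax.numpy as jnp


def chunk_loss(head_w, h, y):
  """Summed token cross-entropy for one chunk; logits are upcast to fp32."""
  logits = (h @ head_w).astype(jnp.float32)
  logp = jax.nn.log_softmax(logits, axis=-1)
  return -jnp.take_along_axis(logp, y[:, None], axis=-1).sum()


@functools.partial(jax.custom_vjp, nondiff_argnums=(4,))
def chunked_xent(head_w, other, h, y, num_chunks):
  # `other` is threaded through as a primal but never used by the loss.
  return _fwd(head_w, other, h, y, num_chunks)[0]


def _fwd(head_w, other, h, y, num_chunks):
  hc = h.reshape(num_chunks, -1, h.shape[-1])  # [chunks, tile, emb]
  yc = y.reshape(num_chunks, -1)               # [chunks, tile]

  def body(acc, xs):
    return acc + chunk_loss(head_w, *xs), None

  # fp32 carry: the body output is fp32 even when h is bf16.
  loss, _ = jax.lax.scan(body, jnp.zeros((), jnp.float32), (hc, yc))
  return loss, (head_w, other, hc, yc)


def _bwd(num_chunks, res, g):
  head_w, other, hc, yc = res

  def body(grad_w, xs):
    h_i, y_i = xs
    _, vjp_fn = jax.vjp(lambda w, x: chunk_loss(w, x, y_i), head_w, h_i)
    dw, dh = vjp_fn(jnp.ones((), jnp.float32))
    return grad_w + dw.astype(jnp.float32), dh

  init = jnp.zeros(head_w.shape, jnp.float32)
  grad_w, grad_hc = jax.lax.scan(body, init, (hc, yc))
  # Chain rule with the incoming cotangent, then cast back to primal dtypes.
  grad_w = (grad_w * g).astype(head_w.dtype)
  grad_h = (grad_hc * g).astype(hc.dtype).reshape(-1, hc.shape[-1])
  # Explicit zeros tie each cotangent's shape to its primal's shape.
  grad_other = jax.tree_util.tree_map(jnp.zeros_like, other)
  return grad_w, grad_other, grad_h, None  # None: integer labels


chunked_xent.defvjp(_fwd, _bwd)

# Usage: 12 tokens, embedding size 8, vocab size 16, split into 4 chunks.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k1, (8, 16))
h = jax.random.normal(k2, (12, 8))
y = jax.random.randint(k3, (12,), 0, 16)
other = {"layer": jnp.ones((3, 4))}
grads = jax.grad(chunked_xent, argnums=(0, 1, 2))(w, other, h, y, 4)
```

In this toy version the memory benefit is that only one chunk's logits are alive at a time in either pass; the real code additionally rebuilds the model with nnx.merge inside each scan body.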

layers/nnx_decoders.py: the apply_output_head logits_via_embedding=True branch uses shared_embedding.embedding[...] instead of the deprecated .value shim. The shim records the access in NNX's mutation tracking, which JAX detects as a tracer leak when the embedding is threaded across the custom_vjp boundary. Linen branch unchanged.

models/models.py — deletes dead self.hidden_states = None and the if num_vocab_tiling > 1: self.hidden_states = hidden_state write on the NNX Transformer. Two lines left from an early PR5 idea; neither path reads model.hidden_states (Linen uses mutable=["intermediates"], NNX uses nnx.pop(model, nnx.Intermediate)). Without this fix, AOT compile under pure_nnx=True + num_vocab_tiling>1 raised ValueError: Cannot assign data value of type 'LinearizeTracer' to static attribute 'hidden_states'.

inference/maxengine/maxengine.py + utils/model_creation_utils.py — review-comment fixes from PR #3849:

  • _overlay: add isinstance(src, dict) guard so a structural mismatch can't TypeError on k in src. On mismatch, keep dst (PREFILL) rather than overwriting with a leaf/subtree.
  • _free_device_memory: walk via jax.tree_util.tree_leaves(inner) so AQT serve-mode qrhs.frozen's QTensor leaves (qvalue + scale arrays) are freed, not just single-jax.Array Variables.
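
The guard described for _overlay can be illustrated with a small pure-Python sketch. The `overlay` function below is a hypothetical reconstruction for nested dicts, not the code in maxengine.py; it assumes that leaves from `src` replace leaves in `dst`, and that any structural mismatch keeps `dst`.

```python
def overlay(dst, src):
  """Copy leaves of src onto dst; on structural mismatch keep dst unchanged."""
  if isinstance(dst, dict):
    if not isinstance(src, dict):
      # Without this guard, `k in src` below raises TypeError for leaf types
      # such as int, which do not support the `in` operator.
      return dst
    return {k: overlay(v, src[k]) if k in src else v for k, v in dst.items()}
  if isinstance(src, dict):
    # dst is a leaf but src is a subtree: keep the dst leaf.
    return dst
  return src


# Usage: "a" mismatches (subtree vs. leaf) and is kept; "c" is overwritten.
merged = overlay({"a": {"b": 1}, "c": 2}, {"a": 5, "c": 3})
```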

Tests

tests/unit/tiling_test.py — new VocabTilingNNXTest (9 TPU tests). Loss + grad parity vs. a full-vocab xent reference, covering: non-tied / tied embedding, total_z_loss value parity, half-padded segmentation, argnums=1 (grad_hidden_states), bf16 inputs (caught the fp32-accumulator bug), z_loss=0, num_vocab_tiling ∈ {2,4,8} parity, and a carve-out invariant asserting every non-head leaf has zero grad and at least one head leaf doesn't.
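
The carve-out invariant can be expressed as a short check over a flattened gradient tree. The sketch below is illustrative only: `check_carve_out`, the flat `{path: array}` layout, and the example paths are assumptions, not the helpers used in tiling_test.py.

```python
import numpy as np


def check_carve_out(flat_grads, is_head):
  """Assert non-head grads are exactly zero and at least one head grad is not."""
  head_has_signal = False
  for path, grad in flat_grads.items():
    nonzero = bool(np.any(np.asarray(grad) != 0))
    if is_head(path):
      head_has_signal = head_has_signal or nonzero
    else:
      assert not nonzero, f"non-head leaf {path} has a non-zero gradient"
  # Guards against a test that passes trivially because everything is zero.
  assert head_has_signal, "no head leaf received a gradient"


# Usage with hypothetical parameter paths.
example_grads = {
    ("decoder", "logits_dense", "kernel"): np.full((2, 3), 0.5),
    ("decoder", "layers", "mlp", "kernel"): np.zeros((2, 2)),
}
check_carve_out(example_grads, lambda path: "logits_dense" in path)
```

The second assertion is what keeps the check honest: a backward pass that zeroed every leaf would satisfy the first condition alone.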

tests/unit/train_compile_test.py — removed the now-stale pytest.skip("Vocab tiling not supported on NNX.") in test_vocab_tiling_bf16. Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train step under pure_nnx=true with num_vocab_tiling=4 + weight_dtype=bfloat16. This test surfaced both the models.py dead-code regression and the cotangent-axis-order issue.
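
For readers unfamiliar with AOT compilation in JAX, the following sketch shows the general mechanism on a toy step function. It does not use MaxText's train step or config; `loss_fn`, `train_step`, and the shapes are made up for illustration. The point is that lowering against `jax.ShapeDtypeStruct` placeholders traces the full forward and backward pass without real data, which is how shape and dtype mismatches can surface on CPU.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x):
  # bf16 matmul followed by an fp32 reduction.
  return jnp.mean(jnp.square((x @ w).astype(jnp.float32)))


def train_step(w, x):
  loss, grad = jax.value_and_grad(loss_fn)(w, x)
  return loss, w - 0.1 * grad


# Abstract inputs: only shapes and dtypes, no device buffers.
w_spec = jax.ShapeDtypeStruct((8, 4), jnp.bfloat16)
x_spec = jax.ShapeDtypeStruct((2, 8), jnp.bfloat16)

# Trace, lower, and compile ahead of time.
compiled = jax.jit(train_step).lower(w_spec, x_spec).compile()

# The compiled executable accepts concrete arrays matching the specs.
loss, new_w = compiled(jnp.ones((8, 4), jnp.bfloat16), jnp.ones((2, 8), jnp.bfloat16))
```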

Existing tests untouched.

Stats

  • Diff: +532 / −59 across 7 files (5 src, 2 test).
  • Linen preservation: vocab_tiling_linen_loss byte-for-byte unchanged; nnx_decoders.py Linen branch unchanged; TransformerLinenPure unchanged.
  • NNX impact outside vocab tiling: the apply_output_head change from .value to [...] is the only edit that runs in non-vocab-tiling NNX paths. 52 adjacent NNX tests pass (train_nnx, dpo_nnx, grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-vocab-tiling-custom-vjp branch from a8a53c3 to e015455 Compare May 8, 2026 16:19
@codecov

codecov Bot commented May 8, 2026


Codecov Report

❌ Patch coverage is 93.50649% with 5 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/model_creation_utils.py | 40.00% | 1 Missing and 2 partials ⚠️ |
| src/maxtext/utils/vocabulary_tiling.py | 97.18% | 1 Missing and 1 partial ⚠️ |


@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-vocab-tiling-custom-vjp branch 13 times, most recently from b525127 to 1727514 Compare May 15, 2026 01:11
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-vocab-tiling-custom-vjp branch 11 times, most recently from da2a12a to 023b019 Compare May 22, 2026 21:10
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-vocab-tiling-custom-vjp branch 4 times, most recently from 4da6eb3 to a56510e Compare May 28, 2026 04:40
@github-actions

github-actions Bot commented Jun 9, 2026


🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

github-actions Bot commented Jun 9, 2026


🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request successfully prepares MaxText for the NNX migration (10/N) by implementing a memory-efficient jax.custom_vjp for vocabulary tiling on NNX with output-head carve-out. It introduces robust logic for serve-mode quantization reload, ensuring that on-disk scale factors are correctly merged and matched during from-pretrained model load. The code quality, comments, and comprehensive test suite are exceptionally high, making this an excellent contribution to the codebase.

🔍 General Feedback

  • Exceptional Testing Quality: The inclusion of 9 dedicated unit tests covering tied/non-tied embeddings, padded segmentations, and gradients over hidden states ensures strong correctness guarantees.
  • Thorough and Informative Comments: The inline documentation explaining why complex steps (like explicit zero cotangents or 4-way split of loaded AQT states) were taken is incredibly helpful and makes the codebase highly maintainable.
  • Robust Integration: Correctly modernizes the variable access pattern (embedding[...] instead of .value) to prevent tracer leakages.

Comment thread src/maxtext/inference/maxengine/maxengine.py Outdated
Comment thread src/maxtext/utils/model_creation_utils.py Outdated
Comment thread src/maxtext/utils/vocabulary_tiling.py
Comment thread src/maxtext/inference/maxengine/maxengine.py Outdated

@bvandermoon bvandermoon left a comment

Collaborator


LGTM, one small comment

Comment thread src/maxtext/layers/nnx_decoders.py Outdated
@ecnal-cienet

ecnal-cienet commented Jun 12, 2026

Collaborator Author

Because the TPU UT check hanging issue remains unresolved, I will add "pull ready" manually and run TAP Presubmit on the Critique side first.

Replaces the PR5 NNX vocab-tiling MVP (chunked forward + default
autograd backward) with a jax.custom_vjp that mirrors the Linen path's
backward-memory savings, then carves out the output-head params so the
custom_vjp's residuals + grad accumulator scale with LM-head size, not
with the full model. Linen vocab_tiling_linen_loss is byte-for-byte
unchanged. Call sites in train.py / pyconfig_deprecated.py /
configs/types.py are unchanged.

Custom_vjp + output-head carve-out (vocabulary_tiling.py):
- Outside the custom_vjp: 3-way nnx.split with a callable path filter
  (_is_output_head_param_path) matching {token_embedder,
  shared_embedding, decoder_norm, logits_dense} — the only nnx.Param
  paths apply_output_head touches. Returns (graphdef, head_params,
  other_params, rest).
- Custom_vjp primals: (head_params, other_params, rest, hidden_states,
  labels, segmentation). Only head_params and hidden_states are
  differentiated; other_params + rest are threaded through as
  non-differentiated primals so their tracers don't have to cross both
  the custom_vjp and the inner lax.scan boundary (which previously
  caused UnexpectedTracerError under logits_via_embedding=True).
- Forward (_chunked_cross_entropy_loss_fwd): reshapes to
  (num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body
  rebuilds the model per chunk via nnx.merge(graphdef, chunk_head,
  chunk_other, chunk_rest, copy=True) and calls
  logits_from_hidden_states. Initial scan accumulator is fp32 (was
  hidden_states.dtype previously — caused a lax.scan carry dtype
  mismatch with bf16 hidden_states since cross_entropy_with_logits
  always returns fp32). Residuals are (chunk_head, chunk_other,
  chunk_rest, reshaped_*, batch/seq/emb).
- Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose
  body builds loss_fn_for_vjp = lambda p, h: ..., calls
  jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk),
  accumulates grad_head via tree.map(add), emits per-chunk grad_hidden.
  Chain-rules grad_head *= loss_cotangent and dtype-casts to each
  primal's dtype (custom_vjp requires this). chunk_other_params and
  chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero
  pytrees, NOT None — None makes JAX synthesize zeros at AOT trace time
  with axis-0 stacking (jax.scan convention) for nnx.scan-stacked
  transformer-layer params, which carry axis-1 stacking (nnx
  convention), and the cotangent-shape check fails as
  "Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M],
  but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent
  shape to the primal shape exactly.
- Correctness: logits_from_hidden_states provably depends only on
  head_params; the gradient w.r.t. other_params through this loss is
  exactly zero. When train.py also calls the full model forward (which
  produces hidden_states), transformer-layer gradients flow back
  through grad_hidden_states → outer backward, unaffected by the
  carve-out.
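
A rough illustration of the path-filter idea, in plain Python: the sketch
below partitions a flat {path: value} mapping into head and non-head groups.
The names is_output_head_path and partition_params, the flat-dict layout,
and the "any path component matches" rule are assumptions made for this
example; the real _is_output_head_param_path is passed to nnx.split and may
match paths differently.

```python
HEAD_NAMES = frozenset(
    {"token_embedder", "shared_embedding", "decoder_norm", "logits_dense"}
)


def is_output_head_path(path):
  """True if any component of the path names an output-head module."""
  return any(part in HEAD_NAMES for part in path)


def partition_params(flat_params):
  """Split a flat {path: value} mapping into (head, other) dicts."""
  head, other = {}, {}
  for path, value in flat_params.items():
    target = head if is_output_head_path(path) else other
    target[path] = value
  return head, other


# Usage with hypothetical parameter paths and placeholder values.
params = {
    ("token_embedder", "embedding"): "E",
    ("decoder", "decoder_norm", "scale"): "S",
    ("decoder", "logits_dense", "kernel"): "K",
    ("decoder", "layers", "self_attention", "query", "kernel"): "Q",
}
head, other = partition_params(params)
```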

Supporting fixes (touched for the carve-out to work end-to-end):
- nnx_decoders.py::apply_output_head logits_via_embedding=True branch
  reads embedding_table = shared_embedding.embedding[...] instead of
  the deprecated .value shim. The .value shim registers the access in
  NNX mutation tracking, which JAX detects as a tracer leak when the
  embedding is closure-captured / threaded across the custom_vjp +
  lax.scan boundaries. The Linen branch is unchanged.
- models.py: deletes dead-code self.hidden_states = None and
  if num_vocab_tiling > 1: self.hidden_states = hidden_state from the
  NNX Transformer class. Two lines left over from an early PR5
  implementation idea — neither path actually reads
  model.hidden_states (Linen reads via mutable=["intermediates"]; NNX
  reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown
  ("decoder", "hidden_states") intermediate). Without this fix, AOT
  compile under pure_nnx=True + num_vocab_tiling>1 raised
  ValueError: Cannot assign data value of type 'LinearizeTracer' to
  static attribute 'hidden_states' of Pytree type 'Transformer' —
  would have silently broken any post-PR11 user with vocab tiling on.

Tests (tiling_test.py — new VocabTilingNNXTest class with 9 TPU tests):
- test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss +
  grad parity vs. full-vocab xent reference for both LM-head modes.
- test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second
  tuple element matches the reference (was untested before).
- test_nnx_vocab_tiling_padded_segmentation: half-padded mask;
  exercises the segmentation != 0 mask branch and asserts padded loss
  is strictly less than unpadded.
- test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1
  differentiation; exercises the custom_vjp's second-primal cotangent
  path (grad_reshaped_hidden_states), shape + dtype + value parity.
- test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol
  loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast
  preserves the primal's dtype). Caught the fp32-accumulator bug.
- test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0;
  total_z_loss == 0.0 exactly and grad parity holds.
- test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8}
  and asserts identical loss + grads (catches off-by-one in
  vocab_tile_size and scan/reshape interactions).
- test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out
  invariant): asserts every non-head leaf has gradient exactly zero
  AND at least one head leaf has non-zero gradient (so the test can't
  trivially pass by zeroing everything). Catches filter bugs (e.g.
  forgetting that NNX names the embedder token_embedder while Linen
  names it shared_embedding) and bwd zero-shape bugs.

AOT compile coverage (train_compile_test.py):
- Removed the now-stale pytest.skip("Vocab tiling not supported on
  NNX.") in test_vocab_tiling_bf16.
- Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train
  step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true
  with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the
  models.py dead-code regression and the cotangent-axis-ordering issue
  the explicit-zeros bwd fixes.

Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one
Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx,
grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx)
— regression check for the nnx_decoders.py change.

3 participants