[NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out#3849
Conversation
a8a53c3 to
e015455
Compare
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
b525127 to
1727514
Compare
da2a12a to
023b019
Compare
4da6eb3 to
a56510e
Compare
225a082 to
7146bab
Compare
|
🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
There was a problem hiding this comment.
This Pull Request successfully prepares MaxText for the NNX migration (10/N) by implementing a memory-efficient jax.custom_vjp for vocabulary tiling on NNX with output-head carve-out. It introduces robust logic for serve-mode quantization reload, ensuring that on-disk scale factors are correctly merged and matched during from-pretrained model load. The code quality, comments, and comprehensive test suite are exceptionally high, making this an excellent contribution to the codebase.
🔍 General Feedback
- Exceptional Testing Quality: The inclusion of 9 dedicated unit tests covering tied/non-tied embeddings, padded segmentations, and gradients over hidden states ensures strong correctness guarantees.
- Thorough and Informative Comments: The inline documentation explaining why complex steps (like explicit zero cotangents or 4-way split of loaded AQT states) were taken is incredibly helpful and makes the codebase highly maintainable.
- Robust Integration: Correctly modernizes the variable access pattern (
embedding[...]instead of.value) to prevent tracer leakages.
bvandermoon
left a comment
There was a problem hiding this comment.
LGTM, one small comment
|
Due to TPU UT check hanging issue remains unsolved, I will add "pull ready" manually and run TAP Presubmit on critique side first. |
Replaces the PR9.5 NNX vocab-tiling MVP (chunked forward + default
autograd backward) with a jax.custom_vjp that mirrors the Linen path's
backward-memory savings, then carves out the output-head params so the
custom_vjp's residuals + grad accumulator scale with LM-head size, not
with the full model. Linen vocab_tiling_linen_loss is byte-for-byte
unchanged. Call sites in train.py / pyconfig_deprecated.py /
configs/types.py are unchanged.
Custom_vjp + output-head carve-out (vocabulary_tiling.py):
- Outside the custom_vjp: 3-way nnx.split with a callable path filter
(_is_output_head_param_path) matching {token_embedder,
shared_embedding, decoder_norm, logits_dense} — the only nnx.Param
paths apply_output_head touches. Returns (graphdef, head_params,
other_params, rest).
- Custom_vjp primals: (head_params, other_params, rest, hidden_states,
labels, segmentation). Only head_params and hidden_states are
differentiated; other_params + rest are threaded through as
non-differentiated primals so their tracers don't have to cross both
the custom_vjp and the inner lax.scan boundary (which previously
caused UnexpectedTracerError under logits_via_embedding=True).
- Forward (_chunked_cross_entropy_loss_fwd): reshapes to
(num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body
rebuilds the model per chunk via nnx.merge(graphdef, chunk_head,
chunk_other, chunk_rest, copy=True) and calls
logits_from_hidden_states. Initial scan accumulator is fp32 (was
hidden_states.dtype previously — caused a lax.scan carry dtype
mismatch with bf16 hidden_states since cross_entropy_with_logits
always returns fp32). Residuals are (chunk_head, chunk_other,
chunk_rest, reshaped_*, batch/seq/emb).
- Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose
body builds loss_fn_for_vjp = lambda p, h: ..., calls
jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk),
accumulates grad_head via tree.map(add), emits per-chunk grad_hidden.
Chain-rules grad_head *= loss_cotangent and dtype-casts to each
primal's dtype (custom_vjp requires this). chunk_other_params and
chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero
pytrees, NOT None — None makes JAX synthesize zeros at AOT trace time
with axis-0 stacking (jax.scan convention) for nnx.scan-stacked
transformer-layer params, which carry axis-1 stacking (nnx
convention), and the cotangent-shape check fails as
"Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M],
but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent
shape to the primal shape exactly.
- Correctness: logits_from_hidden_states provably depends only on
head_params; the gradient w.r.t. other_params through this loss is
exactly zero. When train.py also calls the full model forward (which
produces hidden_states), transformer-layer gradients flow back
through grad_hidden_states → outer backward, unaffected by the
carve-out.
Supporting fixes (touched for the carve-out to work end-to-end):
- nnx_decoders.py::apply_output_head logits_via_embedding=True branch
reads embedding_table = shared_embedding.embedding[...] instead of
the deprecated .value shim. The .value shim registers the access in
NNX mutation tracking, which JAX detects as a tracer leak when the
embedding is closure-captured / threaded across the custom_vjp +
lax.scan boundaries. The Linen branch is unchanged.
- models.py: deletes dead-code self.hidden_states = None and
if num_vocab_tiling > 1: self.hidden_states = hidden_state from the
NNX Transformer class. Two lines left over from an early PR5
implementation idea — neither path actually reads
model.hidden_states (Linen reads via mutable=["intermediates"]; NNX
reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown
("decoder", "hidden_states") intermediate). Without this fix, AOT
compile under pure_nnx=True + num_vocab_tiling>1 raised
ValueError: Cannot assign data value of type 'LinearizeTracer' to
static attribute 'hidden_states' of Pytree type 'Transformer' —
would have silently broken any post-PR11 user with vocab tiling on.
Tests (tiling_test.py — new VocabTilingNNXTest class with 9 TPU tests):
- test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss +
grad parity vs. full-vocab xent reference for both LM-head modes.
- test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second
tuple element matches the reference (was untested before).
- test_nnx_vocab_tiling_padded_segmentation: half-padded mask;
exercises the segmentation != 0 mask branch and asserts padded loss
is strictly less than unpadded.
- test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1
differentiation; exercises the custom_vjp's second-primal cotangent
path (grad_reshaped_hidden_states), shape + dtype + value parity.
- test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol
loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast
preserves the primal's dtype). Caught the fp32-accumulator bug.
- test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0;
total_z_loss == 0.0 exactly and grad parity holds.
- test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8}
and asserts identical loss + grads (catches off-by-one in
vocab_tile_size and scan/reshape interactions).
- test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out
invariant): asserts every non-head leaf has gradient exactly zero
AND at least one head leaf has non-zero gradient (so the test can't
trivially pass by zeroing everything). Catches filter bugs (e.g.
forgetting that NNX names the embedder token_embedder while Linen
names it shared_embedding) and bwd zero-shape bugs.
AOT compile coverage (train_compile_test.py):
- Removed the now-stale pytest.skip("Vocab tiling not supported on
NNX.") in test_vocab_tiling_bf16.
- Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train
step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true
with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the
models.py dead-code regression and the cotangent-axis-ordering issue
the explicit-zeros bwd fixes.
Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one
Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx,
grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx)
— regression check for the nnx_decoders.py change.
NNX Migration Route Map
pure_nnxflag,init_state_fn,TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
4.6. ✅ Linen↔NNX checkpoint comparator. (PR [NNX] NNX migration prep (4.6/N): Linen<->NNX checkpoint comparator #3846)
9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
custom_vjpfor NNX (with output-head carve-out — originally scoped as PR10.5, bundled in here).True; regenerate sharding goldens; flip back integration-testpure_nnx=Falseannotations.Description
Replaces the PR5 NNX vocab-tiling MVP (chunked forward + default autograd backward) with a
jax.custom_vjpthat mirrors the Linen path's backward-memory savings. Adds an output-head carve-out so thecustom_vjp's residuals + grad accumulator scale with LM-head size, not the full model. The PR10.5 carve-out is bundled here because it's required for the NNX path to AOT-compile correctly undernum_vocab_tiling > 1 + pure_nnx=True.Linen
vocab_tiling_linen_lossis byte-for-byte unchanged. Call sites intrain.py/pyconfig_deprecated.py/configs/types.pyare unchanged.Diff: +532 / −59 across 7 files (5 src, 2 test).
Changes
utils/vocabulary_tiling.py—vocab_tiling_nnx_lossrewrite.custom_vjp: 3-waynnx.splitwith a path filter for{token_embedder, shared_embedding, decoder_norm, logits_dense}— the onlynnx.Parampathsapply_output_headtouches.head_paramsis differentiated;other_paramsandrestare threaded through as non-differentiated primals (closure-capturing them leaks tracers acrosscustom_vjp+lax.scan, which fails onlogits_via_embedding=True).lax.scanover chunks,nnx.merge(graphdef, chunk_head, chunk_other, chunk_rest, copy=True)per chunk. Initial accumulator is fp32 — the previoushidden_states.dtypeaccumulator mismatched the bf16 carry against fp32 body output fromcross_entropy_with_logitsunderlax.scan.jax.vjpover(chunk_head_params, hidden_chunk).chunk_other_params/chunk_restcotangents are explicittree_map(jnp.zeros_like, ...)notNone— under AOT,Nonesynthesizes zeros with the wrong axis order fornnx.scan-stacked layer params and the cotangent shape check fails.logits_from_hidden_states_for_vocab_tilingprovably depends only onhead_params, so the loss gradient w.r.t.other_paramsis exactly zero. The full model forward intrain.pystill produceshidden_states, so transformer-layer gradients flow back throughgrad_hidden_states→ outer backward, unaffected by the carve-out.layers/nnx_decoders.py—apply_output_headlogits_via_embedding=Truebranch usesshared_embedding.embedding[...]instead of the deprecated.valueshim. The shim records the access in NNX's mutation tracking, which JAX detects as a tracer leak when the embedding is threaded across thecustom_vjpboundary. Linen branch unchanged.models/models.py— deletes deadself.hidden_states = Noneand theif num_vocab_tiling > 1: self.hidden_states = hidden_statewrite on the NNXTransformer. Two lines left from an early PR5 idea; neither path readsmodel.hidden_states(Linen usesmutable=["intermediates"], NNX usesnnx.pop(model, nnx.Intermediate)). Without this fix, AOT compile underpure_nnx=True + num_vocab_tiling>1raisedValueError: Cannot assign data value of type 'LinearizeTracer' to static attribute 'hidden_states'.inference/maxengine/maxengine.py+utils/model_creation_utils.py— review-comment fixes from PR #3849:_overlay: addisinstance(src, dict)guard so a structural mismatch can't TypeError onk in src. On mismatch, keepdst(PREFILL) rather than overwriting with a leaf/subtree._free_device_memory: walk viajax.tree_util.tree_leaves(inner)so AQT serve-modeqrhs.frozen's QTensor leaves (qvalue + scale arrays) are freed, not just single-jax.ArrayVariables.Tests
tests/unit/tiling_test.py— newVocabTilingNNXTest(9 TPU tests). Loss + grad parity vs. a full-vocab xent reference, covering: non-tied / tied embedding,total_z_lossvalue parity, half-padded segmentation,argnums=1(grad_hidden_states), bf16 inputs (caught the fp32-accumulator bug),z_loss=0,num_vocab_tiling ∈ {2,4,8}parity, and a carve-out invariant asserting every non-head leaf has zero grad and at least one head leaf doesn't.tests/unit/train_compile_test.py— removed the now-stalepytest.skip("Vocab tiling not supported on NNX.")intest_vocab_tiling_bf16. Addedtest_vocab_tiling_bf16_nnx(cpu_only): AOT-compiles the train step underpure_nnx=truewithnum_vocab_tiling=4+weight_dtype=bfloat16. This test surfaced both themodels.pydead-code regression and the cotangent-axis-order issue.Existing tests untouched.
Stats
vocab_tiling_linen_lossbyte-for-byte unchanged;nnx_decoders.pyLinen branch unchanged;TransformerLinenPureunchanged.apply_output_head.value→[...]change is the only edit that runs in non-vocab-tiling NNX paths. 52 adjacent NNX tests pass (train_nnx, dpo_nnx, grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx).Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.