[NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844
Merged
This Pull Request successfully integrates AQT quantization with the NNX model path in MaxEngine and the GRPO trainer. It also resolves a significant bug in the GPT-3 attention implementation where KV caches were not being correctly updated during prefill, and provides a model-agnostic layerwise quantization (conversion) utility for NNX models.
🔍 General Feedback
- Robust NNX+AQT Integration: The changes to `model_creation_utils.py` and `maxengine.py` demonstrate a deep understanding of the NNX variable system and AQT's composite types (QTensor). The use of `get_value()` and the removal of `Partitioned` wrappers are critical for stable checkpoint roundtrips.
- Model Correctness: The GPT-3 fix brings the model's inference behavior in line with other major architectures in the repository.
- Improved Testing: The addition of `aqt_serve_roundtrip_nnx_test.py` and other NNX-specific unit tests provides excellent coverage for the complex logic involved in loading and aligning quantized checkpoints.
- Clean Metrics Collection: `collect_intermediates_by_suffix` is a valuable utility for unified metrics collection across diverse model structures.
NuojCheng
approved these changes
Jun 8, 2026
bvandermoon
approved these changes
Jun 9, 2026
Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a fix for a pre-existing gpt3
prefill bug surfaced by the AQT end-to-end validation.
NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
"serve") through from_config / create_model /
get_nnx_create_model_fn / create_nnx_abstract_model /
from_pretrained. Default "train" preserves existing callers; "serve"
propagates to configure_quantization so AQT layers don't materialize
the full-precision kernel when the on-disk checkpoint already
carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
config.checkpoint_is_quantized; _load_params_nnx drops its
NotImplementedError. Two paths: pre-quantized
(checkpoint_is_quantized=True) loads via quant_mode_str="serve";
full-precision + quantization=int8 loads in TRAIN mode and AQT
layers quantize per-forward (same numerical result for absmax
calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
convert path. Loads full-precision in TRAIN mode, transfers kernels
into a CONVERT-mode model, runs forward to populate qrhs.frozen via
the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
saves serve-mode-shaped state.
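The note that per-forward quantization in TRAIN mode gives the same numerical result as a stored serve-mode kernel under absmax calibration can be illustrated without AQT. The sketch below is library-free and makes several assumptions: `absmax_quantize`, `quantized_matvec`, and `forward_train_mode` are hypothetical helpers, scaling is per-tensor rather than per-channel, and activation quantization is omitted.

```python
def absmax_quantize(kernel):
    """Quantize a 2D list of floats to int8 using one absmax scale (assumed per-tensor)."""
    absmax = max(abs(v) for row in kernel for v in row)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    qvalue = [[max(-127, min(127, round(v / scale))) for v in row] for row in kernel]
    return qvalue, scale


def quantized_matvec(x, qvalue, scale):
    """Compute y[j] = scale * sum_i x[i] * qvalue[i][j]."""
    cols = len(qvalue[0])
    return [scale * sum(x[i] * qvalue[i][j] for i in range(len(x))) for j in range(cols)]


def forward_train_mode(x, kernel):
    """Keep the float kernel and re-quantize it on every forward call."""
    qvalue, scale = absmax_quantize(kernel)
    return quantized_matvec(x, qvalue, scale)


kernel = [[0.5, -1.25], [2.0, 0.75]]
x = [1.0, -2.0]

# Convert-then-serve: quantize once and keep only (qvalue, scale).
frozen_q, frozen_scale = absmax_quantize(kernel)
y_serve = quantized_matvec(x, frozen_q, frozen_scale)

# Train-mode inference: quantize against the loaded float kernel each time.
y_train = forward_train_mode(x, kernel)

assert y_train == y_serve
```

Because the absmax scale is a deterministic function of the kernel alone, both paths produce the same integers and the same scale. The serve path simply avoids keeping the float kernel in memory.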
Sharding helpers and from_pretrained QTensor handling (5 chained fixes
that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
parallel-tree of replicated NamedSharding leaves when a Variable's
value is a composite pytree (AQT serve-mode QTensor with a qvalue
int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
_unwrap_for_align use Variable.get_value() instead of v[...]
indexing for QTensor leaves (QTensor.__getitem__ trips on the
LogicallyPartitioned wrapper around qvalue). Widens the restore
filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
type. Skips QTensor leaves in the per-axis shape-alignment dispatch
(their saved shape already matches the model). _build_value_target
strips Partitioned wrappers around composite-leaf values so the
restore tree path matches the on-disk layout (LogicallyPartitioned
was adding an extra .value key under each QTensor leaf, which made
orbax silently fill the path with zero-init values).
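The path mismatch behind the last fix can be reproduced in plain Python. This is a minimal sketch under stated assumptions: `Wrapper` stands in for a `Partitioned`-style box, `flatten_with_path` mimics path-based flattening, and `restore` mimics a restore that leaves unmatched paths at their init value. None of these are the real Flax, JAX, or orbax APIs.

```python
from dataclasses import dataclass


@dataclass
class Wrapper:
    """Hypothetical stand-in for a Partitioned-style box holding one value."""
    value: object


def flatten_with_path(tree, prefix=()):
    """Return {path: leaf}; a Wrapper adds an extra 'value' key to the path."""
    if isinstance(tree, dict):
        out = {}
        for key, sub in tree.items():
            out.update(flatten_with_path(sub, prefix + (key,)))
        return out
    if isinstance(tree, Wrapper):
        return flatten_with_path(tree.value, prefix + ("value",))
    return {prefix: tree}


def strip_wrappers(tree):
    """Replace every Wrapper with the value it holds."""
    if isinstance(tree, dict):
        return {key: strip_wrappers(sub) for key, sub in tree.items()}
    if isinstance(tree, Wrapper):
        return strip_wrappers(tree.value)
    return tree


def restore(target, on_disk):
    """Fill target leaves by path; paths missing on disk keep their init value."""
    flat_disk = flatten_with_path(on_disk)
    return {path: flat_disk.get(path, init) for path, init in flatten_with_path(target).items()}


on_disk = {"qrhs": {"frozen": {"qvalue": [3, -7], "scale": [0.5]}}}
wrapped_target = {"qrhs": {"frozen": {"qvalue": Wrapper([0, 0]), "scale": Wrapper([1.0])}}}

bad = restore(wrapped_target, on_disk)                   # paths end in 'value', nothing matches
good = restore(strip_wrappers(wrapped_target), on_disk)  # paths line up with the saved layout
```

In `bad`, every leaf keeps its zero-init value because the extra `value` key never appears on disk, and no error is raised. Stripping the wrappers before building the target makes the paths match.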
gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
ever calling update_kv_caches to build cached_values, so any
non-TRAIN forward (prefill or autoregressive) tripped the
`assert prefill_kv_cache` check. Mirror the standard Attention
plumbing in attentions.py: __init__ constructs a KVCache_0 module
when model_mode != MODEL_MODE_TRAIN, threads
max_prefill_predict_length into AttentionOp; __call__ calls
self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
cached_values to attention_op. TRAIN-mode shape unchanged.
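The shape of the fix can be sketched with toy classes. Everything here is hypothetical (`ToyKVCache`, `ToyAttention`, `toy_attention_op`, and the string mode constants); it only shows the control flow of building the cache outside TRAIN mode and passing `[prefill_kv_cache, ar_kv_cache]` into the attention op, not the actual MaxText modules.

```python
MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"


class ToyKVCache:
    """Hypothetical cache: prefill stores the prompt K/V, autoregressive appends one step."""

    def __init__(self):
        self.prefill = None
        self.ar = []

    def __call__(self, key, value, model_mode):
        if model_mode == MODEL_MODE_PREFILL:
            self.prefill = (list(key), list(value))
            self.ar = []
        elif model_mode == MODEL_MODE_AUTOREGRESSIVE:
            self.ar.append((key[-1], value[-1]))
        ar_cache = self.ar if model_mode == MODEL_MODE_AUTOREGRESSIVE else None
        return self.prefill, ar_cache


def toy_attention_op(key, cached_values, model_mode):
    """Return how many key positions the current query can see."""
    prefill_kv_cache, ar_kv_cache = cached_values
    if model_mode == MODEL_MODE_TRAIN:
        return len(key)
    # Without the cache call in ToyAttention.__call__, this assertion would fail.
    assert prefill_kv_cache, "cache must be built before a non-TRAIN forward"
    visible = len(prefill_kv_cache[0])
    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
        visible += len(ar_kv_cache)
    return visible


class ToyAttention:
    def __init__(self, model_mode):
        # Only non-TRAIN modes own a cache, so the TRAIN-mode module is unchanged.
        self.kv_cache = ToyKVCache() if model_mode != MODEL_MODE_TRAIN else None

    def __call__(self, key, value, model_mode):
        cached_values = [None, None]
        if model_mode != MODEL_MODE_TRAIN:
            cached_values = list(self.kv_cache(key, value, model_mode))
        return toy_attention_op(key, cached_values, model_mode)


attn = ToyAttention(MODEL_MODE_PREFILL)
prefill_visible = attn([1, 2, 3], [4, 5, 6], MODEL_MODE_PREFILL)
ar_visible = attn([7], [8], MODEL_MODE_AUTOREGRESSIVE)
train_visible = ToyAttention(MODEL_MODE_TRAIN)([1, 2], [3, 4], MODEL_MODE_TRAIN)
```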
Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
_strip_kernels_at_quantized_paths covering quantized removal,
non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
builds a small NNX model in CONVERT mode with int8, runs a forward
to populate qrhs.frozen via the ToNNX bridge, saves the
serve-mode-shape state to a tmp local orbax checkpoint, reloads via
from_pretrained(quant_mode_str="serve"), and asserts every saved
qrhs.frozen.qvalue array byte-matches what came back. Guards the
full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
test_quantize_passes_gate_for_nnx; added
test_load_pre_quantized_nnx_passes_quant_gate and
test_quantized_prefill_nnx_train_mode (real numerical verification
with quantization=int8 + random params + TRAIN mode).
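The three cases listed for `_strip_kernels_at_quantized_paths` can be pictured on nested dicts. The function below is a hypothetical reimplementation, not the one in the PR, and the state layout (a `kernel` key next to an `AqtDotGeneral_0` subtree holding `qrhs`) is an assumption made for illustration.

```python
def has_qrhs(tree):
    """True if a 'qrhs' key appears anywhere inside this subtree."""
    if not isinstance(tree, dict):
        return False
    return "qrhs" in tree or any(has_qrhs(sub) for sub in tree.values())


def strip_kernels_at_quantized_paths(state):
    """Drop 'kernel' wherever a sibling subtree carries quantized 'qrhs' state."""
    if not isinstance(state, dict):
        return state
    siblings = {key: sub for key, sub in state.items() if key != "kernel"}
    drop_kernel = "kernel" in state and has_qrhs(siblings)
    return {
        key: strip_kernels_at_quantized_paths(sub)
        for key, sub in state.items()
        if not (drop_kernel and key == "kernel")
    }


state = {
    "decoder": {
        "mlp": {
            "wi": {
                "kernel": [[1.0]],
                "AqtDotGeneral_0": {"qrhs": {"frozen": {"qvalue": [[127]], "scale": [0.01]}}},
            }
        },
        "norm": {"scale": [1.0]},
    },
    "token_embedder": {"embedding": [[0.1]]},
    "logits_dense": {"kernel": [[0.2]]},  # not quantized, so its kernel must survive
}

stripped = strip_kernels_at_quantized_paths(state)
```

The quantized layer loses its float kernel but keeps `qvalue` and `scale`, while the norm, the embedding, and the unquantized dense layer pass through untouched even though they sit at different depths.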
End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.
NNX Migration Route Map
`pure_nnx` flag, `init_state_fn`, `TrainStateNNX`, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
7.5. ✅ Finish NNX MaxEngine inference carve-outs.
9.5. 🔄 [This PR] NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
`custom_vjp` for NNX.
`True`; regenerate sharding goldens; flip back integration-test `pure_nnx=False` annotations.
Description
Migrates the NNX + AQT integration in MaxEngine so `pure_nnx=True` can both load pre-quantized checkpoints directly (`checkpoint_is_quantized=True`) and convert full-precision checkpoints to int8 on load (`checkpoint_is_quantized=False` + `quantization=int8`). Also bundles a fix for a pre-existing gpt3 prefill / autoregressive bug surfaced by the AQT end-to-end validation.

Originally part of PR9; split into its own follow-up to keep each PR reviewable. Linen paths are preserved byte-for-byte (every NNX edit is gated on `config.pure_nnx`). Stacks on PR9 (`feat/nnx-qk-clip-and-checkpoint-utils`); the two halves are file-disjoint.

Diff: +641 / −58 across 8 files (5 src, 3 test, 2 new).

What it does
- `quant_mode_str` plumbing: threads `"train"` / `"convert"` / `"serve"` through `from_config` → `create_model` → `get_nnx_create_model_fn` → `create_nnx_abstract_model` → `from_pretrained`. Default `"train"` preserves existing callers; `"serve"` propagates to `configure_quantization` so AQT layers don't materialize the full-precision kernel when the on-disk checkpoint already carries qrhs scale factors.
- `maxengine.__init__`: selects the quant mode from `config.checkpoint_is_quantized`; `_load_params_nnx` drops its `NotImplementedError`. Two paths now work: pre-quantized loads via `quant_mode_str="serve"`; full-precision + `quantization=int8` loads in TRAIN mode and AQT quantizes per-forward against the loaded kernel (same numerical result as serve mode for absmax calibration).
- `layerwise_quantization._load_and_quantize_nnx` (new): whole-model NNX convert path. Load full-precision in TRAIN mode → copy kernels into a separate CONVERT-mode model → run a forward (the `ToNNX(AqtDotGeneral)` bridge auto-captures `qrhs.frozen`) → strip kernels at quantized paths via `_strip_kernels_at_quantized_paths` → save serve-mode-shape state. The DeepSeek-only assertion is lifted for NNX.
- Sharding helpers + `from_pretrained` QTensor handling: 5 chained fixes that kept serve-mode reload from working:
  1. `get_nnx_named_sharding_with_scan_axis` emits a parallel-tree of replicated `NamedSharding` leaves when a Variable's value is a composite pytree (QTensor int8 `qvalue` + bf16 `scale` list). Previously it returned the Variable as-is when `val` had no `.shape`.
  2. `_build_value_target` / `_free_device_memory` / `_unwrap_for_align` use `Variable.get_value()` instead of `v[...]`, because `QTensor.__getitem__` trips on the `LogicallyPartitioned` wrapper around `qvalue`.
  3. The restore filter widens from `nnx.Param` only to "everything except `nnx.RngState` / `nnx.Cache`", and `_load_params_nnx` adds a 4-way `nnx.split` + overlay so `aqt`-typed `qrhs.frozen` leaves survive into the rest-state.
  4. `_build_value_target` strips `Partitioned` wrappers around composite-leaf values so the restore tree path matches the on-disk layout. Without this, `jax.tree.flatten_with_path` added an extra `.value` key under every QTensor leaf and orbax silently filled the missing paths with zero-init values (`qvalue=0`, `scale=1`, which is exactly the symptom).
  5. `_walk_align` skips composite leaves in per-axis shape alignment: quantized payloads are saved at the model shape. It previously crashed on `QTensor.shape` (which delegates to `qvalue.shape` on a `LogicallyPartitioned`).

  Also dropped a redundant `jax.set_mesh(mesh)` wrap in `create_nnx_abstract_model`: under `jax.set_mesh`, Flax 0.12.6's `_to_variable` rejects serve-mode AQT variables (`NamedSharding(mesh=AbstractMesh, spec=None)`).
- gpt3 prefill fix (`models/gpt3.py`): pre-existing bug surfaced by the AQT e2e validation. `Gpt3MultiHeadAttention.__call__` invoked `attention_op(...)` without ever calling `update_kv_caches` to build `cached_values`, so any non-TRAIN forward (prefill or autoregressive) tripped `assert prefill_kv_cache`. The fix mirrors the standard `Attention` plumbing in `attentions.py`: `__init__` constructs `KVCache_0` when `model_mode != MODEL_MODE_TRAIN` (and threads `max_prefill_predict_length` into `AttentionOp`, which previously kept its `-1` default); `__call__` calls `self.KVCache_0(...)` and passes `[prefill_kv_cache, ar_kv_cache]` as `cached_values`. TRAIN-mode shape unchanged. Bundled here because the AQT e2e exercises this path.

Tests
- `aqt_serve_roundtrip_nnx_test` (new, 1 test): end-to-end regression. Builds a small CONVERT-mode NNX model with `quantization=int8`, runs a forward to populate `qrhs.frozen`, saves serve-mode state to orbax, reloads via `from_pretrained(quant_mode_str="serve")`, and asserts every saved `qvalue` byte-matches what came back. Guards the full chain of QTensor / Partitioned / filter fixes.
- `layerwise_quantization_nnx_test` (new, 3 tests): `_strip_kernels_at_quantized_paths` covering quantized-kernel removal, non-quantized preservation (norms, embeddings), and mixed-shape trees.
- `maxengine_test`: replaced `test_quantize_raises_for_nnx` with `test_quantize_passes_gate_for_nnx`; added `test_load_pre_quantized_nnx_passes_quant_gate` (serve-mode gate) and `test_quantized_prefill_nnx_train_mode` (real numerical prefill with `quantization=int8` + random params + TRAIN mode produces finite logits).

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.