
[NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844

Merged: copybara-service[bot] merged 1 commit into main from feat/nnx-aqt-maxengine on Jun 9, 2026

@ecnal-cienet ecnal-cienet commented May 7, 2026

Collaborator

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: model creation, gradient accumulation, checkpointing, dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX + post-training bugfixes. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling MVP on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference (core prefill/generate/insert path). (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
    7.5. ✅ Finish NNX MaxEngine inference carve-outs.
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. ✅ NNX-aware QK-Clip + NNX-format checkpoint utilities. (PR [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836)
    9.5. 🔄 [This PR] NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

Migrates the NNX + AQT integration in MaxEngine so pure_nnx=True can both load pre-quantized checkpoints directly (checkpoint_is_quantized=True) and convert full-precision checkpoints to int8 on load (checkpoint_is_quantized=False + quantization=int8). Also bundles a fix for a pre-existing gpt3 prefill / autoregressive bug surfaced by the AQT end-to-end validation.

Originally part of PR9; split into its own follow-up to keep each PR reviewable. Linen paths byte-for-byte preserved (every NNX edit is gated on config.pure_nnx). Stacks on PR9 (feat/nnx-qk-clip-and-checkpoint-utils); the two halves are file-disjoint.

Diff: +641 / −58 across 8 files (5 src, 3 test, 2 new).

What it does

  • quant_mode_str plumbing: thread "train" / "convert" / "serve" through from_config → create_model → get_nnx_create_model_fn → create_nnx_abstract_model → from_pretrained. Default "train" preserves existing callers; "serve" propagates to configure_quantization so AQT layers don't materialize the full-precision kernel when the on-disk checkpoint already carries qrhs scale factors.

  • maxengine.__init__ — selects the quant mode from config.checkpoint_is_quantized; _load_params_nnx drops its NotImplementedError. Two paths now work: pre-quantized loads via quant_mode_str="serve"; full-precision + quantization=int8 loads in TRAIN mode and AQT quantizes per-forward against the loaded kernel (same numerical result as serve mode for absmax calibration).

  • layerwise_quantization._load_and_quantize_nnx (new) — whole-model NNX convert path. Load full-precision in TRAIN mode → copy kernels into a separate CONVERT-mode model → run a forward (the ToNNX(AqtDotGeneral) bridge auto-captures qrhs.frozen) → strip kernels at quantized paths via _strip_kernels_at_quantized_paths → save serve-mode-shape state. The DeepSeek-only assertion is lifted for NNX.

  • Sharding helpers + from_pretrained QTensor handling — 5 chained fixes that kept serve-mode reload from working:

    1. get_nnx_named_sharding_with_scan_axis emits a parallel-tree of replicated NamedSharding leaves when a Variable's value is a composite pytree (QTensor int8 qvalue + bf16 scale list). Previously returned the Variable as-is when val had no .shape.
    2. _build_value_target / _free_device_memory / _unwrap_for_align use Variable.get_value() instead of v[...]; QTensor.__getitem__ trips on the LogicallyPartitioned wrapper around qvalue.
    3. Restore filter widened from nnx.Param only to "everything except nnx.RngState / nnx.Cache", and _load_params_nnx adds a 4-way nnx.split + overlay so aqt-typed qrhs.frozen leaves survive into the rest-state.
    4. _build_value_target strips Partitioned wrappers around composite-leaf values so the restore tree path matches the on-disk layout. Without this, jax.tree.flatten_with_path added an extra .value key under every QTensor leaf and orbax silently filled the missing paths with zero-init values (qvalue=0, scale=1 — exactly the symptom).
    5. _walk_align skips composite leaves in per-axis shape alignment — quantized payloads are saved at the model shape; previously crashed on QTensor.shape (which delegates to qvalue.shape on a LogicallyPartitioned).

    Also dropped a redundant jax.set_mesh(mesh) wrap in create_nnx_abstract_model — under jax.set_mesh, Flax 0.12.6's _to_variable rejects serve-mode AQT variables (NamedSharding(mesh=AbstractMesh, spec=None)).

  • gpt3 prefill fix (models/gpt3.py) — pre-existing bug surfaced by the AQT e2e validation. Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without ever calling update_kv_caches to build cached_values, so any non-TRAIN forward (prefill or autoregressive) tripped assert prefill_kv_cache. Mirrors the standard Attention plumbing in attentions.py: __init__ constructs KVCache_0 when model_mode != MODEL_MODE_TRAIN (and threads max_prefill_predict_length into AttentionOp, was -1 default); __call__ calls self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as cached_values. TRAIN-mode shape unchanged. Bundled here because the AQT e2e exercises this path.
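
The quant-mode selection described for maxengine.__init__ above can be sketched in a few lines. This is a minimal illustration, not the MaxEngine code: `ToyConfig` and `select_quant_mode_str` are hypothetical names, and only the config fields named in this PR are modeled.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the MaxText config (assumption: only these fields matter here)."""
    pure_nnx: bool = True
    checkpoint_is_quantized: bool = False
    quantization: str = ""


def select_quant_mode_str(config: ToyConfig) -> str:
    """Pick the mode string that would be threaded through model creation.

    Pre-quantized checkpoints load in "serve" mode. Everything else keeps the
    default "train" mode, including full-precision + quantization="int8",
    where the layers quantize per forward against the loaded kernel.
    """
    if config.checkpoint_is_quantized:
        return "serve"
    return "train"


pre_quantized = ToyConfig(checkpoint_is_quantized=True, quantization="int8")
convert_on_load = ToyConfig(checkpoint_is_quantized=False, quantization="int8")
print(select_quant_mode_str(pre_quantized), select_quant_mode_str(convert_on_load))
```

Keeping "train" as the fallback mirrors the stated goal that existing callers are unaffected by the new parameter.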

Tests

  • aqt_serve_roundtrip_nnx_test (new, 1 test) — end-to-end regression. Builds a small CONVERT-mode NNX model with quantization=int8, runs a forward to populate qrhs.frozen, saves serve-mode state to orbax, reloads via from_pretrained(quant_mode_str="serve"), asserts every saved qvalue byte-matches what came back. Guards the full chain of QTensor / Partitioned / filter fixes.
  • layerwise_quantization_nnx_test (new, 3 tests) — _strip_kernels_at_quantized_paths covering quantized-kernel removal, non-quantized preservation (norms, embeddings), mixed-shape trees.
  • maxengine_test — replaced test_quantize_raises_for_nnx with test_quantize_passes_gate_for_nnx; added test_load_pre_quantized_nnx_passes_quant_gate (serve-mode gate) and test_quantized_prefill_nnx_train_mode (real numerical prefill with quantization=int8 + random params + TRAIN mode produces finite logits).
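
The TRAIN-mode prefill test leans on the description's claim that per-forward quantization gives the same numbers as serve mode under absmax calibration. A toy version of that argument, assuming symmetric int8 absmax quantization of the kernel only (this is not AQT's implementation, and `absmax_quantize` / `dot_quantized` are hypothetical helpers):

```python
def absmax_quantize(column, bits=8):
    """Symmetric absmax quantization of a list of floats: returns (int values, scale)."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8
    absmax = max(abs(x) for x in column)
    scale = absmax / qmax if absmax > 0 else 1.0
    qvalue = [round(x / scale) for x in column]
    return qvalue, scale


def dot_quantized(activations, qvalue, scale):
    """Dot product against the quantized kernel, rescaled back to float."""
    return sum(a * q for a, q in zip(activations, qvalue)) * scale


kernel = [0.5, -1.27, 0.02, 0.9]
activations = [1.0, 2.0, -3.0, 0.5]

# "Serve" style: quantize once and keep only (qvalue, scale).
frozen_qvalue, frozen_scale = absmax_quantize(kernel)
serve_out = dot_quantized(activations, frozen_qvalue, frozen_scale)

# "Train" style: re-quantize the full-precision kernel on each forward.
train_out = dot_quantized(activations, *absmax_quantize(kernel))

assert serve_out == train_out
```

The equality holds because absmax calibration is a deterministic function of the kernel alone, so quantizing once ahead of time or on every forward yields the same integers and scale.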

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov Bot commented May 7, 2026


Codecov Report

❌ Patch coverage is 40.00000% with 72 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/layerwise_quantization.py | 19.71% | 56 Missing and 1 partial ⚠️ |
| src/maxtext/utils/model_creation_utils.py | 66.66% | 10 Missing and 5 partials ⚠️ |


@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch from ca957ab to e173538 Compare May 7, 2026 21:53
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch 15 times, most recently from 31ac0e6 to 88417d0 Compare May 14, 2026 22:51
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch 9 times, most recently from 54f4f9d to 71525e7 Compare May 21, 2026 19:58
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch 2 times, most recently from d4bcba4 to b3dd0c1 Compare May 25, 2026 15:26

github-actions Bot commented Jun 3, 2026


🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request successfully integrates AQT quantization with the NNX model path in MaxEngine and the GRPO trainer. It also resolves a significant bug in the GPT-3 attention implementation where KV caches were not being correctly updated during prefill, and provides a model-agnostic layerwise quantization (conversion) utility for NNX models.

🔍 General Feedback

  • Robust NNX+AQT Integration: The changes to model_creation_utils.py and maxengine.py demonstrate a deep understanding of the NNX variable system and AQT's composite types (QTensor). The use of get_value() and the removal of Partitioned wrappers are critical for stable checkpoint roundtrips.
  • Model Correctness: The GPT-3 fix brings the model's inference behavior in line with other major architectures in the repository.
  • Improved Testing: The addition of aqt_serve_roundtrip_nnx_test.py and other NNX-specific unit tests provides excellent coverage for the complex logic involved in loading and aligning quantized checkpoints.
  • Clean Metrics Collection: collect_intermediates_by_suffix is a valuable utility for unified metrics collection across diverse model structures.

Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill
bug surfaced by the AQT end-to-end validation.

NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
  "serve") through from_config / create_model /
  get_nnx_create_model_fn / create_nnx_abstract_model /
  from_pretrained. Default "train" preserves existing callers; "serve"
  propagates to configure_quantization so AQT layers don't materialize
  the full-precision kernel when the on-disk checkpoint already
  carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
  config.checkpoint_is_quantized; _load_params_nnx drops its
  NotImplementedError. Two paths: pre-quantized
  (checkpoint_is_quantized=True) loads via quant_mode_str="serve";
  full-precision + quantization=int8 loads in TRAIN mode and AQT
  layers quantize per-forward (same numerical result for absmax
  calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
  convert path. Loads full-precision in TRAIN mode, transfers kernels
  into a CONVERT-mode model, runs forward to populate qrhs.frozen via
  the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
  saves serve-mode-shaped state.
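
The "strip kernels at quantized paths" step can be sketched on a plain nested dict. This is an illustration under a loudly assumed layout (a module is treated as quantized when it has a direct "qrhs" child containing "frozen"); the real helper operates on NNX state, and `strip_kernels_sketch` is a hypothetical name.

```python
def strip_kernels_sketch(state):
    """Return a copy of a nested-dict state with "kernel" dropped wherever the
    same module also carries a "qrhs" -> "frozen" entry (assumed layout)."""
    if not isinstance(state, dict):
        return state
    qrhs = state.get("qrhs")
    quantized_here = isinstance(qrhs, dict) and "frozen" in qrhs
    stripped = {}
    for key, value in state.items():
        if key == "kernel" and quantized_here:
            continue  # the full-precision kernel is redundant in serve mode
        stripped[key] = strip_kernels_sketch(value)
    return stripped


toy_state = {
    "mlp": {
        "wi": {"kernel": [0.5, -1.27], "qrhs": {"frozen": {"qvalue": [50, -127], "scale": [0.01]}}},
    },
    "norm": {"scale": [1.0]},
    "embedding": {"kernel": [0.1, 0.2]},  # not quantized, so its kernel stays
}
serve_state = strip_kernels_sketch(toy_state)
print(sorted(serve_state["mlp"]["wi"]))
```

Returning a new tree instead of mutating in place keeps the TRAIN-mode state usable after the serve-mode state has been derived from it.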

Sharding helpers and from_pretrained QTensor handling (5 chained fixes
that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
  parallel-tree of replicated NamedSharding leaves when a Variable's
  value is a composite pytree (AQT serve-mode QTensor with a qvalue
  int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
  jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
  AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
  _unwrap_for_align use Variable.get_value() instead of v[...]
  indexing for QTensor leaves (QTensor.__getitem__ trips on the
  LogicallyPartitioned wrapper around qvalue). Widens the restore
  filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
  type. Skips QTensor leaves in the per-axis shape-alignment dispatch
  (their saved shape already matches the model). _build_value_target
  strips Partitioned wrappers around composite-leaf values so the
  restore tree path matches the on-disk layout (LogicallyPartitioned
  was adding an extra .value key under each QTensor leaf, which made
  orbax silently fill the path with zero-init values).
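
The path-mismatch failure described above can be reproduced with a toy restore. This sketch assumes nested dicts in place of NNX state and a hypothetical `lenient_restore` that, like the behaviour reported here, fills paths it cannot find with a default instead of raising; it does not use or model the orbax API.

```python
def flatten_with_path(tree, prefix=()):
    """Flatten a nested dict into {path_tuple: leaf}."""
    if isinstance(tree, dict):
        flat = {}
        for key, value in tree.items():
            flat.update(flatten_with_path(value, prefix + (key,)))
        return flat
    return {prefix: tree}


def lenient_restore(target, on_disk, default=0):
    """Toy restore: take each target path from disk, or silently fall back to a default."""
    return {path: on_disk.get(path, default) for path in flatten_with_path(target)}


on_disk = flatten_with_path({"qrhs": {"frozen": {"qvalue": 57, "scale": 0.25}}})

# Target built while a wrapper still contributes an extra "value" key per leaf.
wrapped_target = {"qrhs": {"frozen": {"qvalue": {"value": None}, "scale": {"value": None}}}}
# Target built after the wrapper has been stripped.
stripped_target = {"qrhs": {"frozen": {"qvalue": None, "scale": None}}}

bad = lenient_restore(wrapped_target, on_disk)
good = lenient_restore(stripped_target, on_disk)
print(bad)
print(good)
```

With the wrapper in place no target path exists on disk, so every leaf comes back as the default and no error is raised; once the wrapper is stripped the paths line up and the saved values are returned.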

gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
  ever calling update_kv_caches to build cached_values, so any
  non-TRAIN forward (prefill or autoregressive) tripped the
  `assert prefill_kv_cache` check. Mirror the standard Attention
  plumbing in attentions.py: __init__ constructs a KVCache_0 module
  when model_mode != MODEL_MODE_TRAIN, threads
  max_prefill_predict_length into AttentionOp; __call__ calls
  self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
  cached_values to attention_op. TRAIN-mode shape unchanged.
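
A toy version of the gpt3 bug and its fix, for orientation only. All classes here (`ToyKVCache`, `ToyAttention`, `toy_attention_op`) are hypothetical stand-ins and share nothing with the MaxText modules beyond the control flow described above.

```python
MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"


class ToyKVCache:
    """Hypothetical cache: returns [prefill_kv_cache, ar_kv_cache] for the given mode."""

    def __call__(self, key, value, model_mode):
        if model_mode == MODEL_MODE_PREFILL:
            return [(key, value), None]
        # Any other non-train mode is treated as autoregressive in this toy.
        return [(key, value), (key, value)]


def toy_attention_op(query, model_mode, cached_values):
    """Stand-in attention op that requires a prefill cache outside training."""
    prefill_kv_cache, ar_kv_cache = cached_values
    if model_mode != MODEL_MODE_TRAIN:
        assert prefill_kv_cache, "non-train forward needs a prefill KV cache"
    return query


class ToyAttention:
    def __init__(self, model_mode, build_cache):
        self.model_mode = model_mode
        # The fix: construct a cache whenever the module is not in train mode.
        needs_cache = build_cache and model_mode != MODEL_MODE_TRAIN
        self.kv_cache = ToyKVCache() if needs_cache else None

    def __call__(self, query, key, value):
        cached_values = [None, None]
        if self.kv_cache is not None:
            cached_values = self.kv_cache(key, value, self.model_mode)
        return toy_attention_op(query, self.model_mode, cached_values)


fixed = ToyAttention(MODEL_MODE_PREFILL, build_cache=True)
print(fixed([1.0], [2.0], [3.0]))
```

Constructing `ToyAttention(MODEL_MODE_PREFILL, build_cache=False)` and calling it reproduces the pre-fix behaviour: the assertion fires because no cache was ever built.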

Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
  _strip_kernels_at_quantized_paths covering quantized removal,
  non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
  builds a small NNX model in CONVERT mode with int8, runs a forward
  to populate qrhs.frozen via the ToNNX bridge, saves the
  serve-mode-shape state to a tmp local orbax checkpoint, reloads via
  from_pretrained(quant_mode_str="serve"), and asserts every saved
  qrhs.frozen.qvalue array byte-matches what came back. Guards the
  full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
  test_quantize_passes_gate_for_nnx; added
  test_load_pre_quantized_nnx_passes_quant_gate and
  test_quantized_prefill_nnx_train_mode (real numerical verification
  with quantization=int8 + random params + TRAIN mode).
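
The byte-match assertion at the heart of the roundtrip test can be sketched with the standard library alone. Assumptions: pickle stands in for the orbax checkpoint, signed-byte arrays stand in for int8 qvalue tensors, and the path strings are made up for illustration.

```python
import array
import os
import pickle
import tempfile


def save_state(path, state):
    """Persist each int8 payload as raw bytes (pickle is a stand-in for the checkpointer)."""
    with open(path, "wb") as f:
        pickle.dump({name: arr.tobytes() for name, arr in state.items()}, f)


def load_state(path):
    """Reload the payloads as signed-byte arrays."""
    with open(path, "rb") as f:
        raw = pickle.load(f)
    return {name: array.array("b", payload) for name, payload in raw.items()}


def assert_byte_match(saved, loaded):
    """Fail unless both states have the same paths and identical bytes at each path."""
    assert saved.keys() == loaded.keys(), "path sets differ"
    for name in saved:
        assert saved[name].tobytes() == loaded[name].tobytes(), f"mismatch at {name}"
    return len(saved)


def roundtrip(state):
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "state.pkl")
        save_state(path, state)
        return load_state(path)


saved = {
    "decoder/mlp/wi/qrhs/frozen/qvalue": array.array("b", [50, -127, 2, 90]),
    "decoder/mlp/wo/qrhs/frozen/qvalue": array.array("b", [-1, 0, 127, 3]),
}
loaded = roundtrip(saved)
compared = assert_byte_match(saved, loaded)
print(compared)
```

Comparing raw bytes rather than checking for finite outputs is what catches a silent zero-fill: an all-zero qvalue still produces finite logits but fails this check.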

End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.
3 participants