Skip to content

[NNX] NNX migration prep (8/N): NNX native lora grpo#3824

Open
ecnal-cienet wants to merge 1 commit into
mainfrom
feat/nnx-native-lora-grpo
Open

[NNX] NNX migration prep (8/N): NNX native lora grpo#3824
ecnal-cienet wants to merge 1 commit into
mainfrom
feat/nnx-native-lora-grpo

Conversation

@ecnal-cienet
Copy link
Copy Markdown
Collaborator

@ecnal-cienet ecnal-cienet commented May 6, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
    4.6. ❌ Linen↔NNX checkpoint comparator (sibling branch on PR4.5).
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. 🔄 [This PR] NNX-native LoRA + GRPO. NNX-native serving / decode-checkpoint LoRA via apply_lora_on_base_params_nnx / unapply_lora_from_base_params_nnx / get_lora_abstract_state_nnx (the maxengine pure_nnx + LoRA carve-out from PR7 is cleared); NNX-native GRPO trainer via grpo_loss_fn_nnx + compute_log_probs_nnx + NNX setup_train_loop/train_step/eval_step paths. Stacks on PR7.
  9. ❌ NNX-aware QK-Clip + remaining checkpoint utilities.
    9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

This PR implements NNX-native LoRA serving and NNX-native GRPO by adding NNX-shape walkers and step helpers alongside the existing Linen ones, then dispatching on config.pure_nnx. Every NNX modification is gated by if config.pure_nnx:, preserving the Linen path byte-for-byte. The diff spans +551 / −84 across 5 source files, plus 2 new test files (515 lines).

Part 1: NNX-shape LoRA Walkers

New helpers in src/maxtext/utils/lora_utils.py operating on nnx.State pure trees (no {"params": ...} outer wrap):

  • apply_lora_on_base_params_nnx mutates base_params in place: W += B @ A * scale at target attention paths
  • unapply_lora_from_base_params_nnx is the symmetric inverse
  • get_lora_abstract_state_nnx walks the abstract state.model substate and emits a parallel tree with lora_a.kernel/lora_b.kernel leaves at target attention paths and None elsewhere
  • _nnx_param_subtree drops the outer TrainStateNNX wrapping

The base model stays pristine; "apply" merges the delta into the kernel, "unapply" reverses. No nnx.LoRA wrapper, no model surgery. The on-disk format (HuggingFace PEFT-style lora_a.kernel / lora_b.kernel) round-trips between Linen and NNX consumers unchanged.

Part 2: LoRA Dispatch in setup_initial_lora_state and load_adapter

Both top-level entry points in lora_utils.py branch on config.pure_nnx:

  • NNX init builds the abstract base via model_creation_utils.create_nnx_abstract_model + TrainStateNNX(model, optimizer)
  • Linen branch is the original init_initial_state + get_lora_abstract_state path, untouched

Part 3: MaxEngine LoRA Carve-out Cleared

src/maxtext/inference/maxengine/maxengine.py:

  • load_single_adapter no longer raises NotImplementedError on pure_nnx
  • apply_adapter / unapply_adapter branch on config.pure_nnx to call the _nnx siblings

Part 4: GRPO Loss and Step Helpers

src/maxtext/experimental/rl/grpo_trainer.py:

  • grpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train). Signature matches Linen grpo_loss_fn so callers dispatch on the same shape. dropout_rng and params are unused on NNX; reference_model is a frozen nnx.Module and the reference forward is wrapped in stop_gradient. Returns (loss, LossAux), same dataclass as Linen.
  • _train_step_nnx: nnx.merge(graphdef, state) to reconstruct TrainStateNNX, value_and_grad over policy params, state.apply_gradients(grads), return nnx.state(new_state, nnx.Not(nnx.Intermediate)).
  • _eval_step_nnx: same merge + loss-fn call, no state update.
  • train_step / eval_step early-dispatch on config.pure_nnx; Linen branches verbatim.

Part 5: GRPO setup_train_loop on NNX

grpo_trainer.py::setup_train_loop:

  • Builds training and inference models via mt.from_config(rngs=create_nnx_rngs(...))
  • Initializes state via create_nnx_abstract_model + TrainStateNNX(model, optimizer, reference_model=...)
  • Reference uses the same init seed as policy and is never updated by apply_gradients (sibling field on TrainStateNNX, not embedded in params)
  • The WARNING: GRPO RL trainer does not yet support pure_nnx natively log is removed

Part 6: GRPO train_loop NNX Branches

grpo_trainer.py::train_loop — three Linen-coupled spots branched on pure_nnx:

  • Initial reference seeding is skipped on NNX (already set up by init_state_fn)
  • metric_logger.write_setup_info_to_tensorboard receives a flat nnx.Param state on NNX
  • Checkpoint save passes the whole TrainStateNNX on NNX; the Linen _split_grpo_state(state)[0] strip is bypassed

The reshard call routes to pathways_reshard_nnx when pure_nnx. New helpers in grpo_utils.py:

  • compute_log_probs_nnx: NNX model is called directly; intermediates pulled via nnx.state(model, nnx.Intermediate).to_pure_dict()
  • pathways_reshard_nnx: splits state.model to a flat nnx.Param state, reshards onto the inference mesh, calls inference_engine.update_params(...)

Part 7: Carve-outs (NotImplementedError Sites)

Feature Tracked In
GRPO + gradient_accumulation_steps > 1 Follow-up
GRPO + scan_layers=False Follow-up (needs an NNX-aware unscan helper)

Tests

New unit tests (tests/unit/lora_utils_nnx_test.py, 10 tests):

  • 5 on get_lora_abstract_state_nnx: q/k/v/o shape derivation, target-vs-non-target masking, sharding propagation, leaf type validation, error paths
  • 3 on apply_lora_on_base_params_nnx: apply→unapply identity, target-only mutation, numerical parity vs Linen apply_lora_on_base_params on the same random inputs
  • 2 Linen regression smoke tests on apply_lora_on_base_params and unapply_lora_from_base_params (no existing unit test for these helpers in the tree)

New unit tests (tests/unit/grpo_nnx_test.py, 8 tests):

  • 5 on grpo_loss_fn_nnx: LossAux shape parity, signature compatibility, identical-policy/reference → zero KL, grpo_beta=0aux.avg_kl=None, finite policy grads
  • 1 on compute_log_probs_nnx: shape [B, S] → [B, S-1]
  • 2 Linen regression smoke tests on grpo_loss_fn and compute_log_probs (the existing Linen integration test is TPU-only and currently @pytest.mark.skip)

Modified test: tests/unit/maxengine_test.py swaps test_lora_raises_for_nnx (asserted NotImplementedError) for test_lora_load_single_adapter_reaches_loader_on_nnx (asserts FileNotFoundError from the loader).

Existing Linen tests: untouched and still pass; pure_nnx=False stays default.

Test results: 198 passed, 1 skipped (pre-existing CPU-only skip) across the broader NNX regression sweep — maxengine_test, dpo_nnx_test, train_nnx_test, lora_utils_nnx_test, grpo_nnx_test, train_state_nnx_test, train_utils_nnx_test, gradient_accumulation_nnx_test, linen_nnx_converter_test, compare_linen_nnx_checkpoint_test.

Linting: bash lint.sh — pyink + pylint 10.00/10.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet changed the title Feat/nnx native lora grpo [NNX] NNX migration prep (8/N): Feat/nnx native lora grpo May 6, 2026
@ecnal-cienet ecnal-cienet changed the title [NNX] NNX migration prep (8/N): Feat/nnx native lora grpo [NNX] NNX migration prep (8/N): native lora grpo May 6, 2026
@ecnal-cienet ecnal-cienet changed the title [NNX] NNX migration prep (8/N): native lora grpo [NNX] NNX migration prep (8/N): NNX native lora grpo May 6, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented May 6, 2026

Codecov Report

❌ Patch coverage is 60.90909% with 43 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/utils/lora_utils.py 60.90% 36 Missing and 7 partials ⚠️

📢 Thoughts on this report? Let us know!

Comment thread src/maxtext/experimental/rl/grpo_trainer.py Outdated
@github-actions
Copy link
Copy Markdown

github-actions Bot commented Jun 2, 2026

🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

The Pull Request successfully implements NNX-native support for LoRA serving and GRPO training, which is a key milestone in the NNX migration for MaxText. The changes are comprehensive, covering utilities, trainer logic, and inference engine integration, while maintaining parity with the existing Linen implementation. The addition of thorough unit tests for both GRPO and LoRA in NNX ensures numerical correctness and structural integrity.

🔍 General Feedback

  • Preservation of Logic: The migration faithfully reproduces the Linen logic, including specialized LoRA update walkers and GRPO loss functions, which minimizes the risk of regression.
  • Variable Naming: As noted in the inline comments, some variable naming in the LoRA utilities is inherited from a confusing pattern in the Linen path. While correct, cleaning this up in the NNX implementation would improve long-term maintainability.
  • Modularity: The separation of NNX-specific logic into *_nnx functions and branches is well-handled and keeps the codebase clean during this transition period.
  • Testing: Excellent test coverage with parity checks against Linen is a major highlight of this PR.

Comment thread src/maxtext/utils/lora_utils.py Outdated
Comment thread src/maxtext/utils/lora_utils.py Outdated
Comment thread src/maxtext/experimental/rl/grpo_trainer.py
Comment thread src/maxtext/utils/lora_utils.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants