Skip to content

Commit 2f7f039

Browse files
committed
NNX: native MaxEngine inference (drop route-to-Linen path in maxengine.py)
PR5 audited maxengine.py and routed the inference path to the Linen implementation regardless of pure_nnx, with a comment block explaining that "the flag affects training, not inference serving." That kept the Linen serving path unchanged but meant pure_nnx=True users silently got the Linen engine. This change replaces the route with a real NNX flow: when config.pure_nnx=True, the engine builds an NNX Transformer, splits out (params, cache, rest) with nnx.split, and at every JIT body merges the model concretely with nnx.merge to run the forward pass. Linen is preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:` and pure_nnx=False is still the default. maxengine.py (__init__): - Build two abstract NNX Transformers on the NNX path: self.model with model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on, decode_state shape). Both are needed because NNX cache vars inherit CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and bulk_insert searches for the substring "cache_batch" in the AR-mode logical-axes tuple. nnx.eval_shape is called directly inside nn_partitioning.axis_rules rather than through create_nnx_abstract_model to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh). - Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT bodies can pass (params, cache, rest) separately to nnx.merge. The rest slot (RNG vars etc.) is materialized concretely in load_params. maxengine.py (cache adapter + _nnx_run_model): - bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the cache via tree_map_with_path and switch on path[-1].key (the cache variable name like "cached_prefill_key"). Linen mutable cache is a plain nested dict. NNX Cache state would expose a ".value" accessor at that position. Bridge via nnx.State.to_pure_dict() (after the model run) and nnx.replace_by_pure_dict (before nnx.merge), so the cache plumbing helpers see the same shape on both paths. - Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True) -> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True avoids reusing Variable objects across traces (TraceContextError), mirroring train.py's diff_wrapper workaround. - Add _nnx_cache_state_template / _nnx_init_cache_dict helpers parametrised by mode so prefill (batch 1) and decode_state (batch N) pull from the right abstract model. maxengine.py (load_params): - New _load_params_nnx: accepts user-provided NNX-shape params or loads via from_pretrained. For user-provided params, materializes a concrete model once via _create_model_fn() to capture a real rest state for nnx.merge (wasteful but simple; the from_pretrained branch avoids this). Refreshes self.graphdef from the concrete model so subsequent merges line up exactly. - Builds self.abstract_params, populates self.prefill_kv_cache_annotations and self.kv_cache_annotations (using model_ar for the latter so bulk_insert's substring lookup hits), wraps both into NamedSharding. - pure_nnx + quantization, pure_nnx + LoRA, pure_nnx + stack_prefill_result_cache=True, pure_nnx + prefill_multisampling, and pure_nnx + prefill_concat raise NotImplementedError for now; the Linen path is the workaround. AOT compilation (aot_compile / _compile_generate_and_get_layouts) is not gated and may work as-is; not exercised by tests yet. maxengine.py (init_decode_state, _prefill_jit, _generate_jit): - _init_decode_state_nnx zero-initializes a pure-dict cache from model_ar (so the leading batch dim matches generate's input shape) and builds kv_cache_annotations_named per leaf by reading nnx.Cache.metadata. Tries "out_sharding", "sharding", and "sharding_names" because Flax 0.12.6 renamed these. - _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch that calls _nnx_run_model in place of self.model.apply with mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict cache directly (no params|{"cache":...} dict-merge — params is an nnx.State, not a dict). maxtext_utils.py: - New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx that mirror the Linen helpers' return shape (per-leaf PartitionSpec tree). Both delegate to _nnx_cache_partition_specs which extracts nnx.Cache state via nnx.split, calls get_nnx_named_sharding_with_scan_axis inside nn_partitioning.axis_rules so logical axes ("layers", "cache_batch", "norm", ...) resolve to physical mesh axes, and converts the result to a pure-dict tree. tests/unit/maxengine_test.py: - New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and per-layer cache shape checks), test_basic_decode_nnx (4-step generate with next_pos advancement check), test_quantize_raises_for_nnx, test_lora_raises_for_nnx. - New test_linen_nnx_parity_prefill: bridges Linen-init params into the NNX engine via linen_nnx_converter (convert_linen_to_nnx -> _strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the NNX engine's prefill matches Linen on the same weights — logits within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses bf16 compute) and exact greedy first-token argmax. - Existing Linen tests untouched. Test summary: 9 passed, 1 skipped (test_chunked_prefill is a pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink all green.
1 parent 7796d38 commit 2f7f039

3 files changed

Lines changed: 413 additions & 35 deletions

File tree

0 commit comments

Comments
 (0)