Skip to content

Commit c91a2a6

Browse files
committed
NNX: finish MaxEngine inference carve-outs (multisampling, concat, stacked prefill cache)
PR7 (NNX-native MaxEngine inference) made the core prefill/generate/insert path work under pure_nnx=True but left three serving features raising NotImplementedError on the NNX path. This promotes all three to NNX-native. Linen is preserved byte-for-byte: the original model.apply(..., mutable=["cache"]) calls are unchanged, just moved into else: branches, and every NNX edit is gated `if config.pure_nnx:`. maxengine.py: - _prefill_multisampling_jit: drops the NotImplementedError; adds a pure_nnx branch that runs prefill through _nnx_run_model (MODEL_MODE_PREFILL, batch=1) with a fresh _nnx_init_cache_dict. The loop that draws num_samples first tokens from the shared logits is unchanged. - prefill_concat: same swap; the packed positions and segment ids thread through _nnx_run_model unchanged. - stack_prefill_result_cache=True: now supported for both scan_layers values. scan_layers=True already stacks the per-layer KV cache on axis 0 (the Linen post-stack shape), so _maybe_stack/_maybe_unstack_prefill_result_cache are no-ops and prefill_kv_cache_shardings stays the full tree. scan_layers=False keeps unstacked per-layer subtrees under cache["decoder"]["layers"][i] (int keys), so _maybe_stack stacks them into one subtree with a leading layer axis, _maybe_unstack splits it back into the int-keyed per-layer dict that bulk_insert/_insert_jit walk, and _load_params_nnx prepends a layer axis to each prefix-sharding spec (the NNX analog of the Linen P(None, *spec) + ["decoder"]["layers_0"] reshape). tests/integration/maxengine_test.py: - New _build_linen_params helper and a shared _stack_prefill_roundtrip helper. - test_prefill_multisampling_nnx, test_prefill_concat_nnx: NNX vs Linen result-shape parity, finite logits + cache. - test_stack_prefill_result_cache_nnx (scan_layers=True) and test_stack_prefill_result_cache_scan_layers_false_nnx (scan_layers=False): prefill -> insert -> generate round-trip, layer-stacked leaves, finite logits, next_pos advances. Remaining NNX MaxEngine carve-outs are quantization (PR9) and LoRA (PR8), which are other PRs' scope.
1 parent 2f7f039 commit c91a2a6

2 files changed

Lines changed: 187 additions & 33 deletions

File tree

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 81 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -417,11 +417,15 @@ def _load_params_nnx(self, params, rng):
417417
lambda x: jax.sharding.NamedSharding(self._mesh, x),
418418
self.prefill_kv_cache_annotations,
419419
)
420-
if self.config.stack_prefill_result_cache:
421-
# With scan_layers=True the NNX cache leaves are already stacked on axis 0,
422-
# so the engine's manual-stack helper (which assumes an unstacked Linen tree)
423-
# doesn't apply. Wiring this up cleanly is a Phase-2 follow-up.
424-
raise NotImplementedError("pure_nnx + stack_prefill_result_cache=True not yet supported.")
420+
if self.config.stack_prefill_result_cache and not self.config.scan_layers:
421+
# scan_layers=False has unstacked per-layer subtrees; _maybe_stack_prefill_result_cache
422+
# stacks them on a new axis 0, so add that axis to each spec and keep one layer's subtree.
423+
self.prefill_kv_cache_shardings = jax.tree.map(
424+
lambda x: jax.sharding.NamedSharding(self._mesh, jax.sharding.PartitionSpec(None, *x.spec)),
425+
self.prefill_kv_cache_shardings,
426+
)
427+
self.prefill_kv_cache_shardings = {"decoder": {"layers": self.prefill_kv_cache_shardings["decoder"]["layers"][0]}}
428+
# scan_layers=True is already stacked on axis 0; shardings stay as-is and stack/unstack are no-ops.
425429
# AR-mode abstract model so axis names use CACHE_BATCH (not CACHE_BATCH_PREFILL);
426430
# bulk_insert / _insert_jit search for "cache_batch" in the per-leaf logical axes.
427431
self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(self.model_ar, self.config, self._mesh)
@@ -525,6 +529,16 @@ def _maybe_stack_prefill_result_cache(self, cache):
525529
if not self.config.stack_prefill_result_cache:
526530
return cache
527531

532+
if self.config.pure_nnx:
533+
if self.config.scan_layers:
534+
# scan_layers already stacks the per-layer KV cache on axis 0; nothing to restack.
535+
return cache
536+
# scan_layers=False: stack the per-layer subtrees under decoder/layers into one
537+
# subtree with a leading layer axis (matching the scan_layers=True shape).
538+
layers = cache["decoder"]["layers"]
539+
stacked = jax.tree.map(lambda *c: jnp.stack(c), *[layers[i] for i in range(self.config.num_decoder_layers)])
540+
return {"decoder": {"layers": stacked}}
541+
528542
layer_keys = []
529543
for i in range(self.config.num_decoder_layers):
530544
layer_keys.append(f"layers_{i}")
@@ -538,6 +552,16 @@ def _maybe_unstack_prefill_result_cache(self, cache):
538552
if not self.config.stack_prefill_result_cache:
539553
return cache
540554

555+
if self.config.pure_nnx:
556+
if self.config.scan_layers:
557+
# Mirror _maybe_stack_prefill_result_cache: the cache already carries the
558+
# layer axis, so there is nothing to unstack.
559+
return cache
560+
# scan_layers=False: split the leading layer axis back into per-layer subtrees.
561+
stacked = cache["decoder"]["layers"]
562+
layers = {i: jax.tree.map(lambda x, i=i: x[i], stacked) for i in range(self.config.num_decoder_layers)}
563+
return {"decoder": {"layers": layers}}
564+
541565
flat_cache, treedef = jax.tree.flatten(cache)
542566
layer_cache = [jax.tree.unflatten(treedef, flat_cache_vars) for flat_cache_vars in zip(*flat_cache, strict=True)]
543567
res_cache = {"decoder": {}}
@@ -918,9 +942,6 @@ def _prefill_multisampling_jit(
918942
prefilling stage. The number of tokens is specified by num_samples.
919943
"""
920944

921-
if self.config.pure_nnx:
922-
raise NotImplementedError("pure_nnx + prefill_multisampling not yet supported. Use pure_nnx=False.")
923-
924945
input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE]
925946
positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0)
926947

@@ -930,17 +951,32 @@ def _prefill_multisampling_jit(
930951
sequence_indicator = jnp.expand_dims(one_d_output, 0)
931952

932953
rng, new_rng = jax.random.split(rng)
933-
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
934-
flat_logits, new_vars = self.model.apply(
935-
params,
936-
input_tokens,
937-
positions,
938-
decoder_segment_ids=sequence_indicator,
939-
enable_dropout=False,
940-
model_mode=MODEL_MODE_PREFILL,
941-
rngs={"params": new_rng},
942-
mutable=["cache"],
943-
)
954+
if self.config.pure_nnx:
955+
# Prefill is batch=1 (one prompt); multi-sampling only draws several first
956+
# tokens from the shared logits below. Mirror the _prefill_jit NNX branch.
957+
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
958+
flat_logits, new_cache_dict = self._nnx_run_model(
959+
params=params,
960+
cache_dict=self._nnx_init_cache_dict(mode=MODEL_MODE_PREFILL),
961+
decoder_input_tokens=input_tokens,
962+
decoder_positions=positions,
963+
decoder_segment_ids=sequence_indicator,
964+
enable_dropout=False,
965+
model_mode=MODEL_MODE_PREFILL,
966+
)
967+
new_vars = {"cache": new_cache_dict}
968+
else:
969+
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
970+
flat_logits, new_vars = self.model.apply(
971+
params,
972+
input_tokens,
973+
positions,
974+
decoder_segment_ids=sequence_indicator,
975+
enable_dropout=False,
976+
model_mode=MODEL_MODE_PREFILL,
977+
rngs={"params": new_rng},
978+
mutable=["cache"],
979+
)
944980

945981
next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32)
946982
selected_logits = jax.lax.dynamic_slice(
@@ -1046,26 +1082,38 @@ def prefill_concat(
10461082
if existing_prefix:
10471083
raise ValueError("We don't know what to do with existing_prefix")
10481084

1049-
if self.config.pure_nnx:
1050-
raise NotImplementedError("pure_nnx + prefill_concat not yet supported. Use pure_nnx=False.")
1051-
10521085
if rng is None:
10531086
rng = jax.random.PRNGKey(0)
10541087
input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE]
10551088
decoder_positions = jnp.expand_dims(decoder_positions, 0)
10561089
decoder_segment_ids = jnp.expand_dims(decoder_segment_ids, 0)
10571090
rng, new_rng = jax.random.split(rng)
1058-
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
1059-
flat_logits, new_vars = self.model.apply(
1060-
params,
1061-
input_tokens,
1062-
decoder_positions,
1063-
decoder_segment_ids=decoder_segment_ids,
1064-
enable_dropout=False,
1065-
model_mode=MODEL_MODE_PREFILL,
1066-
rngs={"params": new_rng},
1067-
mutable=["cache"],
1068-
)
1091+
if self.config.pure_nnx:
1092+
# Packed prompts run as a single batch=1 prefill; the packed positions and
1093+
# segment ids keep the prompts separated. Mirror the _prefill_jit NNX branch.
1094+
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
1095+
flat_logits, new_cache_dict = self._nnx_run_model(
1096+
params=params,
1097+
cache_dict=self._nnx_init_cache_dict(mode=MODEL_MODE_PREFILL),
1098+
decoder_input_tokens=input_tokens,
1099+
decoder_positions=decoder_positions,
1100+
decoder_segment_ids=decoder_segment_ids,
1101+
enable_dropout=False,
1102+
model_mode=MODEL_MODE_PREFILL,
1103+
)
1104+
new_vars = {"cache": new_cache_dict}
1105+
else:
1106+
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
1107+
flat_logits, new_vars = self.model.apply(
1108+
params,
1109+
input_tokens,
1110+
decoder_positions,
1111+
decoder_segment_ids=decoder_segment_ids,
1112+
enable_dropout=False,
1113+
model_mode=MODEL_MODE_PREFILL,
1114+
rngs={"params": new_rng},
1115+
mutable=["cache"],
1116+
)
10691117
cache = new_vars["cache"]
10701118
cache = self._maybe_stack_prefill_result_cache(cache)
10711119
if return_prompt_logp:

tests/integration/maxengine_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,18 @@ def _build_nnx_params(self, cfg, mesh):
178178
_, params_state, _ = nnx.split(model, nnx.Param, ...)
179179
return params_state
180180

181+
def _build_linen_params(self, cfg, mesh):
182+
"""Materialize a Linen Transformer and return its init vars (for NNX/Linen shape parity)."""
183+
quant = quantizations.configure_quantization(cfg)
184+
model = models.transformer_as_linen(config=cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
185+
s = (cfg.global_batch_size_to_train_on, cfg.max_target_length)
186+
ids = jax.random.randint(self.rng, s, 0, cfg.vocab_size)
187+
segment_ids = jnp.zeros(s) + DECODING_ACTIVE_SEQUENCE_INDICATOR
188+
positions = jnp.stack([jnp.arange(cfg.max_target_length, dtype=jnp.int32) for _ in range(s[0])])
189+
return model.init(
190+
{"params": self.rng, "aqt": self.rng, "dropout": self.rng}, ids, positions, segment_ids, enable_dropout=False
191+
)
192+
181193
def test_init_nnx(self):
182194
"""NNX engine init exposes graphdef + abstract Transformer."""
183195
cfg = self._init_nnx_pyconfig()
@@ -257,6 +269,100 @@ def test_lora_raises_for_nnx(self):
257269
with self.assertRaises(NotImplementedError):
258270
engine.load_single_adapter("/nonexistent/adapter/path")
259271

272+
def test_prefill_multisampling_nnx(self):
273+
"""NNX prefill_multisampling matches the Linen result shape; logits + cache stay finite."""
274+
num_samples = 3
275+
input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0])
276+
true_length = 4
277+
278+
cfg = self._init_nnx_pyconfig()
279+
mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
280+
engine = maxengine.MaxEngine(cfg, jax.devices())
281+
params = engine.load_params(params=self._build_nnx_params(cfg, mesh))
282+
nnx_result, nnx_first = engine.prefill_multisampling(
283+
params=params, padded_tokens=input_tokens, true_length=true_length, num_samples=num_samples
284+
)
285+
286+
lin_cfg = self.init_pyconfig()
287+
lin_mesh = Mesh(maxtext_utils.create_device_mesh(lin_cfg), lin_cfg.mesh_axes)
288+
lin_engine = maxengine.MaxEngine(lin_cfg, jax.devices())
289+
lin_params = lin_engine.load_params(params=self._build_linen_params(lin_cfg, lin_mesh))
290+
lin_result, lin_first = lin_engine.prefill_multisampling(
291+
params=lin_params, padded_tokens=input_tokens, true_length=true_length, num_samples=num_samples
292+
)
293+
294+
self.assertEqual(nnx_result["tokens"].shape, lin_result["tokens"].shape)
295+
self.assertEqual(nnx_result["tokens"].shape[0], num_samples)
296+
self.assertEqual(nnx_first.data.shape, lin_first.data.shape)
297+
self.assertTrue(jnp.all(jnp.isfinite(nnx_result["logits"])))
298+
for leaf in jax.tree.leaves(nnx_result["cache"]):
299+
self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
300+
301+
def test_prefill_concat_nnx(self):
302+
"""NNX prefill_concat matches the Linen result shape for packed prompts."""
303+
# Two prompts of length 2 packed into one prefill of length max_prefill_predict_length=4.
304+
packed = {
305+
"padded_tokens": jnp.array([1, 306, 5360, 304]),
306+
"decoder_positions": jnp.array([0, 1, 0, 1]),
307+
"decoder_segment_ids": jnp.array([1, 1, 2, 2]),
308+
"start_pos": jnp.array([0, 2]),
309+
"true_lengths": jnp.array([2, 2]),
310+
"num_prompts": 2,
311+
}
312+
313+
cfg = self._init_nnx_pyconfig()
314+
mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
315+
engine = maxengine.MaxEngine(cfg, jax.devices())
316+
params = engine.load_params(params=self._build_nnx_params(cfg, mesh))
317+
nnx_cache, nnx_result, nnx_first = engine.prefill_concat(params=params, **packed)
318+
319+
lin_cfg = self.init_pyconfig()
320+
lin_mesh = Mesh(maxtext_utils.create_device_mesh(lin_cfg), lin_cfg.mesh_axes)
321+
lin_engine = maxengine.MaxEngine(lin_cfg, jax.devices())
322+
lin_params = lin_engine.load_params(params=self._build_linen_params(lin_cfg, lin_mesh))
323+
_, lin_result, lin_first = lin_engine.prefill_concat(params=lin_params, **packed)
324+
325+
self.assertEqual(nnx_result["tokens"].shape, lin_result["tokens"].shape)
326+
self.assertEqual(len(nnx_first), len(lin_first))
327+
self.assertEqual(len(nnx_first), packed["num_prompts"])
328+
self.assertTrue(jnp.all(jnp.isfinite(nnx_result["logits"])))
329+
for leaf in jax.tree.leaves(nnx_cache):
330+
self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
331+
332+
def _stack_prefill_roundtrip(self, cfg):
333+
"""NNX prefill -> insert -> generate round-trip with stack_prefill_result_cache=True."""
334+
mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
335+
engine = maxengine.MaxEngine(cfg, jax.devices())
336+
params = engine.load_params(params=self._build_nnx_params(cfg, mesh))
337+
decode_state = engine.init_decode_state()
338+
prefill_result, _ = engine.prefill(params=params, padded_tokens=jnp.array([1, 306, 5360, 304]), true_length=4)
339+
340+
# stack=True puts the layer axis on axis 0: already there for scan_layers=True,
341+
# stacked from the per-layer subtrees for scan_layers=False.
342+
for leaf in jax.tree.leaves(prefill_result["cache"]):
343+
self.assertEqual(leaf.shape[0], cfg.num_decoder_layers, msg=f"layer-axis mismatch, got shape={leaf.shape}")
344+
self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
345+
346+
decode_state = engine.insert(prefill_result, decode_state, slot=0)
347+
initial_next_pos = int(decode_state["next_pos"][0, 0])
348+
for step in range(3):
349+
decode_state, result_token = engine.generate(params=params, decode_state=decode_state)
350+
self.assertEqual(result_token.data.shape[1], 3)
351+
self.assertTrue(jnp.all(jnp.isfinite(decode_state["logits"])))
352+
self.assertEqual(
353+
int(decode_state["next_pos"][0, 0]),
354+
initial_next_pos + step + 1,
355+
msg=f"next_pos didn't advance at step {step}",
356+
)
357+
358+
def test_stack_prefill_result_cache_nnx(self):
359+
"""stack_prefill_result_cache=True with scan_layers=True (cache is already layer-stacked)."""
360+
self._stack_prefill_roundtrip(self._init_nnx_pyconfig(stack_prefill_result_cache=True))
361+
362+
def test_stack_prefill_result_cache_scan_layers_false_nnx(self):
363+
"""stack_prefill_result_cache=True with scan_layers=False (per-layer subtrees get stacked)."""
364+
self._stack_prefill_roundtrip(self._init_nnx_pyconfig(stack_prefill_result_cache=True, scan_layers=False))
365+
260366
@pytest.mark.skip(reason="Can only pass on CPU.")
261367
def test_chunked_prefill(self):
262368
"""Test identical result between chunked prefill with single and multiple chunked.

0 commit comments

Comments
 (0)