Skip to content

Commit b65eb21

Browse files
committed
test: make maxengine prefill/cache tests NNX-only
PR#11 flips the defaults to NNX. As a result, the Linen reference engine in the prefill_multisampling/prefill_concat parity tests silently became an NNX engine and crashed (device_put State-vs-dict), and test_stack_and_unstack_prefill_cache hit the NNX no-op branch. This commit drops the Linen comparisons and asserts the NNX result shapes directly, rewrites the cache test for the NNX scan_layers=False path, and removes _build_linen_params and its imports.
1 parent 5f6366d commit b65eb21

1 file changed

Lines changed: 24 additions & 64 deletions

File tree

tests/integration/maxengine_test.py

Lines changed: 24 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -69,34 +69,21 @@ def init_pyconfig(self, **kwargs):
6969
)
7070
return config
7171

72-
def test_stack_and_unstack_prefill_cache(self):
73-
config = pyconfig.initialize(
74-
[None, get_test_config_path()],
75-
enable_checkpointing=False,
76-
stack_prefill_result_cache=True,
77-
)
78-
engine = maxengine.MaxEngine(config, jax.devices())
72+
def test_stack_and_unstack_prefill_cache_nnx(self):
73+
"""scan_layers=False: per-layer cache subtrees stack onto a leading layer axis and back."""
74+
cfg = self._init_nnx_pyconfig(stack_prefill_result_cache=True, scan_layers=False)
75+
engine = maxengine.MaxEngine(cfg, jax.devices())
7976
num_layers = engine.config.num_decoder_layers
80-
input_d = {
81-
"decoder": {},
82-
}
83-
for i in range(num_layers):
84-
input_d["decoder"][f"layers_{i}"] = {
85-
"a": jnp.ones((1, 10)),
86-
"b": jnp.ones((1, 9)),
87-
}
88-
89-
expected_stacked = {
90-
"a": jnp.ones((num_layers, 1, 10)),
91-
"b": jnp.ones((num_layers, 1, 9)),
92-
}
77+
# scan_layers=False keeps the per-layer subtrees under decoder/layers, keyed by layer index.
78+
cache = {"decoder": {"layers": {i: {"a": jnp.ones((1, 10)), "b": jnp.ones((1, 9))} for i in range(num_layers)}}}
79+
80+
expected_stacked = {"decoder": {"layers": {"a": jnp.ones((num_layers, 1, 10)), "b": jnp.ones((num_layers, 1, 9))}}}
9381
# pylint: disable=protected-access
94-
got_stacked = engine._maybe_stack_prefill_result_cache(input_d)
82+
got_stacked = engine._maybe_stack_prefill_result_cache(cache)
9583
jax.tree.map(np.testing.assert_array_equal, got_stacked, expected_stacked)
9684

97-
# pylint: disable=protected-access
9885
got_unstacked = engine._maybe_unstack_prefill_result_cache(got_stacked)
99-
jax.tree.map(np.testing.assert_array_equal, got_unstacked, input_d)
86+
jax.tree.map(np.testing.assert_array_equal, got_unstacked, cache)
10087

10188
# The Linen-path basic prefill/decode tests were removed when NNX became the
10289
# default. test_basic_prefill_nnx / test_basic_decode_nnx below cover the NNX path.
@@ -113,18 +100,6 @@ def _build_nnx_params(self, cfg, mesh):
113100
_, params_state, _ = nnx.split(model, nnx.Param, ...)
114101
return params_state
115102

116-
def _build_linen_params(self, cfg, mesh):
117-
"""Materialize a Linen Transformer and return its init vars (for NNX/Linen shape parity)."""
118-
quant = quantizations.configure_quantization(cfg)
119-
model = models.transformer_as_linen(config=cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
120-
s = (cfg.global_batch_size_to_train_on, cfg.max_target_length)
121-
ids = jax.random.randint(self.rng, s, 0, cfg.vocab_size)
122-
segment_ids = jnp.zeros(s) + DECODING_ACTIVE_SEQUENCE_INDICATOR
123-
positions = jnp.stack([jnp.arange(cfg.max_target_length, dtype=jnp.int32) for _ in range(s[0])])
124-
return model.init(
125-
{"params": self.rng, "aqt": self.rng, "dropout": self.rng}, ids, positions, segment_ids, enable_dropout=False
126-
)
127-
128103
def test_init_nnx(self):
129104
"""NNX engine init exposes graphdef + abstract Transformer."""
130105
cfg = self._init_nnx_pyconfig()
@@ -248,7 +223,7 @@ def test_lora_load_single_adapter_reaches_loader_on_nnx(self):
248223
engine.load_single_adapter("/nonexistent/adapter/path")
249224

250225
def test_prefill_multisampling_nnx(self):
251-
"""NNX prefill_multisampling matches the Linen result shape; logits + cache stay finite."""
226+
"""NNX prefill_multisampling draws num_samples first tokens; logits + cache stay finite."""
252227
num_samples = 3
253228
input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0])
254229
true_length = 4
@@ -257,27 +232,19 @@ def test_prefill_multisampling_nnx(self):
257232
mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
258233
engine = maxengine.MaxEngine(cfg, jax.devices())
259234
params = engine.load_params(params=self._build_nnx_params(cfg, mesh))
260-
nnx_result, nnx_first = engine.prefill_multisampling(
235+
result, first = engine.prefill_multisampling(
261236
params=params, padded_tokens=input_tokens, true_length=true_length, num_samples=num_samples
262237
)
263238

264-
lin_cfg = self.init_pyconfig()
265-
lin_mesh = Mesh(maxtext_utils.create_device_mesh(lin_cfg), lin_cfg.mesh_axes)
266-
lin_engine = maxengine.MaxEngine(lin_cfg, jax.devices())
267-
lin_params = lin_engine.load_params(params=self._build_linen_params(lin_cfg, lin_mesh))
268-
lin_result, lin_first = lin_engine.prefill_multisampling(
269-
params=lin_params, padded_tokens=input_tokens, true_length=true_length, num_samples=num_samples
270-
)
271-
272-
self.assertEqual(nnx_result["tokens"].shape, lin_result["tokens"].shape)
273-
self.assertEqual(nnx_result["tokens"].shape[0], num_samples)
274-
self.assertEqual(nnx_first.data.shape, lin_first.data.shape)
275-
self.assertTrue(jnp.all(jnp.isfinite(nnx_result["logits"])))
276-
for leaf in jax.tree.leaves(nnx_result["cache"]):
239+
self.assertEqual(result["tokens"].shape[0], num_samples)
240+
# data packs [token, valid, length] for each sample.
241+
self.assertEqual(first.data.shape, (num_samples, 3))
242+
self.assertTrue(jnp.all(jnp.isfinite(result["logits"])))
243+
for leaf in jax.tree.leaves(result["cache"]):
277244
self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
278245

279246
def test_prefill_concat_nnx(self):
280-
"""NNX prefill_concat matches the Linen result shape for packed prompts."""
247+
"""NNX prefill_concat returns one result per packed prompt; logits + cache stay finite."""
281248
# Two prompts of length 2 packed into one prefill of length max_prefill_predict_length=4.
282249
packed = {
283250
"padded_tokens": jnp.array([1, 306, 5360, 304]),
@@ -292,19 +259,12 @@ def test_prefill_concat_nnx(self):
292259
mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
293260
engine = maxengine.MaxEngine(cfg, jax.devices())
294261
params = engine.load_params(params=self._build_nnx_params(cfg, mesh))
295-
nnx_cache, nnx_result, nnx_first = engine.prefill_concat(params=params, **packed)
296-
297-
lin_cfg = self.init_pyconfig()
298-
lin_mesh = Mesh(maxtext_utils.create_device_mesh(lin_cfg), lin_cfg.mesh_axes)
299-
lin_engine = maxengine.MaxEngine(lin_cfg, jax.devices())
300-
lin_params = lin_engine.load_params(params=self._build_linen_params(lin_cfg, lin_mesh))
301-
_, lin_result, lin_first = lin_engine.prefill_concat(params=lin_params, **packed)
302-
303-
self.assertEqual(nnx_result["tokens"].shape, lin_result["tokens"].shape)
304-
self.assertEqual(len(nnx_first), len(lin_first))
305-
self.assertEqual(len(nnx_first), packed["num_prompts"])
306-
self.assertTrue(jnp.all(jnp.isfinite(nnx_result["logits"])))
307-
for leaf in jax.tree.leaves(nnx_cache):
262+
cache, result, first_tokens = engine.prefill_concat(params=params, **packed)
263+
264+
self.assertEqual(result["tokens"].shape[0], packed["num_prompts"])
265+
self.assertEqual(len(first_tokens), packed["num_prompts"])
266+
self.assertTrue(jnp.all(jnp.isfinite(result["logits"])))
267+
for leaf in jax.tree.leaves(cache):
308268
self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
309269

310270
def _stack_prefill_roundtrip(self, cfg):

0 commit comments

Comments
 (0)