Skip to content

Commit e6c5bee

Browse files
committed
test: make maxengine prefill/cache tests NNX-only
PR #11 flips the defaults to NNX. As a result, the "Linen" reference engine in the prefill_multisampling/prefill_concat parity tests silently became an NNX engine and crashed (device_put State-vs-dict mismatch), and test_stack_and_unstack_prefill_cache hit the NNX no-op branch. This commit drops the Linen comparisons and asserts the NNX result shapes directly, rewrites the cache test for the NNX scan_layers=False path, and removes _build_linen_params and its imports.
1 parent 4122021 commit e6c5bee

1 file changed

Lines changed: 24 additions & 64 deletions

File tree

tests/integration/maxengine_test.py

Lines changed: 24 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -69,34 +69,21 @@ def init_pyconfig(self, **kwargs):
6969
)
7070
return config
7171

72-
def test_stack_and_unstack_prefill_cache(self):
73-
config = pyconfig.initialize(
74-
[None, get_test_config_path()],
75-
enable_checkpointing=False,
76-
stack_prefill_result_cache=True,
77-
)
78-
engine = maxengine.MaxEngine(config, jax.devices())
72+
def test_stack_and_unstack_prefill_cache_nnx(self):
73+
"""scan_layers=False: per-layer cache subtrees stack onto a leading layer axis and back."""
74+
cfg = self._init_nnx_pyconfig(stack_prefill_result_cache=True, scan_layers=False)
75+
engine = maxengine.MaxEngine(cfg, jax.devices())
7976
num_layers = engine.config.num_decoder_layers
80-
input_d = {
81-
"decoder": {},
82-
}
83-
for i in range(num_layers):
84-
input_d["decoder"][f"layers_{i}"] = {
85-
"a": jnp.ones((1, 10)),
86-
"b": jnp.ones((1, 9)),
87-
}
88-
89-
expected_stacked = {
90-
"a": jnp.ones((num_layers, 1, 10)),
91-
"b": jnp.ones((num_layers, 1, 9)),
92-
}
77+
# scan_layers=False keeps the per-layer subtrees under decoder/layers, keyed by layer index.
78+
cache = {"decoder": {"layers": {i: {"a": jnp.ones((1, 10)), "b": jnp.ones((1, 9))} for i in range(num_layers)}}}
79+
80+
expected_stacked = {"decoder": {"layers": {"a": jnp.ones((num_layers, 1, 10)), "b": jnp.ones((num_layers, 1, 9))}}}
9381
# pylint: disable=protected-access
94-
got_stacked = engine._maybe_stack_prefill_result_cache(input_d)
82+
got_stacked = engine._maybe_stack_prefill_result_cache(cache)
9583
jax.tree.map(np.testing.assert_array_equal, got_stacked, expected_stacked)
9684

97-
# pylint: disable=protected-access
9885
got_unstacked = engine._maybe_unstack_prefill_result_cache(got_stacked)
99-
jax.tree.map(np.testing.assert_array_equal, got_unstacked, input_d)
86+
jax.tree.map(np.testing.assert_array_equal, got_unstacked, cache)
10087

10188
# The Linen-path basic prefill/decode tests were removed when NNX became the
10289
# default. test_basic_prefill_nnx / test_basic_decode_nnx below cover the NNX path.
@@ -113,18 +100,6 @@ def _build_nnx_params(self, cfg, mesh):
113100
_, params_state, _ = nnx.split(model, nnx.Param, ...)
114101
return params_state
115102

116-
def _build_linen_params(self, cfg, mesh):
117-
"""Materialize a Linen Transformer and return its init vars (for NNX/Linen shape parity)."""
118-
quant = quantizations.configure_quantization(cfg)
119-
model = models.transformer_as_linen(config=cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
120-
s = (cfg.global_batch_size_to_train_on, cfg.max_target_length)
121-
ids = jax.random.randint(self.rng, s, 0, cfg.vocab_size)
122-
segment_ids = jnp.zeros(s) + DECODING_ACTIVE_SEQUENCE_INDICATOR
123-
positions = jnp.stack([jnp.arange(cfg.max_target_length, dtype=jnp.int32) for _ in range(s[0])])
124-
return model.init(
125-
{"params": self.rng, "aqt": self.rng, "dropout": self.rng}, ids, positions, segment_ids, enable_dropout=False
126-
)
127-
128103
def test_init_nnx(self):
129104
"""NNX engine init exposes graphdef + abstract Transformer."""
130105
cfg = self._init_nnx_pyconfig()
@@ -249,7 +224,7 @@ def test_lora_load_single_adapter_reaches_loader_on_nnx(self):
249224
engine.load_single_adapter("/nonexistent/adapter/path")
250225

251226
def test_prefill_multisampling_nnx(self):
252-
"""NNX prefill_multisampling matches the Linen result shape; logits + cache stay finite."""
227+
"""NNX prefill_multisampling draws num_samples first tokens; logits + cache stay finite."""
253228
num_samples = 3
254229
input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0])
255230
true_length = 4
@@ -258,27 +233,19 @@ def test_prefill_multisampling_nnx(self):
258233
mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
259234
engine = maxengine.MaxEngine(cfg, jax.devices())
260235
params = engine.load_params(params=self._build_nnx_params(cfg, mesh))
261-
nnx_result, nnx_first = engine.prefill_multisampling(
236+
result, first = engine.prefill_multisampling(
262237
params=params, padded_tokens=input_tokens, true_length=true_length, num_samples=num_samples
263238
)
264239

265-
lin_cfg = self.init_pyconfig()
266-
lin_mesh = Mesh(maxtext_utils.create_device_mesh(lin_cfg), lin_cfg.mesh_axes)
267-
lin_engine = maxengine.MaxEngine(lin_cfg, jax.devices())
268-
lin_params = lin_engine.load_params(params=self._build_linen_params(lin_cfg, lin_mesh))
269-
lin_result, lin_first = lin_engine.prefill_multisampling(
270-
params=lin_params, padded_tokens=input_tokens, true_length=true_length, num_samples=num_samples
271-
)
272-
273-
self.assertEqual(nnx_result["tokens"].shape, lin_result["tokens"].shape)
274-
self.assertEqual(nnx_result["tokens"].shape[0], num_samples)
275-
self.assertEqual(nnx_first.data.shape, lin_first.data.shape)
276-
self.assertTrue(jnp.all(jnp.isfinite(nnx_result["logits"])))
277-
for leaf in jax.tree.leaves(nnx_result["cache"]):
240+
self.assertEqual(result["tokens"].shape[0], num_samples)
241+
# data packs [token, valid, length] for each sample.
242+
self.assertEqual(first.data.shape, (num_samples, 3))
243+
self.assertTrue(jnp.all(jnp.isfinite(result["logits"])))
244+
for leaf in jax.tree.leaves(result["cache"]):
278245
self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
279246

280247
def test_prefill_concat_nnx(self):
281-
"""NNX prefill_concat matches the Linen result shape for packed prompts."""
248+
"""NNX prefill_concat returns one result per packed prompt; logits + cache stay finite."""
282249
# Two prompts of length 2 packed into one prefill of length max_prefill_predict_length=4.
283250
packed = {
284251
"padded_tokens": jnp.array([1, 306, 5360, 304]),
@@ -293,19 +260,12 @@ def test_prefill_concat_nnx(self):
293260
mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
294261
engine = maxengine.MaxEngine(cfg, jax.devices())
295262
params = engine.load_params(params=self._build_nnx_params(cfg, mesh))
296-
nnx_cache, nnx_result, nnx_first = engine.prefill_concat(params=params, **packed)
297-
298-
lin_cfg = self.init_pyconfig()
299-
lin_mesh = Mesh(maxtext_utils.create_device_mesh(lin_cfg), lin_cfg.mesh_axes)
300-
lin_engine = maxengine.MaxEngine(lin_cfg, jax.devices())
301-
lin_params = lin_engine.load_params(params=self._build_linen_params(lin_cfg, lin_mesh))
302-
_, lin_result, lin_first = lin_engine.prefill_concat(params=lin_params, **packed)
303-
304-
self.assertEqual(nnx_result["tokens"].shape, lin_result["tokens"].shape)
305-
self.assertEqual(len(nnx_first), len(lin_first))
306-
self.assertEqual(len(nnx_first), packed["num_prompts"])
307-
self.assertTrue(jnp.all(jnp.isfinite(nnx_result["logits"])))
308-
for leaf in jax.tree.leaves(nnx_cache):
263+
cache, result, first_tokens = engine.prefill_concat(params=params, **packed)
264+
265+
self.assertEqual(result["tokens"].shape[0], packed["num_prompts"])
266+
self.assertEqual(len(first_tokens), packed["num_prompts"])
267+
self.assertTrue(jnp.all(jnp.isfinite(result["logits"])))
268+
for leaf in jax.tree.leaves(cache):
309269
self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
310270

311271
def _stack_prefill_roundtrip(self, cfg):

0 commit comments

Comments
 (0)