@@ -69,34 +69,21 @@ def init_pyconfig(self, **kwargs):
6969 )
7070 return config
7171
72- def test_stack_and_unstack_prefill_cache (self ):
73- config = pyconfig .initialize (
74- [None , get_test_config_path ()],
75- enable_checkpointing = False ,
76- stack_prefill_result_cache = True ,
77- )
78- engine = maxengine .MaxEngine (config , jax .devices ())
72+ def test_stack_and_unstack_prefill_cache_nnx (self ):
73+ """scan_layers=False: per-layer cache subtrees stack onto a leading layer axis and back."""
74+ cfg = self ._init_nnx_pyconfig (stack_prefill_result_cache = True , scan_layers = False )
75+ engine = maxengine .MaxEngine (cfg , jax .devices ())
7976 num_layers = engine .config .num_decoder_layers
80- input_d = {
81- "decoder" : {},
82- }
83- for i in range (num_layers ):
84- input_d ["decoder" ][f"layers_{ i } " ] = {
85- "a" : jnp .ones ((1 , 10 )),
86- "b" : jnp .ones ((1 , 9 )),
87- }
88-
89- expected_stacked = {
90- "a" : jnp .ones ((num_layers , 1 , 10 )),
91- "b" : jnp .ones ((num_layers , 1 , 9 )),
92- }
77+ # scan_layers=False keeps the per-layer subtrees under decoder/layers, keyed by layer index.
78+ cache = {"decoder" : {"layers" : {i : {"a" : jnp .ones ((1 , 10 )), "b" : jnp .ones ((1 , 9 ))} for i in range (num_layers )}}}
79+
80+ expected_stacked = {"decoder" : {"layers" : {"a" : jnp .ones ((num_layers , 1 , 10 )), "b" : jnp .ones ((num_layers , 1 , 9 ))}}}
9381 # pylint: disable=protected-access
94- got_stacked = engine ._maybe_stack_prefill_result_cache (input_d )
82+ got_stacked = engine ._maybe_stack_prefill_result_cache (cache )
9583 jax .tree .map (np .testing .assert_array_equal , got_stacked , expected_stacked )
9684
97- # pylint: disable=protected-access
9885 got_unstacked = engine ._maybe_unstack_prefill_result_cache (got_stacked )
99- jax .tree .map (np .testing .assert_array_equal , got_unstacked , input_d )
86+ jax .tree .map (np .testing .assert_array_equal , got_unstacked , cache )
10087
10188 # The Linen-path basic prefill/decode tests were removed when NNX became the
10289 # default. test_basic_prefill_nnx / test_basic_decode_nnx below cover the NNX path.
@@ -113,18 +100,6 @@ def _build_nnx_params(self, cfg, mesh):
113100 _ , params_state , _ = nnx .split (model , nnx .Param , ...)
114101 return params_state
115102
116- def _build_linen_params (self , cfg , mesh ):
117- """Materialize a Linen Transformer and return its init vars (for NNX/Linen shape parity)."""
118- quant = quantizations .configure_quantization (cfg )
119- model = models .transformer_as_linen (config = cfg , mesh = mesh , quant = quant , model_mode = MODEL_MODE_PREFILL )
120- s = (cfg .global_batch_size_to_train_on , cfg .max_target_length )
121- ids = jax .random .randint (self .rng , s , 0 , cfg .vocab_size )
122- segment_ids = jnp .zeros (s ) + DECODING_ACTIVE_SEQUENCE_INDICATOR
123- positions = jnp .stack ([jnp .arange (cfg .max_target_length , dtype = jnp .int32 ) for _ in range (s [0 ])])
124- return model .init (
125- {"params" : self .rng , "aqt" : self .rng , "dropout" : self .rng }, ids , positions , segment_ids , enable_dropout = False
126- )
127-
128103 def test_init_nnx (self ):
129104 """NNX engine init exposes graphdef + abstract Transformer."""
130105 cfg = self ._init_nnx_pyconfig ()
@@ -249,7 +224,7 @@ def test_lora_load_single_adapter_reaches_loader_on_nnx(self):
249224 engine .load_single_adapter ("/nonexistent/adapter/path" )
250225
251226 def test_prefill_multisampling_nnx (self ):
252- """NNX prefill_multisampling matches the Linen result shape ; logits + cache stay finite."""
227+ """NNX prefill_multisampling draws num_samples first tokens ; logits + cache stay finite."""
253228 num_samples = 3
254229 input_tokens = jnp .array ([1 , 306 , 5360 , 304 , 0 , 0 , 0 , 0 ])
255230 true_length = 4
@@ -258,27 +233,19 @@ def test_prefill_multisampling_nnx(self):
258233 mesh = Mesh (maxtext_utils .create_device_mesh (cfg ), cfg .mesh_axes )
259234 engine = maxengine .MaxEngine (cfg , jax .devices ())
260235 params = engine .load_params (params = self ._build_nnx_params (cfg , mesh ))
261- nnx_result , nnx_first = engine .prefill_multisampling (
236+ result , first = engine .prefill_multisampling (
262237 params = params , padded_tokens = input_tokens , true_length = true_length , num_samples = num_samples
263238 )
264239
265- lin_cfg = self .init_pyconfig ()
266- lin_mesh = Mesh (maxtext_utils .create_device_mesh (lin_cfg ), lin_cfg .mesh_axes )
267- lin_engine = maxengine .MaxEngine (lin_cfg , jax .devices ())
268- lin_params = lin_engine .load_params (params = self ._build_linen_params (lin_cfg , lin_mesh ))
269- lin_result , lin_first = lin_engine .prefill_multisampling (
270- params = lin_params , padded_tokens = input_tokens , true_length = true_length , num_samples = num_samples
271- )
272-
273- self .assertEqual (nnx_result ["tokens" ].shape , lin_result ["tokens" ].shape )
274- self .assertEqual (nnx_result ["tokens" ].shape [0 ], num_samples )
275- self .assertEqual (nnx_first .data .shape , lin_first .data .shape )
276- self .assertTrue (jnp .all (jnp .isfinite (nnx_result ["logits" ])))
277- for leaf in jax .tree .leaves (nnx_result ["cache" ]):
240+ self .assertEqual (result ["tokens" ].shape [0 ], num_samples )
241+ # data packs [token, valid, length] for each sample.
242+ self .assertEqual (first .data .shape , (num_samples , 3 ))
243+ self .assertTrue (jnp .all (jnp .isfinite (result ["logits" ])))
244+ for leaf in jax .tree .leaves (result ["cache" ]):
278245 self .assertTrue (jnp .all (jnp .isfinite (leaf )), msg = f"non-finite cache leaf, shape={ leaf .shape } " )
279246
280247 def test_prefill_concat_nnx (self ):
281- """NNX prefill_concat matches the Linen result shape for packed prompts ."""
248+ """NNX prefill_concat returns one result per packed prompt; logits + cache stay finite ."""
282249 # Two prompts of length 2 packed into one prefill of length max_prefill_predict_length=4.
283250 packed = {
284251 "padded_tokens" : jnp .array ([1 , 306 , 5360 , 304 ]),
@@ -293,19 +260,12 @@ def test_prefill_concat_nnx(self):
293260 mesh = Mesh (maxtext_utils .create_device_mesh (cfg ), cfg .mesh_axes )
294261 engine = maxengine .MaxEngine (cfg , jax .devices ())
295262 params = engine .load_params (params = self ._build_nnx_params (cfg , mesh ))
296- nnx_cache , nnx_result , nnx_first = engine .prefill_concat (params = params , ** packed )
297-
298- lin_cfg = self .init_pyconfig ()
299- lin_mesh = Mesh (maxtext_utils .create_device_mesh (lin_cfg ), lin_cfg .mesh_axes )
300- lin_engine = maxengine .MaxEngine (lin_cfg , jax .devices ())
301- lin_params = lin_engine .load_params (params = self ._build_linen_params (lin_cfg , lin_mesh ))
302- _ , lin_result , lin_first = lin_engine .prefill_concat (params = lin_params , ** packed )
303-
304- self .assertEqual (nnx_result ["tokens" ].shape , lin_result ["tokens" ].shape )
305- self .assertEqual (len (nnx_first ), len (lin_first ))
306- self .assertEqual (len (nnx_first ), packed ["num_prompts" ])
307- self .assertTrue (jnp .all (jnp .isfinite (nnx_result ["logits" ])))
308- for leaf in jax .tree .leaves (nnx_cache ):
263+ cache , result , first_tokens = engine .prefill_concat (params = params , ** packed )
264+
265+ self .assertEqual (result ["tokens" ].shape [0 ], packed ["num_prompts" ])
266+ self .assertEqual (len (first_tokens ), packed ["num_prompts" ])
267+ self .assertTrue (jnp .all (jnp .isfinite (result ["logits" ])))
268+ for leaf in jax .tree .leaves (cache ):
309269 self .assertTrue (jnp .all (jnp .isfinite (leaf )), msg = f"non-finite cache leaf, shape={ leaf .shape } " )
310270
311271 def _stack_prefill_roundtrip (self , cfg ):
0 commit comments