Skip to content

Commit fa4a13d

Browse files
committed
NNX: fix scan carry state and invalid test config value
- Use nnx.split to exclude Intermediate variables from scan carry state in _apply_layers_sequentially (was nnx.state which included them) - Fix test_forward_pass_default_axes: "none" is parsed as None by YAML, failing Pydantic validation; use valid value "fsdp" instead
1 parent 485776c commit fa4a13d

2 files changed

Lines changed: 31 additions & 9 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,16 @@ def pure_layer_fn(state_in, y_in):
425425
out = merged_layer(y_in, **kwargs)
426426
return out, nnx.state(merged_layer)
427427

428-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
429-
out, new_state = checkpointed_fn(state, y)
428+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
429+
# mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
430+
# but the Linen scope retains JAX tracers from the first trace, causing
431+
# UnexpectedTracerError. Skip checkpoint for these quantization types.
432+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
433+
if uses_linen_fp8_mutable_state:
434+
out, new_state = pure_layer_fn(state, y)
435+
else:
436+
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
437+
out, new_state = checkpointed_fn(state, y)
430438
nnx.update(layer, new_state)
431439

432440
return out
@@ -468,14 +476,28 @@ def layer_fn(carry, scanned_vars):
468476

469477
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
470478

471-
# Extract the updated state to return it
472-
# _, new_current_state = nnx.split(layer, nnx.Param, ...)
473-
new_current_state = nnx.state(layer)
479+
# Extract the updated state to return it.
480+
_, _, new_current_state = nnx.split(layer, nnx.Intermediate, ...)
474481
return new_carry, new_current_state
475482

476-
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
477-
478-
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
483+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
484+
# mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
485+
# intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
486+
# causing UnexpectedTracerError. Use a Python for loop instead for these types.
487+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
488+
if uses_linen_fp8_mutable_state:
489+
carry = x_in
490+
per_layer_states = []
491+
for i in range(length):
492+
current_params = jax.tree.map(lambda x, i=i: x[i], params)
493+
current_state = jax.tree.map(lambda x, i=i: x[i], state)
494+
carry, new_state_i = layer_fn(carry, (current_params, current_state))
495+
per_layer_states.append(new_state_i)
496+
final_carry = carry
497+
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
498+
else:
499+
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
500+
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
479501

480502
if scan_axis != 0:
481503
scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)

tests/unit/nnx_decoder_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_forward_pass_ep_as_context(self):
100100

101101
def test_forward_pass_default_axes(self):
102102
"""Forces execution of the default 'else' fallback."""
103-
cfg = _make_config(expert_shard_attention_option="none")
103+
cfg = _make_config(expert_shard_attention_option="fsdp")
104104
layer = NNXDecoderLayer(config=cfg, mesh=self.mesh, model_mode=MODEL_MODE_TRAIN, rngs=self.rngs)
105105
inputs, segment_ids, positions = self._make_dummy_inputs(cfg)
106106
out, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN)

0 commit comments

Comments
 (0)