Skip to content

Commit 17c9052

Browse files
hsuan-lun-chiangecnal-cienet
authored andcommitted
Revert the incorrect Fp8 fix in nnx_decoders.py
1 parent 4f4c0b0 commit 17c9052

1 file changed

Lines changed: 1 addition & 16 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -665,22 +665,7 @@ def layer_fn(carry, scanned_vars):
665665
params = nnx_ensure_scan_leading_axis(params, length)
666666
state = nnx_ensure_scan_leading_axis(state, length)
667667

668-
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan
669-
# leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop
670-
# for FP8 instead.
671-
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
672-
if uses_linen_fp8_mutable_state:
673-
carry = x_in
674-
per_layer_states = []
675-
for i in range(length):
676-
current_params = jax.tree.map(lambda x, i=i: x[i], params)
677-
current_state = jax.tree.map(lambda x, i=i: x[i], state)
678-
carry, new_state_i = layer_fn(carry, (current_params, current_state))
679-
per_layer_states.append(new_state_i)
680-
final_carry = carry
681-
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
682-
else:
683-
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
668+
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
684669
returned_kv_stacked = None
685670

686671
if scan_axis != 0:

0 commit comments

Comments
 (0)