@@ -429,8 +429,16 @@ def pure_layer_fn(state_in, y_in):
429429 out = merged_layer (y_in , ** kwargs )
430430 return out , nnx .state (merged_layer )
431431
432- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
433- out , new_state = checkpointed_fn (state , y )
432+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen's
433+ # mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
434+ # but the Linen scope retains JAX tracers from the first trace, causing
435+ # UnexpectedTracerError. Skip checkpoint (i.e. no remat) for these quantization types.
436+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
437+ if uses_linen_fp8_mutable_state :
438+ out , new_state = pure_layer_fn (state , y )
439+ else :
440+ checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
441+ out , new_state = checkpointed_fn (state , y )
434442 nnx .update (layer , new_state )
435443
436444 return out
@@ -467,9 +475,24 @@ def layer_fn(carry, scanned_vars):
467475
468476 return new_carry , new_current_state
469477
470- layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
471-
472- final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
478+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen's
479+ # mutable scope. jax.lax.scan traces the body function, and Linen's setup() creates
480+ # intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
481+ # causing UnexpectedTracerError. Use a Python for loop (without jax.checkpoint) instead.
482+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
483+ if uses_linen_fp8_mutable_state :
484+ carry = x_in
485+ per_layer_states = []
486+ for i in range (length ):
487+ current_params = jax .tree .map (lambda x , i = i : x [i ], params )
488+ current_state = jax .tree .map (lambda x , i = i : x [i ], state )
489+ carry , new_state_i = layer_fn (carry , (current_params , current_state ))
490+ per_layer_states .append (new_state_i )
491+ final_carry = carry
492+ scanned_state = jax .tree .map (lambda * xs : jnp .stack (list (xs )), * per_layer_states )
493+ else :
494+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
495+ final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
473496
474497 if scan_axis != 0 :
475498 scanned_params , scanned_other = scanned_state .split (nnx .Param , ...)
0 commit comments