@@ -425,8 +425,16 @@ def pure_layer_fn(state_in, y_in):
425425 out = merged_layer (y_in , ** kwargs )
426426 return out , nnx .state (merged_layer )
427427
428- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
429- out , new_state = checkpointed_fn (state , y )
428+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
429+ # mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
430+ # but the Linen scope retains JAX tracers from the first trace, causing
431+ # UnexpectedTracerError. Skip checkpoint for these quantization types.
432+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
433+ if uses_linen_fp8_mutable_state :
434+ out , new_state = pure_layer_fn (state , y )
435+ else :
436+ checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
437+ out , new_state = checkpointed_fn (state , y )
430438 nnx .update (layer , new_state )
431439
432440 return out
@@ -468,14 +476,28 @@ def layer_fn(carry, scanned_vars):
468476
469477 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
470478
471- # Extract the updated state to return it
472- # _, new_current_state = nnx.split(layer, nnx.Param, ...)
473- new_current_state = nnx .state (layer )
479+ # Extract the updated state to return it.
480+ _ , _ , new_current_state = nnx .split (layer , nnx .Intermediate , ...)
474481 return new_carry , new_current_state
475482
476- layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
477-
478- final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
483+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
484+ # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
485+ # intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
486+ # causing UnexpectedTracerError. Use a Python for loop instead for these types.
487+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
488+ if uses_linen_fp8_mutable_state :
489+ carry = x_in
490+ per_layer_states = []
491+ for i in range (length ):
492+ current_params = jax .tree .map (lambda x , i = i : x [i ], params )
493+ current_state = jax .tree .map (lambda x , i = i : x [i ], state )
494+ carry , new_state_i = layer_fn (carry , (current_params , current_state ))
495+ per_layer_states .append (new_state_i )
496+ final_carry = carry
497+ scanned_state = jax .tree .map (lambda * xs : jnp .stack (list (xs )), * per_layer_states )
498+ else :
499+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
500+ final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
479501
480502 if scan_axis != 0 :
481503 scanned_params , scanned_other = scanned_state .split (nnx .Param , ...)
0 commit comments