@@ -432,8 +432,16 @@ def pure_layer_fn(state_in, y_in):
432432 out = merged_layer (y_in , ** kwargs )
433433 return out , nnx .state (merged_layer )
434434
435- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
436- out , new_state = checkpointed_fn (state , y )
435+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
436+ # mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
437+ # but the Linen scope retains JAX tracers from the first trace, causing
438+ # UnexpectedTracerError. Skip checkpoint for these quantization types.
439+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
440+ if uses_linen_fp8_mutable_state :
441+ out , new_state = pure_layer_fn (state , y )
442+ else :
443+ checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
444+ out , new_state = checkpointed_fn (state , y )
437445 nnx .update (layer , new_state )
438446
439447 return out
@@ -475,20 +483,29 @@ def layer_fn(carry, scanned_vars):
475483
476484 return new_carry , new_current_state
477485
478- layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
479-
480- final_carry , scanned_other = jax .lax .scan (layer_fn , x_in , (params , state ))
486+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
487+ # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
488+ # intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
489+ # causing UnexpectedTracerError. Use a Python for loop (without jax.checkpoint) instead for these types.
490+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
491+ if uses_linen_fp8_mutable_state :
492+ carry = x_in
493+ per_layer_states = []
494+ for i in range (length ):
495+ current_params = jax .tree .map (lambda x , i = i : x [i ], params )
496+ current_state = jax .tree .map (lambda x , i = i : x [i ], state )
497+ carry , new_state_i = layer_fn (carry , (current_params , current_state ))
498+ per_layer_states .append (new_state_i )
499+ final_carry = carry
500+ scanned_state = jax .tree .map (lambda * xs : jnp .stack (list (xs )), * per_layer_states )
501+ else :
502+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
503+ final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
481504
482505 if scan_axis != 0 :
483506 params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), params )
484507
485- # Params are read-only during the forward pass, so the scan output's copy of
486- # params is at axis=0 (lax.scan default) rather than scan_axis. Discard the
487- # scan-output params and keep the original params (correctly positioned at
488- # scan_axis) to avoid a shape mismatch when _apply_scanned_chunk tries to
489- # write them back via dynamic_update_slice_in_dim.
490- _ , non_param_scanned_state = scanned_state .split (nnx .Param , ...)
491- scanned_state = nnx .State .merge (params , non_param_scanned_state )
508+ scanned_state = nnx .State .merge (params , scanned_state )
492509 return final_carry , nnx .merge (graphdef , scanned_state )
493510
494511 def get_decoder_layers (self ):
0 commit comments