@@ -470,8 +470,16 @@ def pure_layer_fn(state_in, y_in):
470470 out = merged_layer (y_in , ** kwargs )
471471 return out , nnx .state (merged_layer )
472472
473- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
474- out , new_state = checkpointed_fn (state , y )
473+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
474+ # mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
475+ # but the Linen scope retains JAX tracers from the first trace, causing
476+ # UnexpectedTracerError. Skip checkpoint for these quantization types.
477+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
478+ if uses_linen_fp8_mutable_state :
479+ out , new_state = pure_layer_fn (state , y )
480+ else :
481+ checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
482+ out , new_state = checkpointed_fn (state , y )
475483 nnx .update (layer , new_state )
476484
477485 return out
@@ -513,9 +521,24 @@ def layer_fn(carry, scanned_vars):
513521
514522 return new_carry , new_current_state
515523
516- layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
517-
518- final_carry , scanned_other = jax .lax .scan (layer_fn , x_in , (params , state ))
524+ # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
525+ # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
526+ # intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
527+ # causing UnexpectedTracerError. Use a Python for loop instead for these types.
528+ uses_linen_fp8_mutable_state = self .config .quantization in ("fp8_nanoo" , "fp8_gpu" )
529+ if uses_linen_fp8_mutable_state :
530+ carry = x_in
531+ per_layer_states = []
532+ for i in range (length ):
533+ current_params = jax .tree .map (lambda x , i = i : x [i ], params )
534+ current_state = jax .tree .map (lambda x , i = i : x [i ], state )
535+ carry , new_state_i = layer_fn (carry , (current_params , current_state ))
536+ per_layer_states .append (new_state_i )
537+ final_carry = carry
538+ scanned_state = jax .tree .map (lambda * xs : jnp .stack (list (xs )), * per_layer_states )
539+ else :
540+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
541+ final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
519542
520543 if scan_axis != 0 :
521544 params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), params )
@@ -525,7 +548,7 @@ def layer_fn(carry, scanned_vars):
525548 # scan-output params and keep the original params (correctly positioned at
526549 # scan_axis) to avoid a shape mismatch when _apply_scanned_chunk tries to
527550 # write them back via dynamic_update_slice_in_dim.
528- _ , non_param_scanned_state = scanned_other .split (nnx .Param , ...)
551+ _ , non_param_scanned_state = scanned_state .split (nnx .Param , ...)
529552 scanned_state = nnx .State .merge (params , non_param_scanned_state )
530553 return final_carry , nnx .merge (graphdef , scanned_state )
531554
0 commit comments