@@ -530,9 +530,8 @@ def pure_layer_fn(state_in, y_in):
530530 out = merged_layer (y_in , ** kwargs )
531531 return out , nnx .state (merged_layer )
532532
533- if not self ._has_linen_fp8_side_effects ():
534- pure_layer_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
535- out , new_state = pure_layer_fn (state , y )
533+ checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
534+ out , new_state = checkpointed_fn (state , y )
536535 nnx .update (layer , new_state )
537536
538537 return out
@@ -574,8 +573,7 @@ def layer_fn(carry, scanned_vars):
574573 # ONLY return non-param state to prevent memory duplication of weights
575574 return new_carry , new_current_state
576575
577- if not self ._has_linen_fp8_side_effects ():
578- layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
576+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
579577
580578 final_carry , scanned_other = jax .lax .scan (layer_fn , x_in , (params , state ))
581579
@@ -647,19 +645,6 @@ def minimal_policy(self, with_context=False, with_quantization=False):
647645 names .append ("quantization" )
648646 return jax .checkpoint_policies .save_only_these_names (* names )
649647
650- def _has_linen_fp8_side_effects (self ):
651- """Check if the current quantization uses Linen FP8 modules that create mutable state.
652-
653- FP8 GPU/NANOO quantization with QWIX creates Linen FP8 modules (e.g.,
654- nn.Fp8DirectDotGeneralOp, nn.NANOOFp8DotGeneralOp) during the forward pass.
655- These modules use self.variable() to create mutable state (amax histories,
656- scales) as side effects. When called inside jax.checkpoint, these side effects
657- cause UnexpectedTracerError because the traced values escape the checkpoint scope
658- through the Linen variable scope.
659- """
660- cfg = self .config
661- return cfg .use_qwix_quantization and cfg .quantization in ("fp8_gpu" , "fp8_nanoo" )
662-
663648 def get_remat_policy (self ):
664649 """Get remat policy for jax.checkpoint."""
665650 policy = None
@@ -1154,10 +1139,7 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
11541139 out_y , out_kv = merged_layer (y_in , * layer_args , kv_cache = kv_in , ** layer_kwargs )
11551140 return out_y , out_kv , nnx .state (merged_layer )
11561141
1157- if not self ._has_linen_fp8_side_effects ():
1158- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
1159- else :
1160- checkpointed_fn = pure_layer_fn
1142+ checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
11611143
11621144 for lyr , layer in enumerate (self .layers ):
11631145 graphdef , state = nnx .split (layer )
@@ -1261,10 +1243,7 @@ def pure_gemma_fn(graphdef, state_in, y_in):
12611243 )
12621244 return out_y , nnx .state (merged_layer )
12631245
1264- if not self ._has_linen_fp8_side_effects ():
1265- checkpointed_gemma_fn = jax .checkpoint (pure_gemma_fn , policy = policy , prevent_cse = prevent_cse )
1266- else :
1267- checkpointed_gemma_fn = pure_gemma_fn
1246+ checkpointed_gemma_fn = jax .checkpoint (pure_gemma_fn , policy = policy , prevent_cse = prevent_cse )
12681247
12691248 graphdef , state = nnx .split (self .layers_remainder )
12701249 y , new_state = checkpointed_gemma_fn (graphdef , state , y )
0 commit comments