@@ -527,14 +527,14 @@ def layer_fn(carry, scanned_vars):
527527 # with jax.checkpoint's purity requirement, causing UnexpectedTracerError.
528528 # Linen's nn.scan avoids this via variable_axes={"aqt": 0,
529529 # "_overwrite_with_gradient": 0}, but jax.lax.scan has no collection awareness.
530- # Only skip checkpoint for FP8-based quant types that have this tracer leak.
531- # Other quant types (AQT int8, Qwix) don't use mutable Linen state and are
532- # compatible with jax.checkpoint.
533- _fp8_quant_types = ( quantizations . Fp8Quantization , quantizations . NANOOFp8Quantization )
534- _has_fp8_tracer_issue = isinstance (self .quant , _fp8_quant_types ) or (
535- hasattr ( quantizations , "TransformerEngineQuantization" )
536- and isinstance ( self . quant , quantizations . TransformerEngineQuantization )
537- )
530+ # Skip jax.checkpoint for FP8-based quantization configs. We inspect the config
531+ # string rather than self.quant because Qwix quantization (use_qwix_quantization=True)
532+ # sets self.quant=None yet still intercepts lax.dot_general at runtime with
533+ # Linen FP8 modules (NvidaFp8Provider → nn.Fp8DirectDotGeneralOp).
534+ _quant_str = getattr (self .config , "quantization" , "" ) or ""
535+ _has_fp8_tracer_issue = any (
536+ fp8_id in _quant_str for fp8_id in ( "fp8" , "nanoo_fp8" )
537+ ) or _quant_str . startswith ( "te_" )
538538 if not _has_fp8_tracer_issue :
539539 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
540540
0 commit comments