Skip to content

Commit ef4f069

Browse files
committed
update
1 parent 069ca4e commit ef4f069

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -527,14 +527,14 @@ def layer_fn(carry, scanned_vars):
527527
# with jax.checkpoint's purity requirement, causing UnexpectedTracerError.
528528
# Linen's nn.scan avoids this via variable_axes={"aqt": 0,
529529
# "_overwrite_with_gradient": 0}, but jax.lax.scan has no collection awareness.
530-
# Only skip checkpoint for FP8-based quant types that have this tracer leak.
531-
# Other quant types (AQT int8, Qwix) don't use mutable Linen state and are
532-
# compatible with jax.checkpoint.
533-
_fp8_quant_types = (quantizations.Fp8Quantization, quantizations.NANOOFp8Quantization)
534-
_has_fp8_tracer_issue = isinstance(self.quant, _fp8_quant_types) or (
535-
hasattr(quantizations, "TransformerEngineQuantization")
536-
and isinstance(self.quant, quantizations.TransformerEngineQuantization)
537-
)
530+
# Skip checkpoint for FP8-based quantization configs. We check the config
531+
# string rather than self.quant because qwix quantization (use_qwix_quantization=True)
532+
# sets self.quant=None but still intercepts lax.dot_general at runtime with
533+
# Linen FP8 modules (NvidaFp8Provider → nn.Fp8DirectDotGeneralOp).
534+
_quant_str = getattr(self.config, "quantization", "") or ""
535+
_has_fp8_tracer_issue = any(
536+
fp8_id in _quant_str for fp8_id in ("fp8", "nanoo_fp8")
537+
) or _quant_str.startswith("te_")
538538
if not _has_fp8_tracer_issue:
539539
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
540540

0 commit comments

Comments
 (0)