Commit e1bc3f2

fix: update
1 parent 4ae99c2 commit e1bc3f2

3 files changed: 17 additions & 27 deletions

src/maxtext/layers/linears.py

Lines changed: 4 additions & 1 deletion
@@ -94,7 +94,10 @@ def _compute_dot_general_nnx(
   if quant_dot_general is not None:
     if initializing:
       quant_dot_general.lazy_init(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None)
-    return quant_dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None, mutable=["aqt"])
+    return quant_dot_general(
+        inputs, kernel, ((axis, contract_ind), ((), ())),
+        precision=None, mutable=["aqt", "_overwrite_with_gradient"],
+    )
 
   return dot_general(
       inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision, out_sharding=out_sharding
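
Note on the mutable collections: Flax's Linen FP8 ops (e.g. nn.Fp8DirectDotGeneralOp) keep their scale/amax state in a dedicated "_overwrite_with_gradient" collection, and Linen only permits writes to collections named in mutable=. The toy module below is a sketch of that mechanism, not MaxText code; TrackedDot and its shapes are made up for illustration.

# Toy sketch, not MaxText code: a Linen module that, like the FP8 dot-general
# ops, writes forward-pass state into the "_overwrite_with_gradient" collection.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TrackedDot(nn.Module):  # hypothetical stand-in for the FP8 dot-general op
  @nn.compact
  def __call__(self, x, kernel):
    # Stand-in for the amax history the real FP8 op maintains.
    amax = self.variable("_overwrite_with_gradient", "amax_history", lambda: jnp.zeros(16))
    out = jnp.dot(x, kernel)
    amax.value = amax.value.at[0].set(jnp.max(jnp.abs(out)))
    return out

mod = TrackedDot()
x, k = jnp.ones((2, 4)), jnp.ones((4, 3))
variables = mod.init(jax.random.PRNGKey(0), x, k)
# Omitting the collection from mutable= makes the forward-pass write fail,
# which is why the call above now lists "_overwrite_with_gradient" too.
out, mutated = mod.apply(variables, x, k, mutable=["_overwrite_with_gradient"])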

src/maxtext/layers/nnx_decoders.py

Lines changed: 5 additions & 26 deletions
@@ -530,9 +530,8 @@ def pure_layer_fn(state_in, y_in):
       out = merged_layer(y_in, **kwargs)
       return out, nnx.state(merged_layer)
 
-    if not self._has_linen_fp8_side_effects():
-      pure_layer_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
-    out, new_state = pure_layer_fn(state, y)
+    checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
+    out, new_state = checkpointed_fn(state, y)
     nnx.update(layer, new_state)
 
     return out
@@ -574,8 +573,7 @@ def layer_fn(carry, scanned_vars):
       # ONLY return non-param state to prevent memory duplication of weights
       return new_carry, new_current_state
 
-    if not self._has_linen_fp8_side_effects():
-      layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
+    layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
 
     final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
 
@@ -647,19 +645,6 @@ def minimal_policy(self, with_context=False, with_quantization=False):
       names.append("quantization")
     return jax.checkpoint_policies.save_only_these_names(*names)
 
-  def _has_linen_fp8_side_effects(self):
-    """Check if the current quantization uses Linen FP8 modules that create mutable state.
-
-    FP8 GPU/NANOO quantization with QWIX creates Linen FP8 modules (e.g.,
-    nn.Fp8DirectDotGeneralOp, nn.NANOOFp8DotGeneralOp) during the forward pass.
-    These modules use self.variable() to create mutable state (amax histories,
-    scales) as side effects. When called inside jax.checkpoint, these side effects
-    cause UnexpectedTracerError because the traced values escape the checkpoint scope
-    through the Linen variable scope.
-    """
-    cfg = self.config
-    return cfg.use_qwix_quantization and cfg.quantization in ("fp8_gpu", "fp8_nanoo")
-
   def get_remat_policy(self):
     """Get remat policy for jax.checkpoint."""
     policy = None
@@ -1154,10 +1139,7 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
       out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs)
       return out_y, out_kv, nnx.state(merged_layer)
 
-    if not self._has_linen_fp8_side_effects():
-      checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
-    else:
-      checkpointed_fn = pure_layer_fn
+    checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
 
     for lyr, layer in enumerate(self.layers):
       graphdef, state = nnx.split(layer)
@@ -1261,10 +1243,7 @@ def pure_gemma_fn(graphdef, state_in, y_in):
       )
       return out_y, nnx.state(merged_layer)
 
-    if not self._has_linen_fp8_side_effects():
-      checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse)
-    else:
-      checkpointed_gemma_fn = pure_gemma_fn
+    checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse)
 
     graphdef, state = nnx.split(self.layers_remainder)
     y, new_state = checkpointed_gemma_fn(graphdef, state, y)
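
The deleted _has_linen_fp8_side_effects helper documents the original problem: Linen FP8 modules wrote amax/scale state through the variable scope as a side effect, and those writes escaped the jax.checkpoint scope as tracers (UnexpectedTracerError). With that state now returned explicitly through the NNX call (see the linears.py change), the layer functions are pure and can be checkpointed unconditionally. A minimal sketch of the pure-function pattern, with an assumed nnx.Linear standing in for a decoder layer:

# Minimal sketch (assumed layer and shapes, not MaxText code): thread NNX state
# in and out of the function so jax.checkpoint sees no hidden side effects.
import jax
import jax.numpy as jnp
from flax import nnx

layer = nnx.Linear(8, 8, rngs=nnx.Rngs(0))  # stand-in for a decoder layer
graphdef, state = nnx.split(layer)

def pure_layer_fn(state_in, y_in):
  merged_layer = nnx.merge(graphdef, state_in)
  out = merged_layer(y_in)
  # Returning the (possibly mutated) state keeps the function pure from JAX's view.
  return out, nnx.state(merged_layer)

checkpointed_fn = jax.checkpoint(pure_layer_fn)
out, new_state = checkpointed_fn(state, jnp.ones((2, 8)))
nnx.update(layer, new_state)  # write updated state back onto the live module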

src/maxtext/layers/quantizations.py

Lines changed: 8 additions & 0 deletions
@@ -638,6 +638,10 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"):
     )
 
   if config.use_qwix_quantization:
+    if config.quantization == "fp8_gpu":
+      return Fp8Quantization()
+    if config.quantization == "fp8_nanoo":
+      return NANOOFp8Quantization()
     return None
   quant_cfg = _get_quant_config(config)
   if quant_cfg:
@@ -819,6 +823,10 @@ def maybe_quantize_model(model, config):
   """Quantize the model if quantization is enabled."""
   # Batch split is not using Qwix's interception feature but manual plumbing
   if config.use_qwix_quantization and not config.use_batch_split_schedule:
+    # fp8_gpu/fp8_nanoo dot_general is handled by DenseGeneral's ToNNX wrapper,
+    # bypassing QWIX interception to avoid tracer leaks inside jax.checkpoint.
+    if config.quantization in {"fp8_gpu", "fp8_nanoo"}:
+      return model
     quantization_provider = get_qt_provider(config)
     if quantization_provider:
       model = qwix.quantize_model(model, quantization_provider)
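
Taken together, the two hunks reroute fp8_gpu/fp8_nanoo under QWIX: configure_quantization hands DenseGeneral a Linen FP8 dot-general, and maybe_quantize_model skips the qwix graph rewrite for those modes. A condensed sketch of the resulting control flow; the Config stub is hypothetical, with field names taken from the diff:

# Condensed sketch; Config is a hypothetical stub, not MaxText's real pyconfig.
from dataclasses import dataclass

@dataclass
class Config:
  use_qwix_quantization: bool = True
  use_batch_split_schedule: bool = False
  quantization: str = "fp8_gpu"

def maybe_quantize_model_sketch(model, config):
  if config.use_qwix_quantization and not config.use_batch_split_schedule:
    if config.quantization in {"fp8_gpu", "fp8_nanoo"}:
      # FP8 is injected at the layer level via the Linen FP8 dot-general,
      # so no qwix.quantize_model rewrite is needed here.
      return model
    # Other qwix modes would call qwix.quantize_model(model, provider) here.
  return model

model = object()  # placeholder model
assert maybe_quantize_model_sketch(model, Config()) is model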
