Skip to content

Commit 4ae99c2 — "fix: update"

Browse files
Committed (1 parent: 21fd4f5)

4 files changed

Lines changed: 39 additions & 12 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from maxtext.layers import initializers, linears, mhc, normalizations, quantizations
4343
from maxtext.layers.attentions import Attention
4444
from maxtext.layers.embeddings import Embed, PositionalEmbedding, attend_on_embedding
45-
from maxtext.layers.engram import Engram, NgramHashMapping
4645
from maxtext.layers.normalizations import RMSNorm
4746
from maxtext.layers.quantizations import AqtQuantization as Quant
4847
from maxtext.models import (
@@ -333,7 +332,7 @@ def __init__(
333332
dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
334333
))
335334
current_idx = next_boundary
336-
335+
337336
# 2. Create MoE Chunks (Direct setattr, NO nnx.Dict)
338337
current_idx = config.first_num_dense_layers
339338
while current_idx < config.num_decoder_layers:
@@ -531,8 +530,9 @@ def pure_layer_fn(state_in, y_in):
531530
out = merged_layer(y_in, **kwargs)
532531
return out, nnx.state(merged_layer)
533532

534-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
535-
out, new_state = checkpointed_fn(state, y)
533+
if not self._has_linen_fp8_side_effects():
534+
pure_layer_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
535+
out, new_state = pure_layer_fn(state, y)
536536
nnx.update(layer, new_state)
537537

538538
return out
@@ -574,7 +574,8 @@ def layer_fn(carry, scanned_vars):
574574
# ONLY return non-param state to prevent memory duplication of weights
575575
return new_carry, new_current_state
576576

577-
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
577+
if not self._has_linen_fp8_side_effects():
578+
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
578579

579580
final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
580581

@@ -646,6 +647,19 @@ def minimal_policy(self, with_context=False, with_quantization=False):
646647
names.append("quantization")
647648
return jax.checkpoint_policies.save_only_these_names(*names)
648649

650+
def _has_linen_fp8_side_effects(self):
651+
"""Check if the current quantization uses Linen FP8 modules that create mutable state.
652+
653+
FP8 GPU/NANOO quantization with QWIX creates Linen FP8 modules (e.g.,
654+
nn.Fp8DirectDotGeneralOp, nn.NANOOFp8DotGeneralOp) during the forward pass.
655+
These modules use self.variable() to create mutable state (amax histories,
656+
scales) as side effects. When called inside jax.checkpoint, these side effects
657+
cause UnexpectedTracerError because the traced values escape the checkpoint scope
658+
through the Linen variable scope.
659+
"""
660+
cfg = self.config
661+
return cfg.use_qwix_quantization and cfg.quantization in ("fp8_gpu", "fp8_nanoo")
662+
649663
def get_remat_policy(self):
650664
"""Get remat policy for jax.checkpoint."""
651665
policy = None
@@ -935,7 +949,7 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
935949
def _apply_single_engram_layer(self, y, layer_name, *args, **kwargs):
936950
"""Applies a single, unscanned Engram layer."""
937951
layer = getattr(self, layer_name)
938-
952+
939953
decoder_input_tokens = kwargs.get("decoder_input_tokens")
940954
layer_kwargs = kwargs.get("layer_kwargs", {})
941955

@@ -1000,7 +1014,7 @@ def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx,
10001014
chunk_name = f"{layer_prefix}_{current_idx}_{next_boundary - 1}"
10011015
chunk_stack = getattr(self, chunk_name)
10021016
scan_length = next_boundary - current_idx
1003-
1017+
10041018
y, chunk_stack = self._apply_layers_sequentially(
10051019
chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
10061020
)
@@ -1079,7 +1093,8 @@ def __call__(
10791093
)
10801094

10811095
y = self._apply_interleaved_scanned_layers(
1082-
y, "moe_layers", cfg.first_num_dense_layers, cfg.num_decoder_layers, cfg.engram_layers, *layer_args, **common_kwargs
1096+
y, "moe_layers", cfg.first_num_dense_layers, cfg.num_decoder_layers,
1097+
cfg.engram_layers, *layer_args, **common_kwargs
10831098
)
10841099
else:
10851100
y, self.dense_layers = self._apply_layers_sequentially(
@@ -1139,7 +1154,10 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
11391154
out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs)
11401155
return out_y, out_kv, nnx.state(merged_layer)
11411156

1142-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
1157+
if not self._has_linen_fp8_side_effects():
1158+
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
1159+
else:
1160+
checkpointed_fn = pure_layer_fn
11431161

11441162
for lyr, layer in enumerate(self.layers):
11451163
graphdef, state = nnx.split(layer)
@@ -1243,7 +1261,10 @@ def pure_gemma_fn(graphdef, state_in, y_in):
12431261
)
12441262
return out_y, nnx.state(merged_layer)
12451263

1246-
checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse)
1264+
if not self._has_linen_fp8_side_effects():
1265+
checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse)
1266+
else:
1267+
checkpointed_gemma_fn = pure_gemma_fn
12471268

12481269
graphdef, state = nnx.split(self.layers_remainder)
12491270
y, new_state = checkpointed_gemma_fn(graphdef, state, y)

src/maxtext/layers/quantizations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from aqt.jax.v2 import tiled_dot_general
2727
from aqt.jax.v2 import calibration
2828

29-
from maxtext.layers import nnx_wrappers
29+
3030
import qwix
3131
from qwix._src.core import dot_general_qt
3232

tests/unit/nnx_decoder_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,4 +533,4 @@ def test_different_random_seeds_produce_different_logits(self):
533533

534534

535535
if __name__ == "__main__":
536-
unittest.main()
536+
unittest.main()

tests/unit/tiling_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def test_vocab_tiling_gradient_with_z_loss(self):
236236
run_name="grad_test_z_loss_with_tiling",
237237
enable_checkpointing=False,
238238
enable_dropout=False,
239+
enable_nnx=False,
239240
max_target_length=self.seq_len,
240241
per_device_batch_size=self.batch_size,
241242
logits_via_embedding=False,
@@ -302,6 +303,7 @@ def test_vocab_tiling_gradient_non_tied_embedding(self):
302303
run_name="value_and_grad_test_non_tied_with_tiling",
303304
enable_checkpointing=False,
304305
enable_dropout=False,
306+
enable_nnx=False,
305307
max_target_length=self.seq_len,
306308
per_device_batch_size=self.batch_size,
307309
logits_via_embedding=False,
@@ -366,6 +368,7 @@ def test_vocab_tiling_gradient_tied_embedding(self):
366368
self.base_config,
367369
run_name="grad_test_tied_with_tiling",
368370
enable_checkpointing=False,
371+
enable_nnx=False,
369372
max_target_length=self.seq_len,
370373
per_device_batch_size=self.batch_size,
371374
logits_via_embedding=True,
@@ -428,6 +431,7 @@ def test_vocab_tiling_gradient_data_parallelism(self):
428431
run_name="value_and_grad_test_dp_tiling",
429432
enable_checkpointing=False,
430433
enable_dropout=False,
434+
enable_nnx=False,
431435
max_target_length=self.seq_len,
432436
per_device_batch_size=self.batch_size,
433437
logits_via_embedding=False,
@@ -492,6 +496,7 @@ def test_vocab_tiling_gradient_tensor_parallelism(self):
492496
run_name="value_and_grad_test_tp_tiling",
493497
enable_checkpointing=False,
494498
enable_dropout=False,
499+
enable_nnx=False,
495500
max_target_length=self.seq_len,
496501
per_device_batch_size=self.batch_size,
497502
logits_via_embedding=False,
@@ -558,6 +563,7 @@ def test_vocab_tiling_gradient_context_parallelism(self):
558563
run_name="value_and_grad_test_cp_tiling",
559564
enable_checkpointing=False,
560565
enable_dropout=False,
566+
enable_nnx=False,
561567
max_target_length=self.seq_len,
562568
per_device_batch_size=self.batch_size,
563569
logits_via_embedding=False,

0 commit comments

Comments (0)