|
42 | 42 | from maxtext.layers import initializers, linears, mhc, normalizations, quantizations |
43 | 43 | from maxtext.layers.attentions import Attention |
44 | 44 | from maxtext.layers.embeddings import Embed, PositionalEmbedding, attend_on_embedding |
45 | | -from maxtext.layers.engram import Engram, NgramHashMapping |
46 | 45 | from maxtext.layers.normalizations import RMSNorm |
47 | 46 | from maxtext.layers.quantizations import AqtQuantization as Quant |
48 | 47 | from maxtext.models import ( |
@@ -333,7 +332,7 @@ def __init__( |
333 | 332 | dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs |
334 | 333 | )) |
335 | 334 | current_idx = next_boundary |
336 | | - |
| 335 | + |
337 | 336 | # 2. Create MoE Chunks (Direct setattr, NO nnx.Dict) |
338 | 337 | current_idx = config.first_num_dense_layers |
339 | 338 | while current_idx < config.num_decoder_layers: |
@@ -531,8 +530,9 @@ def pure_layer_fn(state_in, y_in): |
531 | 530 | out = merged_layer(y_in, **kwargs) |
532 | 531 | return out, nnx.state(merged_layer) |
533 | 532 |
|
534 | | - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) |
535 | | - out, new_state = checkpointed_fn(state, y) |
| 533 | + if not self._has_linen_fp8_side_effects(): |
| 534 | + pure_layer_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) |
| 535 | + out, new_state = pure_layer_fn(state, y) |
536 | 536 | nnx.update(layer, new_state) |
537 | 537 |
|
538 | 538 | return out |
@@ -574,7 +574,8 @@ def layer_fn(carry, scanned_vars): |
574 | 574 | # ONLY return non-param state to prevent memory duplication of weights |
575 | 575 | return new_carry, new_current_state |
576 | 576 |
|
577 | | - layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) |
| 577 | + if not self._has_linen_fp8_side_effects(): |
| 578 | + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) |
578 | 579 |
|
579 | 580 | final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state)) |
580 | 581 |
|
@@ -646,6 +647,19 @@ def minimal_policy(self, with_context=False, with_quantization=False): |
646 | 647 | names.append("quantization") |
647 | 648 | return jax.checkpoint_policies.save_only_these_names(*names) |
648 | 649 |
|
| 650 | + def _has_linen_fp8_side_effects(self): |
| 651 | + """Check if the current quantization uses Linen FP8 modules that create mutable state. |
| 652 | +
|
| 653 | + FP8 GPU/NANOO quantization with QWIX creates Linen FP8 modules (e.g., |
| 654 | + nn.Fp8DirectDotGeneralOp, nn.NANOOFp8DotGeneralOp) during the forward pass. |
| 655 | + These modules use self.variable() to create mutable state (amax histories, |
| 656 | + scales) as side effects. When called inside jax.checkpoint, these side effects |
| 657 | + cause UnexpectedTracerError because the traced values escape the checkpoint scope |
| 658 | + through the Linen variable scope. |
| 659 | + """ |
| 660 | + cfg = self.config |
| 661 | + return cfg.use_qwix_quantization and cfg.quantization in ("fp8_gpu", "fp8_nanoo") |
| 662 | + |
649 | 663 | def get_remat_policy(self): |
650 | 664 | """Get remat policy for jax.checkpoint.""" |
651 | 665 | policy = None |
@@ -935,7 +949,7 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices): |
935 | 949 | def _apply_single_engram_layer(self, y, layer_name, *args, **kwargs): |
936 | 950 | """Applies a single, unscanned Engram layer.""" |
937 | 951 | layer = getattr(self, layer_name) |
938 | | - |
| 952 | + |
939 | 953 | decoder_input_tokens = kwargs.get("decoder_input_tokens") |
940 | 954 | layer_kwargs = kwargs.get("layer_kwargs", {}) |
941 | 955 |
|
@@ -1000,7 +1014,7 @@ def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx, |
1000 | 1014 | chunk_name = f"{layer_prefix}_{current_idx}_{next_boundary - 1}" |
1001 | 1015 | chunk_stack = getattr(self, chunk_name) |
1002 | 1016 | scan_length = next_boundary - current_idx |
1003 | | - |
| 1017 | + |
1004 | 1018 | y, chunk_stack = self._apply_layers_sequentially( |
1005 | 1019 | chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {}) |
1006 | 1020 | ) |
@@ -1079,7 +1093,8 @@ def __call__( |
1079 | 1093 | ) |
1080 | 1094 |
|
1081 | 1095 | y = self._apply_interleaved_scanned_layers( |
1082 | | - y, "moe_layers", cfg.first_num_dense_layers, cfg.num_decoder_layers, cfg.engram_layers, *layer_args, **common_kwargs |
| 1096 | + y, "moe_layers", cfg.first_num_dense_layers, cfg.num_decoder_layers, |
| 1097 | + cfg.engram_layers, *layer_args, **common_kwargs |
1083 | 1098 | ) |
1084 | 1099 | else: |
1085 | 1100 | y, self.dense_layers = self._apply_layers_sequentially( |
@@ -1139,7 +1154,10 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): |
1139 | 1154 | out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) |
1140 | 1155 | return out_y, out_kv, nnx.state(merged_layer) |
1141 | 1156 |
|
1142 | | - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) |
| 1157 | + if not self._has_linen_fp8_side_effects(): |
| 1158 | + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) |
| 1159 | + else: |
| 1160 | + checkpointed_fn = pure_layer_fn |
1143 | 1161 |
|
1144 | 1162 | for lyr, layer in enumerate(self.layers): |
1145 | 1163 | graphdef, state = nnx.split(layer) |
@@ -1243,7 +1261,10 @@ def pure_gemma_fn(graphdef, state_in, y_in): |
1243 | 1261 | ) |
1244 | 1262 | return out_y, nnx.state(merged_layer) |
1245 | 1263 |
|
1246 | | - checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) |
| 1264 | + if not self._has_linen_fp8_side_effects(): |
| 1265 | + checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) |
| 1266 | + else: |
| 1267 | + checkpointed_gemma_fn = pure_gemma_fn |
1247 | 1268 |
|
1248 | 1269 | graphdef, state = nnx.split(self.layers_remainder) |
1249 | 1270 | y, new_state = checkpointed_gemma_fn(graphdef, state, y) |
|
0 commit comments