Skip to content

Commit 46dd5a6

Browse files
Fix linting
1 parent f8669ea commit 46dd5a6

3 files changed

Lines changed: 548 additions & 122 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,8 +1111,8 @@ position_id_per_seconds: 25
11111111
subslice_shape: ""
11121112

11131113
# NNX
1114-
enable_nnx: false
1115-
pure_nnx_decoder: false
1114+
enable_nnx: True
1115+
pure_nnx_decoder: True
11161116

11171117
################################## Qwen3-Next Specific Configs ##################################
11181118
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/layers/nnx_decoders.py

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def __init__(
311311

312312
num_moe = config.num_decoder_layers - config.first_num_dense_layers
313313

314-
self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
314+
self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
315315
elif self.is_gemma3:
316316
attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
317317
scan_length = config.num_decoder_layers // attention_pattern_length
@@ -337,7 +337,11 @@ def __init__(
337337
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
338338
}
339339

340-
self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs)
340+
if num_layers > 0:
341+
self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs)
342+
else:
343+
self.layers = nnx.List([])
344+
341345
else:
342346
self.layers = nnx.List([])
343347

@@ -437,7 +441,7 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
437441
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
438442
graphdef, params, state = nnx.split(
439443
layers, nnx.Param, ...
440-
)
444+
) # state: the mutable state we carry (KV cache, RNGs, etc.)
441445

442446
scan_axis = self.config.param_scan_axis
443447
if scan_axis != 0:
@@ -447,21 +451,16 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
447451
sig = inspect.signature(layer_cls.__call__)
448452
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
449453

450-
layer_cls = layers.__class__
451-
sig = inspect.signature(layer_cls.__call__)
452-
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
453-
454454
def layer_fn(carry, scanned_vars):
455455
current_params, current_state = scanned_vars
456456

457457
if self.config.parameter_memory_host_offload:
458458
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
459459

460460
layer = nnx.merge(graphdef, current_params, current_state)
461-
462461
layer_out = layer(carry, *args, **valid_kwargs)
463-
464462
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
463+
new_current_state = nnx.state(layer)
465464

466465
return new_carry, new_current_state
467466

@@ -823,43 +822,41 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg
823822
graphdef, state = nnx.split(layer_stack)
824823
params, rest = state.split(nnx.Param, ...)
825824
scan_axis = self.config.param_scan_axis
826-
825+
827826
# Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
828827
def _extract_slice(x, idx, axis):
829828
slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim))
830829
return x[slices]
831-
830+
832831
# Slice using native indexing instead of jnp.take
833832
sliced_params = jax.tree.map(lambda x: _extract_slice(x, current_idx, scan_axis), params)
834833
sliced_rest = jax.tree.map(lambda x: _extract_slice(x, current_idx, 0), rest)
835-
834+
836835
single_layer = nnx.merge(graphdef, sliced_params, sliced_rest)
837-
836+
838837
# Run the single layer
839838
out = single_layer(
840-
y, *args,
841-
decoder_input_tokens=kwargs.get("decoder_input_tokens"),
842-
**kwargs.get("layer_kwargs", {})
839+
y, *args, decoder_input_tokens=kwargs.get("decoder_input_tokens"), **kwargs.get("layer_kwargs", {})
843840
)
844841
y = out[0] if isinstance(out, tuple) else out
845-
842+
846843
# Re-merge the updated state back into the specific slice of the stack
847844
new_state = nnx.state(single_layer)
848845
new_params, new_rest = new_state.split(nnx.Param, ...)
849-
846+
850847
updated_params = jax.tree.map(
851848
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
852849
s, jnp.expand_dims(new_s, axis=scan_axis), current_idx, axis=scan_axis
853-
),
854-
params, new_params
850+
),
851+
params,
852+
new_params,
855853
)
856854
updated_rest = jax.tree.map(
857-
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
858-
s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0
859-
),
860-
rest, new_rest
855+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0),
856+
rest,
857+
new_rest,
861858
)
862-
859+
863860
nnx.update(layer_stack, updated_params, updated_rest)
864861
return y
865862

@@ -870,38 +867,32 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args
870867
graphdef, state = nnx.split(layer_stack)
871868
params, rest = state.split(nnx.Param, ...)
872869
scan_axis = self.config.param_scan_axis
873-
870+
874871
# Slice the chunk state along the correct axes
875872
chunk_params = jax.tree.map(
876-
lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis),
877-
params
878-
)
879-
chunk_rest = jax.tree.map(
880-
lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0),
881-
rest
873+
lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params
882874
)
875+
chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest)
883876
chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest)
884-
877+
885878
# Apply sequentially
886879
y, chunk_stack = self._apply_layers_sequentially(
887880
chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
888881
)
889-
882+
890883
# Update the original stack state
891884
new_state = nnx.state(chunk_stack)
892885
new_params, new_rest = new_state.split(nnx.Param, ...)
893-
886+
894887
updated_params = jax.tree.map(
895-
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis),
896-
params, new_params
888+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params
897889
)
898890
updated_rest = jax.tree.map(
899-
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0),
900-
rest, new_rest
891+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest
901892
)
902-
893+
903894
nnx.update(layer_stack, updated_params, updated_rest)
904-
895+
905896
return y
906897

907898
def _apply_interleaved_scanned_layers(self, y, layer_stack, start_idx, end_idx, engram_indices, *args, **kwargs):
@@ -990,7 +981,7 @@ def __call__(
990981

991982
y = self._apply_interleaved_scanned_layers(
992983
y,
993-
self.moe_layer,
984+
self.moe_layers,
994985
0,
995986
(cfg.num_decoder_layers - cfg.first_num_dense_layers),
996987
[e - cfg.first_num_dense_layers for e in cfg.engram_layers],
@@ -1007,7 +998,7 @@ def __call__(
1007998
if cfg.use_batch_split_schedule:
1008999
policy = self.get_remat_policy()
10091000

1010-
mock_params = self._build_linen_params(self.moe_layer)
1001+
mock_params = self._build_linen_params(self.moe_layers)
10111002

10121003
y = deepseek_batchsplit.scan_batch_split_layers(
10131004
y,
@@ -1021,8 +1012,8 @@ def __call__(
10211012
policy=policy,
10221013
)
10231014
else:
1024-
y, self.moe_layer = self._apply_layers_sequentially(
1025-
self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs
1015+
y, self.moe_layers = self._apply_layers_sequentially(
1016+
self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs
10261017
)
10271018
elif self.is_gemma3:
10281019
y = self._apply_gemma3_scanned_blocks(
@@ -1038,7 +1029,8 @@ def __call__(
10381029
)
10391030
else:
10401031
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
1041-
y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
1032+
if scan_length > 0:
1033+
y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
10421034
else:
10431035
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg)
10441036

@@ -1056,7 +1048,16 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
10561048

10571049
for lyr, layer in enumerate(self.layers):
10581050
graphdef, state = nnx.split(layer)
1059-
kv_cache = kv_caches[lyr] if kv_caches is not None else None
1051+
if kv_caches is not None:
1052+
if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
1053+
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
1054+
kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])
1055+
else:
1056+
kv_cache = None
1057+
else:
1058+
kv_cache = kv_caches[lyr]
1059+
else:
1060+
kv_cache = None
10601061

10611062
input_tokens = decoder_input_tokens if cfg.engram_layers else None
10621063
if input_tokens is not None:
@@ -1066,7 +1067,12 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
10661067
nnx.update(layer, new_state)
10671068

10681069
if kv_caches is not None and kv_cache is not None:
1069-
kv_caches[lyr] = kv_cache
1070+
if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
1071+
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
1072+
kv_caches["key_cache"][lyr] = kv_cache[0]
1073+
kv_caches["value_cache"][lyr] = kv_cache[1]
1074+
else:
1075+
kv_caches[lyr] = kv_cache
10701076

10711077
if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
10721078
visual_embeds = deepstack_visual_embeds[lyr]
@@ -1088,7 +1094,7 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
10881094

10891095
# When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
10901096
# Instead, we keep track on the hidden states, which has smaller size compared to full logits
1091-
if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
1097+
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
10921098
logits = None
10931099
self.sow(nnx.Intermediate, "hidden_states", hidden_state)
10941100

0 commit comments

Comments
 (0)