
Commit 1d3cc0c

Fix linting
Parent: 585f1d4

3 files changed: 86 additions & 105 deletions


src/maxtext/configs/base.yml

Lines changed: 2 additions & 2 deletions
@@ -1086,8 +1086,8 @@ position_id_per_seconds: 25
 subslice_shape: ""

 # NNX
-enable_nnx: false
-pure_nnx_decoder: false
+enable_nnx: True
+pure_nnx_decoder: True

 ################################## Qwen3-Next Specific Configs ##################################
 # Kernel size for the 1D convolution in the Gated Delta Net
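Note: this flips the NNX decoder path on by default. A minimal sketch of reading these flags programmatically, assuming MaxText's pyconfig-style initialization and its key=value CLI override convention (the import path and exact call signature here are assumptions, not taken from this commit):

# Hypothetical usage sketch; `pyconfig.initialize` is assumed to follow
# MaxText's argv-style convention: [program, config_path, overrides...].
from maxtext import pyconfig  # assumed import path for this repo layout

config = pyconfig.initialize(
    ["", "src/maxtext/configs/base.yml", "enable_nnx=false"]  # override back to the Linen path
)
print(config.enable_nnx)        # False (overridden at launch)
print(config.pure_nnx_decoder)  # True (new default from this commit)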

src/maxtext/layers/nnx_decoders.py

Lines changed: 31 additions & 40 deletions
@@ -311,7 +311,7 @@ def __init__(

       num_moe = config.num_decoder_layers - config.first_num_dense_layers

-      self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
+      self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
     elif self.is_gemma3:
       attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
       scan_length = config.num_decoder_layers // attention_pattern_length
@@ -437,7 +437,7 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
     prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
     graphdef, params, state = nnx.split(
         layers, nnx.Param, ...
-    )
+    )  # state: the mutable state we carry (KV cache, RNGs, etc.)

     scan_axis = self.config.param_scan_axis
     if scan_axis != 0:
@@ -458,10 +458,9 @@ def layer_fn(carry, scanned_vars):
       current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)

       layer = nnx.merge(graphdef, current_params, current_state)
-
       layer_out = layer(carry, *args, **valid_kwargs)
-
       new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
+      new_current_state = nnx.state(layer)

       return new_carry, new_current_state
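Note: both hunks above rely on Flax NNX's split/merge round trip: `nnx.split(module, nnx.Param, ...)` separates a module into a static graphdef, its parameters, and everything else (KV cache, RNG state, ...), and `nnx.merge` rebuilds a live module so the scan body can run one layer and re-extract its possibly mutated state as the new carry. A minimal runnable sketch with a toy module (`Block` is illustrative, not MaxText code):

from flax import nnx
import jax.numpy as jnp

class Block(nnx.Module):  # toy stand-in for a decoder layer
  def __init__(self, dim, rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

block = Block(4, nnx.Rngs(0))
# Static structure vs. Param state vs. remaining (mutable) state.
graphdef, params, rest = nnx.split(block, nnx.Param, ...)

layer = nnx.merge(graphdef, params, rest)  # rebuild a live module
y = layer(jnp.ones((2, 4)))                # forward pass, as layer_fn does
new_state = nnx.state(layer)               # re-extract state for the carry
print(y.shape)  # (2, 4)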
@@ -823,43 +822,41 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs):
     graphdef, state = nnx.split(layer_stack)
     params, rest = state.split(nnx.Param, ...)
     scan_axis = self.config.param_scan_axis
-
+
     # Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
     def _extract_slice(x, idx, axis):
       slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim))
       return x[slices]
-
+
     # Slice using native indexing instead of jnp.take
     sliced_params = jax.tree.map(lambda x: _extract_slice(x, current_idx, scan_axis), params)
     sliced_rest = jax.tree.map(lambda x: _extract_slice(x, current_idx, 0), rest)
-
+
     single_layer = nnx.merge(graphdef, sliced_params, sliced_rest)
-
+
     # Run the single layer
     out = single_layer(
-        y, *args,
-        decoder_input_tokens=kwargs.get("decoder_input_tokens"),
-        **kwargs.get("layer_kwargs", {})
+        y, *args, decoder_input_tokens=kwargs.get("decoder_input_tokens"), **kwargs.get("layer_kwargs", {})
     )
     y = out[0] if isinstance(out, tuple) else out
-
+
     # Re-merge the updated state back into the specific slice of the stack
     new_state = nnx.state(single_layer)
     new_params, new_rest = new_state.split(nnx.Param, ...)
-
+
     updated_params = jax.tree.map(
         lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
             s, jnp.expand_dims(new_s, axis=scan_axis), current_idx, axis=scan_axis
-        ),
-        params, new_params
+        ),
+        params,
+        new_params,
     )
     updated_rest = jax.tree.map(
-        lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
-            s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0
-        ),
-        rest, new_rest
+        lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0),
+        rest,
+        new_rest,
     )
-
+
     nnx.update(layer_stack, updated_params, updated_rest)
     return y

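Note: the slice-out / write-back pattern above is easier to see on a plain array than on a parameter pytree. This sketch (toy shapes, not MaxText's actual trees) mirrors `_extract_slice` and the `jax.lax.dynamic_update_slice_in_dim` re-merge:

import jax
import jax.numpy as jnp

def _extract_slice(x, idx, axis):
  # Basic N-d indexing (e.g. x[:, idx, :]) instead of jnp.take, as in the diff.
  slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim))
  return x[slices]

stack = jnp.arange(12.0).reshape(4, 3)      # 4 stacked "layers" of shape (3,)
layer_params = _extract_slice(stack, 2, 0)  # pull out layer 2 -> shape (3,)
updated = layer_params + 1.0                # pretend the layer mutated its params
# Update must match the stack's rank, hence expand_dims before writing back.
stack = jax.lax.dynamic_update_slice_in_dim(stack, jnp.expand_dims(updated, axis=0), 2, axis=0)
print(stack[2])  # [7. 8. 9.]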
@@ -870,38 +867,32 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs):
     graphdef, state = nnx.split(layer_stack)
     params, rest = state.split(nnx.Param, ...)
     scan_axis = self.config.param_scan_axis
-
+
     # Slice the chunk state along the correct axes
     chunk_params = jax.tree.map(
-        lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis),
-        params
-    )
-    chunk_rest = jax.tree.map(
-        lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0),
-        rest
+        lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params
     )
+    chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest)
     chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest)
-
+
     # Apply sequentially
     y, chunk_stack = self._apply_layers_sequentially(
         chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
     )
-
+
     # Update the original stack state
     new_state = nnx.state(chunk_stack)
     new_params, new_rest = new_state.split(nnx.Param, ...)
-
+
     updated_params = jax.tree.map(
-        lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis),
-        params, new_params
+        lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params
     )
     updated_rest = jax.tree.map(
-        lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0),
-        rest, new_rest
+        lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest
     )
-
+
     nnx.update(layer_stack, updated_params, updated_rest)
-
+
     return y

   def _apply_interleaved_scanned_layers(self, y, layer_stack, start_idx, end_idx, engram_indices, *args, **kwargs):
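Note: the chunked variant swaps single-index extraction for `jax.lax.dynamic_slice_in_dim`. A short sketch on a toy array (shapes and `scan_length` chosen arbitrarily):

import jax
import jax.numpy as jnp

stack = jnp.arange(12.0).reshape(6, 2)  # 6 stacked "layers"
current_idx, scan_length = 2, 3

chunk = jax.lax.dynamic_slice_in_dim(stack, current_idx, scan_length, axis=0)  # layers 2..4
chunk = chunk * 2.0  # stand-in for running the chunk through _apply_layers_sequentially
stack = jax.lax.dynamic_update_slice_in_dim(stack, chunk, current_idx, axis=0)
print(stack[2:5, 0])  # [ 8. 12. 16.]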
@@ -990,7 +981,7 @@ def __call__(

         y = self._apply_interleaved_scanned_layers(
             y,
-            self.moe_layer,
+            self.moe_layers,
             0,
             (cfg.num_decoder_layers - cfg.first_num_dense_layers),
             [e - cfg.first_num_dense_layers for e in cfg.engram_layers],
@@ -1007,7 +998,7 @@ def __call__(
       if cfg.use_batch_split_schedule:
         policy = self.get_remat_policy()

-        mock_params = self._build_linen_params(self.moe_layer)
+        mock_params = self._build_linen_params(self.moe_layers)

         y = deepseek_batchsplit.scan_batch_split_layers(
             y,
@@ -1021,8 +1012,8 @@ def __call__(
             policy=policy,
         )
       else:
-        y, self.moe_layer = self._apply_layers_sequentially(
-            self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs
+        y, self.moe_layers = self._apply_layers_sequentially(
+            self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs
         )
     elif self.is_gemma3:
       y = self._apply_gemma3_scanned_blocks(

tests/unit/nnx_decoder_test.py

Lines changed: 53 additions & 63 deletions
@@ -48,69 +48,59 @@
 # from maxtext.layers.nnx_decoders import decoder_as_linen
 # from maxtext.common.common_types import MODEL_MODE_TRAIN

+
 class TestNNXDecoderLayerLogicalAxesUnmocked(unittest.TestCase):
-    """
-    Executes pure, unmocked forward passes through NNXDecoderLayer to
-    guarantee coverage of the logical_axis_names assignment block.
-    """
-
-    def setUp(self):
-        super().setUp()
-        self.rngs = nnx.Rngs(params=0, dropout=1)
-        self.base_cfg = _make_config()
-        self.mesh = _make_mesh(self.base_cfg)
-
-    def _make_dummy_inputs(self, cfg):
-        batch = cfg.global_batch_size_to_train_on
-        seq_len = cfg.max_target_length
-        emb_dim = cfg.emb_dim
-
-        # Use jnp.ones to ensure stable, non-stochastic arrays for the forward pass
-        inputs = jnp.ones((batch, seq_len, emb_dim), dtype=cfg.dtype)
-        segment_ids = jnp.ones((batch, seq_len), dtype=jnp.int32)
-        positions = jnp.broadcast_to(jnp.arange(seq_len)[None], (batch, seq_len))
-
-        return inputs, segment_ids, positions
-
-    def test_forward_pass_prefill_mode(self):
-        """Forces execution of: if self.model_mode == MODEL_MODE_PREFILL"""
-        cfg = _make_config()
-        layer = NNXDecoderLayer(
-            config=cfg, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL, rngs=self.rngs
-        )
-        inputs, segment_ids, positions = self._make_dummy_inputs(cfg)
-
-        # A real forward pass ensures all sharding and normalization lines are executed
-        out, _ = layer(
-            inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_PREFILL
-        )
-        self.assertEqual(out.shape, inputs.shape)
-
-    def test_forward_pass_ep_as_context(self):
-        """Forces execution of: elif self.config.expert_shard_attention_option == EP_AS_CONTEXT..."""
-        cfg = _make_config(expert_shard_attention_option=EP_AS_CONTEXT)
-        layer = NNXDecoderLayer(
-            config=cfg, mesh=self.mesh, model_mode=MODEL_MODE_TRAIN, rngs=self.rngs
-        )
-        inputs, segment_ids, positions = self._make_dummy_inputs(cfg)
-
-        out, _ = layer(
-            inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN
-        )
-        self.assertEqual(out.shape, inputs.shape)
-
-    def test_forward_pass_default_axes(self):
-        """Forces execution of the default 'else' fallback."""
-        cfg = _make_config(expert_shard_attention_option="none")
-        layer = NNXDecoderLayer(
-            config=cfg, mesh=self.mesh, model_mode=MODEL_MODE_TRAIN, rngs=self.rngs
-        )
-        inputs, segment_ids, positions = self._make_dummy_inputs(cfg)
-
-        out, _ = layer(
-            inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN
-        )
-        self.assertEqual(out.shape, inputs.shape)
+  """
+  Executes pure, unmocked forward passes through NNXDecoderLayer to
+  guarantee coverage of the logical_axis_names assignment block.
+  """
+
+  def setUp(self):
+    super().setUp()
+    self.rngs = nnx.Rngs(params=0, dropout=1)
+    self.base_cfg = _make_config()
+    self.mesh = _make_mesh(self.base_cfg)
+
+  def _make_dummy_inputs(self, cfg):
+    batch = cfg.global_batch_size_to_train_on
+    seq_len = cfg.max_target_length
+    emb_dim = cfg.emb_dim
+
+    # Use jnp.ones to ensure stable, non-stochastic arrays for the forward pass
+    inputs = jnp.ones((batch, seq_len, emb_dim), dtype=cfg.dtype)
+    segment_ids = jnp.ones((batch, seq_len), dtype=jnp.int32)
+    positions = jnp.broadcast_to(jnp.arange(seq_len)[None], (batch, seq_len))
+
+    return inputs, segment_ids, positions
+
+  def test_forward_pass_prefill_mode(self):
+    """Forces execution of: if self.model_mode == MODEL_MODE_PREFILL"""
+    cfg = _make_config()
+    layer = NNXDecoderLayer(config=cfg, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL, rngs=self.rngs)
+    inputs, segment_ids, positions = self._make_dummy_inputs(cfg)
+
+    # A real forward pass ensures all sharding and normalization lines are executed
+    out, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_PREFILL)
+    self.assertEqual(out.shape, inputs.shape)
+
+  def test_forward_pass_ep_as_context(self):
+    """Forces execution of: elif self.config.expert_shard_attention_option == EP_AS_CONTEXT..."""
+    cfg = _make_config(expert_shard_attention_option=EP_AS_CONTEXT)
+    layer = NNXDecoderLayer(config=cfg, mesh=self.mesh, model_mode=MODEL_MODE_TRAIN, rngs=self.rngs)
+    inputs, segment_ids, positions = self._make_dummy_inputs(cfg)
+
+    out, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN)
+    self.assertEqual(out.shape, inputs.shape)
+
+  def test_forward_pass_default_axes(self):
+    """Forces execution of the default 'else' fallback."""
+    cfg = _make_config(expert_shard_attention_option="none")
+    layer = NNXDecoderLayer(config=cfg, mesh=self.mesh, model_mode=MODEL_MODE_TRAIN, rngs=self.rngs)
+    inputs, segment_ids, positions = self._make_dummy_inputs(cfg)
+
+    out, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN)
+    self.assertEqual(out.shape, inputs.shape)
+

 if __name__ == "__main__":
-    unittest.main()
+  unittest.main()
