@@ -311,7 +311,7 @@ def __init__(
311311
312312 num_moe = config .num_decoder_layers - config .first_num_dense_layers
313313
314- self .moe_layer = self ._create_scanned_layers (moe_cls , length = num_moe , rngs = rngs )
314+ self .moe_layers = self ._create_scanned_layers (moe_cls , length = num_moe , rngs = rngs )
315315 elif self .is_gemma3 :
316316 attention_pattern_length = len (gemma3 .GEMMA3_ATTENTION_PATTERN )
317317 scan_length = config .num_decoder_layers // attention_pattern_length
@@ -437,7 +437,7 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
437437 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
438438 graphdef , params , state = nnx .split (
439439 layers , nnx .Param , ...
440- )
440+ ) # params: the nnx.Param variables; state: every remaining variable matched by `...` (e.g. KV cache, RNGs)
441441
442442 scan_axis = self .config .param_scan_axis
443443 if scan_axis != 0 :
@@ -458,10 +458,9 @@ def layer_fn(carry, scanned_vars):
458458 current_params = jax .tree .map (lambda x : jax .device_put (x , max_utils .device_space ()), current_params )
459459
460460 layer = nnx .merge (graphdef , current_params , current_state )
461-
462461 layer_out = layer (carry , * args , ** valid_kwargs )
463-
464462 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
463+ new_current_state = nnx .state (layer )
465464
466465 return new_carry , new_current_state
467466
@@ -823,43 +822,41 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg
823822 graphdef , state = nnx .split (layer_stack )
824823 params , rest = state .split (nnx .Param , ...)
825824 scan_axis = self .config .param_scan_axis
826-
825+
827826 # Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
828827 def _extract_slice (x , idx , axis ):
829828 slices = tuple (idx if i == axis else slice (None ) for i in range (x .ndim ))
830829 return x [slices ]
831-
830+
832831 # Slice using native indexing instead of jnp.take
833832 sliced_params = jax .tree .map (lambda x : _extract_slice (x , current_idx , scan_axis ), params )
834833 sliced_rest = jax .tree .map (lambda x : _extract_slice (x , current_idx , 0 ), rest )
835-
834+
836835 single_layer = nnx .merge (graphdef , sliced_params , sliced_rest )
837-
836+
838837 # Run the single layer
839838 out = single_layer (
840- y , * args ,
841- decoder_input_tokens = kwargs .get ("decoder_input_tokens" ),
842- ** kwargs .get ("layer_kwargs" , {})
839+ y , * args , decoder_input_tokens = kwargs .get ("decoder_input_tokens" ), ** kwargs .get ("layer_kwargs" , {})
843840 )
844841 y = out [0 ] if isinstance (out , tuple ) else out
845-
842+
846843 # Re-merge the updated state back into the specific slice of the stack
847844 new_state = nnx .state (single_layer )
848845 new_params , new_rest = new_state .split (nnx .Param , ...)
849-
846+
850847 updated_params = jax .tree .map (
851848 lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (
852849 s , jnp .expand_dims (new_s , axis = scan_axis ), current_idx , axis = scan_axis
853- ),
854- params , new_params
850+ ),
851+ params ,
852+ new_params ,
855853 )
856854 updated_rest = jax .tree .map (
857- lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (
858- s , jnp .expand_dims (new_s , axis = 0 ), current_idx , axis = 0
859- ),
860- rest , new_rest
855+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , jnp .expand_dims (new_s , axis = 0 ), current_idx , axis = 0 ),
856+ rest ,
857+ new_rest ,
861858 )
862-
859+
863860 nnx .update (layer_stack , updated_params , updated_rest )
864861 return y
865862
@@ -870,38 +867,32 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args
870867 graphdef , state = nnx .split (layer_stack )
871868 params , rest = state .split (nnx .Param , ...)
872869 scan_axis = self .config .param_scan_axis
873-
870+
874871 # Slice the chunk state along the correct axes
875872 chunk_params = jax .tree .map (
876- lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = scan_axis ),
877- params
878- )
879- chunk_rest = jax .tree .map (
880- lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = 0 ),
881- rest
873+ lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = scan_axis ), params
882874 )
875+ chunk_rest = jax .tree .map (lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = 0 ), rest )
883876 chunk_stack = nnx .merge (graphdef , chunk_params , chunk_rest )
884-
877+
885878 # Apply sequentially
886879 y , chunk_stack = self ._apply_layers_sequentially (
887880 chunk_stack , y , * args , length = scan_length , ** kwargs .get ("layer_kwargs" , {})
888881 )
889-
882+
890883 # Update the original stack state
891884 new_state = nnx .state (chunk_stack )
892885 new_params , new_rest = new_state .split (nnx .Param , ...)
893-
886+
894887 updated_params = jax .tree .map (
895- lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = scan_axis ),
896- params , new_params
888+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = scan_axis ), params , new_params
897889 )
898890 updated_rest = jax .tree .map (
899- lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = 0 ),
900- rest , new_rest
891+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = 0 ), rest , new_rest
901892 )
902-
893+
903894 nnx .update (layer_stack , updated_params , updated_rest )
904-
895+
905896 return y
906897
907898 def _apply_interleaved_scanned_layers (self , y , layer_stack , start_idx , end_idx , engram_indices , * args , ** kwargs ):
@@ -990,7 +981,7 @@ def __call__(
990981
991982 y = self ._apply_interleaved_scanned_layers (
992983 y ,
993- self .moe_layer ,
984+ self .moe_layers ,
994985 0 ,
995986 (cfg .num_decoder_layers - cfg .first_num_dense_layers ),
996987 [e - cfg .first_num_dense_layers for e in cfg .engram_layers ],
@@ -1007,7 +998,7 @@ def __call__(
1007998 if cfg .use_batch_split_schedule :
1008999 policy = self .get_remat_policy ()
10091000
1010- mock_params = self ._build_linen_params (self .moe_layer )
1001+ mock_params = self ._build_linen_params (self .moe_layers )
10111002
10121003 y = deepseek_batchsplit .scan_batch_split_layers (
10131004 y ,
@@ -1021,8 +1012,8 @@ def __call__(
10211012 policy = policy ,
10221013 )
10231014 else :
1024- y , self .moe_layer = self ._apply_layers_sequentially (
1025- self .moe_layer , y , * layer_args , length = num_moe , ** layer_kwargs
1015+ y , self .moe_layers = self ._apply_layers_sequentially (
1016+ self .moe_layers , y , * layer_args , length = num_moe , ** layer_kwargs
10261017 )
10271018 elif self .is_gemma3 :
10281019 y = self ._apply_gemma3_scanned_blocks (
0 commit comments