@@ -311,7 +311,7 @@ def __init__(
311311
312312 num_moe = config .num_decoder_layers - config .first_num_dense_layers
313313
314- self .moe_layer = self ._create_scanned_layers (moe_cls , length = num_moe , rngs = rngs )
314+ self .moe_layers = self ._create_scanned_layers (moe_cls , length = num_moe , rngs = rngs )
315315 elif self .is_gemma3 :
316316 attention_pattern_length = len (gemma3 .GEMMA3_ATTENTION_PATTERN )
317317 scan_length = config .num_decoder_layers // attention_pattern_length
@@ -441,36 +441,27 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
441441
442442 scan_axis = self .config .param_scan_axis
443443 if scan_axis != 0 :
444- # Move scan_axis to 0 so scan can iterate over it
445444 params = jax .tree .map (lambda x : jnp .moveaxis (x , scan_axis , 0 ), params )
446445
447446 layer_cls = layers .__class__
448447 sig = inspect .signature (layer_cls .__call__ )
449448 valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
450449
451- layer_cls = layers .__class__ # Access the underlying class
450+ layer_cls = layers .__class__
452451 sig = inspect .signature (layer_cls .__call__ )
453- # Filter kwargs to only include keys that exist in the layer's signature
454452 valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
455453
456454 def layer_fn (carry , scanned_vars ):
457- # Unpack the sliced variables for THIS layer
458455 current_params , current_state = scanned_vars
459456
460457 if self .config .parameter_memory_host_offload :
461458 current_params = jax .tree .map (lambda x : jax .device_put (x , max_utils .device_space ()), current_params )
462459
463- # Merge using the SLICED state
464460 layer = nnx .merge (graphdef , current_params , current_state )
465-
466- # Run the layer (Filter kwargs if using the solution from previous turn)
467461 layer_out = layer (carry , * args , ** valid_kwargs )
468-
469462 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
470-
471- # Extract the updated state to return it
472- # _, new_current_state = nnx.split(layer, nnx.Param, ...)
473463 new_current_state = nnx .state (layer )
464+
474465 return new_carry , new_current_state
475466
476467 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
@@ -829,10 +820,19 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
829820 def _apply_single_engram_layer (self , y , current_idx , layer_stack , * args , ** kwargs ):
830821 """Applies a single, unscanned Engram layer by dynamically slicing the NNX state."""
831822 graphdef , state = nnx .split (layer_stack )
823+ params , rest = state .split (nnx .Param , ...)
824+ scan_axis = self .config .param_scan_axis
825+
826+ # Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
827+ def _extract_slice (x , idx , axis ):
828+ slices = tuple (idx if i == axis else slice (None ) for i in range (x .ndim ))
829+ return x [slices ]
832830
833- # Slice the parameters for the current index (assuming scan axis is 0)
834- sliced_state = jax .tree .map (lambda x : x [current_idx ], state )
835- single_layer = nnx .merge (graphdef , sliced_state )
831+ # Slice using native indexing instead of jnp.take
832+ sliced_params = jax .tree .map (lambda x : _extract_slice (x , current_idx , scan_axis ), params )
833+ sliced_rest = jax .tree .map (lambda x : _extract_slice (x , current_idx , 0 ), rest )
834+
835+ single_layer = nnx .merge (graphdef , sliced_params , sliced_rest )
836836
837837 # Run the single layer
838838 out = single_layer (
@@ -841,37 +841,57 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg
841841 y = out [0 ] if isinstance (out , tuple ) else out
842842
843843 # Re-merge the updated state back into the specific slice of the stack
844- new_single_state = nnx .state (single_layer )
845- updated_state = jax .tree .map (
844+ new_state = nnx .state (single_layer )
845+ new_params , new_rest = new_state .split (nnx .Param , ...)
846+
847+ updated_params = jax .tree .map (
848+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (
849+ s , jnp .expand_dims (new_s , axis = scan_axis ), current_idx , axis = scan_axis
850+ ),
851+ params ,
852+ new_params ,
853+ )
854+ updated_rest = jax .tree .map (
846855 lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , jnp .expand_dims (new_s , axis = 0 ), current_idx , axis = 0 ),
847- state ,
848- new_single_state ,
856+ rest ,
857+ new_rest ,
849858 )
850- nnx .update (layer_stack , updated_state )
851859
860+ nnx .update (layer_stack , updated_params , updated_rest )
852861 return y
853862
854863 def _apply_scanned_chunk (self , y , current_idx , next_boundary , layer_stack , * args , ** kwargs ):
855864 """Applies a contiguous chunk of layers using scan over a state slice."""
856865 scan_length = next_boundary - current_idx
857866 if scan_length > 0 :
858867 graphdef , state = nnx .split (layer_stack )
868+ params , rest = state .split (nnx .Param , ...)
869+ scan_axis = self .config .param_scan_axis
859870
860- # Slice the chunk state
861- chunk_state = jax .tree .map (lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = 0 ), state )
862- chunk_stack = nnx .merge (graphdef , chunk_state )
871+ # Slice the chunk state along the correct axes
872+ chunk_params = jax .tree .map (
873+ lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = scan_axis ), params
874+ )
875+ chunk_rest = jax .tree .map (lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = 0 ), rest )
876+ chunk_stack = nnx .merge (graphdef , chunk_params , chunk_rest )
863877
864878 # Apply sequentially
865879 y , chunk_stack = self ._apply_layers_sequentially (
866880 chunk_stack , y , * args , length = scan_length , ** kwargs .get ("layer_kwargs" , {})
867881 )
868882
869883 # Update the original stack state
870- new_chunk_state = nnx .state (chunk_stack )
871- updated_state = jax .tree .map (
872- lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = 0 ), state , new_chunk_state
884+ new_state = nnx .state (chunk_stack )
885+ new_params , new_rest = new_state .split (nnx .Param , ...)
886+
887+ updated_params = jax .tree .map (
888+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = scan_axis ), params , new_params
889+ )
890+ updated_rest = jax .tree .map (
891+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = 0 ), rest , new_rest
873892 )
874- nnx .update (layer_stack , updated_state )
893+
894+ nnx .update (layer_stack , updated_params , updated_rest )
875895
876896 return y
877897
@@ -961,7 +981,7 @@ def __call__(
961981
962982 y = self ._apply_interleaved_scanned_layers (
963983 y ,
964- self .moe_layer ,
984+ self .moe_layers ,
965985 0 ,
966986 (cfg .num_decoder_layers - cfg .first_num_dense_layers ),
967987 [e - cfg .first_num_dense_layers for e in cfg .engram_layers ],
@@ -978,7 +998,7 @@ def __call__(
978998 if cfg .use_batch_split_schedule :
979999 policy = self .get_remat_policy ()
9801000
981- mock_params = self ._build_linen_params (self .moe_layer )
1001+ mock_params = self ._build_linen_params (self .moe_layers )
9821002
9831003 y = deepseek_batchsplit .scan_batch_split_layers (
9841004 y ,
@@ -992,8 +1012,8 @@ def __call__(
9921012 policy = policy ,
9931013 )
9941014 else :
995- y , self .moe_layer = self ._apply_layers_sequentially (
996- self .moe_layer , y , * layer_args , length = num_moe , ** layer_kwargs
1015+ y , self .moe_layers = self ._apply_layers_sequentially (
1016+ self .moe_layers , y , * layer_args , length = num_moe , ** layer_kwargs
9971017 )
9981018 elif self .is_gemma3 :
9991019 y = self ._apply_gemma3_scanned_blocks (
0 commit comments