1- # Copyright 2026 Google LLC
1+ # Copyright 2023– 2026 Google LLC
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
@@ -829,10 +829,19 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
829829 def _apply_single_engram_layer (self , y , current_idx , layer_stack , * args , ** kwargs ):
830830 """Applies a single, unscanned Engram layer by dynamically slicing the NNX state."""
831831 graphdef , state = nnx .split (layer_stack )
832+ params , rest = state .split (nnx .Param , ...)
833+ scan_axis = self .config .param_scan_axis
834+
835+ # Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
836+ def _extract_slice (x , idx , axis ):
837+ slices = tuple (idx if i == axis else slice (None ) for i in range (x .ndim ))
838+ return x [slices ]
832839
833- # Slice the parameters for the current index (assuming scan axis is 0)
834- sliced_state = jax .tree .map (lambda x : x [current_idx ], state )
835- single_layer = nnx .merge (graphdef , sliced_state )
840+ # Slice using native indexing instead of jnp.take
841+ sliced_params = jax .tree .map (lambda x : _extract_slice (x , current_idx , scan_axis ), params )
842+ sliced_rest = jax .tree .map (lambda x : _extract_slice (x , current_idx , 0 ), rest )
843+
844+ single_layer = nnx .merge (graphdef , sliced_params , sliced_rest )
836845
837846 # Run the single layer
838847 out = single_layer (
@@ -841,37 +850,57 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg
841850 y = out [0 ] if isinstance (out , tuple ) else out
842851
843852 # Re-merge the updated state back into the specific slice of the stack
844- new_single_state = nnx .state (single_layer )
845- updated_state = jax .tree .map (
853+ new_state = nnx .state (single_layer )
854+ new_params , new_rest = new_state .split (nnx .Param , ...)
855+
856+ updated_params = jax .tree .map (
857+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (
858+ s , jnp .expand_dims (new_s , axis = scan_axis ), current_idx , axis = scan_axis
859+ ),
860+ params ,
861+ new_params ,
862+ )
863+ updated_rest = jax .tree .map (
846864 lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , jnp .expand_dims (new_s , axis = 0 ), current_idx , axis = 0 ),
847- state ,
848- new_single_state ,
865+ rest ,
866+ new_rest ,
849867 )
850- nnx .update (layer_stack , updated_state )
851868
869+ nnx .update (layer_stack , updated_params , updated_rest )
852870 return y
853871
854872 def _apply_scanned_chunk (self , y , current_idx , next_boundary , layer_stack , * args , ** kwargs ):
855873 """Applies a contiguous chunk of layers using scan over a state slice."""
856874 scan_length = next_boundary - current_idx
857875 if scan_length > 0 :
858876 graphdef , state = nnx .split (layer_stack )
877+ params , rest = state .split (nnx .Param , ...)
878+ scan_axis = self .config .param_scan_axis
859879
860- # Slice the chunk state
861- chunk_state = jax .tree .map (lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = 0 ), state )
862- chunk_stack = nnx .merge (graphdef , chunk_state )
880+ # Slice the chunk state along the correct axes
881+ chunk_params = jax .tree .map (
882+ lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = scan_axis ), params
883+ )
884+ chunk_rest = jax .tree .map (lambda x : jax .lax .dynamic_slice_in_dim (x , current_idx , scan_length , axis = 0 ), rest )
885+ chunk_stack = nnx .merge (graphdef , chunk_params , chunk_rest )
863886
864887 # Apply sequentially
865888 y , chunk_stack = self ._apply_layers_sequentially (
866889 chunk_stack , y , * args , length = scan_length , ** kwargs .get ("layer_kwargs" , {})
867890 )
868891
869892 # Update the original stack state
870- new_chunk_state = nnx .state (chunk_stack )
871- updated_state = jax .tree .map (
872- lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = 0 ), state , new_chunk_state
893+ new_state = nnx .state (chunk_stack )
894+ new_params , new_rest = new_state .split (nnx .Param , ...)
895+
896+ updated_params = jax .tree .map (
897+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = scan_axis ), params , new_params
873898 )
874- nnx .update (layer_stack , updated_state )
899+ updated_rest = jax .tree .map (
900+ lambda s , new_s : jax .lax .dynamic_update_slice_in_dim (s , new_s , current_idx , axis = 0 ), rest , new_rest
901+ )
902+
903+ nnx .update (layer_stack , updated_params , updated_rest )
875904
876905 return y
877906
0 commit comments