@@ -311,7 +311,7 @@ def __init__(

        num_moe = config.num_decoder_layers - config.first_num_dense_layers

-       self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
+       self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
      elif self.is_gemma3:
        attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
        scan_length = config.num_decoder_layers // attention_pattern_length
@@ -337,7 +337,11 @@ def __init__(
337337 "interleave_moe_layer_step" : self .config .interleave_moe_layer_step ,
338338 }
339339
340- self .layers = self ._create_scanned_layers (layer_cls , length = num_layers , rngs = rngs , ** layer_kwargs )
340+ if num_layers > 0 :
341+ self .layers = self ._create_scanned_layers (layer_cls , length = num_layers , rngs = rngs , ** layer_kwargs )
342+ else :
343+ self .layers = nnx .List ([])
344+
341345 else :
342346 self .layers = nnx .List ([])
343347
@@ -437,7 +441,7 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
    prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
    graphdef, params, state = nnx.split(
        layers, nnx.Param, ...
-   )
+   )  # state: the mutable state we carry (KV cache, RNGs, etc.)

    scan_axis = self.config.param_scan_axis
    if scan_axis != 0:
@@ -447,21 +451,16 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
    sig = inspect.signature(layer_cls.__call__)
    valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}

-   layer_cls = layers.__class__
-   sig = inspect.signature(layer_cls.__call__)
-   valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
-
    def layer_fn(carry, scanned_vars):
      current_params, current_state = scanned_vars

      if self.config.parameter_memory_host_offload:
        current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)

      layer = nnx.merge(graphdef, current_params, current_state)
-
      layer_out = layer(carry, *args, **valid_kwargs)
-
      new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
+     new_current_state = nnx.state(layer)

      return new_carry, new_current_state

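For context on the scan pattern this hunk edits, here is a minimal standalone sketch of splitting a stacked NNX module and scanning over it. It assumes plain Flax NNX with an illustrative `Block` module and toy sizes; it is not the MaxText code, and the config-driven parts (host offload, remat, kwarg filtering, scan axis) are omitted.

```python
# Sketch of the split / scan / re-merge pattern used by _apply_layers_sequentially.
import jax
import jax.numpy as jnp
from flax import nnx


class Block(nnx.Module):
  """Illustrative per-layer module; MaxText uses its own decoder layer classes here."""

  def __init__(self, features, rngs):
    self.linear = nnx.Linear(features, features, rngs=rngs)

  def __call__(self, x):
    return jax.nn.relu(self.linear(x))


# Build a stack of 4 identical layers with a vmapped constructor.
@nnx.split_rngs(splits=4)
@nnx.vmap(in_axes=0, out_axes=0)
def make_stack(rngs):
  return Block(8, rngs=rngs)


stack = make_stack(nnx.Rngs(0))
# graphdef describes the module structure; params and rest are stacked along axis 0.
graphdef, params, rest = nnx.split(stack, nnx.Param, ...)


def layer_fn(carry, scanned_vars):
  layer_params, layer_rest = scanned_vars
  layer = nnx.merge(graphdef, layer_params, layer_rest)  # rebuild one layer from its slice
  new_carry = layer(carry)
  return new_carry, nnx.state(layer)  # capture any per-layer state the call mutated


x = jnp.ones((2, 8))
y, stacked_state = jax.lax.scan(layer_fn, x, (params, rest))
```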
@@ -823,43 +822,41 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg
    graphdef, state = nnx.split(layer_stack)
    params, rest = state.split(nnx.Param, ...)
    scan_axis = self.config.param_scan_axis
-
+
    # Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
    def _extract_slice(x, idx, axis):
      slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim))
      return x[slices]
-
+
    # Slice using native indexing instead of jnp.take
    sliced_params = jax.tree.map(lambda x: _extract_slice(x, current_idx, scan_axis), params)
    sliced_rest = jax.tree.map(lambda x: _extract_slice(x, current_idx, 0), rest)
-
+
    single_layer = nnx.merge(graphdef, sliced_params, sliced_rest)
-
+
    # Run the single layer
    out = single_layer(
-       y, *args,
-       decoder_input_tokens=kwargs.get("decoder_input_tokens"),
-       **kwargs.get("layer_kwargs", {})
+       y, *args, decoder_input_tokens=kwargs.get("decoder_input_tokens"), **kwargs.get("layer_kwargs", {})
    )
    y = out[0] if isinstance(out, tuple) else out
-
+
    # Re-merge the updated state back into the specific slice of the stack
    new_state = nnx.state(single_layer)
    new_params, new_rest = new_state.split(nnx.Param, ...)
-
+
    updated_params = jax.tree.map(
        lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
            s, jnp.expand_dims(new_s, axis=scan_axis), current_idx, axis=scan_axis
-       ),
-       params, new_params
+       ),
+       params,
+       new_params,
    )
    updated_rest = jax.tree.map(
-       lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
-           s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0
-       ),
-       rest, new_rest
+       lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0),
+       rest,
+       new_rest,
    )
-
+
    nnx.update(layer_stack, updated_params, updated_rest)
    return y

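The write-back in this hunk relies on `jax.lax.dynamic_update_slice_in_dim`. A tiny standalone illustration of the slice-out / modify / write-back round trip follows; the array shape and the `+ 1.0` update are placeholders standing in for running the real layer.

```python
import jax
import jax.numpy as jnp

stacked = jnp.zeros((4, 3))        # 4 stacked layers, each holding a (3,) parameter
idx = 2                            # layer to update (may be a traced index)

layer_param = stacked[idx]         # basic slice, like _extract_slice above
layer_param = layer_param + 1.0    # placeholder for the layer's state change

# Re-insert the updated slice; expand_dims restores the stacked layer axis.
stacked = jax.lax.dynamic_update_slice_in_dim(
    stacked, jnp.expand_dims(layer_param, axis=0), idx, axis=0
)
```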
@@ -870,38 +867,32 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args
    graphdef, state = nnx.split(layer_stack)
    params, rest = state.split(nnx.Param, ...)
    scan_axis = self.config.param_scan_axis
-
+
    # Slice the chunk state along the correct axes
    chunk_params = jax.tree.map(
-       lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis),
-       params
-   )
-   chunk_rest = jax.tree.map(
-       lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0),
-       rest
+       lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params
    )
+   chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest)
    chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest)
-
+
    # Apply sequentially
    y, chunk_stack = self._apply_layers_sequentially(
        chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
    )
-
+
    # Update the original stack state
    new_state = nnx.state(chunk_stack)
    new_params, new_rest = new_state.split(nnx.Param, ...)
-
+
    updated_params = jax.tree.map(
-       lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis),
-       params, new_params
+       lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params
    )
    updated_rest = jax.tree.map(
-       lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0),
-       rest, new_rest
+       lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest
    )
-
+
    nnx.update(layer_stack, updated_params, updated_rest)
-
+
    return y

  def _apply_interleaved_scanned_layers(self, y, layer_stack, start_idx, end_idx, engram_indices, *args, **kwargs):
@@ -990,7 +981,7 @@ def __call__(

        y = self._apply_interleaved_scanned_layers(
            y,
-           self.moe_layer,
+           self.moe_layers,
            0,
            (cfg.num_decoder_layers - cfg.first_num_dense_layers),
            [e - cfg.first_num_dense_layers for e in cfg.engram_layers],
@@ -1007,7 +998,7 @@ def __call__(
        if cfg.use_batch_split_schedule:
          policy = self.get_remat_policy()

-         mock_params = self._build_linen_params(self.moe_layer)
+         mock_params = self._build_linen_params(self.moe_layers)

          y = deepseek_batchsplit.scan_batch_split_layers(
              y,
@@ -1021,8 +1012,8 @@ def __call__(
              policy=policy,
          )
        else:
-         y, self.moe_layer = self._apply_layers_sequentially(
-             self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs
+         y, self.moe_layers = self._apply_layers_sequentially(
+             self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs
          )
      elif self.is_gemma3:
        y = self._apply_gemma3_scanned_blocks(
@@ -1038,7 +1029,8 @@ def __call__(
        )
      else:
        scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
-       y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
+       if scan_length > 0:
+         y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
    else:
      prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg)

@@ -1056,7 +1048,16 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):

      for lyr, layer in enumerate(self.layers):
        graphdef, state = nnx.split(layer)
-       kv_cache = kv_caches[lyr] if kv_caches is not None else None
+       if kv_caches is not None:
+         if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
+           if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
+             kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])
+           else:
+             kv_cache = None
+         else:
+           kv_cache = kv_caches[lyr]
+       else:
+         kv_cache = None

        input_tokens = decoder_input_tokens if cfg.engram_layers else None
        if input_tokens is not None:
@@ -1066,7 +1067,12 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
        nnx.update(layer, new_state)

        if kv_caches is not None and kv_cache is not None:
-         kv_caches[lyr] = kv_cache
+         if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
+           if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
+             kv_caches["key_cache"][lyr] = kv_cache[0]
+             kv_caches["value_cache"][lyr] = kv_cache[1]
+         else:
+           kv_caches[lyr] = kv_cache

        if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
          visual_embeds = deepstack_visual_embeds[lyr]
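A small sketch of the cache layout the two Qwen3-Next branches in this hunk assume: a dict of per-layer key/value lists, where only every `inhomogeneous_layer_cycle_interval`-th layer in the cycle (the full-attention layer) reads and writes an entry. The dict keys and the cycle rule come from this diff; the shapes and the dummy update are placeholders for the real layer call.

```python
import jax.numpy as jnp

num_layers, interval = 8, 4   # placeholder sizes
kv_caches = {
    "key_cache": [None] * num_layers,
    "value_cache": [None] * num_layers,
}

for lyr in range(num_layers):
  if (lyr + 1) % interval == 0:
    # Full-attention layer: pass its (key, value) pair in and write it back.
    kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])
    kv_cache = (jnp.zeros((1, 16, 2, 8)), jnp.zeros((1, 16, 2, 8)))  # placeholder layer output
    kv_caches["key_cache"][lyr] = kv_cache[0]
    kv_caches["value_cache"][lyr] = kv_cache[1]
  else:
    kv_cache = None  # linear-attention layers in the cycle keep no KV cache
```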
@@ -1088,7 +1094,7 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):

    # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
    # Instead, we keep track on the hidden states, which has smaller size compared to full logits
-   if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
+   elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
      logits = None
      self.sow(nnx.Intermediate, "hidden_states", hidden_state)
