4141from maxtext .layers import initializers , linears , mhc , normalizations , quantizations
4242from maxtext .layers .attentions import Attention
4343from maxtext .layers .embeddings import Embed , PositionalEmbedding , attend_on_embedding
44- from maxtext .layers .engram import Engram , NgramHashMapping
4544from maxtext .layers .normalizations import RMSNorm
4645from maxtext .layers .quantizations import AqtQuantization as Quant
4746from maxtext .models import (
@@ -328,11 +327,15 @@ def __init__(
328327 else :
329328 next_boundary = self ._find_next_boundary (current_idx , config .first_num_dense_layers , config .engram_layers )
330329 chunk_name = f"dense_layers_{ current_idx } _{ next_boundary - 1 } "
331- setattr (self , chunk_name , self ._create_scanned_layers (
332- dense_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
333- ))
330+ setattr (
331+ self ,
332+ chunk_name ,
333+ self ._create_scanned_layers (
334+ dense_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
335+ ),
336+ )
334337 current_idx = next_boundary
335-
338+
336339 # 2. Create MoE Chunks (Direct setattr, NO nnx.Dict)
337340 current_idx = config .first_num_dense_layers
338341 while current_idx < config .num_decoder_layers :
@@ -343,9 +346,13 @@ def __init__(
343346 else :
344347 next_boundary = self ._find_next_boundary (current_idx , config .num_decoder_layers , config .engram_layers )
345348 chunk_name = f"moe_layers_{ current_idx } _{ next_boundary - 1 } "
346- setattr (self , chunk_name , self ._create_scanned_layers (
347- moe_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
348- ))
349+ setattr (
350+ self ,
351+ chunk_name ,
352+ self ._create_scanned_layers (
353+ moe_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
354+ ),
355+ )
349356 current_idx = next_boundary
350357 else :
351358 # Standard DeepSeek logic when Engrams are disabled
@@ -374,7 +381,7 @@ def __init__(
374381 self .layers_remainder = RemattedGemma3Block (
375382 config = self .config , mesh = mesh , quant = self .quant , model_mode = self .model_mode , ** rem_layer_kwargs , rngs = rngs
376383 ) # pytype: disable=wrong-keyword-args
377- elif self .is_gemma4 : # <-- ADDED BLOCK
384+ elif self .is_gemma4 : # <-- ADDED BLOCK
378385 attention_pattern_length = len (gemma4 .GEMMA4_ATTENTION_PATTERN )
379386 scan_length = config .num_decoder_layers // attention_pattern_length
380387 num_remaining_layers = config .num_decoder_layers % attention_pattern_length
@@ -424,7 +431,7 @@ def __init__(
424431 layer_kwargs = {}
425432 if config .decoder_block == DecoderBlockType .GEMMA3 :
426433 layer_kwargs = {"attention_type" : gemma3 .get_attention_type (layer_id = lyr )}
427- elif config .decoder_block == DecoderBlockType .GEMMA4 : # <-- ADDED
434+ elif config .decoder_block == DecoderBlockType .GEMMA4 : # <-- ADDED
428435 layer_kwargs = {"attention_type" : gemma4 .get_attention_type (layer_id = lyr )}
429436 elif config .decoder_block == DecoderBlockType .LLAMA4 :
430437 layer_kwargs = {
@@ -932,16 +939,11 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
932939 def _apply_single_engram_layer (self , y , layer_name , * args , ** kwargs ):
933940 """Applies a single, unscanned Engram layer."""
934941 layer = getattr (self , layer_name )
935-
942+
936943 decoder_input_tokens = kwargs .get ("decoder_input_tokens" )
937944 layer_kwargs = kwargs .get ("layer_kwargs" , {})
938945
939- out = layer (
940- y ,
941- * args ,
942- decoder_input_tokens = decoder_input_tokens ,
943- ** layer_kwargs
944- )
946+ out = layer (y , * args , decoder_input_tokens = decoder_input_tokens , ** layer_kwargs )
945947 if isinstance (out , tuple ):
946948 y = out [0 ]
947949 else :
@@ -997,7 +999,7 @@ def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx,
997999 chunk_name = f"{ layer_prefix } _{ current_idx } _{ next_boundary - 1 } "
9981000 chunk_stack = getattr (self , chunk_name )
9991001 scan_length = next_boundary - current_idx
1000-
1002+
10011003 y , chunk_stack = self ._apply_layers_sequentially (
10021004 chunk_stack , y , * args , length = scan_length , ** kwargs .get ("layer_kwargs" , {})
10031005 )
@@ -1046,7 +1048,7 @@ def __call__(
10461048 # Extract the bidirectional mask locally for layer configurations
10471049 bidirectional_mask = multimodal_input .bidirectional_mask if multimodal_input is not None else None
10481050
1049- if cfg .decoder_block in (DecoderBlockType .GEMMA3 , DecoderBlockType .GEMMA4 ): # <-- UPDATED
1051+ if cfg .decoder_block in (DecoderBlockType .GEMMA3 , DecoderBlockType .GEMMA4 ): # <-- UPDATED
10501052 layer_kwargs ["bidirectional_mask" ] = bidirectional_mask
10511053
10521054 if attention_metadata is not None :
@@ -1071,7 +1073,13 @@ def __call__(
10711073 )
10721074
10731075 y = self ._apply_interleaved_scanned_layers (
1074- y , "moe_layers" , cfg .first_num_dense_layers , cfg .num_decoder_layers , cfg .engram_layers , * layer_args , ** common_kwargs
1076+ y ,
1077+ "moe_layers" ,
1078+ cfg .first_num_dense_layers ,
1079+ cfg .num_decoder_layers ,
1080+ cfg .engram_layers ,
1081+ * layer_args ,
1082+ ** common_kwargs ,
10751083 )
10761084 else :
10771085 y , self .dense_layers = self ._apply_layers_sequentially (
@@ -1123,7 +1131,7 @@ def __call__(
11231131 previous_chunk ,
11241132 page_state ,
11251133 slot ,
1126- )
1134+ )
11271135 else :
11281136 scan_length = int (cfg .num_decoder_layers / cfg .inhomogeneous_layer_cycle_interval )
11291137 if scan_length > 0 :
@@ -1303,6 +1311,7 @@ def pure_gemma_fn(graphdef, state_in, y_in):
13031311
13041312 return y
13051313
1314+
13061315def decoder_as_linen (
13071316 config : Config ,
13081317 mesh : Mesh ,
0 commit comments