@@ -328,11 +328,15 @@ def __init__(
328328 else :
329329 next_boundary = self ._find_next_boundary (current_idx , config .first_num_dense_layers , config .engram_layers )
330330 chunk_name = f"dense_layers_{ current_idx } _{ next_boundary - 1 } "
331- setattr (self , chunk_name , self ._create_scanned_layers (
332- dense_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
333- ))
331+ setattr (
332+ self ,
333+ chunk_name ,
334+ self ._create_scanned_layers (
335+ dense_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
336+ ),
337+ )
334338 current_idx = next_boundary
335-
339+
336340 # 2. Create MoE Chunks (Direct setattr, NO nnx.Dict)
337341 current_idx = config .first_num_dense_layers
338342 while current_idx < config .num_decoder_layers :
@@ -343,9 +347,13 @@ def __init__(
343347 else :
344348 next_boundary = self ._find_next_boundary (current_idx , config .num_decoder_layers , config .engram_layers )
345349 chunk_name = f"moe_layers_{ current_idx } _{ next_boundary - 1 } "
346- setattr (self , chunk_name , self ._create_scanned_layers (
347- moe_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
348- ))
350+ setattr (
351+ self ,
352+ chunk_name ,
353+ self ._create_scanned_layers (
354+ moe_cls , length = (next_boundary - current_idx ), metadata_axis_name = chunk_name , rngs = rngs
355+ ),
356+ )
349357 current_idx = next_boundary
350358 else :
351359 # Standard DeepSeek logic when Engrams are disabled
@@ -374,7 +382,7 @@ def __init__(
374382 self .layers_remainder = RemattedGemma3Block (
375383 config = self .config , mesh = mesh , quant = self .quant , model_mode = self .model_mode , ** rem_layer_kwargs , rngs = rngs
376384 ) # pytype: disable=wrong-keyword-args
377- elif self .is_gemma4 : # <-- ADDED BLOCK
385+ elif self .is_gemma4 : # <-- ADDED BLOCK
378386 attention_pattern_length = len (gemma4 .GEMMA4_ATTENTION_PATTERN )
379387 scan_length = config .num_decoder_layers // attention_pattern_length
380388 num_remaining_layers = config .num_decoder_layers % attention_pattern_length
@@ -424,7 +432,7 @@ def __init__(
424432 layer_kwargs = {}
425433 if config .decoder_block == DecoderBlockType .GEMMA3 :
426434 layer_kwargs = {"attention_type" : gemma3 .get_attention_type (layer_id = lyr )}
427- elif config .decoder_block == DecoderBlockType .GEMMA4 : # <-- ADDED
435+ elif config .decoder_block == DecoderBlockType .GEMMA4 : # <-- ADDED
428436 layer_kwargs = {"attention_type" : gemma4 .get_attention_type (layer_id = lyr )}
429437 elif config .decoder_block == DecoderBlockType .LLAMA4 :
430438 layer_kwargs = {
@@ -932,16 +940,11 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
932940 def _apply_single_engram_layer (self , y , layer_name , * args , ** kwargs ):
933941 """Applies a single, unscanned Engram layer."""
934942 layer = getattr (self , layer_name )
935-
943+
936944 decoder_input_tokens = kwargs .get ("decoder_input_tokens" )
937945 layer_kwargs = kwargs .get ("layer_kwargs" , {})
938946
939- out = layer (
940- y ,
941- * args ,
942- decoder_input_tokens = decoder_input_tokens ,
943- ** layer_kwargs
944- )
947+ out = layer (y , * args , decoder_input_tokens = decoder_input_tokens , ** layer_kwargs )
945948 if isinstance (out , tuple ):
946949 y = out [0 ]
947950 else :
@@ -997,7 +1000,7 @@ def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx,
9971000 chunk_name = f"{ layer_prefix } _{ current_idx } _{ next_boundary - 1 } "
9981001 chunk_stack = getattr (self , chunk_name )
9991002 scan_length = next_boundary - current_idx
1000-
1003+
10011004 y , chunk_stack = self ._apply_layers_sequentially (
10021005 chunk_stack , y , * args , length = scan_length , ** kwargs .get ("layer_kwargs" , {})
10031006 )
@@ -1046,7 +1049,7 @@ def __call__(
10461049 # Extract the bidirectional mask locally for layer configurations
10471050 bidirectional_mask = multimodal_input .bidirectional_mask if multimodal_input is not None else None
10481051
1049- if cfg .decoder_block in (DecoderBlockType .GEMMA3 , DecoderBlockType .GEMMA4 ): # <-- UPDATED
1052+ if cfg .decoder_block in (DecoderBlockType .GEMMA3 , DecoderBlockType .GEMMA4 ): # <-- UPDATED
10501053 layer_kwargs ["bidirectional_mask" ] = bidirectional_mask
10511054
10521055 if attention_metadata is not None :
@@ -1071,7 +1074,13 @@ def __call__(
10711074 )
10721075
10731076 y = self ._apply_interleaved_scanned_layers (
1074- y , "moe_layers" , cfg .first_num_dense_layers , cfg .num_decoder_layers , cfg .engram_layers , * layer_args , ** common_kwargs
1077+ y ,
1078+ "moe_layers" ,
1079+ cfg .first_num_dense_layers ,
1080+ cfg .num_decoder_layers ,
1081+ cfg .engram_layers ,
1082+ * layer_args ,
1083+ ** common_kwargs ,
10751084 )
10761085 else :
10771086 y , self .dense_layers = self ._apply_layers_sequentially (
@@ -1123,7 +1132,7 @@ def __call__(
11231132 previous_chunk ,
11241133 page_state ,
11251134 slot ,
1126- )
1135+ )
11271136 else :
11281137 scan_length = int (cfg .num_decoder_layers / cfg .inhomogeneous_layer_cycle_interval )
11291138 if scan_length > 0 :
@@ -1303,6 +1312,7 @@ def pure_gemma_fn(graphdef, state_in, y_in):
13031312
13041313 return y
13051314
1315+
13061316def decoder_as_linen (
13071317 config : Config ,
13081318 mesh : Mesh ,
0 commit comments