4242from maxtext .models import (
4343 deepseek ,
4444 deepseek_batchsplit ,
45+ deepseek_custom ,
4546 gemma ,
4647 gemma2 ,
4748 gemma3 ,
5253 mistral ,
5354 mixtral ,
5455 olmo3 ,
55- qwen2 ,
5656 qwen3 ,
5757 simple_layer ,
5858)
@@ -458,6 +458,14 @@ def get_decoder_layers(self):
458458 deepseek .DeepSeekDenseLayerToLinen ,
459459 deepseek .DeepSeekMoELayerToLinen ,
460460 ]
461+ case DecoderBlockType .DEEPSEEK_CUSTOM :
462+ deepseek_custom_moe_layer = deepseek_custom .DeepSeekMoELayerToLinen
463+ if self .config .scan_layers and self .config .attention_layer_hybrid_ratio > 1 :
464+ deepseek_custom_moe_layer = deepseek_custom .DeepSeekMoEScannableBlockToLinen
465+ return [
466+ deepseek_custom .DeepSeekDenseLayerToLinen ,
467+ deepseek_custom_moe_layer ,
468+ ]
461469 case DecoderBlockType .GEMMA :
462470 return [gemma .GemmaDecoderLayerToLinen ]
463471 case DecoderBlockType .GEMMA2 :
@@ -468,8 +476,6 @@ def get_decoder_layers(self):
468476 return [gpt3 .Gpt3DecoderLayerToLinen ]
469477 case DecoderBlockType .GPT_OSS :
470478 return [gpt_oss .GptOssScannableBlockToLinen ] if self .config .scan_layers else [gpt_oss .GptOssDecoderLayerToLinen ]
471- case DecoderBlockType .QWEN2 :
472- return [qwen2 .Qwen2DecoderLayerToLinen ]
473479 case DecoderBlockType .QWEN3 :
474480 return [qwen3 .Qwen3DecoderLayerToLinen ]
475481 case DecoderBlockType .QWEN3_MOE :
@@ -525,10 +531,10 @@ def get_norm_layer(self, num_features: int):
525531 DecoderBlockType .MISTRAL ,
526532 DecoderBlockType .MIXTRAL ,
527533 DecoderBlockType .DEEPSEEK ,
534+ DecoderBlockType .DEEPSEEK_CUSTOM ,
528535 DecoderBlockType .GEMMA ,
529536 DecoderBlockType .GEMMA2 ,
530537 DecoderBlockType .GEMMA3 ,
531- DecoderBlockType .QWEN2 ,
532538 DecoderBlockType .QWEN3 ,
533539 DecoderBlockType .QWEN3_MOE ,
534540 DecoderBlockType .GPT_OSS ,
@@ -577,7 +583,7 @@ def get_pipeline_stage_module(self, decoder_blocks):
577583 """get pipeline stage module"""
578584
579585 def get_layer_to_pipeline (blocks , cfg ):
580- if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
586+ if cfg .decoder_block in ( DecoderBlockType .DEEPSEEK , DecoderBlockType . DEEPSEEK_CUSTOM ) :
581587 return blocks [1 ] # return the sparse block
582588 else :
583589 return blocks [0 ]
@@ -803,7 +809,7 @@ def __call__(
803809 if cfg .pipeline_fsdp_ag_once or cfg .pipeline_fsdp_ag_per_repeat
804810 else None
805811 )
806- if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
812+ if cfg .decoder_block in ( DecoderBlockType .DEEPSEEK , DecoderBlockType . DEEPSEEK_CUSTOM ) :
807813 assert len (RemattedBlockLayers ) == 2 , "Scanned layers must have a length of 2 using deepseek."
808814 dense_layer = RemattedBlockLayers [0 ]
809815 moe_layer = RemattedBlockLayers [1 ]
@@ -849,7 +855,7 @@ def __call__(
849855 )(y , * broadcast_args )
850856 else :
851857 if cfg .scan_layers :
852- if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
858+ if cfg .decoder_block in ( DecoderBlockType .DEEPSEEK , DecoderBlockType . DEEPSEEK_CUSTOM ) :
853859 assert len (RemattedBlockLayers ) == 2 , "Scanned layers must have a length of 2 using deepseek."
854860 layer_call_kwargs = {
855861 "page_state" : page_state ,
@@ -927,10 +933,31 @@ def __call__(
927933 policy = policy ,
928934 )
929935 else :
936+ scan_length = num_moe_layers
937+ if cfg .decoder_block == DecoderBlockType .DEEPSEEK_CUSTOM and cfg .scan_layers :
938+ if num_moe_layers % cfg .inhomogeneous_layer_cycle_interval != 0 :
939+ raise ValueError (
940+ f"num_moe_layers ({ num_moe_layers } ) must be divisible by "
941+ f"inhomogeneous_layer_cycle_interval ({ cfg .inhomogeneous_layer_cycle_interval } ) "
942+ "when using DeepSeek Custom and scan_layers is True."
943+ )
944+ if cfg .attention_layer_hybrid_ratio != cfg .inhomogeneous_layer_cycle_interval :
945+ raise ValueError (
946+ f"attention_layer_hybrid_ratio ({ cfg .attention_layer_hybrid_ratio } ) and "
947+ f"inhomogeneous_layer_cycle_interval ({ cfg .inhomogeneous_layer_cycle_interval } ) "
948+ "must be the same."
949+ )
950+ scan_length = num_moe_layers // cfg .inhomogeneous_layer_cycle_interval
951+ max_logging .log (
952+ f"scan_length: { scan_length } , "
953+ f"num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: "
954+ f"{ num_moe_layers // cfg .inhomogeneous_layer_cycle_interval } "
955+ )
956+
930957 y , _ = self .scan_decoder_layers (
931958 cfg ,
932959 moe_layer ,
933- num_moe_layers ,
960+ scan_length ,
934961 "moe_layers" ,
935962 mesh ,
936963 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
@@ -968,7 +995,7 @@ def __call__(
968995 ** layer_kwargs ,
969996 )(y , * broadcast_args )
970997 else :
971- if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
998+ if cfg .decoder_block in ( DecoderBlockType .DEEPSEEK , DecoderBlockType . DEEPSEEK_CUSTOM ) :
972999 assert len (RemattedBlockLayers ) == 2 , "Unscanned layers must have a length of 2 using deepseek."
9731000 dense_layer = RemattedBlockLayers [0 ]
9741001 moe_layer = RemattedBlockLayers [1 ]
@@ -1058,11 +1085,14 @@ def __call__(
10581085 kv_caches ["key_cache" ][lyr ] = returned_cache [0 ]
10591086 kv_caches ["value_cache" ][lyr ] = returned_cache [1 ]
10601087
1061- if deepstack_visual_embeds is not None and lyr < len (deepstack_visual_embeds ):
1062- visual_embeds = deepstack_visual_embeds [lyr ]
1088+ if (
1089+ deepstack_visual_embeds is not None
1090+ and lyr < len (deepstack_visual_embeds )
1091+ and bidirectional_mask is not None
1092+ and deepstack_visual_embeds [lyr ] is not None
1093+ ):
10631094 # Use bidirectional_mask to identify visual token positions
1064- if bidirectional_mask is not None and visual_embeds is not None :
1065- y = deepstack_process (y , bidirectional_mask , visual_embeds )
1095+ y = deepstack_process (y , bidirectional_mask , deepstack_visual_embeds [lyr ])
10661096
10671097 assert isinstance (y , jax .Array )
10681098
0 commit comments