|
42 | 42 | from maxtext.models import ( |
43 | 43 | deepseek, |
44 | 44 | deepseek_batchsplit, |
| 45 | + deepseek_custom, |
45 | 46 | gemma, |
46 | 47 | gemma2, |
47 | 48 | gemma3, |
@@ -458,6 +459,14 @@ def get_decoder_layers(self): |
458 | 459 | deepseek.DeepSeekDenseLayerToLinen, |
459 | 460 | deepseek.DeepSeekMoELayerToLinen, |
460 | 461 | ] |
| 462 | + case DecoderBlockType.DEEPSEEK_CUSTOM: |
| 463 | + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoELayerToLinen |
| 464 | + if self.config.scan_layers and self.config.attention_layer_hybrid_ratio > 1: |
| 465 | + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoEScannableBlockToLinen |
| 466 | + return [ |
| 467 | + deepseek_custom.DeepSeekDenseLayerToLinen, |
| 468 | + deepseek_custom_moe_layer, |
| 469 | + ] |
461 | 470 | case DecoderBlockType.GEMMA: |
462 | 471 | return [gemma.GemmaDecoderLayerToLinen] |
463 | 472 | case DecoderBlockType.GEMMA2: |
@@ -525,6 +534,7 @@ def get_norm_layer(self, num_features: int): |
525 | 534 | DecoderBlockType.MISTRAL, |
526 | 535 | DecoderBlockType.MIXTRAL, |
527 | 536 | DecoderBlockType.DEEPSEEK, |
| 537 | + DecoderBlockType.DEEPSEEK_CUSTOM, |
528 | 538 | DecoderBlockType.GEMMA, |
529 | 539 | DecoderBlockType.GEMMA2, |
530 | 540 | DecoderBlockType.GEMMA3, |
@@ -577,7 +587,7 @@ def get_pipeline_stage_module(self, decoder_blocks): |
577 | 587 | """get pipeline stage module""" |
578 | 588 |
|
579 | 589 | def get_layer_to_pipeline(blocks, cfg): |
580 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 590 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
581 | 591 | return blocks[1] # return the sparse block |
582 | 592 | else: |
583 | 593 | return blocks[0] |
@@ -803,7 +813,7 @@ def __call__( |
803 | 813 | if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat |
804 | 814 | else None |
805 | 815 | ) |
806 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 816 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
807 | 817 | assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." |
808 | 818 | dense_layer = RemattedBlockLayers[0] |
809 | 819 | moe_layer = RemattedBlockLayers[1] |
@@ -849,7 +859,7 @@ def __call__( |
849 | 859 | )(y, *broadcast_args) |
850 | 860 | else: |
851 | 861 | if cfg.scan_layers: |
852 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 862 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
853 | 863 | assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." |
854 | 864 | layer_call_kwargs = { |
855 | 865 | "page_state": page_state, |
@@ -927,10 +937,31 @@ def __call__( |
927 | 937 | policy=policy, |
928 | 938 | ) |
929 | 939 | else: |
| 940 | + scan_length = num_moe_layers |
| 941 | + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM and cfg.scan_layers: |
| 942 | + if num_moe_layers % cfg.inhomogeneous_layer_cycle_interval != 0: |
| 943 | + raise ValueError( |
| 944 | + f"num_moe_layers ({num_moe_layers}) must be divisible by " |
| 945 | + f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) " |
| 946 | + "when using DeepSeek Custom and scan_layers is True." |
| 947 | + ) |
| 948 | + if cfg.attention_layer_hybrid_ratio != cfg.inhomogeneous_layer_cycle_interval: |
| 949 | + raise ValueError( |
| 950 | + f"attention_layer_hybrid_ratio ({cfg.attention_layer_hybrid_ratio}) and " |
| 951 | + f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) " |
| 952 | + "must be the same." |
| 953 | + ) |
| 954 | + scan_length = num_moe_layers // cfg.inhomogeneous_layer_cycle_interval |
| 955 | + max_logging.log( |
| 956 | + f"scan_length: {scan_length}, " |
| 957 | + f"num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: " |
| 958 | + f"{num_moe_layers // cfg.inhomogeneous_layer_cycle_interval}" |
| 959 | + ) |
| 960 | + |
930 | 961 | y, _ = self.scan_decoder_layers( |
931 | 962 | cfg, |
932 | 963 | moe_layer, |
933 | | - num_moe_layers, |
| 964 | + scan_length, |
934 | 965 | "moe_layers", |
935 | 966 | mesh, |
936 | 967 | in_axes_tuple=(nn.broadcast,) * len(broadcast_args), |
@@ -968,7 +999,7 @@ def __call__( |
968 | 999 | **layer_kwargs, |
969 | 1000 | )(y, *broadcast_args) |
970 | 1001 | else: |
971 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 1002 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
972 | 1003 | assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." |
973 | 1004 | dense_layer = RemattedBlockLayers[0] |
974 | 1005 | moe_layer = RemattedBlockLayers[1] |
@@ -1058,11 +1089,14 @@ def __call__( |
1058 | 1089 | kv_caches["key_cache"][lyr] = returned_cache[0] |
1059 | 1090 | kv_caches["value_cache"][lyr] = returned_cache[1] |
1060 | 1091 |
|
1061 | | - if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): |
1062 | | - visual_embeds = deepstack_visual_embeds[lyr] |
| 1092 | + if ( |
| 1093 | + deepstack_visual_embeds is not None |
| 1094 | + and lyr < len(deepstack_visual_embeds) |
| 1095 | + and bidirectional_mask is not None |
| 1096 | + and deepstack_visual_embeds[lyr] is not None |
| 1097 | + ): |
1063 | 1098 | # Use bidirectional_mask to identify visual token positions |
1064 | | - if bidirectional_mask is not None and visual_embeds is not None: |
1065 | | - y = deepstack_process(y, bidirectional_mask, visual_embeds) |
| 1099 | + y = deepstack_process(y, bidirectional_mask, deepstack_visual_embeds[lyr]) |
1066 | 1100 |
|
1067 | 1101 | assert isinstance(y, jax.Array) |
1068 | 1102 |
|
|
0 commit comments