|
42 | 42 | from maxtext.models import ( |
43 | 43 | deepseek, |
44 | 44 | deepseek_batchsplit, |
| 45 | + deepseek_custom, |
45 | 46 | gemma, |
46 | 47 | gemma2, |
47 | 48 | gemma3, |
@@ -457,6 +458,14 @@ def get_decoder_layers(self): |
457 | 458 | deepseek.DeepSeekDenseLayerToLinen, |
458 | 459 | deepseek.DeepSeekMoELayerToLinen, |
459 | 460 | ] |
| 461 | + case DecoderBlockType.DEEPSEEK_CUSTOM: |
| 462 | + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoELayerToLinen |
| 463 | + if self.config.scan_layers and self.config.attention_layer_hybrid_ratio > 1: |
| 464 | + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoEScannableBlockToLinen |
| 465 | + return [ |
| 466 | + deepseek_custom.DeepSeekDenseLayerToLinen, |
| 467 | + deepseek_custom_moe_layer, |
| 468 | + ] |
460 | 469 | case DecoderBlockType.GEMMA: |
461 | 470 | return [gemma.GemmaDecoderLayerToLinen] |
462 | 471 | case DecoderBlockType.GEMMA2: |
@@ -522,6 +531,7 @@ def get_norm_layer(self, num_features: int): |
522 | 531 | DecoderBlockType.MISTRAL, |
523 | 532 | DecoderBlockType.MIXTRAL, |
524 | 533 | DecoderBlockType.DEEPSEEK, |
| 534 | + DecoderBlockType.DEEPSEEK_CUSTOM, |
525 | 535 | DecoderBlockType.GEMMA, |
526 | 536 | DecoderBlockType.GEMMA2, |
527 | 537 | DecoderBlockType.GEMMA3, |
@@ -573,7 +583,7 @@ def get_pipeline_stage_module(self, decoder_blocks): |
573 | 583 | """get pipeline stage module""" |
574 | 584 |
|
575 | 585 | def get_layer_to_pipeline(blocks, cfg): |
576 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 586 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
577 | 587 | return blocks[1] # return the sparse block |
578 | 588 | else: |
579 | 589 | return blocks[0] |
@@ -799,7 +809,7 @@ def __call__( |
799 | 809 | if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat |
800 | 810 | else None |
801 | 811 | ) |
802 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 812 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
803 | 813 | assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." |
804 | 814 | dense_layer = RemattedBlockLayers[0] |
805 | 815 | moe_layer = RemattedBlockLayers[1] |
@@ -845,7 +855,7 @@ def __call__( |
845 | 855 | )(y, *broadcast_args) |
846 | 856 | else: |
847 | 857 | if cfg.scan_layers: |
848 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 858 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
849 | 859 | assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." |
850 | 860 | layer_call_kwargs = { |
851 | 861 | "page_state": page_state, |
@@ -923,10 +933,31 @@ def __call__( |
923 | 933 | policy=policy, |
924 | 934 | ) |
925 | 935 | else: |
| 936 | + scan_length = num_moe_layers |
| 937 | + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM and cfg.scan_layers: |
| 938 | + if num_moe_layers % cfg.inhomogeneous_layer_cycle_interval != 0: |
| 939 | + raise ValueError( |
| 940 | + f"num_moe_layers ({num_moe_layers}) must be divisible by " |
| 941 | + f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) " |
| 942 | + "when using DeepSeek Custom and scan_layers is True." |
| 943 | + ) |
| 944 | + if cfg.attention_layer_hybrid_ratio != cfg.inhomogeneous_layer_cycle_interval: |
| 945 | + raise ValueError( |
| 946 | + f"attention_layer_hybrid_ratio ({cfg.attention_layer_hybrid_ratio}) and " |
| 947 | + f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) " |
| 948 | + "must be the same." |
| 949 | + ) |
| 950 | + scan_length = num_moe_layers // cfg.inhomogeneous_layer_cycle_interval |
| 951 | + max_logging.log( |
| 952 | + f"scan_length: {scan_length}, " |
| 953 | + f"num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: " |
| 954 | + f"{num_moe_layers // cfg.inhomogeneous_layer_cycle_interval}" |
| 955 | + ) |
| 956 | + |
926 | 957 | y, _ = self.scan_decoder_layers( |
927 | 958 | cfg, |
928 | 959 | moe_layer, |
929 | | - num_moe_layers, |
| 960 | + scan_length, |
930 | 961 | "moe_layers", |
931 | 962 | mesh, |
932 | 963 | in_axes_tuple=(nn.broadcast,) * len(broadcast_args), |
@@ -964,7 +995,7 @@ def __call__( |
964 | 995 | **layer_kwargs, |
965 | 996 | )(y, *broadcast_args) |
966 | 997 | else: |
967 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 998 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
968 | 999 | assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." |
969 | 1000 | dense_layer = RemattedBlockLayers[0] |
970 | 1001 | moe_layer = RemattedBlockLayers[1] |
@@ -1054,11 +1085,14 @@ def __call__( |
1054 | 1085 | kv_caches["key_cache"][lyr] = returned_cache[0] |
1055 | 1086 | kv_caches["value_cache"][lyr] = returned_cache[1] |
1056 | 1087 |
|
1057 | | - if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): |
1058 | | - visual_embeds = deepstack_visual_embeds[lyr] |
| 1088 | + if ( |
| 1089 | + deepstack_visual_embeds is not None |
| 1090 | + and lyr < len(deepstack_visual_embeds) |
| 1091 | + and bidirectional_mask is not None |
| 1092 | + and deepstack_visual_embeds[lyr] is not None |
| 1093 | + ): |
1059 | 1094 | # Use bidirectional_mask to identify visual token positions |
1060 | | - if bidirectional_mask is not None and visual_embeds is not None: |
1061 | | - y = deepstack_process(y, bidirectional_mask, visual_embeds) |
| 1095 | + y = deepstack_process(y, bidirectional_mask, deepstack_visual_embeds[lyr]) |
1062 | 1096 |
|
1063 | 1097 | assert isinstance(y, jax.Array) |
1064 | 1098 |
|
|
0 commit comments