|
42 | 42 | from maxtext.models import ( |
43 | 43 | deepseek, |
44 | 44 | deepseek_batchsplit, |
| 45 | + deepseek_custom, |
45 | 46 | gemma, |
46 | 47 | gemma2, |
47 | 48 | gemma3, |
@@ -457,6 +458,14 @@ def get_decoder_layers(self): |
457 | 458 | deepseek.DeepSeekDenseLayerToLinen, |
458 | 459 | deepseek.DeepSeekMoELayerToLinen, |
459 | 460 | ] |
| 461 | + case DecoderBlockType.DEEPSEEK_CUSTOM: |
| 462 | + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoELayerToLinen |
| 463 | + if self.config.scan_layers and self.config.attention_layer_hybrid_ratio > 1: |
| 464 | + deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoEScannableBlockToLinen |
| 465 | + return [ |
| 466 | + deepseek_custom.DeepSeekDenseLayerToLinen, |
| 467 | + deepseek_custom_moe_layer, |
| 468 | + ] |
460 | 469 | case DecoderBlockType.GEMMA: |
461 | 470 | return [gemma.GemmaDecoderLayerToLinen] |
462 | 471 | case DecoderBlockType.GEMMA2: |
@@ -522,6 +531,7 @@ def get_norm_layer(self, num_features: int): |
522 | 531 | DecoderBlockType.MISTRAL, |
523 | 532 | DecoderBlockType.MIXTRAL, |
524 | 533 | DecoderBlockType.DEEPSEEK, |
| 534 | + DecoderBlockType.DEEPSEEK_CUSTOM, |
525 | 535 | DecoderBlockType.GEMMA, |
526 | 536 | DecoderBlockType.GEMMA2, |
527 | 537 | DecoderBlockType.GEMMA3, |
@@ -573,7 +583,7 @@ def get_pipeline_stage_module(self, decoder_blocks): |
573 | 583 | """get pipeline stage module""" |
574 | 584 |
|
575 | 585 | def get_layer_to_pipeline(blocks, cfg): |
576 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 586 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
577 | 587 | return blocks[1] # return the sparse block |
578 | 588 | else: |
579 | 589 | return blocks[0] |
@@ -799,7 +809,7 @@ def __call__( |
799 | 809 | if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat |
800 | 810 | else None |
801 | 811 | ) |
802 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 812 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
803 | 813 | assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." |
804 | 814 | dense_layer = RemattedBlockLayers[0] |
805 | 815 | moe_layer = RemattedBlockLayers[1] |
@@ -845,7 +855,7 @@ def __call__( |
845 | 855 | )(y, *broadcast_args) |
846 | 856 | else: |
847 | 857 | if cfg.scan_layers: |
848 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 858 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
849 | 859 | assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." |
850 | 860 | layer_call_kwargs = { |
851 | 861 | "page_state": page_state, |
@@ -923,10 +933,19 @@ def __call__( |
923 | 933 | policy=policy, |
924 | 934 | ) |
925 | 935 | else: |
| 936 | + scan_length = num_moe_layers |
| 937 | + if cfg.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM and cfg.scan_layers: |
| 938 | + if num_moe_layers % cfg.inhomogeneous_layer_cycle_interval != 0: |
| 939 | + raise ValueError(f"num_moe_layers ({num_moe_layers}) must be divisible by inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) when using DeepSeek Custom and scan_layers is True.") |
| 940 | + if cfg.attention_layer_hybrid_ratio != cfg.inhomogeneous_layer_cycle_interval: |
| 941 | + raise ValueError(f"attention_layer_hybrid_ratio ({cfg.attention_layer_hybrid_ratio}) and inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) must be the same.") |
| 942 | + scan_length = num_moe_layers // cfg.inhomogeneous_layer_cycle_interval |
| 943 | + max_logging.log(f"scan_length: {scan_length}, num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: {num_moe_layers // cfg.inhomogeneous_layer_cycle_interval}") |
| 944 | + |
926 | 945 | y, _ = self.scan_decoder_layers( |
927 | 946 | cfg, |
928 | 947 | moe_layer, |
929 | | - num_moe_layers, |
| 948 | + scan_length, |
930 | 949 | "moe_layers", |
931 | 950 | mesh, |
932 | 951 | in_axes_tuple=(nn.broadcast,) * len(broadcast_args), |
@@ -964,7 +983,7 @@ def __call__( |
964 | 983 | **layer_kwargs, |
965 | 984 | )(y, *broadcast_args) |
966 | 985 | else: |
967 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 986 | + if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM): |
968 | 987 | assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." |
969 | 988 | dense_layer = RemattedBlockLayers[0] |
970 | 989 | moe_layer = RemattedBlockLayers[1] |
|
0 commit comments