Skip to content

Commit 5c13ca3

Browse files
custom DeepSeek v3
PiperOrigin-RevId: 882309708
1 parent 37ded59 commit 5c13ca3

11 files changed

Lines changed: 1019 additions & 30 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class DecoderBlockType(enum.Enum):
101101
SIMPLE_MLP = "simple_mlp"
102102
LLAMA4 = "llama4"
103103
OLMO3 = "olmo3"
104+
DEEPSEEK_CUSTOM = "deepseek_custom"
104105

105106

106107
class AttentionType(enum.Enum):

src/maxtext/configs/base.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ base_num_kv_heads: 16
155155
base_mlp_dim: 7168
156156
base_num_decoder_layers: 16
157157
head_dim: 128
158+
attention_output_dim: -1
159+
local_num_query_heads: -1
160+
local_num_kv_heads: -1
161+
global_num_query_heads: -1
162+
global_num_kv_heads: -1
163+
attention_layer_hybrid_ratio: -1
158164
mlp_activations: ["silu", "linear"]
159165
mlp_activations_limit: -1.0
160166
dropout_rate: 0.0
@@ -240,6 +246,8 @@ use_2d_fsdp_sharding: False
240246

241247
# deepseek moe
242248
base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
249+
moe_model_dim: -1 # dimension of token entering moe layer.
250+
shared_expert_mlp_dim: -1 # intermediate dimension of the shared expert.
243251
first_num_dense_layers: 0 # number of initial dense layers in the model
244252
shared_experts: 1
245253
routed_scaling_factor: 1.0 # scaling factor for routing scores
@@ -484,6 +492,7 @@ logical_axis_rules: [
484492
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
485493
['embed_no_exp', ['fsdp', 'sequence', 'context']],
486494
['embed_tensor_transpose', ['tensor_transpose']],
495+
['attention_out_proj', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
487496
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
488497
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
489498
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# model config for DeepSeek Custom
16+
17+
18+
base_emb_dim: 16384
19+
moe_model_dim: 8192
20+
base_moe_mlp_dim: 16384 # (2 * 8192)
21+
shared_expert_mlp_dim: 32768 # (4 * 8192)
22+
23+
base_num_decoder_layers: 61
24+
first_num_dense_layers: 3
25+
mlp_activations: ["silu","linear"]
26+
vocab_size: 129280
27+
enable_dropout: False
28+
logits_via_embedding: False
29+
normalization_layer_epsilon: 1.0e-6
30+
31+
num_experts: 256
32+
num_experts_per_tok: 4 # (1 shared + 4 routed)
33+
shared_experts: 1
34+
routed_scaling_factor: 2.5
35+
routed_score_func: "sigmoid"
36+
routed_bias: True
37+
decoder_block: "deepseek_custom"
38+
39+
# Hybrid GQA Attention
40+
attention_output_dim: 8192 # same as moe_model_dim
41+
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
42+
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
43+
head_dim: 256
44+
45+
local_num_query_heads: 64
46+
local_num_kv_heads: 8
47+
sliding_window_size: 1024
48+
49+
global_num_query_heads: 64
50+
global_num_kv_heads: 4
51+
52+
mscale: 1.0
53+
# RoPE
54+
rope_type: "yarn"
55+
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
56+
max_position_embeddings: 163840
57+
original_max_position_embeddings: 4096
58+
rope_factor: 40
59+
beta_fast: 32
60+
rope_interleave: True
61+
rope_truncate: True
62+
rope_attention_scaling: False
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# model config for DeepSeek Custom
16+
17+
base_emb_dim: 1024
18+
moe_model_dim: 512
19+
base_moe_mlp_dim: 1024
20+
shared_expert_mlp_dim: 4096
21+
num_experts: 16
22+
num_experts_per_tok: 2
23+
shared_experts: 1
24+
base_num_decoder_layers: 4
25+
first_num_dense_layers: 1
26+
mlp_activations: ["silu", "linear"]
27+
vocab_size: 129280
28+
enable_dropout: False
29+
logits_via_embedding: False
30+
normalization_layer_epsilon: 1.0e-6
31+
routed_scaling_factor: 2.5
32+
routed_score_func: "sigmoid"
33+
routed_bias: True
34+
decoder_block: "deepseek_custom"
35+
36+
37+
# Hybrid GQA Attention
38+
39+
attention_output_dim: 512 # same as moe_model_dim
40+
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
41+
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
42+
head_dim: 256
43+
44+
local_num_query_heads: 4
45+
local_num_kv_heads: 2
46+
sliding_window_size: 128
47+
48+
global_num_query_heads: 4
49+
global_num_kv_heads: 1
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# model config for DeepSeek Custom
16+
17+
18+
base_emb_dim: 7168
19+
moe_model_dim: 3072
20+
base_moe_mlp_dim: 6144 # (2 * 3072)
21+
shared_expert_mlp_dim: 15360 # (5 * 3072)
22+
23+
base_num_decoder_layers: 61
24+
first_num_dense_layers: 3
25+
mlp_activations: ["silu","linear"]
26+
vocab_size: 129280
27+
enable_dropout: False
28+
logits_via_embedding: False
29+
normalization_layer_epsilon: 1.0e-6
30+
31+
num_experts: 256
32+
num_experts_per_tok: 4 # (1 shared + 4 routed)
33+
shared_experts: 1
34+
routed_scaling_factor: 2.5
35+
routed_score_func: "sigmoid"
36+
routed_bias: True
37+
decoder_block: "deepseek_custom"
38+
39+
# Hybrid GQA Attention
40+
attention_output_dim: 3072 # same as moe_model_dim
41+
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
42+
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
43+
head_dim: 256
44+
45+
local_num_query_heads: 64
46+
local_num_kv_heads: 8
47+
sliding_window_size: 1024
48+
49+
global_num_query_heads: 64
50+
global_num_kv_heads: 4
51+
52+
mscale: 1.0
53+
# RoPE
54+
rope_type: "yarn"
55+
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
56+
max_position_embeddings: 163840
57+
original_max_position_embeddings: 4096
58+
rope_factor: 40
59+
beta_fast: 32
60+
rope_interleave: True
61+
rope_truncate: True
62+
rope_attention_scaling: False

src/maxtext/configs/types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ class ProfilerType(str, Enum):
222222
"deepseek3-test",
223223
"deepseek3-tiny",
224224
"deepseek3.2-671b",
225-
"deepseek-custom",
225+
"deepseek3-custom-small",
226+
"deepseek3-custom",
227+
"deepseek3-custom-large",
226228
"kimi-k2-1t",
227229
"gemma-7b",
228230
"gemma-2b",
@@ -428,6 +430,12 @@ class ModelArchitecture(BaseModel):
428430
base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
429431
base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
430432
head_dim: int = Field(128, description="Dimension of each attention head.")
433+
attention_output_dim: int = Field(-1, description="Override output dimension for attention block")
434+
local_num_query_heads: int = Field(-1, description="Number of query heads in local context layers.")
435+
local_num_kv_heads: int = Field(-1, description="Number of KV heads in local context layers.")
436+
global_num_query_heads: int = Field(-1, description="Number of query heads in global context layers.")
437+
global_num_kv_heads: int = Field(-1, description="Number of KV heads in global context layers.")
438+
attention_layer_hybrid_ratio: int = Field(-1, description="Ratio of layer context styles (e.g. 5 means 4 local followed by 1 global).")
431439
mlp_activations: list[str] = Field(["silu", "linear"], description="Activation functions in the MLP layer.")
432440
mlp_activations_limit: float = Field(
433441
-1.0,
@@ -705,6 +713,8 @@ class DeepSeekMoE(BaseModel):
705713
"""Configuration specific to DeepSeek-style MoE layers."""
706714

707715
base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
716+
moe_model_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer.")
717+
shared_expert_mlp_dim: int = Field(-1, description="Intermediate dimension for the shared expert.")
708718
first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
709719
shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
710720
routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")

src/maxtext/layers/decoders.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from maxtext.models import (
4343
deepseek,
4444
deepseek_batchsplit,
45+
deepseek_custom,
4546
gemma,
4647
gemma2,
4748
gemma3,
@@ -457,6 +458,14 @@ def get_decoder_layers(self):
457458
deepseek.DeepSeekDenseLayerToLinen,
458459
deepseek.DeepSeekMoELayerToLinen,
459460
]
461+
case DecoderBlockType.DEEPSEEK_CUSTOM:
462+
deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoELayerToLinen
463+
if self.config.scan_layers and self.config.attention_layer_hybrid_ratio > 1:
464+
deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoEScannableBlockToLinen
465+
return [
466+
deepseek_custom.DeepSeekDenseLayerToLinen,
467+
deepseek_custom_moe_layer,
468+
]
460469
case DecoderBlockType.GEMMA:
461470
return [gemma.GemmaDecoderLayerToLinen]
462471
case DecoderBlockType.GEMMA2:
@@ -522,6 +531,7 @@ def get_norm_layer(self, num_features: int):
522531
DecoderBlockType.MISTRAL,
523532
DecoderBlockType.MIXTRAL,
524533
DecoderBlockType.DEEPSEEK,
534+
DecoderBlockType.DEEPSEEK_CUSTOM,
525535
DecoderBlockType.GEMMA,
526536
DecoderBlockType.GEMMA2,
527537
DecoderBlockType.GEMMA3,
@@ -573,7 +583,7 @@ def get_pipeline_stage_module(self, decoder_blocks):
573583
"""get pipeline stage module"""
574584

575585
def get_layer_to_pipeline(blocks, cfg):
576-
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
586+
if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
577587
return blocks[1] # return the sparse block
578588
else:
579589
return blocks[0]
@@ -799,7 +809,7 @@ def __call__(
799809
if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat
800810
else None
801811
)
802-
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
812+
if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
803813
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
804814
dense_layer = RemattedBlockLayers[0]
805815
moe_layer = RemattedBlockLayers[1]
@@ -845,7 +855,7 @@ def __call__(
845855
)(y, *broadcast_args)
846856
else:
847857
if cfg.scan_layers:
848-
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
858+
if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
849859
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
850860
layer_call_kwargs = {
851861
"page_state": page_state,
@@ -923,10 +933,19 @@ def __call__(
923933
policy=policy,
924934
)
925935
else:
936+
scan_length = num_moe_layers
937+
if cfg.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM and cfg.scan_layers:
938+
if num_moe_layers % cfg.inhomogeneous_layer_cycle_interval != 0:
939+
raise ValueError(f"num_moe_layers ({num_moe_layers}) must be divisible by inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) when using DeepSeek Custom and scan_layers is True.")
940+
if cfg.attention_layer_hybrid_ratio != cfg.inhomogeneous_layer_cycle_interval:
941+
raise ValueError(f"attention_layer_hybrid_ratio ({cfg.attention_layer_hybrid_ratio}) and inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) must be the same.")
942+
scan_length = num_moe_layers // cfg.inhomogeneous_layer_cycle_interval
943+
max_logging.log(f"scan_length: {scan_length}, num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: {num_moe_layers // cfg.inhomogeneous_layer_cycle_interval}")
944+
926945
y, _ = self.scan_decoder_layers(
927946
cfg,
928947
moe_layer,
929-
num_moe_layers,
948+
scan_length,
930949
"moe_layers",
931950
mesh,
932951
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
@@ -964,7 +983,7 @@ def __call__(
964983
**layer_kwargs,
965984
)(y, *broadcast_args)
966985
else:
967-
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
986+
if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
968987
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
969988
dense_layer = RemattedBlockLayers[0]
970989
moe_layer = RemattedBlockLayers[1]

src/maxtext/layers/linears.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ def get_norm_layer(self, num_features: int):
474474
DecoderBlockType.GEMMA3,
475475
DecoderBlockType.QWEN3,
476476
DecoderBlockType.DEEPSEEK,
477+
DecoderBlockType.DEEPSEEK_CUSTOM,
477478
DecoderBlockType.LLAMA4,
478479
):
479480
return functools.partial(normalizations.RMSNorm, num_features=num_features)

0 commit comments

Comments
 (0)