
Commit b305ca1

custom DeepSeek v3
PiperOrigin-RevId: 882309708
1 parent 093ab89 commit b305ca1

11 files changed: 1043 additions & 38 deletions


src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ class DecoderBlockType(enum.Enum):
   SIMPLE_MLP = "simple_mlp"
   LLAMA4 = "llama4"
   OLMO3 = "olmo3"
+  DEEPSEEK_CUSTOM = "deepseek_custom"
 
 
 class AttentionType(enum.Enum):

src/maxtext/configs/base.yml

Lines changed: 9 additions & 0 deletions
@@ -155,6 +155,12 @@ base_num_kv_heads: 16
 base_mlp_dim: 7168
 base_num_decoder_layers: 16
 head_dim: 128
+attention_output_dim: -1
+local_num_query_heads: -1
+local_num_kv_heads: -1
+global_num_query_heads: -1
+global_num_kv_heads: -1
+attention_layer_hybrid_ratio: -1
 mlp_activations: ["silu", "linear"]
 mlp_activations_limit: -1.0
 dropout_rate: 0.0

@@ -240,6 +246,8 @@ use_2d_fsdp_sharding: False
 
 # deepseek moe
 base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
+moe_model_dim: -1 # dimension of tokens entering the MoE layer.
+shared_expert_mlp_dim: -1 # intermediate dimension of the shared expert.
 first_num_dense_layers: 0 # number of initial dense layers in the model
 shared_experts: 1
 routed_scaling_factor: 1.0 # scaling factor for routing scores

@@ -484,6 +492,7 @@ logical_axis_rules: [
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'context']],
   ['embed_tensor_transpose', ['tensor_transpose']],
+  ['attention_out_proj', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
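The six new attention keys and the two new MoE keys all default to -1, the same "unset" sentinel used by mlp_activations_limit above, so existing model configs are unaffected; the DeepSeek Custom configs below override all of them. The new attention_out_proj rule presumably exists because attention_output_dim lets the attention output projection differ in width from the model dim, so that weight needs its own logical sharding axis.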
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek Custom
+
+
+base_emb_dim: 16384
+moe_model_dim: 8192
+base_moe_mlp_dim: 16384 # (2 * 8192)
+shared_expert_mlp_dim: 32768 # (4 * 8192)
+
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+
+num_experts: 256
+num_experts_per_tok: 4 # (1 shared + 4 routed)
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek_custom"
+
+# Hybrid GQA Attention
+attention_output_dim: 8192 # same as moe_model_dim
+attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
+inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
+head_dim: 256
+
+local_num_query_heads: 64
+local_num_kv_heads: 8
+sliding_window_size: 1024
+
+global_num_query_heads: 64
+global_num_kv_heads: 4
+
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
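The hybrid split shows up directly in the KV-cache budget: global layers keep only 4 KV heads but over the full 163,840-token context, while local layers keep 8 KV heads over just the 1,024-token sliding window. A back-of-envelope sketch, assuming standard GQA caching of one K and one V tensor per layer in bf16 (illustrative arithmetic, not a MaxText measurement):

head_dim = 256
bytes_per_value = 2  # bf16

def kv_bytes_per_token(num_kv_heads: int) -> int:
  # K and V each store num_kv_heads * head_dim values per cached token.
  return 2 * num_kv_heads * head_dim * bytes_per_value

print(kv_bytes_per_token(8))  # local layer:  8192 bytes/token, capped at 1024 cached tokens
print(kv_bytes_per_token(4))  # global layer: 4096 bytes/token over the full context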
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek Custom
+
+base_emb_dim: 1024
+moe_model_dim: 512
+base_moe_mlp_dim: 1024
+shared_expert_mlp_dim: 4096
+num_experts: 16
+num_experts_per_tok: 2
+shared_experts: 1
+base_num_decoder_layers: 4
+first_num_dense_layers: 1
+mlp_activations: ["silu", "linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek_custom"
+
+
+# Hybrid GQA Attention
+
+attention_output_dim: 512 # same as moe_model_dim
+attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
+inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
+head_dim: 256
+
+local_num_query_heads: 4
+local_num_kv_heads: 2
+sliding_window_size: 128
+
+global_num_query_heads: 4
+global_num_kv_heads: 1
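For a sanity check on scale, this test config's MoE block can be sized by hand, assuming the usual gated-MLP expert layout implied by mlp_activations: ["silu", "linear"] (two input projections plus one output projection). Illustrative arithmetic, not a MaxText parameter count:

moe_model_dim, moe_mlp_dim, shared_mlp_dim = 512, 1024, 4096
num_experts, num_experts_per_tok = 16, 2

per_routed_expert = 3 * moe_model_dim * moe_mlp_dim   # wi_0 (silu), wi_1 (linear), wo
shared_expert = 3 * moe_model_dim * shared_mlp_dim
total = num_experts * per_routed_expert + shared_expert
active = num_experts_per_tok * per_routed_expert + shared_expert
print(f"{total:,} total MoE params, {active:,} active per token")
# 31,457,280 total MoE params, 9,437,184 active per token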
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek Custom
+
+
+base_emb_dim: 7168
+moe_model_dim: 3072
+base_moe_mlp_dim: 6144 # (2 * 3072)
+shared_expert_mlp_dim: 15360 # (5 * 3072)
+
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+
+num_experts: 256
+num_experts_per_tok: 4 # (1 shared + 4 routed)
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek_custom"
+
+# Hybrid GQA Attention
+attention_output_dim: 3072 # same as moe_model_dim
+attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
+inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
+head_dim: 256
+
+local_num_query_heads: 64
+local_num_kv_heads: 8
+sliding_window_size: 1024
+
+global_num_query_heads: 64
+global_num_kv_heads: 4
+
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
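These YaRN knobs follow the usual pattern: max_position_embeddings = rope_factor * original_max_position_embeddings (40 * 4096 = 163,840), and mscale: 1.0 with rope_attention_scaling: False appears to disable YaRN's optional attention-temperature adjustment. A minimal sketch of the standard YaRN frequency schedule these parameters feed, assuming the reference formulation with beta_slow taken as 1 (MaxText's rope_interleave/rope_truncate handling may differ in detail):

import math

def yarn_inv_freq(dim=256, base=10_000, factor=40, orig_len=4096,
                  beta_fast=32, beta_slow=1):
  def correction_dim(num_rotations):
    # Head dimension whose RoPE wavelength completes num_rotations turns at orig_len.
    return dim * math.log(orig_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

  low = max(math.floor(correction_dim(beta_fast)), 0)
  high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
  inv_freq = []
  for i in range(0, dim, 2):
    freq = base ** (-i / dim)
    ramp = min(max((i / 2 - low) / max(high - low, 1), 0.0), 1.0)
    # ramp ~ 0: high-frequency dims, extrapolated unchanged;
    # ramp ~ 1: low-frequency dims, interpolated by 1/factor.
    inv_freq.append(freq * ((1 - ramp) + ramp / factor))
  return inv_freq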

src/maxtext/configs/types.py

Lines changed: 13 additions & 0 deletions
@@ -223,6 +223,9 @@ class ProfilerType(str, Enum):
     "deepseek3-tiny",
     "deepseek3.2-671b",
     "deepseek-custom",
+    "deepseek3-custom-small",
+    "deepseek3-custom",
+    "deepseek3-custom-large",
     "kimi-k2-1t",
     "gemma-7b",
     "gemma-2b",

@@ -430,6 +433,14 @@ class ModelArchitecture(BaseModel):
   base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
   base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
   head_dim: int = Field(128, description="Dimension of each attention head.")
+  attention_output_dim: int = Field(-1, description="Override output dimension for the attention block.")
+  local_num_query_heads: int = Field(-1, description="Number of query heads in local context layers.")
+  local_num_kv_heads: int = Field(-1, description="Number of KV heads in local context layers.")
+  global_num_query_heads: int = Field(-1, description="Number of query heads in global context layers.")
+  global_num_kv_heads: int = Field(-1, description="Number of KV heads in global context layers.")
+  attention_layer_hybrid_ratio: int = Field(
+      -1, description="Ratio of layer context styles (e.g. 5 means 4 local layers followed by 1 global)."
+  )
   mlp_activations: list[str] = Field(["silu", "linear"], description="Activation functions in the MLP layer.")
   mlp_activations_limit: float = Field(
       -1.0,

@@ -707,6 +718,8 @@ class DeepSeekMoE(BaseModel):
   """Configuration specific to DeepSeek-style MoE layers."""
 
   base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
+  moe_model_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer.")
+  shared_expert_mlp_dim: int = Field(-1, description="Intermediate dimension for the shared expert.")
   first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
   shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
   routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
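The attention_layer_hybrid_ratio description fixes the schedule: a ratio of r means each repeating cycle runs r-1 local-attention layers followed by one global layer, and the DeepSeek Custom configs set r = 2, alternating local and global. A hypothetical helper (not MaxText code) that spells out the assignment:

def attention_kind(layer_idx: int, hybrid_ratio: int) -> str:
  """hybrid_ratio r => r-1 local layers, then 1 global layer, repeating."""
  if hybrid_ratio <= 1:
    return "global"  # assumption: no hybrid schedule when the ratio is unset
  return "global" if layer_idx % hybrid_ratio == hybrid_ratio - 1 else "local"

print([attention_kind(i, 5) for i in range(5)])  # 4 local, then 1 global
print([attention_kind(i, 2) for i in range(4)])  # the configs' "1 Local : 1 Global"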

src/maxtext/layers/decoders.py

Lines changed: 43 additions & 13 deletions
@@ -42,6 +42,7 @@
 from maxtext.models import (
     deepseek,
     deepseek_batchsplit,
+    deepseek_custom,
     gemma,
     gemma2,
     gemma3,

@@ -52,7 +53,6 @@
     mistral,
     mixtral,
     olmo3,
-    qwen2,
     qwen3,
     simple_layer,
 )

@@ -458,6 +458,14 @@ def get_decoder_layers(self):
             deepseek.DeepSeekDenseLayerToLinen,
             deepseek.DeepSeekMoELayerToLinen,
         ]
+      case DecoderBlockType.DEEPSEEK_CUSTOM:
+        deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoELayerToLinen
+        if self.config.scan_layers and self.config.attention_layer_hybrid_ratio > 1:
+          deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoEScannableBlockToLinen
+        return [
+            deepseek_custom.DeepSeekDenseLayerToLinen,
+            deepseek_custom_moe_layer,
+        ]
       case DecoderBlockType.GEMMA:
         return [gemma.GemmaDecoderLayerToLinen]
       case DecoderBlockType.GEMMA2:

@@ -468,8 +476,6 @@ def get_decoder_layers(self):
         return [gpt3.Gpt3DecoderLayerToLinen]
       case DecoderBlockType.GPT_OSS:
         return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
-      case DecoderBlockType.QWEN2:
-        return [qwen2.Qwen2DecoderLayerToLinen]
       case DecoderBlockType.QWEN3:
         return [qwen3.Qwen3DecoderLayerToLinen]
       case DecoderBlockType.QWEN3_MOE:

@@ -525,10 +531,10 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.MISTRAL,
         DecoderBlockType.MIXTRAL,
         DecoderBlockType.DEEPSEEK,
+        DecoderBlockType.DEEPSEEK_CUSTOM,
         DecoderBlockType.GEMMA,
         DecoderBlockType.GEMMA2,
         DecoderBlockType.GEMMA3,
-        DecoderBlockType.QWEN2,
         DecoderBlockType.QWEN3,
         DecoderBlockType.QWEN3_MOE,
         DecoderBlockType.GPT_OSS,

@@ -577,7 +583,7 @@ def get_pipeline_stage_module(self, decoder_blocks):
     """get pipeline stage module"""
 
     def get_layer_to_pipeline(blocks, cfg):
-      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+      if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
        return blocks[1]  # return the sparse block
       else:
         return blocks[0]

@@ -803,7 +809,7 @@ def __call__(
         if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat
         else None
     )
-    if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+    if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
       assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
       dense_layer = RemattedBlockLayers[0]
       moe_layer = RemattedBlockLayers[1]

@@ -849,7 +855,7 @@ def __call__(
           )(y, *broadcast_args)
     else:
       if cfg.scan_layers:
-        if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+        if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
          assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
          layer_call_kwargs = {
              "page_state": page_state,

@@ -927,10 +933,31 @@ def __call__(
                 policy=policy,
             )
           else:
+            scan_length = num_moe_layers
+            if cfg.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM and cfg.scan_layers:
+              if num_moe_layers % cfg.inhomogeneous_layer_cycle_interval != 0:
+                raise ValueError(
+                    f"num_moe_layers ({num_moe_layers}) must be divisible by "
+                    f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) "
+                    "when using DeepSeek Custom and scan_layers is True."
+                )
+              if cfg.attention_layer_hybrid_ratio != cfg.inhomogeneous_layer_cycle_interval:
+                raise ValueError(
+                    f"attention_layer_hybrid_ratio ({cfg.attention_layer_hybrid_ratio}) and "
+                    f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) "
+                    "must be the same."
+                )
+              scan_length = num_moe_layers // cfg.inhomogeneous_layer_cycle_interval
+              max_logging.log(
+                  f"scan_length: {scan_length}, "
+                  f"num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: "
+                  f"{num_moe_layers // cfg.inhomogeneous_layer_cycle_interval}"
+              )
+
             y, _ = self.scan_decoder_layers(
                 cfg,
                 moe_layer,
-                num_moe_layers,
+                scan_length,
                 "moe_layers",
                 mesh,
                 in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
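For the full-size config above (61 layers, 3 of them dense, cycle interval 2) this gives 58 MoE layers folded into 29 scanned local/global pairs. The same checks and arithmetic as a standalone sketch:

def moe_scan_length(num_moe_layers: int, hybrid_ratio: int, cycle_interval: int) -> int:
  """Mirrors the validation above: the scan iterates over whole attention cycles."""
  if num_moe_layers % cycle_interval != 0:
    raise ValueError("num_moe_layers must be divisible by inhomogeneous_layer_cycle_interval")
  if hybrid_ratio != cycle_interval:
    raise ValueError("attention_layer_hybrid_ratio must equal inhomogeneous_layer_cycle_interval")
  return num_moe_layers // cycle_interval

print(moe_scan_length(61 - 3, 2, 2))  # 29: each scan step runs one local/global pair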
@@ -968,7 +995,7 @@ def __call__(
               **layer_kwargs,
           )(y, *broadcast_args)
       else:
-        if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+        if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
          assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
          dense_layer = RemattedBlockLayers[0]
          moe_layer = RemattedBlockLayers[1]

@@ -1058,11 +1085,14 @@ def __call__(
           kv_caches["key_cache"][lyr] = returned_cache[0]
           kv_caches["value_cache"][lyr] = returned_cache[1]
 
-        if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
-          visual_embeds = deepstack_visual_embeds[lyr]
+        if (
+            deepstack_visual_embeds is not None
+            and lyr < len(deepstack_visual_embeds)
+            and bidirectional_mask is not None
+            and deepstack_visual_embeds[lyr] is not None
+        ):
           # Use bidirectional_mask to identify visual token positions
-          if bidirectional_mask is not None and visual_embeds is not None:
-            y = deepstack_process(y, bidirectional_mask, visual_embeds)
+          y = deepstack_process(y, bidirectional_mask, deepstack_visual_embeds[lyr])
 
       assert isinstance(y, jax.Array)

src/maxtext/layers/linears.py

Lines changed: 1 addition & 0 deletions
@@ -474,6 +474,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.GEMMA3,
         DecoderBlockType.QWEN3,
         DecoderBlockType.DEEPSEEK,
+        DecoderBlockType.DEEPSEEK_CUSTOM,
         DecoderBlockType.LLAMA4,
     ):
       return functools.partial(normalizations.RMSNorm, num_features=num_features)
