
Commit e6176a0

custom DeepSeek v3
PiperOrigin-RevId: 882309708
1 parent 093ab89 commit e6176a0

12 files changed

Lines changed: 1064 additions & 16 deletions


src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ class DecoderBlockType(enum.Enum):
   SIMPLE_MLP = "simple_mlp"
   LLAMA4 = "llama4"
   OLMO3 = "olmo3"
+  DEEPSEEK_CUSTOM = "deepseek_custom"


 class AttentionType(enum.Enum):

src/maxtext/configs/base.yml

Lines changed: 5 additions & 0 deletions
@@ -184,6 +184,11 @@ num_experts_per_tok: 1
 megablox: true
 sparse_matmul: true
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
+ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
+# By default (-1), this buffer is sized for the worst case to ensure no dropping.
+# When set to 1.0, this buffer is sized assuming perfectly balanced routing. If the routing dictates
+# a size larger than this, then tokens will be dropped.
+# In general, if ragged_buffer_factor > 0, the ragged buffer size is balanced_size * ragged_buffer_factor.
 load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
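
As a rough illustration of the comment above, here is a minimal Python sketch of how a ragged buffer could be sized from this factor. The helper, its arguments, and the per-expert interpretation of balanced_size are assumptions for illustration only, not the MaxText implementation:

def ragged_buffer_size(num_tokens: int, num_experts_per_tok: int, num_experts: int,
                       ragged_buffer_factor: float) -> int:
  """Hypothetical sizing helper for the routed-MoE ragged activation buffer."""
  total_assignments = num_tokens * num_experts_per_tok
  # Per-expert share if routing were perfectly balanced (ceiling division).
  balanced_size = -(-total_assignments // num_experts)
  if ragged_buffer_factor < 0:
    # Default: worst case, every assignment could land on one expert, so nothing is dropped.
    return total_assignments
  # Otherwise scale the balanced size; routing overflow beyond this is dropped.
  return int(balanced_size * ragged_buffer_factor)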
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek Custom
+
+
+base_emb_dim: 16384
+moe_model_dim: 8192
+base_moe_mlp_dim: 16384 # (2 * 8192)
+shared_expert_mlp_dim: 32768 # (4 * 8192)
+
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+
+num_experts: 256
+num_experts_per_tok: 4 # (1 shared + 4 routed)
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek_custom"
+
+# Hybrid GQA Attention
+attention_output_dim: 8192 # same as moe_model_dim
+attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
+inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
+head_dim: 256
+
+local_num_query_heads: 64
+local_num_kv_heads: 8
+sliding_window_size: 1024
+
+global_num_query_heads: 64
+global_num_kv_heads: 4
+
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
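
For intuition on the hybrid attention settings above, a small sketch (not MaxText code) of how a 1:1 cycle could map layer indices to local (sliding-window) versus global attention. Which position in the cycle is local is an assumption here; the actual assignment lives in the deepseek_custom model code, which is not part of this diff:

def attention_kind(layer_idx: int, cycle_interval: int = 2) -> str:
  """Hypothetical helper: alternate local and global attention within each cycle."""
  # Assumes the first layer of every cycle uses local (sliding-window) attention.
  return "local" if layer_idx % cycle_interval == 0 else "global"

# With attention_layer_hybrid_ratio = 2, four consecutive layers alternate:
# ['local', 'global', 'local', 'global']
print([attention_kind(i, 2) for i in range(4)])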
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek Custom
+
+base_emb_dim: 1024
+moe_model_dim: 512
+base_moe_mlp_dim: 1024
+shared_expert_mlp_dim: 4096
+num_experts: 16
+num_experts_per_tok: 2
+shared_experts: 1
+base_num_decoder_layers: 4
+first_num_dense_layers: 1
+mlp_activations: ["silu", "linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek_custom"
+
+
+# Hybrid GQA Attention
+
+attention_output_dim: 512 # same as moe_model_dim
+attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
+inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
+head_dim: 256
+
+local_num_query_heads: 4
+local_num_kv_heads: 2
+sliding_window_size: 128
+
+global_num_query_heads: 4
+global_num_kv_heads: 1
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek Custom
+
+
+base_emb_dim: 7168
+moe_model_dim: 3072
+base_moe_mlp_dim: 6144 # (2 * 3072)
+shared_expert_mlp_dim: 15360 # (5 * 3072)
+
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+
+num_experts: 256
+num_experts_per_tok: 8 # (1 shared + 8 routed)
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek_custom"
+
+# Hybrid GQA Attention
+attention_output_dim: 3072 # same as moe_model_dim
+attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
+inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
+head_dim: 256
+
+local_num_query_heads: 64
+local_num_kv_heads: 8
+sliding_window_size: 1024
+
+global_num_query_heads: 64
+global_num_kv_heads: 4
+
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek Custom
+
+
+base_emb_dim: 7168
+moe_model_dim: 3072
+base_moe_mlp_dim: 6144 # (2 * 3072)
+shared_expert_mlp_dim: 15360 # (5 * 3072)
+
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+
+num_experts: 256
+num_experts_per_tok: 4 # (1 shared + 4 routed)
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek_custom"
+
+# Hybrid GQA Attention
+attention_output_dim: 3072 # same as moe_model_dim
+attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
+inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
+head_dim: 256
+
+local_num_query_heads: 64
+local_num_kv_heads: 8
+sliding_window_size: 1024
+
+global_num_query_heads: 64
+global_num_kv_heads: 4
+
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -610,6 +610,7 @@ class MoEGeneral(BaseModel):
   num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
   num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
+  ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
   load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
   use_custom_sort_vjp: bool = Field(
       True,

src/maxtext/layers/decoders.py

Lines changed: 43 additions & 9 deletions
@@ -42,6 +42,7 @@
 from maxtext.models import (
     deepseek,
     deepseek_batchsplit,
+    deepseek_custom,
     gemma,
     gemma2,
     gemma3,
@@ -458,6 +459,14 @@ def get_decoder_layers(self):
             deepseek.DeepSeekDenseLayerToLinen,
             deepseek.DeepSeekMoELayerToLinen,
         ]
+      case DecoderBlockType.DEEPSEEK_CUSTOM:
+        deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoELayerToLinen
+        if self.config.scan_layers and self.config.attention_layer_hybrid_ratio > 1:
+          deepseek_custom_moe_layer = deepseek_custom.DeepSeekMoEScannableBlockToLinen
+        return [
+            deepseek_custom.DeepSeekDenseLayerToLinen,
+            deepseek_custom_moe_layer,
+        ]
       case DecoderBlockType.GEMMA:
         return [gemma.GemmaDecoderLayerToLinen]
       case DecoderBlockType.GEMMA2:
@@ -525,6 +534,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.MISTRAL,
         DecoderBlockType.MIXTRAL,
         DecoderBlockType.DEEPSEEK,
+        DecoderBlockType.DEEPSEEK_CUSTOM,
         DecoderBlockType.GEMMA,
         DecoderBlockType.GEMMA2,
         DecoderBlockType.GEMMA3,
@@ -577,7 +587,7 @@ def get_pipeline_stage_module(self, decoder_blocks):
     """get pipeline stage module"""

     def get_layer_to_pipeline(blocks, cfg):
-      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+      if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
         return blocks[1]  # return the sparse block
       else:
         return blocks[0]
@@ -803,7 +813,7 @@ def __call__(
           if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat
           else None
       )
-      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+      if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
         assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
         dense_layer = RemattedBlockLayers[0]
         moe_layer = RemattedBlockLayers[1]
@@ -849,7 +859,7 @@ def __call__(
         )(y, *broadcast_args)
       else:
         if cfg.scan_layers:
-          if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+          if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
             assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
             layer_call_kwargs = {
                 "page_state": page_state,
@@ -927,10 +937,31 @@ def __call__(
                 policy=policy,
             )
           else:
+            scan_length = num_moe_layers
+            if cfg.decoder_block == DecoderBlockType.DEEPSEEK_CUSTOM and cfg.scan_layers:
+              if num_moe_layers % cfg.inhomogeneous_layer_cycle_interval != 0:
+                raise ValueError(
+                    f"num_moe_layers ({num_moe_layers}) must be divisible by "
+                    f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) "
+                    "when using DeepSeek Custom and scan_layers is True."
+                )
+              if cfg.attention_layer_hybrid_ratio != cfg.inhomogeneous_layer_cycle_interval:
+                raise ValueError(
+                    f"attention_layer_hybrid_ratio ({cfg.attention_layer_hybrid_ratio}) and "
+                    f"inhomogeneous_layer_cycle_interval ({cfg.inhomogeneous_layer_cycle_interval}) "
+                    "must be the same."
+                )
+              scan_length = num_moe_layers // cfg.inhomogeneous_layer_cycle_interval
+              max_logging.log(
+                  f"scan_length: {scan_length}, "
+                  f"num_moe_layers // cfg.inhomogeneous_layer_cycle_interval: "
+                  f"{num_moe_layers // cfg.inhomogeneous_layer_cycle_interval}"
+              )
+
             y, _ = self.scan_decoder_layers(
                 cfg,
                 moe_layer,
-                num_moe_layers,
+                scan_length,
                 "moe_layers",
                 mesh,
                 in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
@@ -968,7 +999,7 @@ def __call__(
             **layer_kwargs,
         )(y, *broadcast_args)
     else:
-      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
+      if cfg.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK_CUSTOM):
         assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
         dense_layer = RemattedBlockLayers[0]
         moe_layer = RemattedBlockLayers[1]
@@ -1058,11 +1089,14 @@ def __call__(
         kv_caches["key_cache"][lyr] = returned_cache[0]
         kv_caches["value_cache"][lyr] = returned_cache[1]

-        if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
-          visual_embeds = deepstack_visual_embeds[lyr]
+        if (
+            deepstack_visual_embeds is not None
+            and lyr < len(deepstack_visual_embeds)
+            and bidirectional_mask is not None
+            and deepstack_visual_embeds[lyr] is not None
+        ):
           # Use bidirectional_mask to identify visual token positions
-          if bidirectional_mask is not None and visual_embeds is not None:
-            y = deepstack_process(y, bidirectional_mask, visual_embeds)
+          y = deepstack_process(y, bidirectional_mask, deepstack_visual_embeds[lyr])

     assert isinstance(y, jax.Array)

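To make the scan-length check above concrete, here is a standalone sketch using values from the 61-layer config in this commit (3 dense layers leave 58 MoE layers); the numbers are illustrative only and the checks simply restate the new validation in decoders.py:

num_moe_layers = 58                      # 61 decoder layers - 3 dense layers
inhomogeneous_layer_cycle_interval = 2   # one local + one global layer per cycle
attention_layer_hybrid_ratio = 2

# Same validation the new decoders.py branch performs before scanning.
if num_moe_layers % inhomogeneous_layer_cycle_interval != 0:
  raise ValueError("num_moe_layers must be divisible by inhomogeneous_layer_cycle_interval")
if attention_layer_hybrid_ratio != inhomogeneous_layer_cycle_interval:
  raise ValueError("attention_layer_hybrid_ratio must equal inhomogeneous_layer_cycle_interval")

# Each scan step now covers one full local/global cycle.
scan_length = num_moe_layers // inhomogeneous_layer_cycle_interval  # 29
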
src/maxtext/layers/linears.py

Lines changed: 1 addition & 0 deletions
@@ -474,6 +474,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.GEMMA3,
         DecoderBlockType.QWEN3,
         DecoderBlockType.DEEPSEEK,
+        DecoderBlockType.DEEPSEEK_CUSTOM,
         DecoderBlockType.LLAMA4,
     ):
       return functools.partial(normalizations.RMSNorm, num_features=num_features)
