Skip to content

Commit 4bfeed7

Browse files
custom DeepSeek v3
PiperOrigin-RevId: 882309708
1 parent c9ffb30 commit 4bfeed7

13 files changed

Lines changed: 1275 additions & 66 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class DecoderBlockType(enum.Enum):
102102
SIMPLE_MLP = "simple_mlp"
103103
LLAMA4 = "llama4"
104104
OLMO3 = "olmo3"
105+
DEEPSEEK_CUSTOM = "deepseek_custom"
105106

106107

107108
class AttentionType(enum.Enum):

src/maxtext/configs/base.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ base_num_kv_heads: 16
155155
base_mlp_dim: 7168
156156
base_num_decoder_layers: 16
157157
head_dim: 128
158+
attention_output_dim: -1
159+
local_num_query_heads: -1
160+
local_num_kv_heads: -1
161+
global_num_query_heads: -1
162+
global_num_kv_heads: -1
163+
attention_layer_hybrid_ratio: -1
158164
mlp_activations: ["silu", "linear"]
159165
mlp_activations_limit: -1.0
160166
dropout_rate: 0.0
@@ -184,6 +190,11 @@ num_experts_per_tok: 1
184190
megablox: true
185191
sparse_matmul: true
186192
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
193+
ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
194+
# By default (-1), this buffer will be worst case size to ensure no dropping.
195+
# When set to 1.0, this buffer is set to the size assuming perfectly balanced routing. If the routing dictates
196+
# a size larger than this then tokens will be dropped.
197+
# In general if ragged_buffer_factor>0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
187198
load_balance_loss_weight: 0.0 # weight for the load balance loss
188199
use_random_routing: false # whether to use random routing for debug/test purpose
189200
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
@@ -240,6 +251,8 @@ use_2d_fsdp_sharding: False
240251

241252
# deepseek moe
242253
base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
254+
moe_model_dim: -1 # dimension of tokens entering the MoE layer.
255+
shared_expert_mlp_dim: -1 # intermediate dimension of the shared expert.
243256
first_num_dense_layers: 0 # number of initial dense layers in the model
244257
shared_experts: 1
245258
routed_scaling_factor: 1.0 # scaling factor for routing scores
@@ -485,6 +498,7 @@ logical_axis_rules: [
485498
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
486499
['embed_no_exp', ['fsdp', 'sequence', 'context']],
487500
['embed_tensor_transpose', ['tensor_transpose']],
501+
['attention_out_proj', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
488502
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
489503
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
490504
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# model config for DeepSeek Custom
16+
17+
18+
base_emb_dim: 16384
19+
moe_model_dim: 8192
20+
base_moe_mlp_dim: 16384 # (2 * 8192)
21+
shared_expert_mlp_dim: 32768 # (4 * 8192)
22+
23+
base_num_decoder_layers: 61
24+
first_num_dense_layers: 3
25+
mlp_activations: ["silu","linear"]
26+
vocab_size: 129280
27+
enable_dropout: False
28+
logits_via_embedding: False
29+
normalization_layer_epsilon: 1.0e-6
30+
31+
num_experts: 256
32+
num_experts_per_tok: 4 # routed experts per token; the 1 shared expert is configured separately via shared_experts
33+
shared_experts: 1
34+
routed_scaling_factor: 2.5
35+
routed_score_func: "sigmoid"
36+
routed_bias: True
37+
decoder_block: "deepseek_custom"
38+
39+
# Hybrid GQA Attention
40+
attention_output_dim: 8192 # same as moe_model_dim
41+
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
42+
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
43+
head_dim: 256
44+
45+
local_num_query_heads: 64
46+
local_num_kv_heads: 8
47+
sliding_window_size: 1024
48+
49+
global_num_query_heads: 64
50+
global_num_kv_heads: 4
51+
52+
mscale: 1.0
53+
# RoPE
54+
rope_type: "yarn"
55+
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
56+
max_position_embeddings: 163840
57+
original_max_position_embeddings: 4096
58+
rope_factor: 40
59+
beta_fast: 32
60+
rope_interleave: True
61+
rope_truncate: True
62+
rope_attention_scaling: False
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# model config for DeepSeek Custom
16+
17+
base_emb_dim: 1024
18+
moe_model_dim: 512
19+
base_moe_mlp_dim: 1024
20+
shared_expert_mlp_dim: 4096
21+
num_experts: 16
22+
num_experts_per_tok: 2
23+
shared_experts: 1
24+
base_num_decoder_layers: 4
25+
first_num_dense_layers: 1
26+
mlp_activations: ["silu", "linear"]
27+
vocab_size: 129280
28+
enable_dropout: False
29+
logits_via_embedding: False
30+
normalization_layer_epsilon: 1.0e-6
31+
routed_scaling_factor: 2.5
32+
routed_score_func: "sigmoid"
33+
routed_bias: True
34+
decoder_block: "deepseek_custom"
35+
36+
37+
# Hybrid GQA Attention
38+
39+
attention_output_dim: 512 # same as moe_model_dim
40+
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
41+
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
42+
head_dim: 256
43+
44+
local_num_query_heads: 4
45+
local_num_kv_heads: 2
46+
sliding_window_size: 128
47+
48+
global_num_query_heads: 4
49+
global_num_kv_heads: 1
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# model config for DeepSeek Custom
16+
17+
18+
base_emb_dim: 7168
19+
moe_model_dim: 3072
20+
base_moe_mlp_dim: 6144 # (2 * 3072)
21+
shared_expert_mlp_dim: 15360 # (5 * 3072)
22+
23+
base_num_decoder_layers: 61
24+
first_num_dense_layers: 3
25+
mlp_activations: ["silu","linear"]
26+
vocab_size: 129280
27+
enable_dropout: False
28+
logits_via_embedding: False
29+
normalization_layer_epsilon: 1.0e-6
30+
31+
num_experts: 256
32+
num_experts_per_tok: 4 # routed experts per token; the 1 shared expert is configured separately via shared_experts
33+
shared_experts: 1
34+
routed_scaling_factor: 2.5
35+
routed_score_func: "sigmoid"
36+
routed_bias: True
37+
decoder_block: "deepseek_custom"
38+
39+
# Hybrid GQA Attention
40+
attention_output_dim: 3072 # same as moe_model_dim
41+
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
42+
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
43+
head_dim: 256
44+
45+
local_num_query_heads: 64
46+
local_num_kv_heads: 8
47+
sliding_window_size: 1024
48+
49+
global_num_query_heads: 64
50+
global_num_kv_heads: 4
51+
52+
mscale: 1.0
53+
# RoPE
54+
rope_type: "yarn"
55+
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
56+
max_position_embeddings: 163840
57+
original_max_position_embeddings: 4096
58+
rope_factor: 40
59+
beta_fast: 32
60+
rope_interleave: True
61+
rope_truncate: True
62+
rope_attention_scaling: False

src/maxtext/configs/types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ class ProfilerType(str, Enum):
224224
"deepseek3-tiny",
225225
"deepseek3.2-671b",
226226
"deepseek-custom",
227+
"deepseek3-custom-small",
228+
"deepseek3-custom",
229+
"deepseek3-custom-large",
227230
"kimi-k2-1t",
228231
"gemma-7b",
229232
"gemma-2b",
@@ -437,6 +440,14 @@ class ModelArchitecture(BaseModel):
437440
base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
438441
base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
439442
head_dim: int = Field(128, description="Dimension of each attention head.")
443+
attention_output_dim: int = Field(-1, description="Override output dimension for attention block")
444+
local_num_query_heads: int = Field(-1, description="Number of query heads in local context layers.")
445+
local_num_kv_heads: int = Field(-1, description="Number of KV heads in local context layers.")
446+
global_num_query_heads: int = Field(-1, description="Number of query heads in global context layers.")
447+
global_num_kv_heads: int = Field(-1, description="Number of KV heads in global context layers.")
448+
attention_layer_hybrid_ratio: int = Field(
449+
-1, description="Ratio of layer context styles (e.g. 5 means 4 local followed by 1 global)."
450+
)
440451
mlp_activations: list[str] = Field(["silu", "linear"], description="Activation functions in the MLP layer.")
441452
mlp_activations_limit: float = Field(
442453
-1.0,
@@ -617,6 +628,7 @@ class MoEGeneral(BaseModel):
617628
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
618629
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
619630
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
631+
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
620632
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
621633
use_custom_sort_vjp: bool = Field(
622634
True,
@@ -714,6 +726,8 @@ class DeepSeekMoE(BaseModel):
714726
"""Configuration specific to DeepSeek-style MoE layers."""
715727

716728
base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
729+
moe_model_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer.")
730+
shared_expert_mlp_dim: int = Field(-1, description="Intermediate dimension for the shared expert.")
717731
first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
718732
shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
719733
routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")

0 commit comments

Comments
 (0)