1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
@@ -102,6 +102,7 @@ class DecoderBlockType(enum.Enum):
SIMPLE_MLP = "simple_mlp"
LLAMA4 = "llama4"
OLMO3 = "olmo3"
DEEPSEEK_CUSTOM = "deepseek_custom"


class AttentionType(enum.Enum):
14 changes: 14 additions & 0 deletions src/maxtext/configs/base.yml
@@ -155,6 +155,12 @@ base_num_kv_heads: 16
base_mlp_dim: 7168
base_num_decoder_layers: 16
head_dim: 128
attention_output_dim: -1
local_num_query_heads: -1
local_num_kv_heads: -1
global_num_query_heads: -1
global_num_kv_heads: -1
attention_layer_hybrid_ratio: -1
mlp_activations: ["silu", "linear"]
mlp_activations_limit: -1.0
dropout_rate: 0.0
@@ -184,6 +190,11 @@ num_experts_per_tok: 1
megablox: true
sparse_matmul: true
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
# By default (-1), this buffer is sized for the worst case to ensure no dropping.
# When set to 1.0, the buffer is sized assuming perfectly balanced routing; if the routing
# requires a larger size than this, tokens will be dropped.
# In general, if ragged_buffer_factor > 0, the ragged buffer size is balanced_size * ragged_buffer_factor.
load_balance_loss_weight: 0.0 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
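To make the sizing rule in the ragged_buffer_factor comment above concrete, here is a minimal sketch. The helper name and the per-expert-slice interpretation of "balanced size" are illustrative assumptions, not the MaxText implementation:

```python
def ragged_buffer_size(num_tokens: int, num_experts: int, experts_per_tok: int,
                       ragged_buffer_factor: float) -> int:
  """Illustrative per-expert sizing of the ragged buffer for routed MoE activations."""
  total_assignments = num_tokens * experts_per_tok
  if ragged_buffer_factor <= 0:
    # Worst case: every assignment could land on a single expert's slice,
    # so no token is ever dropped.
    return total_assignments
  # Balanced case: assignments spread evenly over experts, then scaled by the factor.
  balanced_size = total_assignments // num_experts
  return int(balanced_size * ragged_buffer_factor)

# Example: 4096 tokens, 256 experts, 4 routed experts per token.
print(ragged_buffer_size(4096, 256, 4, -1.0))  # worst case -> 16384
print(ragged_buffer_size(4096, 256, 4, 1.0))   # balanced   -> 64 per expert slice
```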
@@ -240,6 +251,8 @@ use_2d_fsdp_sharding: False

# deepseek moe
base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
moe_model_dim: -1 # dimension of token entering moe layer.
shared_expert_mlp_dim: -1 # intermediate dimension of the shared expert.
first_num_dense_layers: 0 # number of initial dense layers in the model
shared_experts: 1
routed_scaling_factor: 1.0 # scaling factor for routing scores
@@ -485,6 +498,7 @@ logical_axis_rules: [
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'context']],
['embed_tensor_transpose', ['tensor_transpose']],
['attention_out_proj', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
62 changes: 62 additions & 0 deletions src/maxtext/configs/models/deepseek3-custom-large.yml
@@ -0,0 +1,62 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek Custom


base_emb_dim: 16384
moe_model_dim: 8192
base_moe_mlp_dim: 16384 # (2 * 8192)
shared_expert_mlp_dim: 32768 # (4 * 8192)

base_num_decoder_layers: 61
first_num_dense_layers: 3
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6

num_experts: 256
num_experts_per_tok: 4 # routed experts per token (plus 1 shared expert)
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek_custom"

# Hybrid GQA Attention
attention_output_dim: 8192 # same as moe_model_dim
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
head_dim: 256

local_num_query_heads: 64
local_num_kv_heads: 8
sliding_window_size: 1024

global_num_query_heads: 64
global_num_kv_heads: 4

mscale: 1.0
# RoPE
rope_type: "yarn"
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
max_position_embeddings: 163840
original_max_position_embeddings: 4096
rope_factor: 40
beta_fast: 32
rope_interleave: True
rope_truncate: True
rope_attention_scaling: False
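The hybrid GQA settings above follow the attention_layer_hybrid_ratio semantics described in types.py ("5 means 4 local followed by 1 global"). A minimal sketch of that layer-cycling rule, assuming each cycle is simply (ratio - 1) local layers followed by one global layer; the function name is illustrative, not the MaxText implementation:

```python
def uses_global_attention(layer_idx: int, hybrid_ratio: int) -> bool:
  """Illustrative cycling rule: each cycle has (hybrid_ratio - 1) local
  (sliding-window) layers followed by 1 global layer."""
  if hybrid_ratio <= 0:
    return True  # no hybrid pattern configured; treat every layer as global
  return layer_idx % hybrid_ratio == hybrid_ratio - 1

# With attention_layer_hybrid_ratio = 2 (this config): local, global, local, global, ...
pattern = ["G" if uses_global_attention(i, 2) else "L" for i in range(8)]
print("".join(pattern))  # LGLGLGLG
```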
49 changes: 49 additions & 0 deletions src/maxtext/configs/models/deepseek3-custom-small.yml
@@ -0,0 +1,49 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek Custom

base_emb_dim: 1024
moe_model_dim: 512
base_moe_mlp_dim: 1024
shared_expert_mlp_dim: 4096
num_experts: 16
num_experts_per_tok: 2
shared_experts: 1
base_num_decoder_layers: 4
first_num_dense_layers: 1
mlp_activations: ["silu", "linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek_custom"


# Hybrid GQA Attention

attention_output_dim: 512 # same as moe_model_dim
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
head_dim: 256

local_num_query_heads: 4
local_num_kv_heads: 2
sliding_window_size: 128

global_num_query_heads: 4
global_num_kv_heads: 1
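As a quick sanity check on this small config, the comments tie attention_output_dim to moe_model_dim, and GQA requires the query-head count to be a multiple of the KV-head count. A small standalone check, assuming the YAML file is readable with PyYAML from the path shown in this diff:

```python
import yaml  # PyYAML

with open("src/maxtext/configs/models/deepseek3-custom-small.yml") as f:
  cfg = yaml.safe_load(f)

# The comment above says attention_output_dim should match moe_model_dim.
assert cfg["attention_output_dim"] == cfg["moe_model_dim"]

# GQA: query heads must be divisible by KV heads for both layer kinds.
assert cfg["local_num_query_heads"] % cfg["local_num_kv_heads"] == 0    # 4 / 2
assert cfg["global_num_query_heads"] % cfg["global_num_kv_heads"] == 0  # 4 / 1
print("deepseek3-custom-small: head and dim relations look consistent")
```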
62 changes: 62 additions & 0 deletions src/maxtext/configs/models/deepseek3-custom.yml
@@ -0,0 +1,62 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek Custom


base_emb_dim: 7168
moe_model_dim: 3072
base_moe_mlp_dim: 6144 # (2 * 3072)
shared_expert_mlp_dim: 15360 # (5 * 3072)

base_num_decoder_layers: 61
first_num_dense_layers: 3
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6

num_experts: 256
num_experts_per_tok: 4 # routed experts per token (plus 1 shared expert)
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek_custom"

# Hybrid GQA Attention
attention_output_dim: 3072 # same as moe_model_dim
attention_layer_hybrid_ratio: 2 # 1 Local : 1 Global ratio
inhomogeneous_layer_cycle_interval: 2 # same as attention_layer_hybrid_ratio
head_dim: 256

local_num_query_heads: 64
local_num_kv_heads: 8
sliding_window_size: 1024

global_num_query_heads: 64
global_num_kv_heads: 4

mscale: 1.0
# RoPE
rope_type: "yarn"
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
max_position_embeddings: 163840
original_max_position_embeddings: 4096
rope_factor: 40
beta_fast: 32
rope_interleave: True
rope_truncate: True
rope_attention_scaling: False
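The YaRN block at the end of both DeepSeek custom configs keeps the usual relation between the original and extended context lengths; a quick arithmetic check (variable names are just for illustration):

```python
original_max_position_embeddings = 4096
rope_factor = 40

# YaRN extends the trained context window by rope_factor.
extended_context = original_max_position_embeddings * rope_factor
print(extended_context)  # 163840, matching max_position_embeddings above
```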
14 changes: 14 additions & 0 deletions src/maxtext/configs/types.py
@@ -224,6 +224,9 @@ class ProfilerType(str, Enum):
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek-custom",
"deepseek3-custom-small",
"deepseek3-custom",
"deepseek3-custom-large",
"kimi-k2-1t",
"gemma-7b",
"gemma-2b",
@@ -437,6 +440,14 @@ class ModelArchitecture(BaseModel):
base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
head_dim: int = Field(128, description="Dimension of each attention head.")
attention_output_dim: int = Field(-1, description="Override for the output dimension of the attention block.")
local_num_query_heads: int = Field(-1, description="Number of query heads in local context layers.")
local_num_kv_heads: int = Field(-1, description="Number of KV heads in local context layers.")
global_num_query_heads: int = Field(-1, description="Number of query heads in global context layers.")
global_num_kv_heads: int = Field(-1, description="Number of KV heads in global context layers.")
attention_layer_hybrid_ratio: int = Field(
-1, description="Ratio of layer context styles (e.g. 5 means 4 local followed by 1 global)."
)
mlp_activations: list[str] = Field(["silu", "linear"], description="Activation functions in the MLP layer.")
mlp_activations_limit: float = Field(
-1.0,
@@ -617,6 +628,7 @@ class MoEGeneral(BaseModel):
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, the ragged buffer uses its worst-case size.")
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
use_custom_sort_vjp: bool = Field(
True,
@@ -714,6 +726,8 @@ class DeepSeekMoE(BaseModel):
"""Configuration specific to DeepSeek-style MoE layers."""

base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
moe_model_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer.")
shared_expert_mlp_dim: int = Field(-1, description="Intermediate dimension for the shared expert.")
first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
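To illustrate how the two new DeepSeekMoE fields sit next to the existing ones, here is a self-contained Pydantic mirror of just the fields touched in this hunk, populated with the deepseek3-custom values. This re-declares a toy model for illustration; it is not the actual class from maxtext.configs.types:

```python
from pydantic import BaseModel, Field

class DeepSeekMoESketch(BaseModel):
  """Toy mirror of the DeepSeekMoE fields shown in this diff."""
  base_moe_mlp_dim: int = Field(7168)
  moe_model_dim: int = Field(-1)          # dimension of tokens entering the MoE layer
  shared_expert_mlp_dim: int = Field(-1)  # intermediate dimension of the shared expert

# Values from configs/models/deepseek3-custom.yml.
cfg = DeepSeekMoESketch(base_moe_mlp_dim=6144, moe_model_dim=3072, shared_expert_mlp_dim=15360)
print(cfg.model_dump())
```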