diff --git a/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml
new file mode 100644
index 0000000000..9a217ebf72
--- /dev/null
+++ b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml
@@ -0,0 +1,83 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek V3 - 671B that uses fsdp on two logical axes
+
+# For DeepSeek default device-limited routing,
+# please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.
+
+base_emb_dim: 7168
+base_num_query_heads: 128
+base_num_kv_heads: 128
+base_mlp_dim: 18432
+base_moe_mlp_dim: 2048
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+num_experts: 256
+num_experts_per_tok: 8
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek"
+# MLA
+attention_type: "mla"
+q_lora_rank: 1536
+kv_lora_rank: 512
+qk_nope_head_dim: 128
+qk_rope_head_dim: 64
+v_head_dim: 128
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000  # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
+
+override_logical_axis_rules: True
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
+logical_axis_rules: [
+  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_norm_length', ['context']],
+  ['activation_heads', []],
+  ['activation_stage', 'stage'],
+  ['embed', ['fsdp']],
+  ['embed_no_exp', ['fsdp']],
+  ['q_lora', ['fsdp']],
+  ['kv_lora', ['fsdp']],
+  ['layers', 'stage'],
+  ['q_lora_up_proj', ['fsdp_transpose']],
+  ['kv_lora_up_proj', ['fsdp_transpose']],
+  ['q_heads', ['fsdp_transpose']],
+  ['kv_heads', ['fsdp_transpose']],
+  ['heads', ['fsdp_transpose']],
+  ['mlp', ['fsdp_transpose']],
+  ['fsdp_transpose_and_expert', ['fsdp_transpose', 'expert']],
+  ['fsdp_transpose_only', ['fsdp_transpose']],
+  ['expert_only', ['expert']],
+]
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index e66ecac8fa..68e750fe1a 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -219,6 +219,7 @@ class ProfilerType(str, Enum):
     "deepseek2-236b",
     "deepseek3-671b",
     "deepseek3-671b-2dfsdp",
+    "deepseek3-671b-batchsplit",
     "deepseek3-test",
     "deepseek3-tiny",
     "deepseek3.2-671b",
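The config above works through Flax's logical-to-physical axis mapping: each logical axis name on the left resolves to the first listed mesh axes not already claimed by another dimension of the same array, which is also why `activation_embed_and_logits_batch` appears twice (the second entry is a fallback). Below is a minimal sketch of that resolution using `flax.linen.logical_to_mesh_axes`; the three-rule subset is illustrative only, not the full list from the config.

```python
import flax.linen as nn

# Illustrative subset of the config's logical_axis_rules, written as the
# tuples flax expects. Hypothetical subset, not the full config.
rules = (
    ("embed", "fsdp"),
    ("mlp", "fsdp_transpose"),
    ("expert_only", "expert"),
)

# Resolve a weight whose logical shape is ('embed', 'mlp') to a physical
# PartitionSpec over the mesh axes.
spec = nn.logical_to_mesh_axes(("embed", "mlp"), rules)
print(spec)  # PartitionSpec('fsdp', 'fsdp_transpose')
```

Mapping `embed` to `fsdp` and `mlp` to `fsdp_transpose` is what gives this config its two weight-sharding axes, matching the "fsdp on two logical axes" description in the file header.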
diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py
index 14ce96978c..c9200bfb94 100644
--- a/src/maxtext/models/deepseek_batchsplit.py
+++ b/src/maxtext/models/deepseek_batchsplit.py
@@ -180,26 +180,18 @@ def fn(weights):
       (routed_wi_0, routed_wi_1, routed_wo),
       (shared_wi_0, shared_wi_1, shared_wo),
   ) = weights
-  # All-gather across FSDP axis. Expert axis is used for FSDP in attention.
-  wq_a = jax.lax.all_gather(wq_a, axis_name="expert", tiled=True, axis=1)
+  # All-gather across FSDP axis.
   wq_a = jax.lax.all_gather(wq_a, axis_name="fsdp", tiled=True)
-  wq_b = jax.lax.all_gather(wq_b, axis_name="expert", tiled=True, axis=1)
   wq_b = jax.lax.all_gather(wq_b, axis_name="fsdp", tiled=True)
-  wkv_a = jax.lax.all_gather(wkv_a, axis_name="expert", tiled=True, axis=1)
   wkv_a = jax.lax.all_gather(wkv_a, axis_name="fsdp", tiled=True)
-  wkv_b = jax.lax.all_gather(wkv_b, axis_name="expert", tiled=True, axis=1)
   wkv_b = jax.lax.all_gather(wkv_b, axis_name="fsdp", tiled=True)
-  out = jax.lax.all_gather(out, axis_name="expert", tiled=True)
   out = jax.lax.all_gather(out, axis_name="fsdp", tiled=True, axis=2)
   gate = jax.lax.all_gather(gate, axis_name="fsdp", tiled=True)
   routed_wi_0 = jax.lax.all_gather(routed_wi_0, axis_name="fsdp", tiled=True)
   routed_wi_1 = jax.lax.all_gather(routed_wi_1, axis_name="fsdp", tiled=True)
   routed_wo = jax.lax.all_gather(routed_wo, axis_name="fsdp", tiled=True)
-  shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="expert", tiled=True, axis=1)
   shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="fsdp", tiled=True)
-  shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="expert", tiled=True, axis=1)
   shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="fsdp", tiled=True)
-  shared_wo = jax.lax.all_gather(shared_wo, axis_name="expert", tiled=True)
   shared_wo = jax.lax.all_gather(shared_wo, axis_name="fsdp", tiled=True, axis=1)
   return (
     (
@@ -224,13 +216,13 @@ def fn(weights):
         jax.sharding.PartitionSpec(None),
       ),
       (
-        jax.sharding.PartitionSpec("fsdp", "expert"),
-        jax.sharding.PartitionSpec("fsdp", "expert", None),
+        jax.sharding.PartitionSpec("fsdp", None),
+        jax.sharding.PartitionSpec("fsdp", None, None),
         jax.sharding.PartitionSpec(None),
-        jax.sharding.PartitionSpec("fsdp", "expert"),
-        jax.sharding.PartitionSpec("fsdp", "expert", None),
+        jax.sharding.PartitionSpec("fsdp", None),
+        jax.sharding.PartitionSpec("fsdp", None, None),
         jax.sharding.PartitionSpec(None),
-        jax.sharding.PartitionSpec("expert", None, "fsdp"),
+        jax.sharding.PartitionSpec(None, None, "fsdp"),
       ),
     ),
     (
@@ -244,9 +236,9 @@ def fn(weights):
         jax.sharding.PartitionSpec("fsdp", "expert", None),
       ),
       (
-        jax.sharding.PartitionSpec("fsdp", "expert"),
-        jax.sharding.PartitionSpec("fsdp", "expert"),
-        jax.sharding.PartitionSpec("expert", "fsdp"),
+        jax.sharding.PartitionSpec("fsdp", None),
+        jax.sharding.PartitionSpec("fsdp", None),
+        jax.sharding.PartitionSpec(None, "fsdp"),
       ),
     ),
   ),
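For context on the gathers above: with the expert mesh axis no longer used to shard attention and shared-expert weights, each weight is rebuilt from its `fsdp` shards by a single tiled all-gather. `tiled=True` concatenates the shards along `axis` (default 0) instead of stacking a new leading device dimension, which is why `out` and `shared_wo` pass an explicit `axis=`. Below is a minimal runnable sketch of this pattern, assuming a hypothetical two-device mesh with a single `fsdp` axis; `gather_weight` is an illustrative helper, not part of the patch or MaxText's real mesh setup.

```python
import jax
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumes two devices; on CPU, run with
# XLA_FLAGS=--xla_force_host_platform_device_count=2
mesh = Mesh(np.array(jax.devices()[:2]), ("fsdp",))

def gather_weight(w):
  # Inside shard_map, w is the local (4, 16) shard of the (8, 16) weight.
  # tiled=True concatenates the shards along axis 0, yielding the full
  # weight replicated on every device (the same pattern as the wq_a gather).
  return jax.lax.all_gather(w, axis_name="fsdp", tiled=True)

w = jax.random.normal(jax.random.key(0), (8, 16))
full = shard_map(
    gather_weight, mesh=mesh, in_specs=P("fsdp", None), out_specs=P(None, None)
)(w)
assert full.shape == (8, 16)
```

The PartitionSpec changes in the second and third hunks mirror this: once the gathers stop using the expert axis, the specs for the corresponding weights drop `"expert"` and shard only on `"fsdp"`, leaving the expert axis to the routed-expert (MoE) weights.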