From 11623d94f3a4de69f49726aba51e1d66dcdbd65b Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 22 Apr 2026 22:40:33 +0000 Subject: [PATCH] introduce cp-as-ep rule for long context training or strong scaling --- .../configs/custom_mesh_and_rule/cp-as-ep.yml | 78 + src/maxtext/layers/moe.py | 3 + tests/utils/sharding_dump.py | 7 + .../input_shardings.json | 178 ++ .../logical_shardings.json | 980 +++++++ .../named_shardings.json | 2543 +++++++++++++++++ 6 files changed, 3789 insertions(+) create mode 100644 src/maxtext/configs/custom_mesh_and_rule/cp-as-ep.yml create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/input_shardings.json create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/logical_shardings.json create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/named_shardings.json diff --git a/src/maxtext/configs/custom_mesh_and_rule/cp-as-ep.yml b/src/maxtext/configs/custom_mesh_and_rule/cp-as-ep.yml new file mode 100644 index 0000000000..cbb3f948a8 --- /dev/null +++ b/src/maxtext/configs/custom_mesh_and_rule/cp-as-ep.yml @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This rule uses data, stage, FSDP, and expert. Expert axis acts as context parallelism in +# components except core dMoE part (between EP all2all). +mesh_axes: ['data', 'stage', 'fsdp', 'context', 'expert'] +data_sharding: [['data', 'stage', 'fsdp', 'context', 'expert']] +context_sharding: 'context' +logical_axis_rules: [ + # ========================================== + # Vocabulary Embedding + # ========================================== + # Vocab Activations + ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']], + # Vocab Weights + ['vocab', []], + ['embed_vocab', ['fsdp', 'context', 'expert']], + # ========================================== + # Attention + # ========================================== + # Attention Activations + ['activation_batch_attn', ['data', 'fsdp', 'expert']], + ['activation_heads', []], + ['activation_kv_heads', []], + ['activation_length_attn', ['context']], + ['activation_q_length', ['context']], + ['activation_kv_length', []], + ['activation_embed_attn', []], + ['activation_kv', []], + ['activation_kv_batch', ['data', 'fsdp', 'expert']], + ['activation_kv_head_dim', []], + # Attention Weights + ['heads', []], + ['q_heads', []], + ['kv_heads', []], + ['qkv', []], + ['kv', []], + ['kv_head_dim', []], + ['q_lora', ['fsdp', 'context', 'expert']], + ["q_lora_up_proj", []], + ['kv_lora', ['fsdp', 'context', 'expert']], + ["kv_lora_up_proj", []], + # ========================================== + # Mixture of Experts (MoE) + # ========================================== + # MoE Activations + ['activation_batch_moe', ['data', 'fsdp']], + ['activation_exp', ['context', 'expert']], + # MoE Weights + ['exp', ['context', 'expert']], + ['embed_moe', ['fsdp']], + # ========================================== + # Standard MLP / Dense Layers / Model Structure + # ========================================== + # Dense Activations + ['activation_mlp', []], + ['activation_batch', ['data', 'fsdp', 'expert']], + ['activation_length', ['context']], + ['activation_norm_length', ['context']], + ['activation_embed', []], + ['activation_stage', 'stage'], + # General Weights + ['mlp', []], + ['layers', 'stage'], + ['embed', ['fsdp', 'context', 'expert']], + ] \ No newline at end of file diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index dc42694676..9fbb24390d 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -376,6 +376,9 @@ def __init__( if self.config.attention == "vllm_rpa" and self.config.enable_dp_attention: self._expert_parallelism_name = "attn_dp_expert" + elif self.config.custom_mesh_and_rule == "cp-as-ep": + # when custom mesh and rule is cp-as-ep, context axis is same with expert in MoE component + self._expert_parallelism_name = ("context", "expert") else: self._expert_parallelism_name = "expert" diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index cd19a49d63..88958aa0ec 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -55,6 +55,13 @@ "ep-as-cp", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2"), ), + ( + "deepseek2-16b", + "tpu7x-8", + 1, + "cp-as-ep", + ("ici_fsdp_parallelism=-1", "ici_context_parallelism=2", "ici_expert_parallelism=2"), + ), ("qwen3-0.6b", "tpu7x-16", 1, "", ()), ("gpt-oss-20b", "tpu7x-16", 1, "", ()), ("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")), diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/input_shardings.json new file mode 100644 index 0000000000..e8578a5ce9 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/input_shardings.json @@ -0,0 +1,178 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch_attn', 'activation_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "attention_mla/q_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "attention_mla/query: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "attention_mla/key: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "attention_mla/value: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "attention_op/arr: int8[1,4,4]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(None, 'context')" + } + }, + { + "attention_op/arr: int32[2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('context',)" + } + }, + { + "attention_op/query: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('fsdp', 'expert'), None, 'context', None)" + } + }, + { + "attention_op/key: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('fsdp', 'expert'), None, None, None)" + } + }, + { + "attention_op/value: bfloat16[96,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('fsdp', 'expert'), None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_batch_attn', 'activation_length', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "linears/x: bfloat16[96,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "deepseek/mlp: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "deepseek/x: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "moe/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "moe/gate_logits: bfloat16[96,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "moe/w0_kernel: bfloat16[64,2048,1408]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('context', 'expert'), None, None)" + } + }, + { + "moe/w1_kernel: bfloat16[64,2048,1408]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('context', 'expert'), None, None)" + } + }, + { + "moe/wo_kernel: bfloat16[64,1408,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('context', 'expert'), None, None)" + } + }, + { + "linears/x: bfloat16[96,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P(('fsdp', 'expert'), 'context', None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/logical_shardings.json new file mode 100644 index 0000000000..8d30b919f8 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/logical_shardings.json @@ -0,0 +1,980 @@ +{ + ".step": { + "partition_spec": [], + "shape": [] + }, + ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".params/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.count": { + "partition_spec": [], + "shape": [] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[2]/.count": { + "partition_spec": [], + "shape": [] + } +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/named_shardings.json new file mode 100644 index 0000000000..43e03a274f --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_cp-as-ep_ici_fsdp_parallelism=-1_ici_context_parallelism=2_ici_expert_parallelism=2/named_shardings.json @@ -0,0 +1,2543 @@ +{ + ".step": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [], + "shape": [] + }, + ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + "fsdp", + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + null, + "fsdp" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".params/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.count": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [], + "shape": [] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + "fsdp", + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + null, + "fsdp" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + "fsdp", + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "context", + "expert" + ], + null, + null, + "fsdp" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "context", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "context", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[2]/.count": { + "mesh": { + "axis_names": [ + "data", + "stage", + "fsdp", + "context", + "expert" + ], + "shape": { + "data": 1, + "stage": 1, + "fsdp": 2, + "context": 2, + "expert": 2 + } + }, + "partition_spec": [], + "shape": [] + } +} \ No newline at end of file