Describe the bug
Muon QKV splitting assumes that linear_qkv.weight is laid out as three logical
blocks per query group:
[q, k, v]
I hit this while training a Qwen3.5-style MoE model with Muon enabled. The model
uses gated attention output (attention_output_gate=True), where Megatron-Core's
attention module adds an extra query-sized gate block to linear_qkv_out_dim:
self.linear_qkv_out_dim = self.query_projection_size + 2 * self.kv_projection_size
if self.config.attention_output_gate:
    self.linear_qkv_out_dim += self.config.kv_channels * self.config.num_attention_heads
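So each query group of the fused QKV projection effectively holds four logical
blocks: two query-sized blocks (the query and the output gate), followed by the
key and value blocks.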
For a Qwen3.5-style gated-attention shape, this can produce a linear_qkv.weight
first dimension of 4608 = 2048 + 2048 + 256 + 256. The current Muon split path
can still compute and use a 3-way split such as [2048, 256, 256], which makes
the reshape in the Muon orthogonalization path invalid.
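The dimension arithmetic can be checked with a small Python sketch that mirrors the linear_qkv_out_dim formula quoted above. The helper name and the config values (8 attention heads, 1 query group and kv_channels=256, read as per tensor-parallel partition) are hypothetical: they were picked only because they reproduce the 4608 first dimension in this report, not taken from a real Qwen3.5 config.

```python
def linear_qkv_out_dim(num_attention_heads, num_query_groups, kv_channels, attention_output_gate):
    """Sketch mirroring the Megatron-Core formula quoted above (not the library code)."""
    query_projection_size = kv_channels * num_attention_heads
    kv_projection_size = kv_channels * num_query_groups
    out_dim = query_projection_size + 2 * kv_projection_size
    if attention_output_gate:
        # Gated attention output adds an extra query-sized gate block.
        out_dim += kv_channels * num_attention_heads
    return out_dim


# Hypothetical per-partition values chosen to match the numbers in this report.
gated_dim = linear_qkv_out_dim(8, 1, 256, attention_output_gate=True)
ungated_dim = linear_qkv_out_dim(8, 1, 256, attention_output_gate=False)
three_way_sum = sum([2048, 256, 256])

print(gated_dim, ungated_dim, three_way_sum)  # 4608 2560 2560
```

The ungated dimension matches the old 3-way split sum of 2560, while the gated dimension does not, which is the mismatch the Muon split path runs into.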
This is not specific to Qwen3.5 or ms-swift. Any Megatron-Core model that uses
gated attention output together with muon_split_qkv can hit the same structural
problem, because the tensor has four logical QKV-related blocks while the Muon
split path assumes three.
There is also a related metadata propagation issue: tensor-parallel attribute
copying preserves is_qkv, but affected versions do not preserve Muon-specific
metadata such as qkv_split_shapes or a parameter name/debug identifier when
copying attributes to optimizer/master parameters.
A proposed fix PR is available: #4728
Tagging @mcore-oncall for visibility.
Steps/Code to reproduce bug
The failure can be reproduced without ms-swift, a full model, a dataset, or a
training run. The following Megatron-Core-only reproducer directly exercises the
Muon QKV split path:
import importlib
import torch
def import_tensor_parallel_muon():
    # Try the emerging_optimizers module first, then fall back to the muon module.
    try:
        module = importlib.import_module("megatron.core.optimizer.emerging_optimizers")
        return module.TensorParallelMuon
    except ModuleNotFoundError:
        module = importlib.import_module("megatron.core.optimizer.muon")
        return module.TensorParallelMuon
TensorParallelMuon = import_tensor_parallel_muon()
optimizer = TensorParallelMuon.__new__(TensorParallelMuon)
optimizer.pg_collection = None
optimizer.mode = "duplicated"
optimizer.tp_mode = "duplicated"
optimizer.split_qkv = True
optimizer.is_qkv_fn = lambda p: getattr(p, "is_qkv", False)
optimizer.qkv_split_shapes = [2048, 256, 256] # current 3-way [q, k, v] assumption
optimizer.scaled_orthogonalize_fn = lambda grad, tp_group, partition_dim: grad
param = torch.empty(4608, 2048)
param.is_qkv = True
param.partition_dim = -1
grad = torch.empty(4608, 2048) # gated QKV: [2048, 2048, 256, 256]
optimizer.orthogonalize(param, grad)
On an affected version, this fails with:
RuntimeError: shape '[1, 2560, -1]' is invalid for input of size 9437184
The same tensor shape works if the split uses the gated-QKV layout:
optimizer.qkv_split_shapes = [2048, 2048, 256, 256]
out = optimizer.orthogonalize(param, grad)
assert out.shape == grad.shape
The metadata-copy issue can be reproduced independently:
import torch
from megatron.core.tensor_parallel.layers import copy_tensor_model_parallel_attributes
source = torch.empty(1)
destination = torch.empty(1)
source.is_qkv = True
source.qkv_split_shapes = [2048, 2048, 256, 256]
source.muon_param_name = "decoder.layers.0.self_attention.linear_qkv.weight"
copy_tensor_model_parallel_attributes(destination, source)
print(hasattr(destination, "is_qkv"))
print(hasattr(destination, "qkv_split_shapes"))
print(hasattr(destination, "muon_param_name"))
On an affected version this prints:
True
False
False
Expected behavior
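For attention_output_gate=True, Muon should split gated QKV weights and
gradients using four logical blocks per query group: the query and output-gate
blocks (both query-sized), followed by the key and value blocks.
For the shape above, the expected split is [2048, 2048, 256, 256].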
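A minimal sketch of the expected split logic, assuming the split shapes are derived per query group from the attention config and that the two query-sized blocks (query and output gate) precede key and value, as the expected split [2048, 2048, 256, 256] suggests. muon_qkv_split_shapes is a hypothetical helper, not a Megatron-Core API, and this is not necessarily how #4728 implements the fix. The inputs are hypothetical per-partition values (8 heads, 1 query group, kv_channels=256) chosen to reproduce the numbers in this report.

```python
def muon_qkv_split_shapes(num_attention_heads, num_query_groups, kv_channels, attention_output_gate):
    """Per-query-group split shapes for a fused linear_qkv weight (hypothetical helper).

    Assumes the layout inside each query group is [query, gate, key, value]
    when the output gate is enabled, and [query, key, value] otherwise.
    """
    if num_attention_heads % num_query_groups != 0:
        raise ValueError("num_attention_heads must be divisible by num_query_groups")
    q_block = num_attention_heads // num_query_groups * kv_channels
    if attention_output_gate:
        # The gate block has the same size as the query block.
        return [q_block, q_block, kv_channels, kv_channels]
    return [q_block, kv_channels, kv_channels]


print(muon_qkv_split_shapes(8, 1, 256, attention_output_gate=False))  # [2048, 256, 256]
print(muon_qkv_split_shapes(8, 1, 256, attention_output_gate=True))   # [2048, 2048, 256, 256]
```

Keying the shapes on attention_output_gate keeps the ungated path unchanged while making the gated sum equal the actual first dimension (4608).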
Muon-specific QKV metadata should also be preserved when tensor-parallel
attributes are copied to optimizer/master parameters, so the optimizer uses the
correct split for the actual parameter it updates.
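The metadata expectation can be sketched as a helper that copies the Muon-specific attributes alongside the tensor-parallel ones. copy_muon_qkv_metadata and the attribute tuple are hypothetical (the attribute names come from the metadata-copy reproducer above); SimpleNamespace stands in for a torch parameter so the sketch runs without Megatron-Core.

```python
from types import SimpleNamespace

# Hypothetical attribute names, taken from the metadata-copy reproducer in this report.
MUON_QKV_ATTRIBUTES = ("is_qkv", "qkv_split_shapes", "muon_param_name")


def copy_muon_qkv_metadata(destination, source):
    """Copy Muon QKV metadata from a model parameter to its optimizer/master copy (sketch)."""
    for name in MUON_QKV_ATTRIBUTES:
        if hasattr(source, name):
            value = getattr(source, name)
            # Copy lists so the two parameters do not share one mutable object.
            setattr(destination, name, list(value) if isinstance(value, list) else value)


source = SimpleNamespace(
    is_qkv=True,
    qkv_split_shapes=[2048, 2048, 256, 256],
    muon_param_name="decoder.layers.0.self_attention.linear_qkv.weight",
)
destination = SimpleNamespace()
copy_muon_qkv_metadata(destination, source)
print(destination.qkv_split_shapes)  # [2048, 2048, 256, 256]
```

Copying the list, rather than sharing it, keeps a later in-place edit of one parameter's split shapes from silently changing the other's.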
Additional context
The issue was first observed while training a Qwen3.5-style MoE model using Muon
QKV splitting. The important runtime conditions were:
optimizer: dist_muon
muon_split_qkv: true
attention_output_gate: true
linear_qkv.weight first dimension: 4608
old split shapes: [2048, 256, 256]
expected gated split shapes: [2048, 2048, 256, 256]
The full ms-swift training script is not required to reproduce the bug. The
minimal Megatron-Core reproducer above isolates the failing split logic directly.
The corresponding training run failed in the Muon optimizer step with the following traceback:
[rank6]: Traceback (most recent call last):
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/cli/_megatron/sft.py", line 7, in <module>
[rank6]: megatron_sft_main()
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/pipelines/train/sft.py", line 93, in megatron_sft_main
[rank6]: return MegatronSft(args).main()
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/pipelines/base.py", line 52, in main
[rank6]: result = self.run()
[rank6]: ^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/pipelines/train/sft.py", line 68, in run
[rank6]: trainer.train(train_dataset, val_dataset)
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/trainers/base.py", line 694, in train
[rank6]: metrics, grad_norm, update_successful = self.train_step(train_data_iterator)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/trainers/base.py", line 953, in train_step
[rank6]: update_successful, grad_norm, _ = self.optimizer.step()
[rank6]: ^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_share/miniconda3/envs/fc_swift_qwen3.5_py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank6]: return func(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/layer_wise_optimizer.py", line 282, in step
[rank6]: update_successful, grad_norm, num_zeros_in_grad = super().step()
[rank6]: ^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/optimizer.py", line 1388, in step
[rank6]: update_successful = self.step_with_ready_grads()
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/optimizer.py", line 1288, in step_with_ready_grads
[rank6]: success &= optimizer.step_with_ready_grads()
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/optimizer.py", line 597, in step_with_ready_grads
[rank6]: self.optimizer.step()
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/ms-swift/Emerging-Optimizers/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py", line 196, in step
[rank6]: orth_grad = self.orthogonalize(p, grad, **group_kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/emerging_optimizers.py", line 261, in orthogonalize
[rank6]: grad.view(num_query_groups, sum(self.qkv_split_shapes), -1),
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: RuntimeError: shape '[1, 2560, -1]' is invalid for input of size 9437184
Here 9437184 is the number of elements in the actual gradient tensor:
4608 * 2048 = 9437184
The failing reshape uses 2560 because the old 3-way QKV split sums to:
2048 + 256 + 256 = 2560
For the gated-attention tensor, however, the first dimension is:
2048 + 2048 + 256 + 256 = 4608
So the old code tries to view the gradient as [1, 2560, -1], but
9437184 / 2560 is not an integer. With the gated-QKV split
[2048, 2048, 256, 256], the reshape becomes [1, 4608, 2048], which matches
the actual tensor shape.
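The divisibility argument can be expressed as a small pure-Python check that applies the rule torch uses for a -1 dimension in view(): the product of the fixed dimensions must divide the element count. view_shape_or_none is a hypothetical helper written for illustration; num_query_groups=1 is taken from the [1, 2560, -1] shape in the traceback.

```python
def view_shape_or_none(numel, num_query_groups, split_shapes):
    """Shape that view(num_query_groups, sum(split_shapes), -1) would produce, or None if invalid."""
    block_rows = sum(split_shapes)
    fixed = num_query_groups * block_rows
    # A -1 dimension can only be inferred if the fixed dims divide the element count.
    if fixed == 0 or numel % fixed != 0:
        return None
    return [num_query_groups, block_rows, numel // fixed]


numel = 4608 * 2048  # 9437184 elements in the gated linear_qkv gradient
print(view_shape_or_none(numel, 1, [2048, 256, 256]))        # None
print(view_shape_or_none(numel, 1, [2048, 2048, 256, 256]))  # [1, 4608, 2048]
```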