Muon QKV split fails for gated-attention QKV projections #4731

@Moozy23232

Describe the bug

Muon QKV splitting assumes that linear_qkv.weight is laid out as three logical
blocks per query group:

[q, k, v]

I hit this while training a Qwen3.5-style MoE model with Muon enabled. The model
uses gated attention output (attention_output_gate=True), where Megatron-Core's
attention module adds an extra query-sized gate block to linear_qkv_out_dim:

self.linear_qkv_out_dim = self.query_projection_size + 2 * self.kv_projection_size
if self.config.attention_output_gate:
    self.linear_qkv_out_dim += self.config.kv_channels * self.config.num_attention_heads

So the fused QKV projection is effectively laid out as:

[q_gate, q, k, v]

For a Qwen3.5-style gated-attention shape, this can produce a linear_qkv.weight
first dimension of 4608 = 2048 + 2048 + 256 + 256. The current Muon split path
can still compute and use a 3-way split such as [2048, 256, 256], which makes
the reshape in the Muon orthogonalization path invalid.

This is not specific to Qwen3.5 or ms-swift. Any Megatron-Core model that uses
gated attention output together with muon_split_qkv can hit the same structural
problem, because the tensor has four logical QKV-related blocks while the Muon
split path assumes three.
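
The size arithmetic above can be sketched in plain Python. This is only an illustration, not Megatron-Core code: `fused_qkv_blocks` is a hypothetical helper, and it assumes the gate block is exactly query-sized (that is, `kv_channels * num_attention_heads` equals `query_projection_size`), with 2048 and 256 taken from the shape reported in this issue.

```python
def fused_qkv_blocks(query_projection_size, kv_projection_size, attention_output_gate):
    """Return (name, size) pairs for the logical blocks of the fused QKV output.

    Hypothetical helper for illustration; block names follow this issue.
    """
    blocks = []
    if attention_output_gate:
        # Assumption: the gate block is the same size as the query block.
        blocks.append(("q_gate", query_projection_size))
    blocks += [
        ("q", query_projection_size),
        ("k", kv_projection_size),
        ("v", kv_projection_size),
    ]
    return blocks


plain = fused_qkv_blocks(2048, 256, attention_output_gate=False)
gated = fused_qkv_blocks(2048, 256, attention_output_gate=True)

# 3-way layout assumed by the Muon split path.
print([size for _, size in plain], sum(size for _, size in plain))
# 4-way layout actually produced with attention_output_gate=True.
print([size for _, size in gated], sum(size for _, size in gated))
```

The first line prints `[2048, 256, 256] 2560` and the second prints `[2048, 2048, 256, 256] 4608`, which is the mismatch between the assumed split and the real first dimension.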

There is also a related metadata propagation issue: tensor-parallel attribute
copying preserves is_qkv, but affected versions do not preserve Muon-specific
metadata such as qkv_split_shapes or a parameter name/debug identifier when
copying attributes to optimizer/master parameters.

A proposed fix PR is available: #4728

Tagging @mcore-oncall for visibility.

Steps/Code to reproduce bug

The failure can be reproduced without ms-swift, a full model, a dataset, or a
training run. The following Megatron-Core-only reproducer directly exercises the
Muon QKV split path:

import importlib
import torch


def import_tensor_parallel_muon():
    try:
        module = importlib.import_module("megatron.core.optimizer.emerging_optimizers")
        return module.TensorParallelMuon
    except ModuleNotFoundError:
        module = importlib.import_module("megatron.core.optimizer.muon")
        return module.TensorParallelMuon


TensorParallelMuon = import_tensor_parallel_muon()

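# Bypass __init__ so that no process groups or real model are required, then
# set by hand the attributes this reproducer relies on.
optimizer = TensorParallelMuon.__new__(TensorParallelMuon)
optimizer.pg_collection = None
optimizer.mode = "duplicated"
optimizer.tp_mode = "duplicated"
optimizer.split_qkv = True
optimizer.is_qkv_fn = lambda p: getattr(p, "is_qkv", False)
optimizer.qkv_split_shapes = [2048, 256, 256]  # current 3-way [q, k, v] assumption
# Identity stand-in: the failure is in the reshape, not in the orthogonalization itself.
optimizer.scaled_orthogonalize_fn = lambda grad, tp_group, partition_dim: grad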

param = torch.empty(4608, 2048)
param.is_qkv = True
param.partition_dim = -1

grad = torch.empty(4608, 2048)  # gated QKV: [2048, 2048, 256, 256]
optimizer.orthogonalize(param, grad)

On an affected version, this fails with:

RuntimeError: shape '[1, 2560, -1]' is invalid for input of size 9437184

The same tensor shape works if the split uses the gated-QKV layout:

optimizer.qkv_split_shapes = [2048, 2048, 256, 256]
out = optimizer.orthogonalize(param, grad)
assert out.shape == grad.shape

The metadata-copy issue can be reproduced independently:

import torch
from megatron.core.tensor_parallel.layers import copy_tensor_model_parallel_attributes

source = torch.empty(1)
destination = torch.empty(1)
source.is_qkv = True
source.qkv_split_shapes = [2048, 2048, 256, 256]
source.muon_param_name = "decoder.layers.0.self_attention.linear_qkv.weight"

copy_tensor_model_parallel_attributes(destination, source)

print(hasattr(destination, "is_qkv"))
print(hasattr(destination, "qkv_split_shapes"))
print(hasattr(destination, "muon_param_name"))

On an affected version this prints:

True
False
False

Expected behavior

For attention_output_gate=True, Muon should split gated QKV weights and
gradients using four logical blocks:

[q_gate, q, k, v]

For the shape above, the expected split is:

[2048, 2048, 256, 256]

Muon-specific QKV metadata should also be preserved when tensor-parallel
attributes are copied to optimizer/master parameters, so the optimizer uses the
correct split for the actual parameter it updates.
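
A minimal sketch of the metadata propagation described above, under stated assumptions: `copy_muon_qkv_attributes` and `MUON_QKV_ATTRIBUTES` are hypothetical names, the attribute names are the ones used in this issue's reproducer, and `SimpleNamespace` objects stand in for tensors so the example runs without torch. It is not the change made in the proposed fix PR.

```python
from types import SimpleNamespace

# Hypothetical list of Muon-related attributes, named as in this issue.
MUON_QKV_ATTRIBUTES = ("is_qkv", "qkv_split_shapes", "muon_param_name")


def copy_muon_qkv_attributes(destination, source):
    """Copy whichever Muon QKV attributes exist on source onto destination."""
    for name in MUON_QKV_ATTRIBUTES:
        if hasattr(source, name):
            value = getattr(source, name)
            # Copy lists so the two parameters do not share a mutable object.
            setattr(destination, name, list(value) if isinstance(value, list) else value)


# Stand-ins for the model parameter and its optimizer/master copy.
source = SimpleNamespace(
    is_qkv=True,
    qkv_split_shapes=[2048, 2048, 256, 256],
    muon_param_name="decoder.layers.0.self_attention.linear_qkv.weight",
)
destination = SimpleNamespace()

copy_muon_qkv_attributes(destination, source)
print(hasattr(destination, "is_qkv"))
print(hasattr(destination, "qkv_split_shapes"))
print(hasattr(destination, "muon_param_name"))
```

With all three attributes carried over, each `print` reports `True`, which is the behavior the metadata reproducer above would show once the split shapes and name are preserved.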

Additional context

The issue was first observed while training a Qwen3.5-style MoE model using Muon
QKV splitting. The important runtime conditions were:

optimizer: dist_muon
muon_split_qkv: true
attention_output_gate: true
linear_qkv.weight first dimension: 4608
old split shapes: [2048, 256, 256]
expected gated split shapes: [2048, 2048, 256, 256]

The full ms-swift training script is not required to reproduce the bug. The
minimal Megatron-Core reproducer above isolates the failing split logic directly.

The corresponding real training log failed in the Muon optimizer step:

[rank6]: Traceback (most recent call last):
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/cli/_megatron/sft.py", line 7, in <module>
[rank6]:     megatron_sft_main()
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/pipelines/train/sft.py", line 93, in megatron_sft_main
[rank6]:     return MegatronSft(args).main()
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/pipelines/base.py", line 52, in main
[rank6]:     result = self.run()
[rank6]:              ^^^^^^^^^^
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/pipelines/train/sft.py", line 68, in run
[rank6]:     trainer.train(train_dataset, val_dataset)
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/trainers/base.py", line 694, in train
[rank6]:     metrics, grad_norm, update_successful = self.train_step(train_data_iterator)
[rank6]:                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/ms-swift/swift/megatron/trainers/base.py", line 953, in train_step
[rank6]:     update_successful, grad_norm, _ = self.optimizer.step()
[rank6]:                                       ^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/mnt/afs_share/miniconda3/envs/fc_swift_qwen3.5_py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank6]:     return func(*args, **kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^

[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/layer_wise_optimizer.py", line 282, in step
[rank6]:     update_successful, grad_norm, num_zeros_in_grad = super().step()
[rank6]:                                                       ^^^^^^^^^^^^^^
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/optimizer.py", line 1388, in step
[rank6]:     update_successful = self.step_with_ready_grads()
[rank6]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/optimizer.py", line 1288, in step_with_ready_grads
[rank6]:     success &= optimizer.step_with_ready_grads()
[rank6]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/optimizer.py", line 597, in step_with_ready_grads
[rank6]:     self.optimizer.step()
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/ms-swift/Emerging-Optimizers/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py", line 196, in step
[rank6]:     orth_grad = self.orthogonalize(p, grad, **group_kwargs)
[rank6]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/mnt/afs_agents/lizhihan/workspace/megatron-lm/megatron/core/optimizer/emerging_optimizers.py", line 261, in orthogonalize
[rank6]:     grad.view(num_query_groups, sum(self.qkv_split_shapes), -1),
[rank6]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: RuntimeError: shape '[1, 2560, -1]' is invalid for input of size 9437184

Here 9437184 is the number of elements in the actual gradient tensor:

4608 * 2048 = 9437184

The failing reshape uses 2560 because the old 3-way QKV split sums to:

2048 + 256 + 256 = 2560

For the gated-attention tensor, however, the first dimension is:

2048 + 2048 + 256 + 256 = 4608

So the old code tries to view the gradient as [1, 2560, -1], but
9437184 / 2560 = 3686.4 is not an integer, so no size can be inferred for the
last dimension. With the gated-QKV split [2048, 2048, 256, 256], the reshape
becomes [1, 4608, -1], which resolves to [1, 4608, 2048] and matches the
actual tensor shape.
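
The divisibility argument above can be checked without torch. The helper below is a hypothetical sketch that mirrors the `grad.view(num_query_groups, sum(self.qkv_split_shapes), -1)` call from the traceback; it only reproduces the shape arithmetic and is not part of Megatron-Core.

```python
def resolve_qkv_view(grad_shape, num_query_groups, qkv_split_shapes):
    """Return the shape a view(num_query_groups, sum(splits), -1) would resolve to.

    Returns None when the element count is not divisible by the fixed
    dimensions, which is the case where the real view call raises.
    """
    rows, cols = grad_shape
    numel = rows * cols
    per_group = sum(qkv_split_shapes)
    fixed = num_query_groups * per_group
    if numel % fixed != 0:
        return None
    return (num_query_groups, per_group, numel // fixed)


grad_shape = (4608, 2048)  # gated QKV gradient from this issue

old = resolve_qkv_view(grad_shape, 1, [2048, 256, 256])
new = resolve_qkv_view(grad_shape, 1, [2048, 2048, 256, 256])
print(old)
print(new)
```

The 3-way split yields `None` (9437184 is not a multiple of 2560), while the 4-way split resolves to `(1, 4608, 2048)`. A check of this kind could also serve as an early, more descriptive error before the reshape, though whether to add one is a design choice outside this report.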
