Skip to content

Commit 5779927

Browse files
Edwardf0t1 and claude committed
remove Qwen3.5 MoE-specific quantization code
Qwen3_5MoeExperts follows the standard fused expert pattern (gate_up_proj 3D + down_proj 3D + num_experts + act_fn) and is now handled generically by _QuantFusedExperts via register_fused_experts_on_the_fly. Both Qwen3.5-397B and Qwen3.5-35B use fused 3D experts in their main MoE layers; MTP per-expert layers are iterable and handled by _QuantSparseSequentialMoe. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent a52cd14 commit 5779927

File tree

1 file changed

+2
-116
lines changed

1 file changed

+2
-116
lines changed

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 2 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -786,106 +786,6 @@ def forward(
786786
return next_states
787787

788788

789-
class _Qwen35MoeExpertModule(nn.Module):
790-
"""Container for a single Qwen3.5 MoE expert's linear layers.
791-
792-
Produces the naming pattern: experts.{id}.gate_proj.weight
793-
(consistent with standard Qwen3 MoE per-expert module structure).
794-
"""
795-
796-
def __init__(self, hidden_dim: int, expert_dim: int):
797-
super().__init__()
798-
self.gate_proj = nn.Linear(hidden_dim, expert_dim, bias=False)
799-
self.up_proj = nn.Linear(hidden_dim, expert_dim, bias=False)
800-
self.down_proj = nn.Linear(expert_dim, hidden_dim, bias=False)
801-
802-
803-
class _QuantQwen35MoeExperts(QuantModule):
804-
def _setup(self):
805-
"""Modify the Qwen3_5MoeExperts by using per-expert nn.Module containers.
806-
807-
This produces the naming pattern: experts.{id}.gate_proj.weight
808-
(consistent with standard Qwen3 MoE).
809-
"""
810-
from accelerate import init_empty_weights
811-
812-
dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device
813-
814-
def _copy_weight(module, weight):
815-
module.to_empty(device=device)
816-
with torch.no_grad():
817-
module.weight.data = weight.detach().data.to(dtype=dtype, device=device)
818-
819-
expert_dim = self.intermediate_dim
820-
821-
with init_empty_weights():
822-
expert_modules = nn.ModuleList(
823-
[
824-
_Qwen35MoeExpertModule(self.hidden_dim, expert_dim)
825-
for _ in range(self.num_experts)
826-
]
827-
)
828-
829-
for idx in range(self.num_experts):
830-
# gate_up_proj shape: (num_experts, 2*intermediate_dim, hidden_dim)
831-
# Already in (out_features, in_features) format, no transpose needed
832-
_copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :expert_dim, :])
833-
_copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, expert_dim:, :])
834-
# down_proj shape: (num_experts, hidden_dim, intermediate_dim)
835-
# Already in (out_features, in_features) format
836-
_copy_weight(expert_modules[idx].down_proj, self.down_proj[idx])
837-
838-
delattr(self, "gate_up_proj")
839-
delattr(self, "down_proj")
840-
# Register expert modules directly as numbered children (like nn.ModuleList)
841-
# so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting)
842-
for idx in range(self.num_experts):
843-
self.add_module(str(idx), expert_modules[idx])
844-
845-
def __len__(self):
846-
"""Support len() so the module is iterable like standard MoE experts."""
847-
return self.num_experts
848-
849-
def __iter__(self):
850-
"""Support iteration over expert modules."""
851-
for idx in range(self.num_experts):
852-
yield getattr(self, str(idx))
853-
854-
def __getitem__(self, idx):
855-
"""Support indexing to get individual expert modules."""
856-
return getattr(self, str(int(idx)))
857-
858-
def forward(
859-
self,
860-
hidden_states: torch.Tensor,
861-
top_k_index: torch.Tensor,
862-
top_k_weights: torch.Tensor,
863-
) -> torch.Tensor:
864-
final_hidden_states = torch.zeros_like(hidden_states)
865-
with torch.no_grad():
866-
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
867-
expert_mask = expert_mask.permute(2, 1, 0)
868-
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
869-
for expert_idx in expert_hit:
870-
expert_idx = expert_idx[0]
871-
if expert_idx == self.num_experts:
872-
continue
873-
with torch.no_grad():
874-
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
875-
current_state = hidden_states[token_idx]
876-
expert = self[expert_idx]
877-
gate = expert.gate_proj(current_state)
878-
up = expert.up_proj(current_state)
879-
current_hidden_states = self.act_fn(gate) * up
880-
current_hidden_states = expert.down_proj(current_hidden_states)
881-
current_hidden_states = (
882-
current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
883-
)
884-
final_hidden_states.index_add_(
885-
0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
886-
)
887-
return final_hidden_states
888-
889789

890790
def _get_fused_expert_intermediate_dim(module):
891791
"""Resolve the intermediate (expert) dimension from a fused expert module.
@@ -1285,20 +1185,6 @@ def unpack_weight(self):
12851185
pass
12861186

12871187

1288-
try:
1289-
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts
1290-
1291-
# Qwen3_5MoeSparseMoeBlock registration is handled by register_sparse_moe_on_the_fly
1292-
# (auto-detected via gate.top_k + gate.num_experts + experts pattern).
1293-
# Only the fused expert weights need explicit registration.
1294-
if Qwen3_5MoeExperts not in QuantModuleRegistry:
1295-
QuantModuleRegistry.register({Qwen3_5MoeExperts: "hf.Qwen3_5MoeExperts"})(
1296-
_QuantQwen35MoeExperts
1297-
)
1298-
except ImportError:
1299-
pass
1300-
1301-
13021188
class _QuantGptOssExperts(_QuantFunctionalMixin):
13031189
"""Quantized wrapper for `transformers.GptOssExperts`.
13041190
@@ -1659,8 +1545,8 @@ class _QuantMoELinear(QuantModule):
16591545
weights and scales back into the original 3D format.
16601546
16611547
Note: we use expansion-then-reconstruction rather than the add_module() approach
1662-
(as in _QuantQwen35MoeExperts) because vLLM requires stacked 3D scaling factors;
1663-
per-expert expanded keys are not accepted by the downstream serving engine.
1548+
because vLLM requires stacked 3D scaling factors; per-expert expanded keys are
1549+
not accepted by the downstream serving engine.
16641550
"""
16651551

16661552
def _setup(self):

0 commit comments

Comments (0)