diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index 96ecf91e5b..8f1bb7a5c4 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -972,6 +972,7 @@ def module_match_name_list(module, name_list): "Qwen3MoeSparseMoeBlock", "Qwen3NextSparseMoeBlock", "Qwen3_5MoeSparseMoeBlock", + "Qwen3VLMoeTextSparseMoeBlock", "DeepseekMoE", ], ): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 812550e4f4..987dd02315 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -691,9 +691,27 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: return self.w2_linear[expert_idx](x1) +class _Qwen3VLMoeExpertModule(nn.Module): + """Container for a single Qwen3VL MoE expert's linear layers. + + Produces the naming pattern: experts.{id}.gate_proj.weight + (consistent with standard Qwen3 MoE per-expert module structure). + """ + + def __init__(self, hidden_size: int, expert_dim: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, expert_dim, bias=False) + self.up_proj = nn.Linear(hidden_size, expert_dim, bias=False) + self.down_proj = nn.Linear(expert_dim, hidden_size, bias=False) + + class _QuantQwen3VLMoeTextExperts(QuantModule): def _setup(self): - """Modify the Qwen3VLMoeTextExperts by using nn.Linear layers.""" + """Modify the Qwen3VLMoeTextExperts by using per-expert nn.Module containers. + + This produces the naming pattern: experts.{id}.gate_proj.weight + (consistent with standard Qwen3 MoE per-expert module structure). + """ from accelerate import init_empty_weights dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device @@ -713,35 +731,37 @@ def _copy_weight(module, weight): raise AttributeError("Could not find intermediate dimension size in model") with init_empty_weights(): - gate_proj = nn.ModuleList( - [ - nn.Linear(self.hidden_size, expert_dim, bias=False) - for _ in range(self.num_experts) - ] - ) - up_proj = nn.ModuleList( - [ - nn.Linear(self.hidden_size, expert_dim, bias=False) - for _ in range(self.num_experts) - ] - ) - down_proj = nn.ModuleList( + expert_modules = nn.ModuleList( [ - nn.Linear(expert_dim, self.hidden_size, bias=False) + _Qwen3VLMoeExpertModule(self.hidden_size, expert_dim) for _ in range(self.num_experts) ] ) for idx in range(self.num_experts): - _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, :expert_dim].T) - _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, expert_dim:].T) - _copy_weight(down_proj[idx], self.down_proj[idx, :].T) + _copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :, :expert_dim].T) + _copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, :, expert_dim:].T) + _copy_weight(expert_modules[idx].down_proj, self.down_proj[idx, :].T) delattr(self, "gate_up_proj") delattr(self, "down_proj") - self.gate_proj = gate_proj - self.up_proj = up_proj - self.down_proj = down_proj + # Register expert modules directly as numbered children + # so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting) + for idx in range(self.num_experts): + self.add_module(str(idx), expert_modules[idx]) + + def __len__(self): + """Support len() so the module is iterable like standard MoE experts.""" + return self.num_experts + + def __iter__(self): + """Support iteration over expert modules.""" + for idx in range(self.num_experts): + yield getattr(self, str(idx)) + + def __getitem__(self, idx): + """Support indexing to get individual expert modules.""" + return getattr(self, str(int(idx))) def forward( self, @@ -757,13 +777,15 @@ def forward( expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: + expert_idx = expert_idx[0] with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx[0]]) + _, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] - gate = self.gate_proj[expert_idx](current_state) - up = self.up_proj[expert_idx](current_state) + expert = self[expert_idx] + gate = expert.gate_proj(current_state) + up = expert.up_proj(current_state) gated_output = up * self.act_fn(gate) - out = self.down_proj[expert_idx](gated_output) + out = expert.down_proj(gated_output) weighted_output = out * routing_weights[token_idx, expert_idx, None] next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size)