[Qwen3MoE] Potentially a bug on Qwen3MoeSparseMoeBlock #45208
Hi,

I found a typing mismatch in `Qwen3MoeSparseMoeBlock`:
```python
class Qwen3MoeSparseMoeBlock(nn.Module):
    def __init__(self, config: Qwen3MoeConfig):
        super().__init__()
        self.experts = Qwen3MoeExperts(config)
        self.gate = Qwen3MoeTopKRouter(config)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
```

If the code is correct, the return type annotation of `forward` should be `torch.Tensor`, not `tuple[torch.Tensor, torch.Tensor]`. However, I don't know whether `routing_weights` also needs to be returned. Also, `Qwen3MoeSparseMoeBlock` is used in `Qwen3MoeDecoderLayer` as `self.mlp`, and there is a residual connection after `self.mlp(hidden_states)`:
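For illustration only, here is a hypothetical toy stand-in (not the actual transformers code; plain lists stand in for tensors so it runs without torch) showing the annotation fix the mismatch suggests:

```python
from typing import get_type_hints

class ToySparseMoeBlock:
    # Corrected annotation: a single value is returned, not a tuple --
    # matching what the real forward() actually does.
    def forward(self, hidden_states: list) -> list:
        # stand-in for the reshape -> route -> expert-combine -> reshape pipeline
        return [h * 2.0 for h in hidden_states]

block = ToySparseMoeBlock()
out = block.forward([1.0, 2.0])
print(out)  # [2.0, 4.0]
print(get_type_hints(ToySparseMoeBlock.forward)["return"])  # <class 'list'>
```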
```python
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
```

If we returned a tuple of tensors, `hidden_states = residual + hidden_states` would raise this error:
```
TypeError: unsupported operand type(s) for +: 'Tensor' and 'tuple'
```

Did I miss something? Should we also return `routing_weights` for computing the loss during training?
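The failure mode can be reproduced without torch, since adding a tuple to any non-sequence operand fails the same way (floats stand in for tensors in this hypothetical sketch):

```python
residual = 1.0  # stands in for the residual hidden_states tensor

def mlp_returns_tensor(x):
    # current behaviour: a single value comes back, so the residual add works
    return x * 2.0

def mlp_returns_tuple(x):
    # hypothetical variant that also returns routing weights
    return x * 2.0, [0.5, 0.5]

print(residual + mlp_returns_tensor(residual))  # 3.0

try:
    residual + mlp_returns_tuple(residual)
except TypeError as e:
    print(e)  # unsupported operand type(s) for +: 'float' and 'tuple'
```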