Skip to content

Commit fa93aec

Browse files
committed
add tests
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent fc2f4eb commit fa93aec

1 file changed

Lines changed: 5 additions & 8 deletions

File tree

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525
from modelopt.torch.quantization.nn import QuantModuleRegistry
2626
from modelopt.torch.quantization.plugins.huggingface import (
27-
_QuantFusedExperts,
2827
_is_fused_experts_module,
2928
_is_sparse_moe_block,
29+
_QuantFusedExperts,
3030
register_fused_experts_on_the_fly,
3131
register_sparse_moe_on_the_fly,
3232
)
@@ -51,9 +51,7 @@ def __init__(self):
5151
self.gate_up_proj = nn.Parameter(
5252
torch.randn(NUM_EXPERTS, 2 * INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02
5353
)
54-
self.down_proj = nn.Parameter(
55-
torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM) * 0.02
56-
)
54+
self.down_proj = nn.Parameter(torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM) * 0.02)
5755
self.act_fn = nn.SiLU()
5856

5957
def forward(self, hidden_states, top_k_index, top_k_weights):
@@ -70,7 +68,9 @@ def forward(self, hidden_states, top_k_index, top_k_weights):
7068
gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
7169
current_hidden_states = self.act_fn(gate) * up
7270
current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx])
73-
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
71+
current_hidden_states = (
72+
current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
73+
)
7474
final_hidden_states.index_add_(
7575
0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
7676
)
@@ -254,9 +254,6 @@ class TestExportFusedExperts:
254254
def test_export_creates_per_expert_submodules(self):
255255
"""_export_fused_experts should create per-expert submodules with standard naming."""
256256
from modelopt.torch.export.moe_utils import _export_fused_experts
257-
from modelopt.torch.quantization.plugins.huggingface import (
258-
_get_fused_expert_intermediate_dim,
259-
)
260257

261258
experts = _SyntheticFusedExperts()
262259
expert_type = type(experts)

0 commit comments

Comments
 (0)