2424
2525from modelopt .torch .quantization .nn import QuantModuleRegistry
2626from modelopt .torch .quantization .plugins .huggingface import (
27- _QuantFusedExperts ,
2827 _is_fused_experts_module ,
2928 _is_sparse_moe_block ,
29+ _QuantFusedExperts ,
3030 register_fused_experts_on_the_fly ,
3131 register_sparse_moe_on_the_fly ,
3232)
@@ -51,9 +51,7 @@ def __init__(self):
5151 self .gate_up_proj = nn .Parameter (
5252 torch .randn (NUM_EXPERTS , 2 * INTERMEDIATE_DIM , HIDDEN_DIM ) * 0.02
5353 )
54- self .down_proj = nn .Parameter (
55- torch .randn (NUM_EXPERTS , HIDDEN_DIM , INTERMEDIATE_DIM ) * 0.02
56- )
54+ self .down_proj = nn .Parameter (torch .randn (NUM_EXPERTS , HIDDEN_DIM , INTERMEDIATE_DIM ) * 0.02 )
5755 self .act_fn = nn .SiLU ()
5856
5957 def forward (self , hidden_states , top_k_index , top_k_weights ):
@@ -70,7 +68,9 @@ def forward(self, hidden_states, top_k_index, top_k_weights):
7068 gate , up = F .linear (current_state , self .gate_up_proj [expert_idx ]).chunk (2 , dim = - 1 )
7169 current_hidden_states = self .act_fn (gate ) * up
7270 current_hidden_states = F .linear (current_hidden_states , self .down_proj [expert_idx ])
73- current_hidden_states = current_hidden_states * top_k_weights [token_idx , top_k_pos , None ]
71+ current_hidden_states = (
72+ current_hidden_states * top_k_weights [token_idx , top_k_pos , None ]
73+ )
7474 final_hidden_states .index_add_ (
7575 0 , token_idx , current_hidden_states .to (final_hidden_states .dtype )
7676 )
@@ -254,9 +254,6 @@ class TestExportFusedExperts:
254254 def test_export_creates_per_expert_submodules (self ):
255255 """_export_fused_experts should create per-expert submodules with standard naming."""
256256 from modelopt .torch .export .moe_utils import _export_fused_experts
257- from modelopt .torch .quantization .plugins .huggingface import (
258- _get_fused_expert_intermediate_dim ,
259- )
260257
261258 experts = _SyntheticFusedExperts ()
262259 expert_type = type (experts )
0 commit comments