@@ -193,7 +193,14 @@ def test_setup_creates_expert_token_count(self):
193193
194194 converted = QuantModuleRegistry .convert (moe_block )
195195 assert hasattr (converted , "expert_token_count" )
196- expected_num_experts = moe_block .num_experts if hasattr (moe_block , "num_experts" ) else 0
196+ if hasattr (moe_block , "gate" ) and hasattr (moe_block .gate , "num_experts" ):
197+ expected_num_experts = moe_block .gate .num_experts
198+ elif hasattr (moe_block , "num_experts" ):
199+ expected_num_experts = moe_block .num_experts
200+ elif hasattr (moe_block , "experts" ) and hasattr (moe_block .experts , "num_experts" ):
201+ expected_num_experts = moe_block .experts .num_experts
202+ else :
203+ expected_num_experts = 0
197204 assert converted .expert_token_count .shape == (expected_num_experts ,)
198205 assert converted .expert_token_count .dtype == torch .long
199206 assert (converted .expert_token_count == 0 ).all ()
@@ -298,14 +305,16 @@ def test_gate_forward_hook_counts_tokens(self):
298305 converted .expert_token_count .zero_ ()
299306 converted ._count_expert_tokens = True
300307
301- hidden_size = converted .gate .in_features
308+ if TRANSFORMERS_VERSION_GE_5_0 :
309+ hidden_size = converted .gate .weight .shape [1 ]
310+ top_k = converted .gate .top_k
311+ else :
312+ hidden_size = converted .gate .in_features
313+ top_k = converted .top_k if hasattr (converted , "top_k" ) else converted .gate .top_k
314+
302315 x = torch .randn (8 , hidden_size )
303316 with torch .no_grad ():
304317 converted .gate (x )
305-
306- # After one gate call with counting enabled, total assigned tokens should equal
307- # num_tokens * top_k
308- top_k = converted .top_k if hasattr (converted , "top_k" ) else converted .gate .top_k
309318 total_assigned = converted .expert_token_count .sum ().item ()
310319 assert total_assigned == 8 * top_k
311320
0 commit comments