1212from ..utility import dtypes
1313
1414
15+ # DEPRECATED: low-level binding kept for backward compatibility only.
16+ # Will be removed once all callers have migrated to topk_gating() below.
17+ # New code should use topk_gating(), which:
18+ # - accepts an Optional[Tensor] correction_bias (None => no bias)
19+ # - validates score_func string
20+ # - exposes the same C++ kernel under a more accurate name
1521@compile_ops ("module_moe_topk" )
1622def topk_softplus (
1723 topk_weights : torch .Tensor ,
@@ -20,9 +26,54 @@ def topk_softplus(
2026 correction_bias : torch .Tensor ,
2127 need_renorm : bool ,
2228 routed_scaling_factor : float = 1.0 ,
29+ score_func : str = "sqrtsoftplus" ,
2330) -> None : ...
2431
2532
# Score-function names accepted by topk_gating().
_VALID_SCORE_FUNCS = {"sqrtsoftplus", "sigmoid", "softmax"}


def topk_gating(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: Optional[torch.Tensor] = None,
    need_renorm: bool = True,
    routed_scaling_factor: float = 1.0,
    score_func: str = "sqrtsoftplus",
) -> None:
    """Unified fused topk gating for MoE routing.

    Validates ``score_func``, substitutes an empty tensor when no bias is
    given, and forwards all arguments to the ``topk_softplus`` binding.

    Args:
        topk_weights: passed through to ``topk_softplus`` unchanged.
            NOTE(review): presumably an output buffer written in place by
            the kernel, since this function returns None -- confirm.
        topk_indices: passed through to ``topk_softplus`` unchanged.
            NOTE(review): presumably an output buffer as well -- confirm.
        gating_output: gating tensor passed through to ``topk_softplus``;
            its dtype and device are reused for the placeholder bias.
        correction_bias: optional bias tensor, pass None for no bias.
        need_renorm: forwarded to the kernel; ignored (forced to False)
            when ``score_func == "softmax"``.
        routed_scaling_factor: forwarded to the kernel unchanged.
        score_func: one of {"sqrtsoftplus" (DeepSeek V4-Pro default),
                            "sigmoid"      (Llama4),
                            "softmax"      (DeepSeek V3 / classic MoE)}.

    Raises:
        ValueError: if ``score_func`` is not one of the supported names.

    Note: softmax is already normalized, so renorm is forced off.
    """
    # Raise explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, which would let an unknown name reach the kernel.
    if score_func not in _VALID_SCORE_FUNCS:
        raise ValueError(
            f"Unknown score_func '{score_func}', "
            f"expected one of {sorted(_VALID_SCORE_FUNCS)}"
        )
    if correction_bias is None:
        # Match gating dtype/device so dispatch picks DTYPE_B == DTYPE_I,
        # avoiding extra kernel template instantiations.
        correction_bias = torch.empty(
            0, dtype=gating_output.dtype, device=gating_output.device
        )
    if score_func == "softmax":
        need_renorm = False
    topk_softplus(
        topk_weights,
        topk_indices,
        gating_output,
        correction_bias,
        need_renorm,
        routed_scaling_factor,
        score_func,
    )
75+
76+
2677@compile_ops ("module_moe_asm" , fc_name = "biased_grouped_topk" )
2778def biased_grouped_topk_hip (
2879 gating_output : torch .Tensor ,
0 commit comments