diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 2c061a891a..662362a674 100644 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -534,7 +534,7 @@ "srcs": [ "f'{AITER_CSRC_DIR}/pybind/moe_topk_pybind.cu'", "f'{AITER_CSRC_DIR}/py_itfs_ck/topk_sigmoid_kernels.cu'", - "f'{AITER_CSRC_DIR}/kernels/topk_softplus_kernels.cu'", + "f'{AITER_CSRC_DIR}/kernels/topk_gating_kernels.cu'", "f'{CK_DIR}/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp'" ], "flags_extra_cc": [], diff --git a/aiter/ops/topk.py b/aiter/ops/topk.py index 535eade8b0..c5be2b8428 100755 --- a/aiter/ops/topk.py +++ b/aiter/ops/topk.py @@ -12,6 +12,12 @@ from ..utility import dtypes +# DEPRECATED: low-level binding kept for backward compatibility only. +# Will be removed once all callers have migrated to topk_gating() below. +# New code should use topk_gating(), which: +# - accepts an Optional[Tensor] correction_bias (None => no bias) +# - validates score_func string +# - exposes the same C++ kernel under a more accurate name @compile_ops("module_moe_topk") def topk_softplus( topk_weights: torch.Tensor, @@ -20,9 +26,54 @@ def topk_softplus( correction_bias: torch.Tensor, need_renorm: bool, routed_scaling_factor: float = 1.0, + score_func: str = "sqrtsoftplus", ) -> None: ... +_VALID_SCORE_FUNCS = {"sqrtsoftplus", "sigmoid", "softmax"} + + +def topk_gating( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: Optional[torch.Tensor] = None, + need_renorm: bool = True, + routed_scaling_factor: float = 1.0, + score_func: str = "sqrtsoftplus", +) -> None: + """Unified fused topk gating for MoE routing. + + Args: + score_func: one of {"sqrtsoftplus" (DeepSeek V4-Pro default), + "sigmoid" (Llama4), + "softmax" (DeepSeek V3 / classic MoE)}. + correction_bias: optional bias tensor, pass None for no bias. + + Note: softmax is already normalized, so renorm is forced off. + """ + assert ( + score_func in _VALID_SCORE_FUNCS + ), f"Unknown score_func '{score_func}', expected one of {_VALID_SCORE_FUNCS}" + if correction_bias is None: + # Match gating dtype/device so dispatch picks DTYPE_B == DTYPE_I, + # avoiding extra kernel template instantiations. 
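+        # (A zero-element tensor gives correction_bias.numel() == 0 on the C++
+        # side, so the kernel receives a nullptr bias and skips the add.)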
+ correction_bias = torch.empty( + 0, dtype=gating_output.dtype, device=gating_output.device + ) + if score_func == "softmax": + need_renorm = False + topk_softplus( + topk_weights, + topk_indices, + gating_output, + correction_bias, + need_renorm, + routed_scaling_factor, + score_func, + ) + + @compile_ops("module_moe_asm", fc_name="biased_grouped_topk") def biased_grouped_topk_hip( gating_output: torch.Tensor, diff --git a/csrc/include/moe_op.h b/csrc/include/moe_op.h index 6401f3f4a6..ff4fa48de5 100644 --- a/csrc/include/moe_op.h +++ b/csrc/include/moe_op.h @@ -55,7 +55,8 @@ void topk_softplus(torch::Tensor& topk_weights, torch::Tensor& gating_output, torch::Tensor& correction_bias, bool need_renorm, - float routed_scaling_factor = 1.0); + float routed_scaling_factor = 1.0, + const std::string& score_func = "sqrtsoftplus"); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 26462175c9..53e3ecea79 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1180,7 +1180,8 @@ namespace py = pybind11; py::arg("correction_bias"), \ py::arg("need_renorm"), \ py::arg("routed_scaling_factor") = 1.0, \ - "Apply topk sqrtsoftplus to the gating outputs."); + py::arg("score_func") = "sqrtsoftplus", \ + "Fused topk gating: score_func='sqrtsoftplus'|'sigmoid'|'softmax'."); #define MOE_SORTING_PYBIND \ m.def("moe_sorting_fwd", \ diff --git a/csrc/kernels/topk_gating_kernels.cu b/csrc/kernels/topk_gating_kernels.cu new file mode 100644 index 0000000000..8b90aa6953 --- /dev/null +++ b/csrc/kernels/topk_gating_kernels.cu @@ -0,0 +1,518 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +// +// Fused topk gating kernel for MoE routing. +// +// Scoring functions (selected by string at the C++ entry): +// "sqrtsoftplus" → sqrt(softplus(x)) — DeepSeek V4-Pro default +// "sigmoid" → sigmoid(x) — Llama4 +// "softmax" → softmax(x) — DeepSeek V3 / classic MoE +// +// Kernel variants: +// topk_softplus_kernel_opt — register-only, sort+merge (64/128/256/384 experts) +// topk_softplus_kernel — shared-memory fallback (any expert count) + +#include "aiter_hip_common.h" +#include "hip_reduce.h" +#include "py_itfs_common.h" +#include "aiter_opus_plus.h" +#include +#include +#include +#include +#include +#include + +namespace aiter { + +// --------------------------------------------------------------------------- +// Primitives +// --------------------------------------------------------------------------- + +enum { SCORE_SQRTSOFTPLUS = 0, SCORE_SIGMOID = 1, SCORE_SOFTMAX = 2 }; + +// Fused DPP warp argmax: 6× v_max_f32+DPP + ballot + ctzll + readlane ≈ 9 instr. +// NaN-safe: if all lanes have NaN (val_o == max_val is always false), ballot is 0 +// and ctzll(0) is UB. Detect this via the ballot result and fall back to lane 0. +__device__ __forceinline__ void warpReduceMax_softplus(float& val_o, int& idx) +{ + float max_val = multithread_reduce_max_dpp(val_o); + uint64_t mask = __ballot(val_o == max_val); + int win_lane = (mask != 0) ? 
__builtin_ctzll(mask) : 0; + idx = __builtin_amdgcn_readlane(idx, win_lane); + val_o = max_val; +} + +template +__device__ __forceinline__ float compute_score(float x) +{ + if constexpr(SCORE_FUNC == SCORE_SIGMOID) + { + // sigmoid(x) = rcp(1 + 2^(-x·log₂e)) → v_exp_f32 + v_rcp_f32 + return __builtin_amdgcn_rcpf(1.0f + exp2f(-x * 1.4426950408889634f)); + } + else if constexpr(SCORE_FUNC == SCORE_SOFTMAX) + { + // softmax: per-element score is identity; normalization done separately + return x; + } + else + { + // sqrt(softplus(x)) = sqrt(log(1 + exp(x))) + // Highest-precision path: pure libm (expf + log1pf), ≤1 ULP. + // Faster alternatives (commented out, ~0.5-1 ULP extra error): + // float sp = x > 20.0f ? x : log1pf(exp2f(x * 1.4426950408889634f)); // exp2f HW + float sp = x > 20.0f ? x : log2f(1.0f + exp2f(x * 1.4426950408889634f)) * 0.6931471805599453f; // both HW + return sqrtf(sp); + } +} + +// --------------------------------------------------------------------------- +// Sorting network (descending, 3 arrays co-permuted: vals, orig, idxs) +// --------------------------------------------------------------------------- + +#define _CAS_DESC(v, o, id, i, j) \ + do \ + { \ + if((v)[i] < (v)[j]) \ + { \ + float _tv = (v)[i]; (v)[i] = (v)[j]; (v)[j] = _tv; \ + float _to = (o)[i]; (o)[i] = (o)[j]; (o)[j] = _to; \ + int _ti = (id)[i]; (id)[i] = (id)[j]; (id)[j] = _ti; \ + } \ + } while(0) + +template +__device__ __forceinline__ void sort_network_desc(float* vals, float* orig, int* idxs) +{ + if constexpr(N <= 1) + return; + else if constexpr(N == 2) + { + _CAS_DESC(vals, orig, idxs, 0, 1); + } + else if constexpr(N == 3) + { + _CAS_DESC(vals, orig, idxs, 0, 1); + _CAS_DESC(vals, orig, idxs, 0, 2); + _CAS_DESC(vals, orig, idxs, 1, 2); + } + else if constexpr(N == 4) + { // 5-comparator optimal network + _CAS_DESC(vals, orig, idxs, 0, 1); + _CAS_DESC(vals, orig, idxs, 2, 3); + _CAS_DESC(vals, orig, idxs, 0, 2); + _CAS_DESC(vals, orig, idxs, 1, 3); + _CAS_DESC(vals, orig, idxs, 1, 2); + } + else if constexpr(N == 6) + { // 12-comparator optimal network + _CAS_DESC(vals, orig, idxs, 0, 1); + _CAS_DESC(vals, orig, idxs, 2, 3); + _CAS_DESC(vals, orig, idxs, 4, 5); + _CAS_DESC(vals, orig, idxs, 0, 2); + _CAS_DESC(vals, orig, idxs, 1, 4); + _CAS_DESC(vals, orig, idxs, 3, 5); + _CAS_DESC(vals, orig, idxs, 0, 1); + _CAS_DESC(vals, orig, idxs, 2, 3); + _CAS_DESC(vals, orig, idxs, 4, 5); + _CAS_DESC(vals, orig, idxs, 1, 2); + _CAS_DESC(vals, orig, idxs, 3, 4); + _CAS_DESC(vals, orig, idxs, 2, 3); + } + else + { // generic unrolled bubble sort fallback +#pragma unroll + for(int i = 0; i < N - 1; i++) + { +#pragma unroll + for(int j = 0; j < N - 1 - i; j++) + { + _CAS_DESC(vals, orig, idxs, j, j + 1); + } + } + } +} + +#undef _CAS_DESC + +// --------------------------------------------------------------------------- +// Register-only kernel (for expert counts divisible by WARP_SIZE) +// +// Each thread loads EPT = NUM_EXPERTS/WARP_SIZE elements, sorts them locally +// via an optimal sorting network, then participates in a warp-level k-way +// merge (iterative argmax) to extract the global top-K. +// No shared memory, no __syncthreads. +// +// 1 warp = 1 token = 1 block. Multi-warp-per-block was tried (WPB=2,4) and +// regressed K≥4 cases (extra register pressure / wave-scheduling overhead), +// while only marginally helping K=1~2. K-merge serial chain is the actual +// bottleneck, not block-launch overhead. 
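+// Illustrative layout (assuming NUM_EXPERTS=256 and WARP_SIZE=64, so EPT=4):
+// lane L owns experts {L, L+64, L+128, L+192}. After the local descending
+// sort, each lane's vals[0] is its partition maximum, so every merge round
+// argmaxes over 64 lane-front candidates and advances one winning cursor.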
+// --------------------------------------------------------------------------- + +template +__global__ void topk_softplus_kernel_opt( + const DTYPE_I* __restrict__ gating_output, + const DTYPE_B* __restrict__ correction_bias, + float* __restrict__ topk_weights, + int* __restrict__ topk_ids, + const size_t stride_tk, + const int topk, + const int num_tokens, + const float routed_scaling_factor) +{ + static constexpr int EPT = NUM_EXPERTS / WARP_SIZE; + static_assert(NUM_EXPERTS % WARP_SIZE == 0); + + const int token_idx = blockIdx.x; + auto const* input_ptr = gating_output + token_idx * NUM_EXPERTS; + + float vals[EPT]; + float orig[EPT]; + int idxs[EPT]; + + // Step 1: load → score → bias (all in registers, strided access) + // orig[] caches unbiased scores; sorted alongside vals[]/idxs[] so all + // three arrays share one cursor index for the merge phase. +#pragma unroll + for(int i = 0; i < EPT; i++) + { + int e = threadIdx.x + i * static_cast(WARP_SIZE); + float score = compute_score(static_cast(input_ptr[e])); + orig[i] = score; + vals[i] = score; + idxs[i] = e; + if(correction_bias != nullptr) + vals[i] += static_cast(correction_bias[e]); + } + + // Step 2: sort thread-local partition descending + sort_network_desc(vals, orig, idxs); + + // Step 3: warp-level k-way merge + // Winning lane = expert_idx & (WARP_SIZE-1) → readlane broadcasts + // the pre-cached unbiased score (no per-round global memory access). + int cursor = 0; + float sum = 0.0f; + int topk_indice = 0; + float topk_value = 0.0f; + + for(int k = 0; k < topk; ++k) + { + float my_val = (cursor < EPT) ? vals[cursor] : -INFINITY; + int my_idx = (cursor < EPT) ? idxs[cursor] : 0; + + warpReduceMax_softplus(my_val, my_idx); + + bool i_won = (cursor < EPT && idxs[cursor] == my_idx); + float my_orig = i_won ? 
orig[cursor] : 0.0f; + if(i_won) cursor++; + + int win_lane = my_idx & (static_cast(WARP_SIZE) - 1); + float weight = __builtin_bit_cast( + float, __builtin_amdgcn_readlane(__builtin_bit_cast(int, my_orig), win_lane)); + + if(static_cast(threadIdx.x) == k) + { + topk_indice = my_idx; + topk_value = weight; + } + if constexpr(need_renorm) sum += weight; + } + + // Step 4: renorm + scale + write + if constexpr(need_renorm) + sum = routed_scaling_factor / fmaxf(sum, 1e-20f); + else + sum = routed_scaling_factor; + + if(static_cast(threadIdx.x) < topk) + { + topk_weights[token_idx * stride_tk + threadIdx.x] = topk_value * sum; + topk_ids[token_idx * stride_tk + threadIdx.x] = topk_indice; + } +} + +// --------------------------------------------------------------------------- +// Generic fallback kernel (shared-memory based, any expert count) +// --------------------------------------------------------------------------- + +template +__global__ void topk_softplus_kernel( + const DTYPE_I* __restrict__ gating_output, + const DTYPE_B* __restrict__ correction_bias, + float* __restrict__ topk_weights, + int* __restrict__ topk_ids, + const size_t stride_tk, + const int num_experts, + const int topk, + const int num_tokens, + const float routed_scaling_factor) +{ + extern __shared__ char shared_mem[]; + const int token_idx = blockIdx.x; + float* scores = reinterpret_cast(shared_mem); + + using cktype_i = typename hip2opus::type; + f32vec* scores_vec = reinterpret_cast(scores); + static constexpr int vec_size = opus::vector_traits::size(); + using vec_i = opus::vector_t; + const int num_experts_vec = num_experts / vec_size; + + // Step 1: load + score function + // For softmax, bias is NOT added here — it's added AFTER normalization + // (bias only shifts scores for topk selection, not for softmax computation). + auto const* input_ptr = gating_output + token_idx * num_experts; + for(int e = threadIdx.x; e < num_experts_vec; e += blockDim.x) + { + vec_i tmp = reinterpret_cast(input_ptr)[e]; + f32vec gating; +#pragma unroll + for(size_t i = 0; i < vec_size; i++) + { + gating[i] = compute_score(static_cast(tmp[i])); + if constexpr(SCORE_FUNC != SCORE_SOFTMAX) + { + if(correction_bias != nullptr) + gating[i] += static_cast(correction_bias[e * vec_size + i]); + } + } + scores_vec[e] = gating; + } + for(int e = num_experts_vec * vec_size + threadIdx.x; e < num_experts; e += blockDim.x) + { + scores[e] = compute_score(static_cast(input_ptr[e])); + if constexpr(SCORE_FUNC != SCORE_SOFTMAX) + { + if(correction_bias != nullptr) + scores[e] += static_cast(correction_bias[e]); + } + } + __syncthreads(); + + // Softmax: normalize first, then add bias for topk selection. + // scores[] after this block = softmax(x) + bias (biased for selection). + // The topk loop subtracts bias back to get unbiased softmax weights. 
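+    // Numerically the block below uses the standard max-shifted form on the
+    // HW exp2 unit:
+    //   softmax_i = 2^((x_i - max_j x_j) * log2(e)) / sum_k 2^((x_k - max_j x_j) * log2(e))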
+ if constexpr(SCORE_FUNC == SCORE_SOFTMAX) + { + float local_max = -INFINITY; + for(int e = threadIdx.x; e < num_experts; e += blockDim.x) + local_max = fmaxf(local_max, scores[e]); + local_max = multithread_reduce_max_dpp(local_max); + + float local_sum = 0.0f; + for(int e = threadIdx.x; e < num_experts; e += blockDim.x) + { + scores[e] = exp2f((scores[e] - local_max) * 1.4426950408889634f); + local_sum += scores[e]; + } + local_sum = wave_reduce(local_sum, [](float a, float b) { return a + b; }); + + float inv_sum = __builtin_amdgcn_rcpf(local_sum); + for(int e = threadIdx.x; e < num_experts; e += blockDim.x) + { + scores[e] *= inv_sum; + if(correction_bias != nullptr) + scores[e] += static_cast(correction_bias[e]); + } + __syncthreads(); + } + + float sum = 0.0f; + int topk_indice = 0; + float topk_value = 0.0f; + for(int k = 0; k < topk; ++k) + { + float max_val = -INFINITY; + int max_idx = k; + for(int e = threadIdx.x; e < num_experts_vec; e += blockDim.x) + { + f32vec tmp = scores_vec[e]; +#pragma unroll + for(size_t i = 0; i < vec_size; i++) + { + if(tmp[i] > max_val) { max_val = tmp[i]; max_idx = e * vec_size + i; } + } + } + warpReduceMax_softplus(max_val, max_idx); + if(correction_bias != nullptr) + max_val -= static_cast(correction_bias[max_idx]); + scores[max_idx] = -INFINITY; + if(static_cast(threadIdx.x) == k) + { + topk_indice = max_idx; + topk_value = max_val; + } + if(need_renorm) sum += max_val; + } + + if(need_renorm) + sum = routed_scaling_factor / fmaxf(sum, 1e-20f); + else + sum = routed_scaling_factor; + + for(int k = threadIdx.x; k < topk; k += blockDim.x) + { + topk_weights[token_idx * stride_tk + k] = topk_value * sum; + topk_ids[token_idx * stride_tk + k] = topk_indice; + } +} + +// --------------------------------------------------------------------------- +// Launch macros +// --------------------------------------------------------------------------- + +#define LAUNCH_TOPK_KERNEL(VEC_F, RENORM, SF) \ + hipLaunchKernelGGL( \ + (aiter::topk_softplus_kernel), \ + dim3(grid), dim3(block), shared_mem_size, stream, \ + reinterpret_cast(gating_output.data_ptr()), \ + has_bias ? reinterpret_cast(correction_bias.data_ptr()) : nullptr, \ + topk_weights.data_ptr(), topk_indices.data_ptr(), \ + stride_tk, num_experts, topk, num_tokens, routed_scaling_factor); + +#define LAUNCH_TOPK_KERNEL_OPT(NE, RENORM, SF) \ + hipLaunchKernelGGL( \ + (aiter::topk_softplus_kernel_opt), \ + dim3(grid), dim3(block), 0, stream, \ + reinterpret_cast(gating_output.data_ptr()), \ + has_bias ? reinterpret_cast(correction_bias.data_ptr()) : nullptr, \ + topk_weights.data_ptr(), topk_indices.data_ptr(), \ + stride_tk, topk, num_tokens, routed_scaling_factor); + +// --------------------------------------------------------------------------- +// Host dispatch +// --------------------------------------------------------------------------- + +// Resolve "sqrtsoftplus"/"sigmoid"/"softmax" → SCORE_* enum, or AITER_CHECK fail. 
+static inline int parse_score_func(const std::string& s) +{ + if(s == "sqrtsoftplus") return SCORE_SQRTSOFTPLUS; + if(s == "sigmoid") return SCORE_SIGMOID; + if(s == "softmax") return SCORE_SOFTMAX; + AITER_CHECK(false, "unknown score_func: ", s, + " (expected sqrtsoftplus|sigmoid|softmax)"); + return SCORE_SQRTSOFTPLUS; // unreachable +} + +void topk_softplus(torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& gating_output, + torch::Tensor& correction_bias, + bool need_renorm, + float routed_scaling_factor, + const std::string& score_func) +{ + const int sf_code = parse_score_func(score_func); + const int num_tokens = gating_output.size(0); + const int num_experts = gating_output.size(1); + const int topk = topk_indices.size(1); + const size_t stride_tk = topk_indices.stride(0); + const bool has_bias = correction_bias.numel() > 0; + + // Both kernels assign one lane per top-K winner during writeout + // (`if (lane == k) topk_value = ...`), so topk must fit in a single warp + // and cannot exceed the number of routable experts. Fail fast with a + // clear error rather than silently producing partial / wrong output. + AITER_CHECK(topk <= static_cast(WARP_SIZE), + "topk (", topk, ") exceeds WARP_SIZE (", WARP_SIZE, ")"); + AITER_CHECK(topk <= num_experts, + "topk (", topk, ") exceeds num_experts (", num_experts, ")"); + + // Softmax outputs are already a probability distribution that sums to 1 + // across the routed top-K (post-selection); a second renorm would distort + // those weights. Enforce here so direct C++ callers behave the same as + // the Python topk_gating() wrapper (which already forces this). + if(sf_code == SCORE_SOFTMAX) + { + need_renorm = false; + } + + dim3 grid(num_tokens); + dim3 block(get_warp_size_func()); + + // Use PyTorch's current stream so that the kernel runs on the same stream + // as the surrounding torch ops (avoids race conditions and works with CUDA + // graph capture). + // TODO: when this op is migrated to aiter_tensor_t (and @compile_ops uses + // develop=True), switch to aiter::getCurrentHIPStream() — the wrapper + // will then sync torch.cuda.current_stream() before each call. + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + + const auto gating_st = gating_output.scalar_type(); + const auto bias_st = has_bias ? correction_bias.scalar_type() : gating_st; + + // Three-level compile-time dispatch: gating dtype → bias dtype → score_func. + auto dispatch = [&](auto gating_tag, auto bias_tag, auto sf_tag) { + using scalar_t = decltype(gating_tag); + using bias_scalar_t = decltype(bias_tag); + constexpr int SF = decltype(sf_tag)::value; + + // Register-only opt kernel (NOT supported for softmax: needs global reduce). 
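+        // For softmax the shared-memory fallback below is always taken; it is
+        // the path that performs the row-wide max/sum normalization before the
+        // top-K selection loop.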
+ if constexpr(SF != SCORE_SOFTMAX) + { +#define _DISPATCH_REG_KERNEL(NE) \ + if(num_experts == NE) { \ + if(need_renorm) { LAUNCH_TOPK_KERNEL_OPT(NE, true, SF) } \ + else { LAUNCH_TOPK_KERNEL_OPT(NE, false, SF) } \ + return; \ + } + _DISPATCH_REG_KERNEL(64) + _DISPATCH_REG_KERNEL(128) + _DISPATCH_REG_KERNEL(256) + _DISPATCH_REG_KERNEL(384) +#undef _DISPATCH_REG_KERNEL + } + + // Shared-memory fallback kernel + const size_t shared_mem_size = num_experts * sizeof(float); +#define _DISPATCH_SMEM_KERNEL(VEC_LANES) \ + { \ + using VT = opus::vector_t; \ + if(need_renorm) { LAUNCH_TOPK_KERNEL(VT, true, SF) } \ + else { LAUNCH_TOPK_KERNEL(VT, false, SF) } \ + } + switch(num_experts % 4) + { + case 0: _DISPATCH_SMEM_KERNEL(4) break; + case 2: _DISPATCH_SMEM_KERNEL(2) break; + default: _DISPATCH_SMEM_KERNEL(1) break; + } +#undef _DISPATCH_SMEM_KERNEL + }; + + auto dispatch_sf = [&](auto gating_tag, auto bias_tag) { + switch(sf_code) + { + case SCORE_SIGMOID: + dispatch(gating_tag, bias_tag, std::integral_constant{}); break; + case SCORE_SOFTMAX: + dispatch(gating_tag, bias_tag, std::integral_constant{}); break; + default: + dispatch(gating_tag, bias_tag, std::integral_constant{}); break; + } + }; + + auto dispatch_bias = [&](auto gating_tag) { + switch(bias_st) + { + case at::kFloat: dispatch_sf(gating_tag, float{}); break; + case at::kHalf: dispatch_sf(gating_tag, __half{}); break; + case at::kBFloat16: dispatch_sf(gating_tag, hip_bfloat16{}); break; + default: AITER_CHECK(false, "unsupported correction_bias dtype"); break; + } + }; + + switch(gating_st) + { + case at::kFloat: dispatch_bias(float{}); break; + case at::kHalf: dispatch_bias(__half{}); break; + case at::kBFloat16: dispatch_bias(hip_bfloat16{}); break; + default: AITER_CHECK(false, "unsupported gating_output dtype"); break; + } +} + +} // namespace aiter diff --git a/csrc/kernels/topk_softplus_kernels.cu b/csrc/kernels/topk_softplus_kernels.cu deleted file mode 100644 index 2046089ea8..0000000000 --- a/csrc/kernels/topk_softplus_kernels.cu +++ /dev/null @@ -1,233 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -#include "dispatch_utils.h" -#include "hip_reduce.h" -#include "py_itfs_common.h" -#include "aiter_hip_common.h" -#include "aiter_opus_plus.h" -#include -#include -#include -#include -#include - -namespace aiter { - -__inline__ __device__ void warpReduceMax_softplus(float& val_o, int& idx) -{ - using kvp = hipcub::KeyValuePair; - kvp thread_kvp; - thread_kvp.key = idx; - thread_kvp.value = val_o; - auto arg_max = [](kvp a, kvp b) { return a.value > b.value ? 
a : b; }; - const kvp result_kvp = - wave_reduce(thread_kvp, arg_max); - val_o = __builtin_bit_cast( - float, - __builtin_amdgcn_readlane(__builtin_bit_cast(int, result_kvp.value), WARP_SIZE - 1)); - idx = __builtin_bit_cast( - int, __builtin_amdgcn_readlane(result_kvp.key, WARP_SIZE - 1)); -} - -template -__global__ void topk_softplus_kernel( - const DTYPE_I* __restrict__ gating_output, // [num_tokens, num_experts] - const DTYPE_I* __restrict__ correction_bias, // [num_experts] or nullptr - float* __restrict__ topk_weights, // [num_tokens, topk] - int* __restrict__ topk_ids, // [num_tokens, topk] - const size_t stride_tk, - const int num_experts, - const int topk, - const int num_tokens, - const float routed_scaling_factor) -{ - extern __shared__ char shared_mem[]; - const int token_idx = blockIdx.x; - - float* scores = reinterpret_cast(shared_mem); - - using cktype_i = typename t2opus::type; - f32vec* scores_vec = reinterpret_cast(scores); - static constexpr int vec_size = opus::vector_traits::size(); - using vec_i = opus::vector_t; - const int num_experts_vec = num_experts / vec_size; - - // Step 1: compute sqrt(softplus(x)) and optionally add bias for topk selection - auto const* input_ptr = gating_output + token_idx * num_experts; - for(int e = threadIdx.x; e < num_experts_vec; e += blockDim.x) - { - vec_i tmp = reinterpret_cast(input_ptr)[e]; - f32vec gating; -#pragma unroll - for(size_t i = 0; i < vec_size; i++) - { - float x = static_cast(tmp[i]); - // sqrt(softplus(x)) = sqrt(log1p(exp(x))) - // For numerical stability: when x > 20, softplus(x) ≈ x - float sp = x > 20.0f ? x : log1pf(expf(x)); - gating[i] = sqrtf(sp); - if(correction_bias != nullptr) - { - int idx = e * vec_size + i; - float bias_val = static_cast( - reinterpret_cast(correction_bias)[idx]); - gating[i] += bias_val; - } - } - scores_vec[e] = gating; - } - // Handle remainder if num_experts is not divisible by vec_size - for(int e = num_experts_vec * vec_size + threadIdx.x; e < num_experts; e += blockDim.x) - { - float x = static_cast(input_ptr[e]); - float sp = x > 20.0f ? x : log1pf(expf(x)); - scores[e] = sqrtf(sp); - if(correction_bias != nullptr) - { - scores[e] += static_cast( - reinterpret_cast(correction_bias)[e]); - } - } - __syncthreads(); - - // Step 2: find topk - float sum = 0.0f; - int topk_indice; - float topk_value; - for(int k = 0; k < topk; ++k) - { - float max_val = -INFINITY; - int max_idx = k; - - for(int e = threadIdx.x; e < num_experts_vec; e += blockDim.x) - { - f32vec tmp = scores_vec[e]; -#pragma unroll - for(size_t i = 0; i < vec_size; i++) - { - if(tmp[i] > max_val) - { - max_val = tmp[i]; - max_idx = e * vec_size + i; - } - } - } - - warpReduceMax_softplus(max_val, max_idx); - - { - // Subtract bias to get original score as the routing weight - if(correction_bias != nullptr) - { - max_val -= static_cast( - reinterpret_cast(correction_bias)[max_idx]); - } - scores[max_idx] = -INFINITY; - topk_indice = threadIdx.x == k ? max_idx : topk_indice; - topk_value = threadIdx.x == k ? 
max_val : topk_value; - if(need_renorm) - { - sum += max_val; - } - } - } - - // Step 3: apply renorm and route_scale - if(need_renorm) - { - sum = routed_scaling_factor / sum; - } - else - { - sum = routed_scaling_factor; - } - - for(int k = threadIdx.x; k < topk; k += blockDim.x) - { - topk_weights[token_idx * stride_tk + k] = topk_value * sum; - topk_ids[token_idx * stride_tk + k] = topk_indice; - } -} - -#define LAUNCH_TOPK_SOFTPLUS_KERNEL(VEC_F, need_renorm_val) \ - VLLM_DISPATCH_FLOATING_TYPES(gating_output.scalar_type(), "topk_softplus_kernel", [&] { \ - hipLaunchKernelGGL( \ - (aiter::topk_softplus_kernel), \ - dim3(grid), \ - dim3(block), \ - shared_mem_size, \ - stream, \ - gating_output.data_ptr(), \ - has_bias ? correction_bias.data_ptr() : nullptr, \ - topk_weights.data_ptr(), \ - topk_indices.data_ptr(), \ - stride_tk, \ - num_experts, \ - topk, \ - num_tokens, \ - routed_scaling_factor); \ - }); - -void topk_softplus(torch::Tensor& topk_weights, // [num_tokens, topk] - torch::Tensor& topk_indices, // [num_tokens, topk] - torch::Tensor& gating_output, // [num_tokens, num_experts] - torch::Tensor& correction_bias, // [num_experts] - bool need_renorm, - float routed_scaling_factor) -{ - int num_tokens = gating_output.size(0); - int num_experts = gating_output.size(1); - int topk = topk_indices.size(1); - size_t stride_tk = topk_indices.stride(0); - bool has_bias = correction_bias.numel() > 0; - - dim3 grid(num_tokens); - dim3 block(WARP_SIZE); - size_t shared_mem_size = num_experts * sizeof(float); - - const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); - const hipStream_t stream = at::hip::getCurrentHIPStream(); - - switch(num_experts % 4) - { - case 0: { - using vec4_type = opus::vector_t; - if(need_renorm) - { - LAUNCH_TOPK_SOFTPLUS_KERNEL(vec4_type, true) - } - else - { - LAUNCH_TOPK_SOFTPLUS_KERNEL(vec4_type, false) - } - break; - } - case 2: { - using vec2_type = opus::vector_t; - if(need_renorm) - { - LAUNCH_TOPK_SOFTPLUS_KERNEL(vec2_type, true) - } - else - { - LAUNCH_TOPK_SOFTPLUS_KERNEL(vec2_type, false) - } - break; - } - default: { - using vec1_type = opus::vector_t; - if(need_renorm) - { - LAUNCH_TOPK_SOFTPLUS_KERNEL(vec1_type, true) - } - else - { - LAUNCH_TOPK_SOFTPLUS_KERNEL(vec1_type, false) - } - break; - } - } -} - -} // namespace aiter diff --git a/op_tests/test_moe_topk_sigmoid.py b/op_tests/test_moe_topk_gating.py similarity index 56% rename from op_tests/test_moe_topk_sigmoid.py rename to op_tests/test_moe_topk_gating.py index 6da465b527..6add07a1f8 100644 --- a/op_tests/test_moe_topk_sigmoid.py +++ b/op_tests/test_moe_topk_gating.py @@ -15,6 +15,7 @@ import argparse import itertools +import sys import pandas as pd import pytest @@ -26,6 +27,13 @@ ) from aiter.utility.dtypes import str2Dtype, str2tuple +# NOTE on correctness metrics by score function: +# - sigmoid uses element-wise comparison (score_errors/index_errors) because +# both torch/topk and fused paths return sorted top-K. +# - softplus/softmax use set-based ID matching (id_errors/max_weight_err) +# because torch references intentionally use `topk(..., sorted=False)` to +# mirror routing behavior where top-K order is not semantically required. 
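+# Illustrative sketch of the set-based metric used below (i_fused / i_torch
+# are the [num_tokens, topk] index tensors from the fused and torch paths):
+#   id_match = sum(set(i_fused[t].tolist()) == set(i_torch[t].tolist())
+#                  for t in range(num_tokens))
+#   id_errors = 1.0 - id_match / num_tokens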
+ @perftest(num_iters=10, num_warmup=1) def run_torch(gating_output: torch.Tensor, topk: int): @@ -35,22 +43,26 @@ def run_torch(gating_output: torch.Tensor, topk: int): return router_scores, router_indices.to(torch.int32) -@perftest(num_iters=10, num_warmup=1) +@perftest(num_iters=100, num_warmup=1) def run_fused(gating_output: torch.Tensor, topk: int): - tokens, _ = gating_output.shape + tokens, num_experts = gating_output.shape router_scores = torch.empty( (tokens, topk), dtype=torch.float32, device=gating_output.device ) router_indices = torch.empty( (tokens, topk), dtype=torch.int32, device=gating_output.device ) - aiter.topk_sigmoid(router_scores, router_indices, gating_output) + aiter.topk_gating( + router_scores, + router_indices, + gating_output, + score_func="sigmoid", + need_renorm=False, + ) return router_scores, router_indices # -- topk_softplus (DeepSeek V4-Pro sqrtsoftplus routing) -------------- - - @perftest(num_iters=10, num_warmup=1) def run_torch_softplus( gating_output: torch.Tensor, @@ -69,7 +81,7 @@ def run_torch_softplus( return topk_weights, topk_ids.to(torch.int32) -@perftest(num_iters=10, num_warmup=1) +@perftest(num_iters=100, num_warmup=1) def run_fused_softplus( gating_output: torch.Tensor, bias: torch.Tensor, @@ -90,6 +102,47 @@ def run_fused_softplus( return topk_weights, topk_ids +# -- topk_softmax ( classic MoE softmax routing) -------------- +@perftest(num_iters=10, num_warmup=1) +def run_torch_softmax( + gating_output: torch.Tensor, + bias: torch.Tensor, + topk: int, + route_scale: float, +): + scores = torch.softmax(gating_output.float(), dim=-1) + scores_biased = scores + bias.float() if bias.numel() > 0 else scores + topk_ids = scores_biased.topk(topk, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_ids) * route_scale + return topk_weights, topk_ids.to(torch.int32) + + +@perftest(num_iters=100, num_warmup=1) +def run_fused_softmax( + gating_output: torch.Tensor, + bias: torch.Tensor, + topk: int, + route_scale: float, +): + tokens, _ = gating_output.shape + topk_weights = torch.empty( + (tokens, topk), dtype=torch.float32, device=gating_output.device + ) + topk_ids = torch.empty( + (tokens, topk), dtype=torch.int32, device=gating_output.device + ) + aiter.topk_gating( + topk_weights, + topk_ids, + gating_output, + bias, + need_renorm=False, # softmax is already normalized + routed_scaling_factor=route_scale, + score_func="softmax", + ) + return topk_weights, topk_ids + + def benchmark_topk_sigmoid( num_experts: int = 128, num_tokens: int = 1024, @@ -210,13 +263,84 @@ def benchmark_topk_softplus( return result +def benchmark_topk_softmax( + num_experts: int = 256, + num_tokens: int = 1024, + topk: int = 8, + dtype: torch.dtype = torch.bfloat16, + route_scale: float = 1.0, + use_bias: bool = True, +): + gating_output = ( + torch.arange(-1, 1, 2.0 / num_experts) + .repeat((num_tokens, 1)) + .to(dtype=dtype, device="cuda") + ) + permutation = torch.argsort(torch.rand_like(gating_output), dim=-1) + gating_output = torch.gather(gating_output, dim=-1, index=permutation) + bias = ( + torch.randn(num_experts, dtype=torch.float32, device="cuda") * 0.1 + if use_bias + else torch.empty(0, device="cuda") + ) + + (w_torch, i_torch), avg_torch = run_torch_softmax( + gating_output.clone(), bias, topk, route_scale + ) + (w_fused, i_fused), avg_fused = run_fused_softmax( + gating_output.clone(), bias, topk, route_scale + ) + + id_match = 0 + max_w_err = 0.0 + for t in range(num_tokens): + kern_set = set(i_fused[t].tolist()) + ref_set = 
set(i_torch[t].tolist()) + if kern_set == ref_set: + id_match += 1 + for k in range(topk): + kid = i_fused[t, k].item() + ref_k = (i_torch[t] == kid).nonzero(as_tuple=True)[0] + if len(ref_k) > 0: + err = abs(w_fused[t, k].item() - w_torch[t, ref_k[0]].item()) + max_w_err = max(max_w_err, err) + + id_err = 1.0 - id_match / num_tokens + + result = { + "num_experts": num_experts, + "num_tokens": num_tokens, + "topk": topk, + "dtype": str(dtype).split(".")[-1], + "torch_us": avg_torch, + "fused_us": avg_fused, + "uplift": avg_torch / avg_fused, + "id_errors": id_err, + "max_weight_err": max_w_err, + } + + if id_err > 0.01: + print( + f"\n[ERROR] softmax: num_experts={num_experts}, num_tokens={num_tokens}, " + f"topk={topk}, dtype={str(dtype).split('.')[-1]}, id_err={id_err:.4f}" + ) + + return result + + # Pytest-parametrized test functions -- topk_softplus -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("topk", [1, 2, 4, 8]) +# Mirrors DeepSeek-V4 model integration: gating fp32 + bias fp32 is the default. +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("bias_dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) @pytest.mark.parametrize("num_tokens", [64, 1024, 2048]) -@pytest.mark.parametrize("num_experts", [64, 128, 256]) -def test_topk_softplus_correctness(num_experts, num_tokens, topk, dtype): - """Pytest test for correctness of topk_softplus (sqrtsoftplus) operation.""" +@pytest.mark.parametrize("num_experts", [64, 128, 256, 384]) +def test_topk_softplus_correctness(num_experts, num_tokens, topk, dtype, bias_dtype): + """Pytest test for correctness of topk_softplus (sqrtsoftplus) operation. + + Covers the DeepSeek-V4-Pro use case: router_logits=fp32, bias=fp32. + Also covers fp16/bf16 gating with mixed bias dtypes. 
+ """ torch.random.manual_seed(0) route_scale = 2.5 @@ -227,7 +351,9 @@ def test_topk_softplus_correctness(num_experts, num_tokens, topk, dtype): ) permutation = torch.argsort(torch.rand_like(gating_output), dim=-1) gating_output = torch.gather(gating_output, dim=-1, index=permutation) - bias = torch.randn(num_experts, dtype=dtype, device="cuda") * 0.1 + bias = (torch.randn(num_experts, dtype=torch.float32, device="cuda") * 0.1).to( + bias_dtype + ) (w_torch, i_torch), _ = run_torch_softplus( gating_output.clone(), bias, topk, True, route_scale @@ -240,9 +366,10 @@ def test_topk_softplus_correctness(num_experts, num_tokens, topk, dtype): for t in range(num_tokens): kern_set = set(i_fused[t].tolist()) ref_set = set(i_torch[t].tolist()) - assert ( - kern_set == ref_set - ), f"Token {t}: ID mismatch kernel={sorted(kern_set)} ref={sorted(ref_set)}" + assert kern_set == ref_set, ( + f"Token {t} (gating={dtype},bias={bias_dtype},E={num_experts},topk={topk}): " + f"ID mismatch kernel={sorted(kern_set)} ref={sorted(ref_set)}" + ) # compare weights (match by expert id) for t in range(num_tokens): @@ -259,7 +386,7 @@ def test_topk_softplus_correctness(num_experts, num_tokens, topk, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("topk", [1, 2, 4, 8]) @pytest.mark.parametrize("num_tokens", [64, 1024, 2048]) -@pytest.mark.parametrize("num_experts", [64, 128]) +@pytest.mark.parametrize("num_experts", [64, 128, 256, 384]) def test_topk_sigmoid_correctness(num_experts, num_tokens, topk, dtype): """Pytest test for correctness of topk_sigmoid operation.""" torch.random.manual_seed(0) @@ -287,6 +414,51 @@ def test_topk_sigmoid_correctness(num_experts, num_tokens, topk, dtype): assert index_errors <= 0.01, f"Index errors {index_errors} exceed tolerance" +# Pytest-parametrized test functions -- topk_softmax (via topk_gating) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) +@pytest.mark.parametrize("num_tokens", [64, 1024, 2048]) +@pytest.mark.parametrize("num_experts", [64, 128, 256, 384]) +def test_topk_softmax_correctness(num_experts, num_tokens, topk, dtype): + """Pytest test for correctness of topk_gating with score_func='softmax'.""" + torch.random.manual_seed(0) + route_scale = 1.0 + + gating_output = ( + torch.arange(-1, 1, 2.0 / num_experts) + .repeat((num_tokens, 1)) + .to(dtype=dtype, device="cuda") + ) + permutation = torch.argsort(torch.rand_like(gating_output), dim=-1) + gating_output = torch.gather(gating_output, dim=-1, index=permutation) + bias = torch.randn(num_experts, dtype=torch.float32, device="cuda") * 0.1 + + (w_torch, i_torch), _ = run_torch_softmax( + gating_output.clone(), bias, topk, route_scale + ) + (w_fused, i_fused), _ = run_fused_softmax( + gating_output.clone(), bias, topk, route_scale + ) + + # compare ids per token (order may differ) + for t in range(num_tokens): + kern_set = set(i_fused[t].tolist()) + ref_set = set(i_torch[t].tolist()) + assert ( + kern_set == ref_set + ), f"Token {t}: ID mismatch kernel={sorted(kern_set)} ref={sorted(ref_set)}" + + # compare weights (match by expert id) + for t in range(num_tokens): + for k in range(topk): + kid = i_fused[t, k].item() + ref_k = (i_torch[t] == kid).nonzero(as_tuple=True)[0] + assert len(ref_k) > 0 + torch.testing.assert_close( + w_fused[t, k], w_torch[t, ref_k[0]], atol=1e-5, rtol=1e-4 + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Test topk_sigmoid and topk_softplus 
operations" @@ -294,32 +466,32 @@ def test_topk_sigmoid_correctness(num_experts, num_tokens, topk, dtype): parser.add_argument( "--num-experts", type=str2tuple, - default=[128], - help="Comma-separated list of number of experts (default: 128)", + default=[64, 128, 256, 384], + help="Comma-separated list of number of experts (default: 64,128,256,384)", ) parser.add_argument( "--num-tokens", type=str2tuple, - default=[1024], - help="Comma-separated list of number of tokens (default: 1024)", + default=[64, 1024, 2048], + help="Comma-separated list of number of tokens (default: 64,1024,2048)", ) parser.add_argument( "--topk", type=str2tuple, - default=[8], - help="Comma-separated list of topk values (default: 8)", + default=[1, 2, 4, 6, 8], + help="Comma-separated list of topk values (default: 1,2,4,6,8)", ) parser.add_argument( "--dtype", type=str2Dtype, - default=[torch.float16, torch.bfloat16], - help="Comma-separated list of dtypes: fp16, bf16 (default: fp16,bf16)", + default=[torch.float16, torch.bfloat16, torch.float32], + help="Comma-separated list of dtypes: fp16, bf16, fp32 (default: fp16,bf16,fp32)", ) parser.add_argument( "--test", type=str, default="all", - choices=["sigmoid", "softplus", "all"], + choices=["sigmoid", "softplus", "softmax", "all"], help="Which test to run (default: all)", ) @@ -333,16 +505,23 @@ def to_list(x): topk_list = to_list(args.topk) dtype_list = to_list(args.dtype) - configs = list( - itertools.product(num_experts_list, num_tokens_list, topk_list, dtype_list) - ) + # Track whether any benchmark section saw a correctness regression + # (id_errors > 1%); exit non-zero at the end so CI catches it. + failed_sections: list[str] = [] if args.test in ("sigmoid", "all"): + sigmoid_experts = [e for e in num_experts_list] + sigmoid_dtypes = [d for d in dtype_list if d != torch.float32] + sigmoid_configs = list( + itertools.product( + sigmoid_experts, num_tokens_list, topk_list, sigmoid_dtypes + ) + ) print("=" * 80) print("topk_sigmoid benchmark") print("=" * 80) collected = [] - for num_experts, num_tokens, topk, dtype in configs: + for num_experts, num_tokens, topk, dtype in sigmoid_configs: result = benchmark_topk_sigmoid( num_experts=num_experts, num_tokens=num_tokens, topk=topk, dtype=dtype ) @@ -350,13 +529,22 @@ def to_list(x): df = pd.DataFrame(collected) print(df.to_string(index=False)) print(f"\nAverage uplift: {df['uplift'].mean():.2f}x") + # benchmark_topk_sigmoid uses {score,index}_errors columns + errors = df[(df["index_errors"] > 0.01) | (df["score_errors"] > 0.01)] + if len(errors) > 0: + print(f"\nERROR: {len(errors)} sigmoid config(s) had errors > 1%!") + print(errors.to_string(index=False)) + failed_sections.append("sigmoid") if args.test in ("softplus", "all"): + softplus_configs = list( + itertools.product(num_experts_list, num_tokens_list, topk_list, dtype_list) + ) print("\n" + "=" * 80) print("topk_softplus benchmark") print("=" * 80) collected = [] - for num_experts, num_tokens, topk, dtype in configs: + for num_experts, num_tokens, topk, dtype in softplus_configs: result = benchmark_topk_softplus( num_experts=num_experts, num_tokens=num_tokens, topk=topk, dtype=dtype ) @@ -366,8 +554,41 @@ def to_list(x): print(f"\nAverage uplift: {df['uplift'].mean():.2f}x") errors = df[df["id_errors"] > 0.01] if len(errors) > 0: - print(f"\nWARNING: {len(errors)} config(s) had id errors > 1%!") + print(f"\nERROR: {len(errors)} softplus config(s) had id errors > 1%!") print(errors.to_string(index=False)) + failed_sections.append("softplus") else: - 
print("All tests passed!") + print("All softplus tests passed!") + + if args.test in ("softmax", "all"): + softmax_configs = list( + itertools.product(num_experts_list, num_tokens_list, topk_list, dtype_list) + ) + print("\n" + "=" * 80) + print("topk_softmax benchmark (via topk_gating)") + print("=" * 80) + collected = [] + for num_experts, num_tokens, topk, dtype in softmax_configs: + result = benchmark_topk_softmax( + num_experts=num_experts, num_tokens=num_tokens, topk=topk, dtype=dtype + ) + collected.append(result) + df = pd.DataFrame(collected) + print(df.to_string(index=False)) + print(f"\nAverage uplift: {df['uplift'].mean():.2f}x") + errors = df[df["id_errors"] > 0.01] + if len(errors) > 0: + print(f"\nERROR: {len(errors)} softmax config(s) had id errors > 1%!") + print(errors.to_string(index=False)) + failed_sections.append("softmax") + else: + print("All softmax tests passed!") print("=" * 80) + + if failed_sections: + print( + f"FAIL: correctness regression in section(s): " + f"{', '.join(failed_sections)}", + file=sys.stderr, + ) + sys.exit(1)