[Kernel] feat: add NVFP4 blockwise MoE kernels for sm_120#1528
[Kernel] feat: add NVFP4 blockwise MoE kernels for sm_120#1528AlpinDale wants to merge 1 commit into
Conversation
Summary of ChangesHello @AlpinDale, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request lays the groundwork for supporting NVFP4 blockwise Mixture-of-Experts (MoE) operations on NVIDIA's latest SM120 (Blackwell) GPU architecture. It integrates a new reference kernel, updates the underlying CUTLASS dependency, and adjusts the build system and Python dispatch mechanisms to enable this new hardware capability. While functional, the current implementation is noted as not fully optimized, indicating it's an initial step towards full performance parity. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces support for NVFP4 blockwise Mixture-of-Experts (MoE) kernels specifically for the sm_120 architecture, likely targeting upcoming Blackwell GPUs. The changes span across CMake build configurations, Python dispatch logic, and new CUDA kernel implementations. While the overall approach is sound, I've identified several critical issues in the new sm_120 kernel that could lead to incorrect computations, particularly in the dequantization logic and loop bounds. Additionally, there are opportunities to improve code maintainability and performance. My review provides specific suggestions to address these points.
|
|
||
| // E4M3 scale factor dequantization | ||
| __device__ __forceinline__ float dequantize_e4m3_scale(uint8_t e4m3_val) { | ||
| if (e4m3_val == 0) return 1.0f; |
There was a problem hiding this comment.
The special handling for e4m3_val == 0 is incorrect. An E4M3 value of 0x00 represents +0.0f, but the function currently returns 1.0f. The subsequent logic in the function already correctly handles denormalized numbers (when exp == 0), so this special case is not only incorrect but also unnecessary. Removing this line will allow the function to correctly return 0.0f for an input of 0x00.
| // Compute dot product in 16-element blocks | ||
| float sum = 0.0f; | ||
| int k_packed = K / 2; | ||
| int k_blocks = K / 16; |
There was a problem hiding this comment.
The calculation of k_blocks uses integer division, which will truncate the result if K is not a multiple of 16. This will cause the kernel to skip processing the tail end of the K dimension, leading to incorrect matrix multiplication results. You should use ceiling division to ensure all elements are processed.
int k_blocks = (K + 15) / 16;
| major, minor = torch.cuda.get_device_capability(device) | ||
|
|
||
| # Use SM120 kernel for compute capability 12.0 and above | ||
| if major == 12 and minor == 0: |
There was a problem hiding this comment.
The check for the SM120 architecture is too specific. By checking for major == 12 and minor == 0, you are limiting this path to compute capability 12.0 exactly. To ensure forward compatibility with future minor revisions of the same architecture (e.g., 12.1), it's better to check only the major version number.
| if major == 12 and minor == 0: | |
| if major == 12: |
| #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 | ||
| { | ||
| int32_t version_num = get_sm_version_num(); | ||
| if (version_num >= 120) { | ||
| return cutlass_fp4_group_mm_sm120(output, a, b, a_blockscale, | ||
| b_blockscales, alphas, problem_sizes, | ||
| expert_offsets, sf_offsets); | ||
| } | ||
| } | ||
| #endif |
There was a problem hiding this comment.
| for (int e = 0; e < num_experts; e++) { | ||
| int start = expert_offsets[e]; | ||
| int end = (e == num_experts - 1) ? M : expert_offsets[e + 1]; | ||
| if (tid_y >= start && tid_y < end) { | ||
| expert_id = e; | ||
| break; | ||
| } | ||
| } |
There was a problem hiding this comment.
This loop performs a linear scan to find the expert_id. This can be inefficient if num_experts is large. Since expert_offsets is sorted, you can achieve better performance by using a binary search to locate the expert.
// Find expert using binary search
int low = 0, high = num_experts;
while (low < high) {
int mid = low + (high - low) / 2;
if (tid_y < expert_offsets[mid]) {
high = mid;
} else {
low = mid + 1;
}
}
int expert_id = low - 1;
Not fully optimized, as a lot of the sm_100 codepath is still used for this.
Tested with alpindale/Ling-mini-2.0-NVFP4, it gets about 91 tok/s decode (slower than the 140 tok/s with AWQ Marlin).