[feat](kt-kernel): AVX2 MXFP4 MoE + AMX tile MXFP4 dispatch#2010
[feat](kt-kernel): AVX2 MXFP4 MoE + AMX tile MXFP4 dispatch#2010yyj6666667 wants to merge 2 commits into
Conversation
- Add AVX2 MXFP4 MoE kernel (mxfp4-moe.hpp) with 4-token M-blocking - Add AMX N-tail fallback in fp4-moe.hpp for non-aligned expert sizes - Add AMX tile MXFP4 backend selection (_select_mxfp4_backend in amx.py) - Wire AVX2MXFP4_MOE binding in ext_bindings.cpp - Support TP_MOE down_proj slicing and multi-pool per-expert loading - Add test_fp4_moe_avx2.py integration test
There was a problem hiding this comment.
Code Review
This pull request introduces MXFP4 (FP4 E2M1) MoE inference support for AVX2 architectures, complementing the existing AMX path. Key additions include a new AVX2 backend using SSSE3 lookup tables for dequantization, an optimized AMX tile path for prefill operations, and Python-level backend selection logic. Review feedback identified several high-severity issues, including memory leaks from unmanaged std::aligned_alloc buffers and potential data corruption when slicing nibble-packed weights on odd boundaries. Additionally, improvements were suggested for memory deallocation consistency and defensive error handling in the test suite.
| tpc.physical_to_logical_map = config.physical_to_logical_map; | ||
| int per_tp_interm = tpc.intermediate_size; | ||
| size_t down_wt_per_expert = (size_t)tpc.hidden_size * per_tp_interm / 2; | ||
| uint8_t* down_buf = (uint8_t*)std::aligned_alloc(64, (size_t)tpc.expert_num * down_wt_per_expert); |
There was a problem hiding this comment.
Memory leak detected. down_buf is allocated using std::aligned_alloc but is never freed. Since this buffer is intended to persist for the model's lifetime, it should be managed by a member variable in the class (e.g., a std::vector<void*> or a smart pointer with a custom deleter) to ensure it is released when the MoE layer is destroyed. Additionally, std::aligned_alloc returns nullptr on failure, which is not checked here before use in std::memcpy.
There was a problem hiding this comment.
Memory leak detected.
down_bufis allocated usingstd::aligned_allocbut is never freed. Since this buffer is intended to persist for the model's lifetime, it should be managed by a member variable in the class (e.g., astd::vector<void*>or a smart pointer with a custom deleter) to ensure it is released when the MoE layer is destroyed. Additionally,std::aligned_allocreturnsnullptron failure, which is not checked here before use instd::memcpy.
Thanks bot, I already confirmed it, I'll commit the following patch soon
| std::memcpy(dst_down + (size_t)row * per_tp_interm / 2, | ||
| src_down + (size_t)row * full_interm / 2 + (size_t)i * per_tp_interm / 2, | ||
| per_tp_interm / 2); |
There was a problem hiding this comment.
Potential data corruption or crash due to misalignment. MXFP4 is nibble-packed (2 elements per byte). When slicing down_proj along the intermediate dimension (K-split), the starting index i * per_tp_interm must be even to align with a byte boundary. If per_tp_interm is odd, std::memcpy will start or end at a half-byte offset, leading to incorrect weight loading. An assertion or check should be added to ensure per_tp_interm is even.
References
- Enforce defensive programming: ensure invalid inputs or states (like odd split points for packed data) are safely handled.
| struct BufferB { | ||
| uint8_t* b = nullptr; // nibble-packed FP4 (may be nullptr in scale-only mode) | ||
| float* d = nullptr; // FP32 group scales | ||
| int n = 0, k = 0, k_group_size = 0, k_group_count = 0; | ||
|
|
||
| BufferB() = default; | ||
|
|
||
| // Full allocation: b and d packed into a single aligned block. | ||
| BufferB(int n_, int k_, int k_group_size_, void* ptr) | ||
| : b((uint8_t*)ptr), n(n_), k(k_), k_group_size(k_group_size_) { |
There was a problem hiding this comment.
Memory leak in AVX2_MOE_BASE::init. The pointers gate_bb_ptr, up_bb_ptr, and down_bb_ptr are allocated via std::aligned_alloc but are not owned by BufferB or any other object that handles their deallocation. Furthermore, in load_weights (line 651), the b pointer in BufferB is overwritten with a new allocation, causing the original pointer to be lost and leaked. Use a management strategy that ensures all aligned_alloc memory is eventually passed to std::free.
| print(f" t_output {t_output.flatten()[:debug_print_count]}") | ||
|
|
||
| return {"case": pattern, "description": desc, | ||
| "mean": sum(diffs)/len(diffs), "max": max(diffs), "min": min(diffs)} |
There was a problem hiding this comment.
Potential ZeroDivisionError if diffs is empty. Although validation_iter is currently set to 3, it is safer to check the length of the list before division.
| "mean": sum(diffs)/len(diffs), "max": max(diffs), "min": min(diffs)} | |
| "mean": sum(diffs)/len(diffs) if diffs else 0.0, "max": max(diffs) if diffs else 0.0, "min": min(diffs) if diffs else 0.0} |
| pool->dispense_backend()->do_numa_job([&, this](int i) { | ||
| auto& tpc = tps[i]->config_; | ||
| delete[] (uint8_t*)tpc.gate_proj; | ||
| delete[] (uint8_t*)tpc.up_proj; | ||
| delete[] (uint8_t*)tpc.down_proj; | ||
| delete[] (ggml_bf16_t*)tpc.gate_scale; | ||
| delete[] (ggml_bf16_t*)tpc.up_scale; | ||
| delete[] (ggml_bf16_t*)tpc.down_scale; | ||
| }); |
There was a problem hiding this comment.
Incorrect deallocation method. The buffers gate_proj, up_proj, etc., were allocated using new[] at lines 672-677, so delete[] is correct for those. However, if any of these pointers were to be allocated via std::aligned_alloc (as seen elsewhere in this file), using delete[] would result in undefined behavior. Ensure consistency between allocation and deallocation methods across the project.
…lignment, dynamic expert update - Track aligned_alloc pointers in AVX2_MOE_BASE::owned_aligned_allocs_ and free them in the destructor (fixes BufferB backing memory leak on destroy). - Track per-TP down_buf allocations in TP_MOE::tp_owned_down_bufs_ with nullptr checks and size rounding to alignment boundary. - Add nibble-alignment runtime check for per_tp_interm in MXFP4 TP K-split. - Add write_weight_scale_to_buffer override to TP_MOE<AVX2_MXFP4_MOE_TP>, enabling dynamic expert update with kt-threadpool-count>=2. - Guard against ZeroDivisionError in test_fp4_moe_avx2.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
mxfp4-moe.hpp) with 4-token M-blocking, enabling MXFP4 MoE on non-AMX CPUsfp4-moe.hppfor expert sizes not aligned to tile dimensions_select_mxfp4_backendinamx.py)AVX2MXFP4_MOEbinding inext_bindings.cppdown_projslicing and multi-pool per-expert loadingTest Results (DeepSeek V4 Flash × 1×RTX 5090)
MMLU 100-subset
Changed Files
kt-kernel/operators/avx2/mxfp4-moe.hpp— new AVX2 MXFP4 MoE kernelkt-kernel/operators/amx/fp4-moe.hpp— AMX N-tail fallbackkt-kernel/python/utils/amx.py—_select_mxfp4_backend()dispatchkt-kernel/ext_bindings.cpp— AVX2MXFP4_MOE bindingkt-kernel/examples/test_fp4_moe_avx2.py— integration test