Skip to content

[feat](kt-kernel): AVX2 MXFP4 MoE + AMX tile MXFP4 dispatch#2010

Closed
yyj6666667 wants to merge 2 commits into
kvcache-ai:mainfrom
yyj6666667:feat/avx2-mxfp4-moe
Closed

[feat](kt-kernel): AVX2 MXFP4 MoE + AMX tile MXFP4 dispatch#2010
yyj6666667 wants to merge 2 commits into
kvcache-ai:mainfrom
yyj6666667:feat/avx2-mxfp4-moe

Conversation

@yyj6666667
Copy link
Copy Markdown
Collaborator

@yyj6666667 yyj6666667 commented May 17, 2026

  • Add AVX2 MXFP4 MoE kernel (mxfp4-moe.hpp) with 4-token M-blocking, enabling MXFP4 MoE on non-AMX CPUs
  • Add AMX N-tail fallback in fp4-moe.hpp for expert sizes not aligned to tile dimensions
  • Add AMX tile MXFP4 backend auto-selection (_select_mxfp4_backend in amx.py)
  • Wire AVX2MXFP4_MOE binding in ext_bindings.cpp
  • Support TP_MOE down_proj slicing and multi-pool per-expert loading

Test Results (DeepSeek V4 Flash × 1×RTX 5090)

MMLU 100-subset

Build GPU Experts Chunked Prefill mem-fraction-static Score
AVX2 6 2048 0.80 90%

Changed Files

  • kt-kernel/operators/avx2/mxfp4-moe.hpp — new AVX2 MXFP4 MoE kernel
  • kt-kernel/operators/amx/fp4-moe.hpp — AMX N-tail fallback
  • kt-kernel/python/utils/amx.py_select_mxfp4_backend() dispatch
  • kt-kernel/ext_bindings.cpp — AVX2MXFP4_MOE binding
  • kt-kernel/examples/test_fp4_moe_avx2.py — integration test

- Add AVX2 MXFP4 MoE kernel (mxfp4-moe.hpp) with 4-token M-blocking
- Add AMX N-tail fallback in fp4-moe.hpp for non-aligned expert sizes
- Add AMX tile MXFP4 backend selection (_select_mxfp4_backend in amx.py)
- Wire AVX2MXFP4_MOE binding in ext_bindings.cpp
- Support TP_MOE down_proj slicing and multi-pool per-expert loading
- Add test_fp4_moe_avx2.py integration test
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces MXFP4 (FP4 E2M1) MoE inference support for AVX2 architectures, complementing the existing AMX path. Key additions include a new AVX2 backend using SSSE3 lookup tables for dequantization, an optimized AMX tile path for prefill operations, and Python-level backend selection logic. Review feedback identified several high-severity issues, including memory leaks from unmanaged std::aligned_alloc buffers and potential data corruption when slicing nibble-packed weights on odd boundaries. Additionally, improvements were suggested for memory deallocation consistency and defensive error handling in the test suite.

Comment thread kt-kernel/operators/avx2/mxfp4-moe.hpp Outdated
tpc.physical_to_logical_map = config.physical_to_logical_map;
int per_tp_interm = tpc.intermediate_size;
size_t down_wt_per_expert = (size_t)tpc.hidden_size * per_tp_interm / 2;
uint8_t* down_buf = (uint8_t*)std::aligned_alloc(64, (size_t)tpc.expert_num * down_wt_per_expert);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Memory leak detected. down_buf is allocated using std::aligned_alloc but is never freed. Since this buffer is intended to persist for the model's lifetime, it should be managed by a member variable in the class (e.g., a std::vector<void*> or a smart pointer with a custom deleter) to ensure it is released when the MoE layer is destroyed. Additionally, std::aligned_alloc returns nullptr on failure, which is not checked here before use in std::memcpy.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Memory leak detected. down_buf is allocated using std::aligned_alloc but is never freed. Since this buffer is intended to persist for the model's lifetime, it should be managed by a member variable in the class (e.g., a std::vector<void*> or a smart pointer with a custom deleter) to ensure it is released when the MoE layer is destroyed. Additionally, std::aligned_alloc returns nullptr on failure, which is not checked here before use in std::memcpy.

Thanks bot, I already confirmed it, I'll commit the following patch soon

Comment on lines +647 to +649
std::memcpy(dst_down + (size_t)row * per_tp_interm / 2,
src_down + (size_t)row * full_interm / 2 + (size_t)i * per_tp_interm / 2,
per_tp_interm / 2);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Potential data corruption or crash due to misalignment. MXFP4 is nibble-packed (2 elements per byte). When slicing down_proj along the intermediate dimension (K-split), the starting index i * per_tp_interm must be even to align with a byte boundary. If per_tp_interm is odd, std::memcpy will start or end at a half-byte offset, leading to incorrect weight loading. An assertion or check should be added to ensure per_tp_interm is even.

References
  1. Enforce defensive programming: ensure invalid inputs or states (like odd split points for packed data) are safely handled.

Comment on lines +119 to +128
struct BufferB {
uint8_t* b = nullptr; // nibble-packed FP4 (may be nullptr in scale-only mode)
float* d = nullptr; // FP32 group scales
int n = 0, k = 0, k_group_size = 0, k_group_count = 0;

BufferB() = default;

// Full allocation: b and d packed into a single aligned block.
BufferB(int n_, int k_, int k_group_size_, void* ptr)
: b((uint8_t*)ptr), n(n_), k(k_), k_group_size(k_group_size_) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Memory leak in AVX2_MOE_BASE::init. The pointers gate_bb_ptr, up_bb_ptr, and down_bb_ptr are allocated via std::aligned_alloc but are not owned by BufferB or any other object that handles their deallocation. Furthermore, in load_weights (line 651), the b pointer in BufferB is overwritten with a new allocation, causing the original pointer to be lost and leaked. Use a management strategy that ensures all aligned_alloc memory is eventually passed to std::free.

Comment thread kt-kernel/examples/test_fp4_moe_avx2.py Outdated
print(f" t_output {t_output.flatten()[:debug_print_count]}")

return {"case": pattern, "description": desc,
"mean": sum(diffs)/len(diffs), "max": max(diffs), "min": min(diffs)}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Potential ZeroDivisionError if diffs is empty. Although validation_iter is currently set to 3, it is safer to check the length of the list before division.

Suggested change
"mean": sum(diffs)/len(diffs), "max": max(diffs), "min": min(diffs)}
"mean": sum(diffs)/len(diffs) if diffs else 0.0, "max": max(diffs) if diffs else 0.0, "min": min(diffs) if diffs else 0.0}

Comment on lines +730 to +738
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)tpc.gate_proj;
delete[] (uint8_t*)tpc.up_proj;
delete[] (uint8_t*)tpc.down_proj;
delete[] (ggml_bf16_t*)tpc.gate_scale;
delete[] (ggml_bf16_t*)tpc.up_scale;
delete[] (ggml_bf16_t*)tpc.down_scale;
});
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Incorrect deallocation method. The buffers gate_proj, up_proj, etc., were allocated using new[] at lines 672-677, so delete[] is correct for those. However, if any of these pointers were to be allocated via std::aligned_alloc (as seen elsewhere in this file), using delete[] would result in undefined behavior. Ensure consistency between allocation and deallocation methods across the project.

…lignment, dynamic expert update

- Track aligned_alloc pointers in AVX2_MOE_BASE::owned_aligned_allocs_ and
  free them in the destructor (fixes BufferB backing memory leak on destroy).
- Track per-TP down_buf allocations in TP_MOE::tp_owned_down_bufs_ with
  nullptr checks and size rounding to alignment boundary.
- Add nibble-alignment runtime check for per_tp_interm in MXFP4 TP K-split.
- Add write_weight_scale_to_buffer override to TP_MOE<AVX2_MXFP4_MOE_TP>,
  enabling dynamic expert update with kt-threadpool-count>=2.
- Guard against ZeroDivisionError in test_fp4_moe_avx2.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@yyj6666667 yyj6666667 closed this May 18, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant