added support for MLA decode #139
Conversation
Pull request overview
This PR adds an Intel XPU (BMG)-targeted CUTLASS-style Multi-head Latent Attention (MLA) decode kernel and wires it through the C++/PyTorch extension, Python API, tests, and a benchmark script.
Changes:
- Introduces a new SYCL/CUTLASS-based MLA decode kernel (mainloop/epilogue/kernel + tile scheduler + runner) and a PyTorch C++ entrypoint.
- Updates the Python API to accept split query inputs (q_nope, q_pe) and exposes the new cutlass_mla_decode / cutlass_mla_get_workspace_size ops (see the usage sketch below).
- Updates tests and the benchmark to run on XPU and match the new API.
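A minimal usage sketch of the updated Python API. The op names (cutlass_mla_decode, cutlass_mla_get_workspace_size) and the split q_nope/q_pe inputs come from this PR; the import path, argument order, shapes, dtypes, and the workspace-size signature below are assumptions for illustration, not the PR's exact interface:

```python
import torch
# Import path assumed; the PR updates python/sgl_kernel/attention.py.
from sgl_kernel.attention import cutlass_mla_decode, cutlass_mla_get_workspace_size

# Illustrative DeepSeek-style decode shapes (512-dim latent + 64-dim rope); all assumptions.
batch, heads, page_size, num_kv_splits = 4, 128, 64, 1
kv_lora_rank, d_pe = 512, 64
device, dtype = "xpu", torch.bfloat16

q_nope = torch.randn(batch, heads, kv_lora_rank, dtype=dtype, device=device)
q_pe = torch.randn(batch, heads, d_pe, dtype=dtype, device=device)
seq_lens = torch.full((batch,), 1024, dtype=torch.int32, device=device)
max_blocks = (1024 + page_size - 1) // page_size
block_table = torch.arange(batch * max_blocks, dtype=torch.int32, device=device).view(batch, max_blocks)
kv_cache = torch.randn(batch * max_blocks, page_size, kv_lora_rank + d_pe, dtype=dtype, device=device)

# Workspace query + decode call; signatures assumed from the reviewed diff.
workspace_size = cutlass_mla_get_workspace_size(1024, batch, num_kv_splits)
workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device)
out = cutlass_mla_decode(q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits)
```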
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| src/sycl/mla_decode.cpp | Adds the MLA decode PyTorch C++ interface and dispatch (dtype/page-size). |
| src/sycl/kernels/mla/xe_mla_mainloop.hpp | Implements the MLA mainloop: QK, online softmax, and PV accumulation. |
| src/sycl/kernels/mla/xe_mla_epilogue.hpp | Implements epilogue reduction/normalization and output writeback. |
| src/sycl/kernels/mla/xe_mla_kernel.hpp | Orchestrates mainloop + epilogue and constructs tensors/tiles. |
| src/sycl/kernels/mla/mla_tile_scheduler.hpp | Provides workgroup-to-tile mapping for MLA decode. |
| src/sycl/kernels/mla/mla_runner.hpp | Device-layer wrapper to launch the MLA kernel via SYCL launch APIs. |
| src/sycl/kernels/mla/copy_block_slm.hpp | Adds SLM copy helpers used by the epilogue reduction path. |
| src/torch_extension_sycl.cc | Registers cutlass_mla_decode and cutlass_mla_get_workspace_size with torch.ops. |
| include/sgl_flash_kernel_ops.h | Declares the new MLA public C++ entrypoints. |
| include/sgl_kernel_ops.h | Removes MLA declarations from the non-flash ops header. |
| src/sycl/Utils.h | Adds a CUTLASS_CHECK helper macro used by the new kernel path. |
| python/sgl_kernel/attention.py | Updates the Python wrapper to the new MLA API (q_nope, q_pe, sm_scale). |
| tests/test_cutlass_mla.py | Updates the MLA test to be XPU-only and validate the XPU kernel against a CPU reference. |
| benchmark/bench_cutlass_mla.py | Updates the benchmark for XPU and adds result collection/plotting utilities. |
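For orientation, the mainloop and epilogue rows above describe a block-wise online softmax with PV accumulation and a deferred normalization. A minimal PyTorch reference of that pattern (an educational sketch, not the SYCL kernel itself):

```python
import torch

def online_softmax_attention(q, k, v, block=128, scale=1.0):
    """Blockwise online softmax + PV accumulation (mainloop), normalization at the end (epilogue).
    q: [nq, d], k: [seq, d], v: [seq, dv]; float32 tensors assumed."""
    m = torch.full((q.shape[0],), float("-inf"), device=q.device)   # running row max
    l = torch.zeros(q.shape[0], device=q.device)                    # running softmax denominator
    acc = torch.zeros(q.shape[0], v.shape[-1], device=q.device)     # unnormalized P @ V accumulator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                          # partial QK^T scores for this KV block
        m_new = torch.maximum(m, s.max(dim=-1).values)
        correction = torch.exp(m - m_new)               # rescale previously accumulated results
        p = torch.exp(s - m_new[:, None])
        l = l * correction + p.sum(dim=-1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]                             # epilogue: final normalization
```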
kareemshaik80 left a comment
Please add the current performance benchmarking results to the description if possible.
Updated.
Before we go into details, two things to focus on right now:
6a612d0 to 654cc4a
Hi, I have updated the changes; the CI cost is now ~24 min.
Hi, I agree. One thing to note here though: the core of MLA is a standard GQA, but currently our
Hi @mingfeima @pralay-das @kareemshaik80 @airMeng, as the critical comments related to CI timing and performance numbers have been addressed, could we please proceed to merge this PR? Pralay is working on more optimizations, and those will come soon in follow-up PRs. We will keep working to deliver the best performance, but at the same time it is important to merge this PR, as it has been pending in review for quite long. Merging would also enable QA to do exhaustive functional testing and the topology team to integrate. Thanks.
airMeng left a comment
LGTM but @sunjiweiswift please review ASAP
    out = cutlass_mla_decode(
        q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
        q_nope,
So this kernel supports nope fusion? We need another SGLang PR to enable this; how much benefit do we get from the fusion?
I prefer to hack the GQA code a little bit so that it can run 576, even if this may lead to suboptimal perf. Still, we need to compare MLA vs. GQA as a baseline reference.
Let's align the interface a little bit. MLA in sglang has very complex dispatch logic on the CUDA side; this comes from historical reasons, which Intel does not necessarily need to follow.
First of all, we decided to use one backend for all, just like aiter and ascend, which is our XPUAttnBackend. Secondly, it is essentially a mapping of fa3 at the implementation level. So make sure that the interface aligns with the existing XPUAttnBackend, mapping to and managing the workspace in a similar approach as the current GQA split-KV path.
Hello @mingfeima, for any decisions we have taken and the approach being followed in this PR, I would suggest aligning with @pralay-das. Thanks.
@pralay-das we have an automatic performance monitor to track performance per PR; currently it only tracks MoE, with FA to be added later. Would you like to add MLA to the monitor? The performance tracking will look like #128.
I believe the current benchmark number is low. It should be raised to 350G.
Hi, I agree. I am working on it; all performance-related fixes will come in a follow-up PR.
Any reason the perf is low right now? Is there any difficulty that prevents the perf issue from being fixed in this PR?
Hi, I am not sure where the bottleneck is. I tried a few changes, but they are not helping that much. I am working on it; I need some more time and will provide changes in follow-up PRs.
@mingfeima @sunjiweiswift shall we merge with the current performance?
Hello @airMeng @mingfeima @sunjiweiswift, I would request merging this PR: the changes are big, and merging would unblock exhaustive QA validation and integration into topologies, as well as help us avoid logistical issues like rebases and conflict resolution. Since these are initial changes, the span across the code base is large, so it is better to merge them now. We are continuing to work on performance optimizations, and those will definitely come in future PRs. Thanks.
NP. Usually I would not recommend landing a kernel while perf still has a big issue. Anyway, if this is your preference, we can land this one; please continue working on the perf optimization. Please fix the naming issue, don't use
@mkumargarg you may merge this one if this is your preference. Please continue with the optimization efforts to improve kernel performance; we especially care about decoding perf when the input sequence length is 3500 and the output length is 1500. Additionally, please update the DeepSeek R1 attention shape in the benchmark, which has
This PR implements a CUTLASS-based Multi-head Latent Attention (MLA) decode kernel, cutlass_mla_decode, optimized for Intel XPU (BMG). MLA is the attention mechanism used in DeepSeek-V2/V3 models, which compresses the KV cache using low-rank projections to reduce memory bandwidth requirements.

Algorithm Overview
MLA decoding performs the following computation:
Where:
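A minimal PyTorch sketch of the decode computation, assuming the standard absorbed MLA formulation (512-dim compressed latent plus a 64-dim rotary part); names and shapes are illustrative, not taken from the kernel:

```python
import torch

def mla_decode_reference(q_nope, q_pe, kv_c, k_pe, sm_scale):
    """Single-token MLA decode reference (assumed formulation, not the kernel itself).
    q_nope: [heads, 512]  query projected into the latent space
    q_pe:   [heads, 64]   rotary-position part of the query
    kv_c:   [seq, 512]    compressed latent KV cache
    k_pe:   [seq, 64]     rotary-position part of the keys
    """
    # Scores combine the latent ("nope") and rotary ("pe") contributions.
    scores = (q_nope @ kv_c.T + q_pe @ k_pe.T) * sm_scale        # [heads, seq]
    probs = torch.softmax(scores.float(), dim=-1).to(kv_c.dtype)
    # Values are the compressed latents themselves; the up-projection happens outside the kernel.
    return probs @ kv_c                                          # [heads, 512]
```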
Architecture
The implementation follows a CUTLASS-style modular design:
Key Features
Testing
Total: 1152 test configurations
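A hedged sketch of the XPU-versus-CPU-reference check described above; the tolerances, shapes, and the commented-out kernel call are assumptions, not the actual test file:

```python
import pytest
import torch

@pytest.mark.skipif(not hasattr(torch, "xpu") or not torch.xpu.is_available(),
                    reason="requires an XPU device")
@pytest.mark.parametrize("batch", [1, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_mla_decode_matches_cpu_reference(batch, dtype):
    heads, d_nope, d_pe, seq = 128, 512, 64, 256
    q_nope = torch.randn(batch, heads, d_nope, dtype=dtype)
    q_pe = torch.randn(batch, heads, d_pe, dtype=dtype)
    kv_c = torch.randn(batch, seq, d_nope, dtype=dtype)
    k_pe = torch.randn(batch, seq, d_pe, dtype=dtype)
    sm_scale = (d_nope + d_pe) ** -0.5

    # CPU reference in float32.
    scores = torch.einsum("bhd,bsd->bhs", q_nope.float(), kv_c.float()) \
           + torch.einsum("bhd,bsd->bhs", q_pe.float(), k_pe.float())
    ref = torch.softmax(scores * sm_scale, dim=-1) @ kv_c.float()
    assert ref.shape == (batch, heads, d_nope)

    # The XPU path would call the new op here (exact call shown in the API sketch above):
    # out = cutlass_mla_decode(...)
    # torch.testing.assert_close(out.cpu().float(), ref, rtol=2e-2, atol=2e-2)
```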
API
Performance Benchmarking
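For the benchmarking discussion above, a minimal sketch of how decode latency could be timed on XPU, assuming a PyTorch build with XPU support (the actual script is benchmark/bench_cutlass_mla.py):

```python
import time
import torch

def time_xpu(fn, iters=50, warmup=10):
    """Wall-clock timing of a kernel launch on XPU (sketch, not the benchmark script)."""
    for _ in range(warmup):
        fn()
    torch.xpu.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.xpu.synchronize()
    return (time.perf_counter() - start) / iters

# Hypothetical usage: latency = time_xpu(lambda: cutlass_mla_decode(...))
# Effective bandwidth ~= bytes of KV cache read per decode step / latency.
```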