added support for MLA decode #139

Merged
airMeng merged 13 commits into sgl-project:main from pralay-das:dev/pralay/mla on Apr 7, 2026

Conversation

Collaborator

@pralay-das pralay-das commented Mar 17, 2026

This PR implements a CUTLASS-based Multi-head Latent Attention (MLA) decode kernel, cutlass_mla_decode, optimized for Intel XPU (BMG). MLA is the attention mechanism used in the DeepSeek-V2/V3 models; it compresses the KV cache using low-rank projections to reduce memory bandwidth requirements.

Algorithm Overview

MLA decoding performs the following computation:

Score = Q_nope @ K^T + Q_pe @ K_pe^T    # Split attention scores
P = softmax(Score * scale)               # Online softmax
O = P @ V                                 # Output accumulation

Where:

  • Q_nope (bs, num_heads, 512): Query without positional encoding
  • Q_pe (bs, num_heads, 64): Query positional encoding
  • K (page_size, 512, num_pages): Compressed KV cache (latent dimension)
  • K_pe (page_size, 64, num_pages): K positional encoding
  • V (page_size, 512, num_pages): Value (shares storage with K)
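
As a reference for the computation above, here is a minimal, unoptimized PyTorch sketch of MLA decode on a contiguous (non-paged) cache; the helper name mla_decode_reference is ours, not part of the PR:

import torch

def mla_decode_reference(q_nope, q_pe, k_latent, k_pe, scale):
    # q_nope: (bs, num_heads, 512), q_pe: (bs, num_heads, 64)
    # k_latent: (bs, seq_len, 512), k_pe: (bs, seq_len, 64)
    # V shares storage with the latent K, so v = k_latent.
    score = torch.einsum("bhd,bsd->bhs", q_nope.float(), k_latent.float())
    score += torch.einsum("bhe,bse->bhs", q_pe.float(), k_pe.float())
    p = torch.softmax(score * scale, dim=-1)                  # softmax over seq_len
    return torch.einsum("bhs,bsd->bhd", p, k_latent.float())  # (bs, num_heads, 512)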

Architecture

The implementation follows a CUTLASS-style modular design:

Component | File | Purpose
Mainloop | xe_mla_mainloop.hpp | QK score computation, online softmax, PV output accumulation
Epilogue | xe_mla_epilogue.hpp | Cross-subgroup reduction, softmax normalization, output write
Kernel | xe_mla_kernel.hpp | Top-level orchestrator, tensor construction
Tile Scheduler | mla_tile_scheduler.hpp | Workgroup-to-tile mapping
PyTorch Interface | mla_decode.cpp | Python binding, kernel dispatch

Key Features

  • Paged KV Cache Support: Handles both fixed and variable-length sequences with page table lookup for efficient memory utilization
  • Online Softmax: Numerically stable softmax computed incrementally across K tiles (see the sketch after this list)
  • Cross-Subgroup Reduction: Efficient reduction of partial results across Intel GPU subgroups
  • XE_DPAS MMA: Leverages Intel XMX (Xe Matrix Extensions) for high-throughput matrix operations
  • Multiple Page Sizes: Supports page sizes of 16, 32, 64, and 128
  • Mixed Precision: Supports both BF16 and FP16 data types
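
A minimal sketch of the online-softmax update applied per K tile, written for illustration (this is the standard technique, not the PR's kernel code):

import torch

def online_softmax_step(m, l, acc, score_tile, v_tile):
    # m: (rows,) running max; l: (rows,) running normalizer; acc: (rows, d).
    # score_tile: (rows, tile) pre-scaled scores; v_tile: (tile, d) values.
    m_new = torch.maximum(m, score_tile.max(dim=-1).values)
    corr = torch.exp(m - m_new)                    # rescale earlier partials
    p = torch.exp(score_tile - m_new.unsqueeze(-1))
    l_new = l * corr + p.sum(dim=-1)
    acc_new = acc * corr.unsqueeze(-1) + p @ v_tile
    return m_new, l_new, acc_new                   # final output is acc / l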

Testing

  • Data types: BF16, FP16
  • Sequence lengths: 128, 1024, 4096
  • Batch sizes: 1, 2, 4
  • Page sizes: 16, 32, 64, 128
  • Number of heads: 16, 32, 64, 128
  • Variable length: True, False

Total: 1152 test configurations
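
For illustration, a hypothetical pytest parametrization over these axes could look like the sketch below (the test name and body are placeholders; the six listed axes alone give 576 combinations, so the 1152 total presumably includes one more two-valued axis such as num_kv_splits):

import pytest

@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("batch_size", [1, 2, 4])
@pytest.mark.parametrize("page_size", [16, 32, 64, 128])
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
@pytest.mark.parametrize("varlen", [True, False])
def test_cutlass_mla_decode(dtype, seq_len, batch_size, page_size, num_heads, varlen):
    ...  # build inputs on "xpu", run cutlass_mla_decode, compare to a CPU reference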

API

import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

# Get the workspace size and allocate the workspace
workspace_size = cutlass_mla_get_workspace_size(max_seq_len, batch_size, num_kv_splits)
workspace = torch.empty(workspace_size, device="xpu", dtype=torch.uint8)

# Run MLA decode
output = cutlass_mla_decode(
    q_nope,       # (bs, num_heads, 512)
    q_pe,         # (bs, num_heads, 64)
    kv_cache,     # (num_blocks, block_size, 576)
    seq_lens,     # (bs,)
    block_table,  # (bs, max_blocks)
    workspace,
    scale,
    num_kv_splits,
)
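
Assuming fixed-length sequences and the mla_decode_reference sketch from the Algorithm Overview above, a hedged correctness check of this call might look like the following (tolerances are illustrative, not the project's):

# Gather the paged cache back into a contiguous (bs, seq_len, 576) tensor
# and compare the kernel output against the naive reference on CPU.
kv = kv_cache[block_table].reshape(batch_size, -1, 576)[:, :max_seq_len].cpu()
ref = mla_decode_reference(q_nope.cpu(), q_pe.cpu(), kv[..., :512], kv[..., 512:], scale)
torch.testing.assert_close(output.cpu().float(), ref, atol=2e-2, rtol=2e-2)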

Performance Benchmarking

batch_size seq_len num_heads block_size num_kv_splits time_ms GB/s(median) GB/s(min) GB/s(max) GB/s
1 1024 128 16 -1 0.144 10.16 10.14 10.18 10.16
1 1024 64 16 -1 0.136 9.69 9.65 9.71 9.69
1 1024 32 16 -1 0.125 9.96 9.91 9.97 9.96
1 1024 16 16 -1 0.124 9.76 9.73 9.77 9.76
1 2048 128 16 -1 0.29 9.09 9.08 9.1 9.09
1 2048 64 16 -1 0.273 9.17 9.16 9.18 9.17
1 2048 32 16 -1 0.247 9.83 9.82 9.84 9.83
1 2048 16 16 -1 0.247 9.71 9.71 9.72 9.71
1 4096 128 16 -1 0.578 8.65 8.64 8.65 8.65
1 4096 64 16 -1 0.538 9.03 9.02 9.03 9.03
1 4096 32 16 -1 0.491 9.76 9.76 9.77 9.76
1 4096 16 16 -1 0.492 9.66 9.65 9.66 9.66
1 8192 128 16 -1 1.164 8.35 8.34 8.35 8.35
1 8192 64 16 -1 1.072 8.94 8.93 8.94 8.94
1 8192 32 16 -1 0.978 9.73 9.72 9.73 9.73
1 8192 16 16 -1 0.982 9.65 9.65 9.66 9.65
4 1024 128 16 -1 0.5 11.68 11.6 11.75 11.68
4 1024 64 16 -1 0.239 22.1 21.94 22.19 22.1
4 1024 32 16 -1 0.185 26.99 26.88 27.09 26.99
4 1024 16 16 -1 0.175 27.79 27.66 27.89 27.79
4 2048 128 16 -1 1.032 10.22 10.12 10.39 10.22
4 2048 64 16 -1 0.476 21.01 20.09 21.2 21.01
4 2048 32 16 -1 0.364 26.67 26.57 26.73 26.67
4 2048 16 16 -1 0.343 27.89 27.83 27.92 27.89
4 4096 128 16 -1 2.092 9.55 9.33 9.71 9.55
4 4096 64 16 -1 0.944 20.58 20.09 20.65 20.58
4 4096 32 16 -1 0.723 26.49 26.45 26.52 26.49
4 4096 16 16 -1 0.684 27.81 27.76 27.85 27.81
4 8192 128 16 -1 4.405 8.82 8.68 9 8.82
4 8192 64 16 -1 1.959 19.56 19.01 20.09 19.56
4 8192 32 16 -1 1.474 25.8 25.77 25.83 25.8
4 8192 16 16 -1 1.356 27.95 27.92 27.96 27.95
16 1024 128 16 -1 1.619 14.42 14.1 14.64 14.42
16 1024 64 16 -1 1.119 18.86 18.73 18.98 18.86
16 1024 32 16 -1 0.577 34.62 34.31 35.02 34.62
16 1024 16 16 -1 0.345 56.34 56.18 56.51 56.34
16 2048 128 16 -1 3.786 11.15 11.03 11.29 11.15
16 2048 64 16 -1 2.331 17.15 16.92 17.27 17.15
16 2048 32 16 -1 1.149 33.84 33.46 34.23 33.84
16 2048 16 16 -1 0.685 55.92 55.74 56.11 55.92
16 4096 128 16 -1 8.385 9.54 9.32 9.71 9.54
16 4096 64 16 -1 5.04 15.42 15.37 15.67 15.42
16 4096 32 16 -1 2.351 32.59 32.17 33.46 32.59
16 4096 16 16 -1 1.372 55.43 55.33 55.58 55.43
16 8192 128 16 -1 19.429 8 7.92 8.08 8
16 8192 64 16 -1 13.191 11.62 10.45 11.76 11.62
16 8192 32 16 -1 4.962 30.66 29.83 31.4 30.66
16 8192 16 16 -1 2.764 54.84 54.69 55.08 54.84
1 1024 128 32 -1 0.125 11.69 11.58 11.8 11.69
1 1024 64 32 -1 0.093 14.15 13.95 14.21 14.15
1 1024 32 32 -1 0.093 13.51 13.39 13.55 13.51
1 1024 16 32 -1 0.089 13.64 13.55 13.69 13.64
1 2048 128 32 -1 0.256 10.3 10.19 10.41 10.3
1 2048 64 32 -1 0.191 13.09 13.05 13.12 13.09
1 2048 32 32 -1 0.178 13.61 13.58 13.63 13.61
1 2048 16 32 -1 0.177 13.52 13.49 13.53 13.52
1 4096 128 32 -1 0.519 9.62 9.52 9.72 9.62
1 4096 64 32 -1 0.375 12.94 12.9 12.97 12.94
1 4096 32 32 -1 0.351 13.64 13.6 13.66 13.64
1 4096 16 32 -1 0.349 13.64 13.62 13.66 13.64
1 8192 128 32 -1 1.059 9.17 9.05 9.28 9.17
1 8192 64 32 -1 0.744 12.87 12.84 12.88 12.87
1 8192 32 32 -1 0.695 13.68 13.64 13.7 13.68
1 8192 16 32 -1 0.691 13.71 13.68 13.72 13.71
4 1024 128 32 -1 0.531 10.99 10.75 11.33 10.99
4 1024 64 32 -1 0.323 16.32 16.23 16.43 16.32
4 1024 32 32 -1 0.166 30.1 30 30.2 30.1
4 1024 16 32 -1 0.125 38.74 38.43 39.01 38.74
4 2048 128 32 -1 1.181 8.93 8.78 9.08 8.93
4 2048 64 32 -1 0.658 15.19 15.14 15.24 15.19
4 2048 32 32 -1 0.329 29.53 29.45 29.6 29.53
4 2048 16 32 -1 0.243 39.44 39.14 39.75 39.44
4 4096 128 32 -1 2.534 7.89 7.8 7.99 7.89
4 4096 64 32 -1 1.336 14.55 14.51 14.59 14.55
4 4096 32 32 -1 0.654 29.28 29.18 29.37 29.28
4 4096 16 32 -1 0.477 39.9 39.68 40.08 39.9
4 8192 128 32 -1 5.39 7.21 7.17 7.28 7.21
4 8192 64 32 -1 2.712 14.13 14.11 14.15 14.13
4 8192 32 32 -1 1.32 28.81 28.59 28.93 28.81
4 8192 16 32 -1 0.939 40.37 40.27 40.5 40.37
16 1024 128 32 -1 2.326 10.03 9.99 10.08 10.03
16 1024 64 32 -1 1.189 17.75 17.66 17.83 17.75
16 1024 32 32 -1 0.685 29.16 28.65 29.52 29.16
16 1024 16 32 -1 0.348 55.91 55.83 56.02 55.91
16 2048 128 32 -1 4.956 8.52 8.5 8.57 8.52
16 2048 64 32 -1 2.479 16.13 16 16.23 16.13
16 2048 32 32 -1 1.449 26.82 26.64 27.25 26.82
16 2048 16 32 -1 0.691 55.48 55.39 55.6 55.48
16 4096 128 32 -1 10.246 7.8 7.76 7.83 7.8
16 4096 64 32 -1 5.155 15.08 14.98 15.23 15.08
16 4096 32 32 -1 2.944 26.03 25.82 26.27 26.03
16 4096 16 32 -1 1.376 55.29 55.23 55.38 55.29
16 8192 128 32 -1 21.167 7.34 7.33 7.36 7.34
16 8192 64 32 -1 10.714 14.3 14.17 14.48 14.3
16 8192 32 32 -1 6.195 24.56 24.27 24.96 24.56
16 8192 16 32 -1 2.812 53.9 53.58 54.1 53.9
1 1024 128 64 -1 0.127 11.53 11.28 11.83 11.53
1 1024 64 64 -1 0.079 16.62 16.49 16.71 16.62
1 1024 32 64 -1 0.057 21.9 21.61 22.02 21.9
1 1024 16 64 -1 0.057 21.25 21.08 21.35 21.25
1 2048 128 64 -1 0.284 9.27 9.13 9.44 9.27
1 2048 64 64 -1 0.15 16.65 16.57 16.72 16.65
1 2048 32 64 -1 0.108 22.41 22.22 22.5 22.41
1 2048 16 64 -1 0.107 22.36 22.19 22.46 22.36
1 4096 128 64 -1 0.603 8.29 8.2 8.35 8.29
1 4096 64 64 -1 0.299 16.26 16.19 16.33 16.26
1 4096 32 64 -1 0.21 22.76 22.71 22.8 22.76
1 4096 16 64 -1 0.207 22.97 22.92 23.02 22.97
1 8192 128 64 -1 1.248 7.78 7.74 7.82 7.78
1 8192 64 64 -1 0.585 16.37 16.3 16.43 16.37
1 8192 32 64 -1 0.423 22.47 22.22 22.88 22.47
1 8192 16 64 -1 0.409 23.17 23.1 23.21 23.17
4 1024 128 64 -1 0.475 12.29 12.05 12.5 12.29
4 1024 64 64 -1 0.289 18.25 17.61 18.61 18.25
4 1024 32 64 -1 0.179 27.91 27.8 28.02 27.91
4 1024 16 64 -1 0.106 45.7 45.42 45.95 45.7
4 2048 128 64 -1 1.042 10.13 10 10.29 10.13
4 2048 64 64 -1 0.616 16.24 16 16.48 16.24
4 2048 32 64 -1 0.352 27.62 27.56 27.7 27.62
4 2048 16 64 -1 0.208 46.05 45.84 46.22 46.05
4 4096 128 64 -1 2.306 8.67 8.54 8.75 8.67
4 4096 64 64 -1 1.289 15.07 14.8 15.25 15.07
4 4096 32 64 -1 0.697 27.46 27.41 27.51 27.46
4 4096 16 64 -1 0.41 46.4 46.29 46.51 46.4
4 8192 128 64 -1 4.898 7.93 7.89 7.98 7.93
4 8192 64 64 -1 2.692 14.23 13.68 14.35 14.23
4 8192 32 64 -1 1.401 27.15 27.12 27.19 27.15
4 8192 16 64 -1 0.817 46.39 46.31 46.47 46.39
16 1024 128 64 -1 2.307 10.11 10.07 10.15 10.11
16 1024 64 64 -1 1.227 17.2 17.13 17.32 17.2
16 1024 32 64 -1 0.641 31.17 30.72 31.56 31.17
16 1024 16 64 -1 0.361 53.79 53.34 54.25 53.79
16 2048 128 64 -1 4.808 8.78 8.76 8.82 8.78
16 2048 64 64 -1 2.579 15.5 15.44 15.56 15.5
16 2048 32 64 -1 1.344 28.91 28.11 29.66 28.91
16 2048 16 64 -1 0.728 52.64 52.06 52.87 52.64
16 4096 128 64 -1 9.942 8.04 8.03 8.06 8.04
16 4096 64 64 -1 5.285 14.71 14.68 14.73 14.71
16 4096 32 64 -1 2.68 28.59 27.24 28.9 28.59
16 4096 16 64 -1 1.454 52.31 51.98 52.74 52.31
16 8192 128 64 -1 21.321 7.29 7.29 7.29 7.29
16 8192 64 64 -1 11.298 13.56 13.53 13.57 13.56
16 8192 32 64 -1 5.756 26.43 26.15 26.73 26.43
16 8192 16 64 -1 3.06 49.53 48.75 50.28 49.53
1 1024 128 128 -1 0.119 12.21 12.03 12.34 12.21
1 1024 64 128 -1 0.057 23.02 22.79 23.14 23.02
1 1024 32 128 -1 0.049 25.53 25.24 25.78 25.53
1 1024 16 128 -1 0.042 29.23 28.98 29.42 29.23
1 2048 128 128 -1 0.229 11.54 11.44 11.62 11.54
1 2048 64 128 -1 0.107 23.35 23.19 23.45 23.35
1 2048 32 128 -1 0.088 27.71 27.43 28.04 27.71
1 2048 16 128 -1 0.078 30.63 30.37 30.84 30.63
1 4096 128 128 -1 0.444 11.26 11.15 11.34 11.26
1 4096 64 128 -1 0.205 23.65 23.56 23.74 23.65
1 4096 32 128 -1 0.16 29.9 29.7 30.11 29.9
1 4096 16 128 -1 0.149 31.92 31.75 32.01 31.92
1 8192 128 128 -1 0.875 11.11 10.98 11.29 11.11
1 8192 64 128 -1 0.418 22.91 21 23.8 22.91
1 8192 32 128 -1 0.304 31.27 31.08 31.46 31.27
1 8192 16 128 -1 0.29 32.68 32.55 32.84 32.68
4 1024 128 128 -1 0.484 12.06 11.81 12.31 12.06
4 1024 64 128 -1 0.266 19.87 19.58 20.27 19.87
4 1024 32 128 -1 0.175 28.57 28.15 28.87 28.57
4 1024 16 128 -1 0.103 46.96 46.64 47.16 46.96
4 2048 128 128 -1 1.003 10.52 10.33 10.85 10.52
4 2048 64 128 -1 0.538 18.59 18.33 19 18.59
4 2048 32 128 -1 0.341 28.46 28.03 28.79 28.46
4 2048 16 128 -1 0.199 48.05 47.87 48.18 48.05
4 4096 128 128 -1 2.09 9.56 9.45 9.84 9.56
4 4096 64 128 -1 1.131 17.18 16.98 17.5 17.18
4 4096 32 128 -1 0.689 27.78 27.36 28.13 27.78
4 4096 16 128 -1 0.397 47.93 47.85 48 47.93
4 8192 128 128 -1 4.43 8.77 8.66 9.01 8.77
4 8192 64 128 -1 2.494 15.36 15.21 15.65 15.36
4 8192 32 128 -1 1.435 26.49 26.21 26.8 26.49
4 8192 16 128 -1 0.789 48.03 47.92 48.09 48.03
16 1024 128 128 -1 2.084 11.19 11 11.41 11.19
16 1024 64 128 -1 1.038 20.32 19.88 20.7 20.32
16 1024 32 128 -1 0.547 36.56 35.98 37.25 36.56
16 1024 16 128 -1 0.339 57.32 56.54 58 57.32
16 2048 128 128 -1 4.459 9.47 9.26 9.63 9.47
16 2048 64 128 -1 2.177 18.36 18.15 18.64 18.36
16 2048 32 128 -1 1.14 34.1 33.35 34.67 34.1
16 2048 16 128 -1 0.688 55.65 54.86 56.18 55.65
16 4096 128 128 -1 9.675 8.26 8.15 8.29 8.26
16 4096 64 128 -1 4.686 16.59 16.39 17.04 16.59
16 4096 32 128 -1 2.396 31.97 31.42 32.51 31.97
16 4096 16 128 -1 1.416 53.7 53.08 54.23 53.7
16 8192 128 128 -1 22.345 6.96 6.92 7 6.96
16 8192 64 128 -1 13.546 11.31 11.15 11.6 11.31
16 8192 32 128 -1 6.924 21.97 21.39 25.66 21.97
16 8192 16 128 -1 3.202 47.33 46.3 47.99 47.33
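
As a sanity check, the GB/s column is consistent with a simple traffic model: KV-cache read plus Q read plus output write, all in 2-byte (BF16/FP16) elements. A sketch of the model (our reconstruction, not the benchmark's code):

def decode_bandwidth_gbps(batch, seq_len, num_heads, time_ms, elem_bytes=2):
    kv_bytes = batch * seq_len * 576 * elem_bytes      # latent (512) + rope (64)
    q_bytes = batch * num_heads * 576 * elem_bytes     # q_nope + q_pe
    out_bytes = batch * num_heads * 512 * elem_bytes   # attention output
    return (kv_bytes + q_bytes + out_bytes) / (time_ms * 1e-3) / 1e9

For the first row, decode_bandwidth_gbps(1, 1024, 128, 0.144) gives roughly 10.1, close to the reported 10.16 GB/s.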

@pralay-das pralay-das marked this pull request as ready for review March 18, 2026 09:20
Copilot AI review requested due to automatic review settings March 18, 2026 09:20
Contributor

Copilot AI left a comment


Pull request overview

This PR adds an Intel XPU (BMG)-targeted CUTLASS-style Multi-head Latent Attention (MLA) decode kernel and wires it through the C++/PyTorch extension, Python API, tests, and a benchmark script.

Changes:

  • Introduces a new SYCL/CUTLASS-based MLA decode kernel (mainloop/epilogue/kernel + tile scheduler + runner) and a PyTorch C++ entrypoint.
  • Updates the Python API to accept split query inputs (q_nope, q_pe) and exposes the new cutlass_mla_decode / cutlass_mla_get_workspace_size ops.
  • Updates tests and benchmark to run on XPU and match the new API.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 7 comments.

File | Description
src/sycl/mla_decode.cpp | Adds the MLA decode PyTorch C++ interface and dispatch (dtype/page size).
src/sycl/kernels/mla/xe_mla_mainloop.hpp | Implements the MLA mainloop: QK, online softmax, and PV accumulation.
src/sycl/kernels/mla/xe_mla_epilogue.hpp | Implements epilogue reduction/normalization and output writeback.
src/sycl/kernels/mla/xe_mla_kernel.hpp | Orchestrates mainloop + epilogue and constructs tensors/tiles.
src/sycl/kernels/mla/mla_tile_scheduler.hpp | Provides workgroup-to-tile mapping for MLA decode.
src/sycl/kernels/mla/mla_runner.hpp | Device-layer wrapper to launch the MLA kernel via SYCL launch APIs.
src/sycl/kernels/mla/copy_block_slm.hpp | Adds SLM copy helpers used by the epilogue reduction path.
src/torch_extension_sycl.cc | Registers cutlass_mla_decode and cutlass_mla_get_workspace_size with torch.ops.
include/sgl_flash_kernel_ops.h | Declares the new MLA public C++ entrypoints.
include/sgl_kernel_ops.h | Removes MLA declarations from the non-flash ops header.
src/sycl/Utils.h | Adds a CUTLASS_CHECK helper macro used by the new kernel path.
python/sgl_kernel/attention.py | Updates the Python wrapper to the new MLA API (q_nope, q_pe, sm_scale).
tests/test_cutlass_mla.py | Updates the MLA test to be XPU-only and validate the XPU kernel vs. a CPU reference.
benchmark/bench_cutlass_mla.py | Updates the benchmark for XPU and adds result collection/plotting utilities.


Collaborator

@kareemshaik80 kareemshaik80 left a comment


Please add the current performance benchmarking results to the description if possible.

Collaborator

@airMeng airMeng left a comment


It seems that without splitting template instantiations into per-kernel compilation units like #140, the compilation time is much longer.

This PR's CI takes 36 min, while the previous PR's took ~22 min.

@pralay-das
Collaborator Author

Please add the current performance benchmarking results to the description if possible.

updated.

@mingfeima
Collaborator

Before we go into details, two things to focus on right now:

  • interface: MLA is integrated into sglang in two separate stages, forward_prepare and forward_core. prepare handles the weight absorption, and core is just a normal GQA process, only that K and V are the same buffer.
  • perf evaluation: we can simplify the benchmark tuning for the moment. You can test just BS = 1, 4, 16 with seqlen_kv = 1K, 2K, 4K, and 8K, and compare the performance against a standard GQA (K and V in different buffers) as a baseline.
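
For context on the forward_prepare stage mentioned above, here is a minimal sketch of the standard MLA weight absorption (our illustration; the weight name w_uk is hypothetical, and the PR itself implements only the core attention):

import torch

def absorb_uk(q_head, w_uk):
    # q_head: (bs, num_heads, head_dim) query after the usual projection.
    # w_uk:   (num_heads, head_dim, 512) up-projection from the latent cache.
    # Absorbing w_uk into the query lets forward_core attend directly
    # against the shared 512-dim latent K/V buffer, GQA-style.
    return torch.einsum("bhd,hdk->bhk", q_head, w_uk)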

@pralay-das
Collaborator Author

pralay-das commented Mar 27, 2026

It seems that without splitting template instantiations into per-kernel compilation units like #140, the compilation time is much longer.

This PR's CI takes 36 min, while the previous PR's took ~22 min.

hi, I have updated the changes; the CI cost is now ~24 min.

@pralay-das
Collaborator Author

Before we go into details, two things to focus on right now:

  • interface: MLA is integrated into sglang in two separate stages, forward_prepare and forward_core. prepare handles the weight absorption, and core is just a normal GQA process, only that K and V are the same buffer.
  • perf evaluation: we can simplify the benchmark tuning for the moment. You can test just BS = 1, 4, 16 with seqlen_kv = 1K, 2K, 4K, and 8K, and compare the performance against a standard GQA (K and V in different buffers) as a baseline.

hi, I agree. One thing to note: although the core of MLA is a standard GQA, our mha_fwd currently has a limitation, supporting head dims only up to 256, while MLA's head dim is 576 (the standard dimension), so we can't compare directly.
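
For readers following the thread: the 576 head dim decomposes into the 512-dim latent part plus the 64-dim rope part, and V aliases the latent part of K. A slicing sketch (variable names are ours):

# kv_cache: (num_blocks, block_size, 576), as in the API section above.
latent, rope = kv_cache[..., :512], kv_cache[..., 512:]
# K is the concatenation [latent | rope] (576 dims) and V is just `latent`,
# which is why a GQA kernel capped at head_dim 256 cannot run this directly.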

@mkumargarg

Hi @mingfeima @pralay-das @kareemshaik80 @airMeng, as the critical comments related to CI timing and performance numbers have been addressed, could we please merge this PR? Pralay is working on more optimizations, and those will come soon in follow-up PRs. We will keep working toward the best performance, but at the same time it is important to merge this PR, as it has been pending review for quite a long time. Merging would also enable QA to do exhaustive functional testing and the topology team to integrate. Thanks.

@pralay-das pralay-das requested a review from airMeng March 28, 2026 04:35
Collaborator

@airMeng airMeng left a comment


LGTM but @sunjiweiswift please review ASAP

Comment on lines +132 to +133 of tests/test_cutlass_mla.py:

    out = cutlass_mla_decode(
-       q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+       q_nope,
Collaborator


So this kernel supports nope fusion? We would need another SGLang PR to enable this; how much benefit do we get from the fusion?

@mingfeima
Collaborator

Before we go into details, two things to focus on right now:

  • interface: MLA is integrated into sglang in two separate stages, forward_prepare and forward_core. prepare handles the weight absorption, and core is just a normal GQA process, only that K and V are the same buffer.
  • perf evaluation: we can simplify the benchmark tuning for the moment. You can test just BS = 1, 4, 16 with seqlen_kv = 1K, 2K, 4K, and 8K, and compare the performance against a standard GQA (K and V in different buffers) as a baseline.

hi, I agree. One thing to note: although the core of MLA is a standard GQA, our mha_fwd currently has a limitation, supporting head dims only up to 256, while MLA's head dim is 576 (the standard dimension), so we can't compare directly.

I would prefer to hack the GQA code a little so that it can run head dim 576, even if this leads to suboptimal perf. We still need to compare MLA vs. GQA as a baseline reference.

@mingfeima
Collaborator

mingfeima commented Mar 30, 2026

let's align the interface a little bit.

MLA in sglang has very complex dispatch logic on the CUDA side; this comes from historical reasons, which Intel does not necessarily need to follow.

  • (MLA, Hopper CUDA12.3) -> fa3 -> FlashAttentionBackend
  • (MLA, SM100) -> flashinfer -> FlashInferMLAAttnBackend
  • (DeepSeek V3/R1/V3.1 on SM100) -> trtllm_mla -> TRTLLMMLABackend
  • (MLA, diff CUDA version) -> triton -> TritonAttnBackend /DoubleSparseAttnBackend
  • flashinfer + MLA -> FlashInferMLAAttnBackend
  • trtllm_mla -> TRTLLMMLABackend
  • flashmla -> FlashMLABackend
  • cutlass_mla -> CutlassMLABackend
  • fa4 + MLA -> FlashAttentionBackend(fa_impl_ver=4)
  • aiter + MLA -> AiterAttnBackend
  • ascend + MLA -> AscendAttnBackend
  • nsa + MLA -> NativeSparseAttnBackend

First of all, we decided to use one backend for all, just like aiter and ascend: our XPUAttnBackend. Secondly, it is essentially a mapping of fa3 at the implementation level.

So make sure that the interface aligns with the existing XPUAttnBackend.

Mapping to cutlass_mla is not a good option, given that we already follow fa3. That is to say, still use

flash_attn_varlen_func(...)
flash_attn_with_kvcache(...)

and manage the workspace in a similar way to the current GQA split-KV path.
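
A hedged sketch of what such an aligned call could look like, assuming an fa3-style flash_attn_with_kvcache signature (the argument names are illustrative and may not match the actual XPUAttnBackend wiring):

# q packs the nope and rope parts (head_dim 576); the K and V arguments
# alias the same latent cache buffer, the only MLA-specific twist here.
out = flash_attn_with_kvcache(
    q,                     # (bs, 1, num_heads, 576)
    kv_cache,              # paged K cache, (num_blocks, block_size, 1, 576)
    kv_cache[..., :512],   # V view sharing storage with K (head_dim 512)
    cache_seqlens=seq_lens,
    block_table=block_table,
    softmax_scale=scale,
    causal=True,
)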

@mkumargarg

mkumargarg commented Mar 30, 2026

let's align the interface a little bit.

MLA in sglang has very complex dispatch logic on the CUDA side; this comes from historical reasons, which Intel does not necessarily need to follow.

  • (MLA, Hopper CUDA12.3) -> fa3 -> FlashAttentionBackend
  • (MLA, SM100) -> flashinfer -> FlashInferMLAAttnBackend
  • (DeepSeek V3/R1/V3.1 on SM100) -> trtllm_mla -> TRTLLMMLABackend
  • (MLA, diff CUDA version) -> triton -> TritonAttnBackend /DoubleSparseAttnBackend
  • flashinfer + MLA -> FlashInferMLAAttnBackend
  • trtllm_mla -> TRTLLMMLABackend
  • flashmla -> FlashMLABackend
  • cutlass_mla -> CutlassMLABackend
  • fa4 + MLA -> FlashAttentionBackend(fa_impl_ver=4)
  • aiter + MLA -> AiterAttnBackend
  • ascend + MLA -> AscendAttnBackend
  • nsa + MLA -> NativeSparseAttnBackend

First of all, we decided to use one backend for all, just like aiter and ascend: our XPUAttnBackend. Secondly, it is essentially a mapping of fa3 at the implementation level.

So make sure that the interface aligns with the existing XPUAttnBackend.

Mapping to cutlass_mla is not a good option, given that we already follow fa3. That is to say, still use

flash_attn_varlen_func(...)
flash_attn_with_kvcache(...)

and manage the workspace in a similar way to the current GQA split-KV path.

Hello @mingfeima, for any decisions we have taken and the approach being followed in this PR, I would suggest aligning with @pralay-das. Thanks.

@airMeng
Collaborator

airMeng commented Mar 31, 2026

@pralay-das we have an auto performance monitor

- name: Run Sglang Kernel Benchmarks
and
- name: Auto PR for baseline.json update
to track performance per PR; currently it only tracks MoE, with FA coming later. Would you like to add MLA to the monitor? The performance tracking will look like #128.

@sunjiweiswift
Collaborator

I believe the current benchmark number is low. It should be raised to 350 GB/s.

@pralay-das
Collaborator Author

I believe the current benchmark number is low. It should be raised to 350 GB/s.

Hi, I agree. I am working on it; all performance-related fixes will come in follow-up PRs.

@mingfeima
Collaborator

mingfeima commented Apr 2, 2026

I believe the current benchmark number is low. It should be raised to 350 GB/s.

Hi, I agree. I am working on it; all performance-related fixes will come in follow-up PRs.

Any reason the perf is low right now? Is there any difficulty preventing the perf issue from being fixed right now in this PR?

@pralay-das
Collaborator Author

pralay-das commented Apr 2, 2026

Any difficulty preventing the perf issue from being fixed right now in this PR?

Hi, I am not sure where the bottleneck is. I made a few changes, but they are not helping much. I am working on it; I need some more time and will provide the changes in follow-up PRs.

@airMeng
Collaborator

airMeng commented Apr 3, 2026

@mingfeima @sunjiweiswift shall we merge with the current performance?

@mkumargarg

@mingfeima @sunjiweiswift shall we merge with the current performance?

Hello @airMeng @mingfeima @sunjiweiswift, I would request merging this PR: the changes are big, and merging would unblock exhaustive QA validation and integration into topologies, as well as help us avoid logistics issues like rebases and conflict resolution. As these are initial changes, their span across the code base is big, so it is better to merge them now. We are continuing to work on performance optimizations, and those will definitely come in future PRs. Thanks.
cc: @pralay-das

@mingfeima
Collaborator

@mingfeima @sunjiweiswift shall we merge with the current performance?

Hello @airMeng @mingfeima @sunjiweiswift, I would request merging this PR: the changes are big, and merging would unblock exhaustive QA validation and integration into topologies, as well as help us avoid logistics issues like rebases and conflict resolution. As these are initial changes, their span across the code base is big, so it is better to merge them now. We are continuing to work on performance optimizations, and those will definitely come in future PRs. Thanks. cc: @pralay-das

NP. Usually I would not recommend landing a kernel while the perf still has a big issue. Anyway, if this is your preference, we can land this one; please continue working on the perf optimization.

Please fix the naming issue: don't use cutlass_xxx, Intel doesn't really have CUTLASS :)

@mingfeima
Collaborator

mingfeima commented Apr 7, 2026

@mkumargarg you may merge this one if that is your preference. Please continue with the optimization efforts to improve kernel performance; we especially care about decoding perf when the input sequence length is 3500 and the output length is 1500.

Additionally, please update the DeepSeek R1 attention shape in the benchmark, which has num_heads of 22; handle this magic number well. We need to parallelize over this dimension, and we need a different policy for different num_requests or batch sizes, e.g. 1*22, 2*11, 4*6 (with padding), etc.
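
Purely as an illustration of such a policy (this helper is entirely our sketch, not project code), one could pick the head split from the batch size and pad the uneven case:

import math

def head_split_policy(num_heads: int, batch_size: int, target_wgs: int = 32):
    # Choose (groups, heads_per_group) so that batch_size * groups workgroups
    # roughly fill the device; the last group is padded when the split is
    # uneven, e.g. num_heads=22 can map to 1*22, 2*11, or 4*6 (with padding).
    groups = max(1, min(num_heads, target_wgs // max(batch_size, 1)))
    heads_per_group = math.ceil(num_heads / groups)
    return groups, heads_per_group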

@airMeng airMeng merged commit a1e6c7a into sgl-project:main Apr 7, 2026
1 of 2 checks passed