Implement FlashDecoding++ async softmax for split-K SDPA #18867

Gasoonjia wants to merge 6 commits into `cuda-graph-sampling`
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18867
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 6 New Failures, 1 Pending, 4 Unrelated Failures as of commit 1a79d9d with merge base c48ea12:
- NEW FAILURES: the following jobs have failed
- FLAKY: the following jobs failed, but likely due to flakiness present on trunk
- BROKEN TRUNK: the following jobs failed, but were already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed 1af2029 to d5209fc
This PR needs a
TIL FlashDecoding++, thanks.
A couple of things: it seems like this is still online softmax (i.e., you still need to do a reduce later). And can you add performance numbers with a sweep to the PR summary? Thanks.
Also, ask Claude to do a quality comparison using various prompts, just to spot-check quality.
@claude review this PR.
Claude finished @digantdesai's task in 1m 17s (View job)

Review of PR #18867

This PR makes several simplifying changes to the split-K decode kernel. Here's my analysis:

Summary of Changes

Issues & Questions
1. In the previous version, …
2. The previous version used …
3. The old heuristic targeted 2 full SM waves and adapted to the actual GPU (A100 vs RTX 4090 have different SM counts). The new …
4. This was a minor optimization (avoiding an integer division in the kernel). Since …
5. Reduce kernel: removing …

Minor Observations

Overall
The core FlashDecoding++ logic (fixed phi, simple sum in reduce) is correct and unchanged from the prior commit. The main changes are simplification/tuning tradeoffs. The buffer caching is a good optimization for decode latency.
Force-pushed 9ced5cc to b2fbad0
Replace online softmax (per-tile max tracking + cross-split rescaling) with a unified maximum value (phi=5.0) approach from FlashDecoding++.

Key changes:
- Split kernel: subtract a fixed phi instead of tracking the running max m_i, eliminating alpha rescaling between tiles
- Reduce kernel: simple summation of partial outputs instead of max-aware weighted combination; removes the M_partial buffer
- ~12.9% average kernel-level speedup (6.8%-20.1% range) by saving HBM bandwidth (no M_partial reads/writes) and reducing ALU ops

The unified phi works because exp(qk - phi) is numerically stable for typical attention score ranges, and the fixed constant allows all splits to compute independently without synchronization.
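The split/reduce restructuring in the commit message above can be sketched end-to-end in NumPy (my own illustration, not the actual Triton kernel; the function name and `num_splits` are hypothetical). Because every split subtracts the same fixed phi, the constant cancels between numerator and denominator in the reduce, so the split-K result matches a standard softmax without any cross-split max exchange:

```python
import numpy as np

def async_softmax_splitk(scores, values, phi=5.0, num_splits=4):
    """Split-K attention with a unified phi (FlashDecoding++-style sketch)."""
    acc = np.zeros(values.shape[1])
    denom = 0.0
    # Each split handles a disjoint chunk of keys; no rescaling between chunks.
    for s_chunk, v_chunk in zip(np.array_split(scores, num_splits),
                                np.array_split(values, num_splits, axis=0)):
        p = np.exp(s_chunk - phi)   # subtract fixed phi, not a running max
        acc += p @ v_chunk          # partial (unnormalized) output
        denom += p.sum()            # partial softmax denominator
    # Reduce: plain sums of partials; exp(-phi) cancels in the division.
    return acc / denom

rng = np.random.default_rng(0)
scores = rng.normal(0.0, 1.0, 128)          # typical attention score range
values = rng.normal(0.0, 1.0, (128, 8))

# Reference: standard max-subtracted softmax.
w = np.exp(scores - scores.max())
ref = (w / w.sum()) @ values

out = async_softmax_splitk(scores, values)
assert np.allclose(out, ref)
```

Note this only removes the cross-split synchronization; within the stated score range the fixed phi plays the same stabilizing role the running max did.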
Keep only sdpa.py changes on this branch; revert all other files (aoti_delegate_handle.h, benchmark_sdpa.py, cuda_backend.cpp, main.cpp, model.py) to their main branch state.
Force-pushed b2fbad0 to 39589ae
Force-pushed d3bca0d to 5245f64
```diff
-        safe_diff = tl.where(
-            m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
-        )
+        # FlashDecoding++ async softmax: subtract unified phi instead of local max
```
@digantdesai here we replace the online softmax with an async softmax using a unified phi:
> Replace online softmax (per-tile max tracking + cross-split rescaling) with a unified maximum value (phi=5.0) approach from FlashDecoding++. The unified phi works because exp(qk - phi) is numerically stable for typical attention score ranges, and the fixed constant allows all splits to compute independently without synchronization.
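One quick way to sanity-check the "numerically stable for typical attention score ranges" claim (my own illustration, not part of the PR): in float16, `exp` overflows once its argument exceeds ln(65504) ≈ 11.09, so with phi = 5.0 raw scores remain finite up to roughly 16, which comfortably covers typical post-scaling attention scores:

```python
import numpy as np

phi = np.float16(5.0)

# 16 - 5 = 11 < ln(65504) ~= 11.09, so exp stays finite in float16.
safe = np.exp(np.float16(16.0) - phi)
# 18 - 5 = 13 > 11.09, so exp overflows to inf in float16.
hot = np.exp(np.float16(18.0) - phi)

assert np.isfinite(safe)
assert np.isinf(hot)
```

This is why the choice of the fixed constant matters: phi must sit close enough to the realistic score maximum that exp(qk - phi) neither overflows above it nor underflows to zero for scores well below it.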
Also used KernelAgent (https://github.com/meta-pytorch/KernelAgent) to further optimize the kernel.
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell