[ET-VK][sdpa] Use numerically-stable softmax in attention weights #18407
meta-codesync[bot] merged 1 commit into gh/SS-JIA/500/base
Conversation
The SDPA attention-weights softmax shader computed the naive softmax:
exp(x) / sum(exp(x)). When attention weights are large (e.g., 151.29 for
Phi-4-mini with head_dim=128), exp(x) overflows float32 (threshold ~88.7),
producing Infinity and then NaN from inf/inf in the normalization step.
This replaces the naive softmax with the standard numerically stable variant:
exp(x - max(x)) / sum(exp(x - max(x))). The implementation adds a cooperative
max-finding pass (same workgroup reduction pattern as the existing exp_sum pass)
before the exp_sum and normalization passes. The max subtraction ensures that the
largest exponent is 0, preventing overflow.
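For reference (not part of the original description), subtracting a constant from the inputs leaves the softmax unchanged, so the max subtraction is exact rather than an approximation:

$$\frac{e^{x_i - m}}{\sum_j e^{x_j - m}} = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_j e^{x_j}} = \frac{e^{x_i}}{\sum_j e^{x_j}}, \qquad m = \max_k x_k.$$

With m = max(x), every exponent x_i - m is at most 0, so exp stays in [0, 1] and cannot overflow; for the 151.29 value seen in Phi-4-mini, this turns exp(151.29) = Infinity into exp(0) = 1.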
This fixes Phi-4-mini Vulkan inference, which previously produced garbage output
due to NaN propagation from the first transformer layer's attention.
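As an illustration only, here is a minimal GLSL compute-shader sketch of the three-pass pattern described above (cooperative max reduction, shifted exp plus sum, normalization). The workgroup size, buffer bindings, and names (`attn`, `seq_len`, `shared_vals`) are assumptions made for the sketch, not the actual ET-VK shader:

```glsl
#version 450
// Sketch only: one workgroup handles one row of attention weights.
layout(local_size_x = 64) in;

layout(set = 0, binding = 0) buffer AttnWeights { float attn[]; };
layout(set = 0, binding = 1) uniform Params { uint seq_len; };

shared float shared_vals[64];

void main() {
  const uint lid  = gl_LocalInvocationID.x;
  const uint base = gl_WorkGroupID.x * seq_len;  // start of this row

  // Pass 1: cooperative max reduction over the row.
  float local_max = -3.402823466e38;             // most negative finite float
  for (uint i = lid; i < seq_len; i += 64u) {
    local_max = max(local_max, attn[base + i]);
  }
  shared_vals[lid] = local_max;
  barrier();
  for (uint s = 32u; s > 0u; s >>= 1) {
    if (lid < s) {
      shared_vals[lid] = max(shared_vals[lid], shared_vals[lid + s]);
    }
    barrier();
  }
  const float row_max = shared_vals[0];
  barrier();

  // Pass 2: shifted exp (largest exponent is 0) and cooperative sum.
  float local_sum = 0.0;
  for (uint i = lid; i < seq_len; i += 64u) {
    const float e = exp(attn[base + i] - row_max);
    attn[base + i] = e;
    local_sum += e;
  }
  shared_vals[lid] = local_sum;
  barrier();
  for (uint s = 32u; s > 0u; s >>= 1) {
    if (lid < s) {
      shared_vals[lid] += shared_vals[lid + s];
    }
    barrier();
  }
  const float row_sum = shared_vals[0];

  // Pass 3: normalize.
  for (uint i = lid; i < seq_len; i += 64u) {
    attn[base + i] /= row_sum;
  }
}
```

The max pass mirrors the shared-memory reduction already used by the exp_sum pass, which is consistent with the benchmarks below showing no measurable overhead from the extra pass.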
On-device A/B benchmarks on Samsung Galaxy S24 (Adreno 750) with Llama 3.2 1B
(8da4w g128 q4emb, 677 MB) confirm no performance regression:
Llama 3.2 1B (short prompt, 4 tokens, --warmup):
Prefill: 67.2 tok/s | Decode: 59.4 tok/s | TTFT: 60 ms
Llama 3.2 1B (medium prompt, 197 tokens, --warmup):
Prefill: 723.5 tok/s | Decode: 53.3 tok/s | TTFT: 273 ms
These numbers are within run-to-run variance of the baseline (no fix) measurements,
confirming the additional max-finding pass has negligible overhead.
Differential Revision: [D97757920](https://our.internmc.facebook.com/intern/diff/D97757920/)
Merged commit 3ba9dec into gh/SS-JIA/500/base
ghstack-source-id: 356136427
Pull Request resolved: #18407