Commit 1bd74b3
[ET-VK][sdpa] Use numerically-stable softmax in attention weights
The SDPA attention weights softmax shader computed the naive softmax:
exp(x) / sum(exp(x)). When attention weights are large (e.g., 151.29 for
Phi-4-mini with head_dim=128), exp(x) overflows float32 (threshold ~88.7),
producing Infinity and then NaN from inf/inf in the normalization step.
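To make the failure mode concrete, here is a minimal standalone C++ repro of the overflow (illustrative only, not the shader code): once a single weight exceeds float32's exp threshold, the normalization divides inf by inf and yields NaN.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float w = 151.29f;            // attention weight magnitude seen with Phi-4-mini
  float e = std::exp(w);        // exceeds the ~88.7 threshold, overflows to +inf
  float norm = e / (e + 1.0f);  // inf / inf -> NaN, as in the normalization step
  std::printf("exp(%g) = %g, normalized = %g\n", w, e, norm);
  return 0;
}
```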
This replaces the naive softmax with the standard numerically-stable variant:
exp(x - max(x)) / sum(exp(x - max(x))). The implementation adds a cooperative
max-finding pass (same workgroup reduction pattern as the existing exp_sum pass)
before the exp_sum and normalization passes. The max subtraction ensures that the
largest exponent is 0, preventing overflow.
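For reference, a CPU-side C++ sketch of the resulting three-pass structure (max, exp_sum, normalize). The actual change is in the GLSL shader, where the max and exp_sum passes are cooperative workgroup reductions; the function name below is illustrative only.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically-stable softmax: exp(x - max(x)) / sum(exp(x - max(x))).
void stable_softmax(std::vector<float>& x) {
  if (x.empty()) return;

  // Pass 1: find the maximum (corresponds to the new cooperative
  // max-finding pass in the shader).
  float max_val = *std::max_element(x.begin(), x.end());

  // Pass 2: exponentiate shifted values and accumulate the sum
  // (corresponds to the existing exp_sum pass). Subtracting max_val
  // pins the largest exponent at 0, so exp() never exceeds 1.0 and
  // cannot overflow float32.
  float sum = 0.0f;
  for (float& v : x) {
    v = std::exp(v - max_val);
    sum += v;
  }

  // Pass 3: normalize by the (now finite) sum.
  for (float& v : x) {
    v /= sum;
  }
}
```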
This fixes Phi-4-mini Vulkan inference, which previously produced garbage output
due to NaN propagation from the first transformer layer's attention.
On-device A/B benchmarks on Samsung Galaxy S24 (Adreno 750) with Llama 3.2 1B
(8da4w g128 q4emb, 677 MB) confirm no performance regression:
Llama 3.2 1B (short prompt, 4 tokens, --warmup):
Prefill: 67.2 tok/s | Decode: 59.4 tok/s | TTFT: 60 ms
Llama 3.2 1B (medium prompt, 197 tokens, --warmup):
Prefill: 723.5 tok/s | Decode: 53.3 tok/s | TTFT: 273 ms
These numbers are within run-to-run variance of the baseline (no fix) measurements,
confirming the additional max-finding pass has negligible overhead.
Differential Revision: [D97757920](https://our.internmc.facebook.com/intern/diff/D97757920/)
ghstack-source-id: 356136427
Pull Request resolved: #184071
1 file changed: 64 additions & 18 deletions