Skip to content

Commit 1bd74b3

Browse files
ssjia (SS-JIA)
authored and committed
[ET-VK][sdpa] Use numerically-stable softmax in attention weights
The SDPA attention weights softmax shader computed naive softmax: exp(x) / sum(exp(x)). When attention weights are large (e.g., 151.29 for Phi-4-mini with head_dim=128), exp(x) overflows float32 (threshold ~88.7), producing Infinity and then NaN from inf/inf in the normalization step. This replaces the naive softmax with the standard numerically-stable variant: exp(x - max(x)) / sum(exp(x - max(x))). The implementation adds a cooperative max-finding pass (same workgroup reduction pattern as the existing exp_sum pass) before the exp_sum and normalization passes. The max subtraction ensures that the largest exponent is 0, preventing overflow. This fixes Phi-4-mini Vulkan inference which previously produced garbage output due to NaN propagation from the first transformer layer's attention. On-device A/B benchmarks on Samsung Galaxy S24 (Adreno 750) with Llama 3.2 1B (8da4w g128 q4emb, 677 MB) confirm no performance regression: Llama 3.2 1B (short prompt, 4 tokens, --warmup): Prefill: 67.2 tok/s | Decode: 59.4 tok/s | TTFT: 60 ms Llama 3.2 1B (medium prompt, 197 tokens, --warmup): Prefill: 723.5 tok/s | Decode: 53.3 tok/s | TTFT: 273 ms These numbers are within run-to-run variance of the baseline (no fix) measurements, confirming the additional max-finding pass has negligible overhead. Differential Revision: [D97757920](https://our.internmc.facebook.com/intern/diff/D97757920/) ghstack-source-id: 356136427 Pull Request resolved: #18407
1 parent dc084a9 commit 1bd74b3

1 file changed

Lines changed: 64 additions & 18 deletions

File tree

backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ ${layout_declare_ubo(B, "int", "input_pos")}
3232

3333
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434

35-
// Shared memory for cooperative exp sum finding
35+
// Shared memory for cooperative max finding and exp sum reduction
36+
shared T shared_max[NUM_WORKERS_PER_WG];
3637
shared T shared_exp_sum[NUM_WORKERS_PER_WG];
3738

3839
VEC4_T load_attn_weights_c4(
@@ -87,24 +88,24 @@ void main() {
8788
return;
8889
}
8990

90-
// Initialize thread-local min/max
91-
T local_exp_sum = T(0);
92-
9391
const int context_len_aligned_down = context_len - mod_4(context_len);
9492
const int C4_limit = div_4(context_len_aligned_down);
9593

96-
// Each thread processes elements along a context_len row with a stride of the
97-
// number of threads in the work group.
94+
// =========================================================================
95+
// Pass 1: Find the maximum value across the row for numerical stability.
96+
// Without this, exp(x) can overflow float32 when x > ~88.7.
97+
// =========================================================================
98+
99+
T local_max = T(-1.0 / 0.0); // -infinity
100+
98101
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
99102
VEC4_T in_texel = load_attn_weights_c4(
100103
c4, s, q_h, context_texel_len, S_aligned, Q_H);
101104

102105
for (int comp = 0; comp < 4; comp++) {
103-
local_exp_sum += exp(in_texel[comp]);
106+
local_max = max(local_max, in_texel[comp]);
104107
}
105108
}
106-
// First thread in the work group responsible for handling last texel if it
107-
// contains any padded elements
108109
if (worker_id == 0) {
109110
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
110111
const int c_base = mul_4(c4);
@@ -113,19 +114,63 @@ void main() {
113114

114115
[[unroll]] for (int comp = 0; comp < 4; comp++) {
115116
if (c_base + comp < context_len) {
116-
local_exp_sum += exp(in_texel[comp]);
117+
local_max = max(local_max, in_texel[comp]);
118+
}
119+
}
120+
}
121+
}
122+
123+
shared_max[worker_id] = local_max;
124+
125+
memoryBarrierShared();
126+
barrier();
127+
128+
// Tree reduction to find the global max
129+
for (int i = NUM_WORKERS_PER_WG / 2; i > 0; i >>= 1) {
130+
if (worker_id < i) {
131+
shared_max[worker_id] = max(
132+
shared_max[worker_id], shared_max[worker_id + i]);
133+
}
134+
memoryBarrierShared();
135+
barrier();
136+
}
137+
138+
const T global_max = shared_max[0];
139+
140+
// =========================================================================
141+
// Pass 2: Compute sum(exp(x - max)) using the global max for stability
142+
// =========================================================================
143+
144+
T local_exp_sum = T(0);
145+
146+
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
147+
VEC4_T in_texel = load_attn_weights_c4(
148+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
149+
150+
for (int comp = 0; comp < 4; comp++) {
151+
local_exp_sum += exp(in_texel[comp] - global_max);
152+
}
153+
}
154+
if (worker_id == 0) {
155+
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
156+
const int c_base = mul_4(c4);
157+
VEC4_T in_texel = load_attn_weights_c4(
158+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
159+
160+
[[unroll]] for (int comp = 0; comp < 4; comp++) {
161+
if (c_base + comp < context_len) {
162+
local_exp_sum += exp(in_texel[comp] - global_max);
117163
}
118164
}
119165
}
120166
}
121167

122-
// Store thread-local results in shared memory
123168
shared_exp_sum[worker_id] = local_exp_sum;
124169

125170
memoryBarrierShared();
126171
barrier();
127172

128-
// Tree reduction to compute the overall result
173+
// Tree reduction to compute the overall exp sum
129174
for (int i = NUM_WORKERS_PER_WG / 2; i > 0; i >>= 1) {
130175
if (worker_id < i) {
131176
shared_exp_sum[worker_id] = shared_exp_sum[worker_id] +
@@ -136,28 +181,29 @@ void main() {
136181
}
137182

138183
local_exp_sum = shared_exp_sum[0];
139-
// Now go back through each element in the row and normalize
184+
185+
// =========================================================================
186+
// Pass 3: Normalize each element: out = exp(x - max) / sum(exp(x - max))
187+
// =========================================================================
188+
140189
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
141190
VEC4_T in_texel = load_attn_weights_c4(
142191
c4, s, q_h, context_texel_len, S_aligned, Q_H);
143192

144-
VEC4_T out_texel = exp(in_texel) / local_exp_sum;
193+
VEC4_T out_texel = exp(in_texel - global_max) / local_exp_sum;
145194
store_attn_weights_softmax_c4(
146195
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
147196
}
148-
// First thread in the work group responsible for handling last texel if it
149-
// contains any padded elements
150197
if (worker_id == 0) {
151198
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
152199
const int c_base = mul_4(c4);
153200
VEC4_T in_texel = load_attn_weights_c4(
154201
c4, s, q_h, context_texel_len, S_aligned, Q_H);
155202

156-
// Ensure that padding elements are set to 0.
157203
VEC4_T out_texel = VEC4_T(0);
158204
[[unroll]] for (int comp = 0; comp < 4; comp++) {
159205
if (c_base + comp < context_len) {
160-
out_texel[comp] = exp(in_texel[comp]) / local_exp_sum;
206+
out_texel[comp] = exp(in_texel[comp] - global_max) / local_exp_sum;
161207
}
162208
}
163209
store_attn_weights_softmax_c4(

0 commit comments

Comments (0)