Skip to content

Commit a2abca6

Browse files
committed
Attention Perf: Transpose blocked K right before QK instead of pre-transposing before the kernel
stack-info: PR: #2374, branch: AmesingFlank/stack/50
1 parent b14c444 commit a2abca6

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

examples/attention.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def attention(
6161
assert head_dim == k_in.size(-1) == v_in.size(-1)
6262
q_view = q_in.reshape([-1, m_dim, head_dim])
6363
v_view = v_in.reshape([-1, n_dim, head_dim])
64-
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
64+
k_view = k_in.reshape([-1, n_dim, head_dim])
6565
out = torch.empty_like(q_view)
6666
sm_scale = 1.0 / math.sqrt(head_dim)
6767
qk_scale = sm_scale * 1.44269504 # 1/log(2)
@@ -73,10 +73,10 @@ def attention(
7373
for tile_n in hl.tile(v_view.size(1)):
7474
# scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q
7575
q_scaled = q * qk_scale
76-
k = k_view[tile_b, :, tile_n]
76+
k = k_view[tile_b, tile_n, :]
7777
# Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs.
7878
# same as hl.dot(q, k, out_dtype=torch.float32)
79-
qk = torch.bmm(q_scaled, k, torch.float32)
79+
qk = torch.bmm(q_scaled, k.transpose(1, 2), torch.float32)
8080
m_ij = torch.maximum(m_i, torch.amax(qk, -1))
8181
qk = qk - m_ij[:, :, None]
8282
p = torch.exp2(qk)

test/test_pallas.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -306,7 +306,7 @@ def pallas_attention(
306306
head_dim = hl.specialize(q_in.size(-1))
307307
assert head_dim == k_in.size(-1) == v_in.size(-1)
308308
q_view = q_in.reshape([-1, m_dim, head_dim])
309-
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
309+
k_view = k_in.reshape([-1, n_dim, head_dim])
310310
v_view = v_in.reshape([-1, n_dim, head_dim])
311311
out = torch.empty_like(q_view)
312312
sm_scale = 1.0 / math.sqrt(head_dim)
@@ -319,9 +319,10 @@ def pallas_attention(
319319
for tile_n in hl.tile(v_view.size(1)):
320320
# scaling Q in-loop on-demand reduces spillage, faster than keeping pre-scaled Q
321321
q_scaled = q * qk_scale
322-
k = k_view[tile_b, :, tile_n]
322+
k = k_view[tile_b, tile_n, :]
323+
# Keep scores in fp32 to match SDPA tolerances on bf16/fp16 inputs.
323324
# same as hl.dot(q, k, out_dtype=torch.float32)
324-
qk = torch.bmm(q_scaled, k, torch.float32)
325+
qk = torch.bmm(q_scaled, k.transpose(1, 2), torch.float32)
325326
m_ij = torch.maximum(m_i, torch.amax(qk, -1))
326327
qk = qk - m_ij[:, :, None]
327328
p = torch.exp2(qk)

0 commit comments

Comments
 (0)