Skip to content

Commit d18e73f

Browse files
Ziminlivoltjia
authored andcommitted
Fix the Triton implementation in scaled_dot_product_attention.py
1 parent 95499b9 commit d18e73f

2 files changed

Lines changed: 14 additions & 13 deletions

File tree

ops/triton/kernels/scaled_dot_product_attention.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def kernel(
4848
o_stride_m,
4949
o_stride_n,
5050
scale,
51-
seq_len,
51+
seq_len_q,
52+
seq_len_k_v,
5253
EMB_DIM: tl.constexpr,
5354
BLOCK_SIZE_M: tl.constexpr,
5455
BLOCK_SIZE_N: tl.constexpr,
@@ -62,7 +63,7 @@ def kernel(
6263
q_off = off_z * q_stride_z + off_h * q_stride_h
6364
q_block_ptr = tl.make_block_ptr(
6465
base=q_ptr + q_off,
65-
shape=(seq_len, EMB_DIM),
66+
shape=(seq_len_q, EMB_DIM),
6667
strides=(q_stride_m, q_stride_k),
6768
offsets=(offs_m_start, 0),
6869
block_shape=(BLOCK_SIZE_M, EMB_DIM),
@@ -71,7 +72,7 @@ def kernel(
7172
k_off = off_z * k_stride_z + off_h * k_stride_h
7273
k_block_ptr = tl.make_block_ptr(
7374
base=k_ptr + k_off,
74-
shape=(EMB_DIM, seq_len),
75+
shape=(EMB_DIM, seq_len_k_v),
7576
strides=(k_stride_k, k_stride_n),
7677
offsets=(0, 0),
7778
block_shape=(EMB_DIM, BLOCK_SIZE_N),
@@ -80,7 +81,7 @@ def kernel(
8081
v_off = off_z * v_stride_z + off_h * v_stride_h
8182
v_block_ptr = tl.make_block_ptr(
8283
base=v_ptr + v_off,
83-
shape=(seq_len, EMB_DIM),
84+
shape=(seq_len_k_v, EMB_DIM),
8485
strides=(v_stride_k, v_stride_n),
8586
offsets=(0, 0),
8687
block_shape=(BLOCK_SIZE_N, EMB_DIM),
@@ -89,7 +90,7 @@ def kernel(
8990
o_off = off_z * o_stride_z + off_h * o_stride_h
9091
o_block_ptr = tl.make_block_ptr(
9192
base=o_ptr + o_off,
92-
shape=(seq_len, EMB_DIM),
93+
shape=(seq_len_q, EMB_DIM),
9394
strides=(o_stride_m, o_stride_n),
9495
offsets=(offs_m_start, 0),
9596
block_shape=(BLOCK_SIZE_M, EMB_DIM),
@@ -103,10 +104,10 @@ def kernel(
103104
l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32)
104105
m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32)
105106

106-
for i in range(0, tl.cdiv(seq_len, BLOCK_SIZE_N)):
107+
for i in range(0, tl.cdiv(seq_len_k_v, BLOCK_SIZE_N)):
107108
k = tl.load(k_block_ptr, boundary_check=(0, 1))
108109

109-
mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len
110+
mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len_k_v
110111
qk = tl.where(mask, tl.dot(q, k), float("-inf"))
111112

112113
m_ij = tl.maximum(m_i, tl.max(qk, 1))

scaled_dot_product_attention.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from transformers.models.llama.modeling_llama import repeat_kv
99

1010
import ops.ninetoothed.torch
11-
import ops.triton.torch
1211
import rope
1312

1413

@@ -89,13 +88,14 @@ def _rope(x, sin_table, cos_table):
8988
if __name__ == "__main__":
9089
torch.manual_seed(0)
9190

92-
shape = (2, 4, 1024, 64)
91+
q_o_shape = (2, 8, 1024, 64)
92+
k_v_shape = (2, 8, 1024, 64)
9393
dtype = torch.float16
9494
device = "cuda"
9595

96-
q = torch.randn(shape, dtype=dtype, device=device)
97-
k = torch.randn(shape, dtype=dtype, device=device)
98-
v = torch.randn(shape, dtype=dtype, device=device)
96+
q = torch.randn(q_o_shape, dtype=dtype, device=device)
97+
k = torch.randn(k_v_shape, dtype=dtype, device=device)
98+
v = torch.randn(k_v_shape, dtype=dtype, device=device)
9999

100100
ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v)
101101
torch_output = F.scaled_dot_product_attention(q, k, v)
@@ -109,7 +109,7 @@ def _rope(x, sin_table, cos_table):
109109
print("✅ NineToothed and PyTorch match.")
110110
else:
111111
print("❌ NineToothed and PyTorch differ.")
112-
if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0):
112+
if torch.allclose(ninetoothed_output, triton_output, atol=1e-3, rtol=0):
113113
print("✅ NineToothed and Triton match.")
114114
else:
115115
print("❌ NineToothed and Triton differ.")

0 commit comments

Comments
 (0)