Skip to content

Commit fc45cea

Browse files
committed
Fix the boundary issues
1 parent 6952672 commit fc45cea

3 files changed

Lines changed: 13 additions & 11 deletions

File tree

ops/ninetoothed/kernels/scaled_dot_product_attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def application(q, k, v, scale, o):
3838

3939
for i in range(k.shape[0]):
4040
qk = ntl.dot(q_loaded, ntl.trans(k[i]))
41+
qk = ntl.where(k[i].offsets(-2) < k.source.shape[-2], qk, float("-inf"))
4142

4243
m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
4344
p = ntl.exp2(qk - m_ij[:, None])
@@ -49,7 +50,7 @@ def application(q, k, v, scale, o):
4950
l_i = l_i * alpha + l_ij
5051

5152
acc /= l_i[:, None]
52-
o = acc # noqa: F841
53+
o = acc.to(o.dtype) # noqa: F841
5354

5455

5556
shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128})

ops/triton/kernels/scaled_dot_product_attention.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,24 @@ def kernel(
9696
order=(1, 0),
9797
)
9898

99-
q = (tl.load(q_block_ptr) * scale * 1.44269504089).to(q_block_ptr.type.element_ty)
99+
q = tl.load(q_block_ptr, boundary_check=(0, 1))
100+
q = (q * scale * 1.44269504089).to(q_block_ptr.type.element_ty)
100101

101102
acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32)
102103
l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32)
103104
m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32)
104105

105-
for _ in range(0, tl.cdiv(seq_len, BLOCK_SIZE_N)):
106-
k = tl.load(k_block_ptr)
106+
for i in range(0, tl.cdiv(seq_len, BLOCK_SIZE_N)):
107+
k = tl.load(k_block_ptr, boundary_check=(0, 1))
107108

108-
qk = tl.dot(q, k)
109+
mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len
110+
qk = tl.where(mask, tl.dot(q, k), float("-inf"))
109111

110112
m_ij = tl.maximum(m_i, tl.max(qk, 1))
111-
qk -= m_ij[:, None]
112-
p = tl.exp2(qk)
113+
p = tl.exp2(qk - m_ij[:, None])
113114
l_ij = tl.sum(p, 1)
114115

115-
v = tl.load(v_block_ptr)
116+
v = tl.load(v_block_ptr, boundary_check=(0, 1))
116117
alpha = tl.exp2(m_i - m_ij)
117118
acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v)
118119
m_i = m_ij
@@ -123,4 +124,4 @@ def kernel(
123124

124125
acc /= l_i[:, None]
125126

126-
tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty))
127+
tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty), boundary_check=(0, 1))

scaled_dot_product_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _rope(x, sin_table, cos_table):
109109
print("✅ NineToothed and PyTorch match.")
110110
else:
111111
print("❌ NineToothed and PyTorch differ.")
112-
if torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001):
112+
if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0):
113113
print("✅ NineToothed and Triton match.")
114114
else:
115115
print("❌ NineToothed and Triton differ.")
@@ -143,7 +143,7 @@ def benchmark(seq_len, provider):
143143
triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v)
144144

145145
assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025)
146-
assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001)
146+
assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
147147

148148
if provider == "ninetoothed":
149149
ms = triton.testing.do_bench(

0 commit comments

Comments
 (0)