Commit 4838f40 (parent 89a1639)

Add causal mask to SDPA kernels and use backend SDPA directly

4 files changed: 29 additions & 19 deletions

modules/attention.py

Lines changed: 2 additions & 10 deletions
@@ -99,16 +99,8 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        # TODO: NineToothed SDPA kernel lacks causal masking support, which is
-        # required by autoregressive inference. Fall back to torch so end-to-end
-        # generation produces coherent output.
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=attention_mask is None and query_states.shape[-2] > 1,
-            scale=self.scaling,
+        attn_output = type(self).scaled_dot_product_attention(
+            query_states, key_states, value_states, scale=self.scaling
         )
         attn_output = attn_output.transpose(1, 2)
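
The replacement call resolves scaled_dot_product_attention on the class via type(self), so whichever SDPA implementation the active backend has installed on the attention class is invoked directly; the torch fallback and its attn_mask / is_causal bookkeeping become unnecessary once the kernels mask causally themselves. A minimal sketch of that dispatch pattern (the Attention class and the backend assignment below are illustrative, not the repo's actual wiring):

import torch
import torch.nn.functional as F


class Attention(torch.nn.Module):
    # Backends are expected to install their own SDPA on the class, e.g. the
    # NineToothed or Triton wrapper, which after this commit applies the
    # causal mask internally. The torch default here is for illustration only.
    scaled_dot_product_attention = staticmethod(F.scaled_dot_product_attention)


# Backend selection (hypothetical):
# Attention.scaled_dot_product_attention = staticmethod(
#     ops.ninetoothed.torch.scaled_dot_product_attention
# )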

ops/ninetoothed/kernels/scaled_dot_product_attention.py

Lines changed: 18 additions & 7 deletions
@@ -7,7 +7,7 @@


 def arrangement(
-    q, k, v, scale, o, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N
+    q, k, v, scale, q_start, o, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N
 ):
     def arrange_q_or_o(input):
         arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))

@@ -26,10 +26,17 @@ def arrange_k_or_v(input):

     q_arranged = arrange_q_or_o(q)

-    return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), scale, arrange_q_or_o(o)
+    return (
+        q_arranged,
+        arrange_k_or_v(k),
+        arrange_k_or_v(v),
+        scale,
+        q_start,
+        arrange_q_or_o(o),
+    )


-def application(q, k, v, scale, o):
+def application(q, k, v, scale, q_start, o):
     q_loaded = (q * scale * 1.44269504089).to(q.dtype)

     acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)

@@ -38,7 +45,11 @@ def application(q, k, v, scale, o):

     for i in range(k.shape[0]):
         qk = ntl.dot(q_loaded, ntl.trans(k[i]))
-        qk = ntl.where(k[i].offsets(-2) < k.source.shape[-2], qk, float("-inf"))
+        qk = ntl.where(
+            (q.offsets(-2) + q_start)[:, None] >= k[i].offsets(-2),
+            qk,
+            float("-inf"),
+        )

         m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
         p = ntl.exp2(qk - m_ij[:, None])

@@ -53,8 +64,8 @@ def application(q, k, v, scale, o):
     o = acc.to(o.dtype)  # noqa: F841


-shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128})
-q, k, v, o = (Tensor(4, shape_options=shape_options) for _ in range(4))
-tensors = (q, k, v, Tensor(0), o)
+_shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128})
+_q, _k, _v, _o = (Tensor(4, shape_options=_shape_options) for _ in range(4))
+tensors = (_q, _k, _v, Tensor(0), Tensor(0), _o)

 kernel = ninetoothed.make(arrangement, application, tensors)
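
The rewritten mask encodes causality rather than just key bounds: each query's row offset plus q_start gives its absolute position in the full sequence, and a query may only attend to keys at or before that position. Out-of-range key offsets are at least the key sequence length, which no query position ever reaches, so the causal predicate also subsumes the bounds check the old line performed. (The 1.44269504089 factor in the surrounding context is log2(e); scores are pre-scaled so the online softmax can use exp2.) A plain-PyTorch sketch of the semantics the kernel now implements, assuming queries occupy the last L_q positions of an L_kv-long sequence, as in KV-cache decoding:

import math

import torch


def causal_sdpa_reference(q, k, v, scale=None):
    """Bottom-right-aligned causal SDPA; mirrors the kernel's mask."""
    if scale is None:
        scale = 1 / math.sqrt(q.shape[-1])
    q_start = k.shape[-2] - q.shape[-2]  # same offset the wrapper computes
    q_pos = torch.arange(q.shape[-2]) + q_start  # absolute query positions
    k_pos = torch.arange(k.shape[-2])
    allowed = q_pos[:, None] >= k_pos[None, :]  # query i sees keys 0..i
    scores = (q @ k.transpose(-2, -1)) * scale
    scores = scores.masked_fill(~allowed, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v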

ops/ninetoothed/torch.py

Lines changed: 5 additions & 1 deletion
@@ -149,9 +149,13 @@ def scaled_dot_product_attention(q, k, v, scale=None):
     if scale is None:
         scale = 1 / math.sqrt(q.shape[-1])

+    q_start = k.shape[-2] - q.shape[-2]
+
     o = torch.empty_like(q)

-    ops.ninetoothed.kernels.scaled_dot_product_attention.kernel(q, k, v, scale, o)
+    ops.ninetoothed.kernels.scaled_dot_product_attention.kernel(
+        q, k, v, scale, q_start, o
+    )

     return o
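
The wrapper derives q_start from the shapes alone: it is the absolute position of the first query row. During prefill, q and k cover the same tokens, so q_start is 0 and the mask is the usual lower triangle; during cached decoding, q holds only the newest tokens while k and v span the whole sequence. A small arithmetic check with made-up lengths:

# Prefill: 10 prompt tokens, no cache, q and k the same length.
q_len, kv_len = 10, 10
assert kv_len - q_len == 0  # q_start = 0: row i attends to keys 0..i

# Decode step: 1 new token against a 10-long KV sequence.
q_len, kv_len = 1, 10
assert kv_len - q_len == 9  # q_start = 9: the token attends to keys 0..9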

ops/triton/kernels/scaled_dot_product_attention.py

Lines changed: 4 additions & 1 deletion
@@ -96,10 +96,13 @@ def kernel(
     l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32)
     m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32)

+    q_offsets = seq_len_k_v - seq_len_q + offs_m_start + tl.arange(0, BLOCK_SIZE_M)
+
     for i in range(0, tl.cdiv(seq_len_k_v, BLOCK_SIZE_N)):
         k = tl.load(k_block_ptr, boundary_check=(0, 1))

-        mask = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) < seq_len_k_v
+        k_offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        mask = q_offsets[:, None] >= k_offsets[None, :]
         qk = tl.where(mask, tl.dot(q, k), float("-inf"))

         m_ij = tl.maximum(m_i, tl.max(qk, 1))
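
In the Triton kernel, q_offsets maps each query row of this program's tile to its absolute position: seq_len_k_v - seq_len_q shifts past the cached prefix and offs_m_start past earlier query tiles. Comparing it against each key tile's k_offsets yields the causal mask, and, as in the NineToothed kernel, keys beyond seq_len_k_v get offsets larger than any valid query position, so the old length check is covered for free. A NumPy sketch of the same index algebra with small, made-up sizes:

import numpy as np

# Illustrative sizes: 4 queries over an 8-long KV sequence, one query tile.
seq_len_q, seq_len_k_v = 4, 8
BLOCK_SIZE_M = BLOCK_SIZE_N = 4
offs_m_start = 0

q_offsets = seq_len_k_v - seq_len_q + offs_m_start + np.arange(BLOCK_SIZE_M)
# q_offsets == [4, 5, 6, 7]: the queries are the last four positions.

for i in range(-(-seq_len_k_v // BLOCK_SIZE_N)):  # same as tl.cdiv
    k_offsets = i * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)
    mask = q_offsets[:, None] >= k_offsets[None, :]
    # i == 0: all True, since keys 0..3 precede every query.
    # i == 1: lower triangle over keys 4..7 -- the causal part.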
