Skip to content

Commit c14e60d

Browse files
committed
Add scaled_dot_product_attention operator
1 parent 2da73f8 commit c14e60d

3 files changed

Lines changed: 378 additions & 0 deletions

File tree

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
BLOCK_SIZE_M = ninetoothed.block_size()
8+
BLOCK_SIZE_N = ninetoothed.block_size()
9+
10+
11+
def arrangement(
    query,
    key,
    value,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    scale,
    output,
    with_kv_cache,
    BLOCK_SIZE_M=BLOCK_SIZE_M,
    BLOCK_SIZE_N=BLOCK_SIZE_N,
):
    """Tile every operand into the layout the application functions consume.

    `query` and `output` are cut into blocks of `BLOCK_SIZE_M` along
    dimension 2 and then grouped along dimension 1 by the ratio
    `query.shape[-3] // key.shape[-3]`. `key` and `value` are cut into
    blocks of `BLOCK_SIZE_N` along dimension 2, gathered into a single
    outer tile over the last two dimensions, and expanded along dimension 2
    to the number of query blocks. `attn_mask` is cut into
    `BLOCK_SIZE_M x BLOCK_SIZE_N` blocks that are gathered along the last
    dimension. The four `present_*` tensors are each gathered over their
    last two dimensions. `scale` is passed through untouched.

    Returns:
        A 10-tuple (query, key, value, present_key, present_value,
        present_key_slot, present_value_slot, attn_mask, scale, output)
        when `with_kv_cache` is truthy, otherwise the 6-tuple (query, key,
        value, attn_mask, scale, output). The `present_*` tensors are
        arranged in both cases; they are merely left out of the short tuple.
    """

    def _blocks_of_rows(tensor):
        # Shared by `query` and `output`.
        row_blocks = tensor.tile((1, 1, BLOCK_SIZE_M, -1))
        grouped = row_blocks.tile((1, query.shape[-3] // key.shape[-3], 1, 1))
        grouped.dtype = grouped.dtype.squeeze((0, 2, 3))
        grouped.dtype.dtype = grouped.dtype.dtype.squeeze((0, 1))

        return grouped

    def _blocks_of_keys(tensor, num_row_blocks):
        # Shared by `key` and `value`.
        key_blocks = tensor.tile((1, 1, BLOCK_SIZE_N, -1))
        gathered = key_blocks.tile((1, 1, -1, -1))
        broadcast = gathered.expand((-1, -1, num_row_blocks, -1))
        broadcast.dtype = broadcast.dtype.squeeze((0, 1, 3))
        broadcast.dtype.dtype = broadcast.dtype.dtype.squeeze((0, 1))

        return broadcast

    def _whole_trailing_plane(tensor):
        # Shared by the four `present_*` tensors.
        plane = tensor.tile((1, 1, -1, -1))
        plane.dtype = plane.dtype.squeeze((0, 1))

        return plane

    def _blocks_of_mask(tensor):
        mask_blocks = tensor.tile((1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N))
        mask_rows = mask_blocks.tile((1, 1, 1, -1))
        mask_rows.dtype = mask_rows.dtype.squeeze((0, 1, 2))
        mask_rows.dtype.dtype = mask_rows.dtype.dtype.squeeze((0, 1))

        return mask_rows

    # `query` has to be arranged first: the key/value layout is expanded to
    # the number of query blocks.
    query_arranged = _blocks_of_rows(query)
    num_row_blocks = query_arranged.shape[-2]

    key_arranged = _blocks_of_keys(key, num_row_blocks)
    value_arranged = _blocks_of_keys(value, num_row_blocks)
    present_key_arranged = _whole_trailing_plane(present_key)
    present_value_arranged = _whole_trailing_plane(present_value)
    present_key_slot_arranged = _whole_trailing_plane(present_key_slot)
    present_value_slot_arranged = _whole_trailing_plane(present_value_slot)
    attn_mask_arranged = _blocks_of_mask(attn_mask)
    scale_arranged = scale
    output_arranged = _blocks_of_rows(output)

    if not with_kv_cache:
        return (
            query_arranged,
            key_arranged,
            value_arranged,
            attn_mask_arranged,
            scale_arranged,
            output_arranged,
        )

    return (
        query_arranged,
        key_arranged,
        value_arranged,
        present_key_arranged,
        present_value_arranged,
        present_key_slot_arranged,
        present_value_slot_arranged,
        attn_mask_arranged,
        scale_arranged,
        output_arranged,
    )
94+
95+
96+
# Kernel application used when the caller supplies KV-cache operands.
#
# It first assigns each `present_*` operand to its matching `*_slot`
# operand and then delegates the attention computation itself to
# `application_without_kv_cache`.
#
# NOTE(review): in plain Python the two assignments would be dead stores
# (hence the `noqa: F841`); presumably ninetoothed lowers an assignment to a
# parameter into a write to the underlying tensor, so that the slot regions
# are filled before `key` / `value` are read — confirm against the
# ninetoothed code generator. Do not reorder these statements.
def application_with_kv_cache(
    query,
    key,
    value,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    scale,
    output,
):
    present_key_slot = present_key  # noqa: F841
    present_value_slot = present_value  # noqa: F841

    application_without_kv_cache(query, key, value, attn_mask, scale, output)
112+
113+
114+
# Kernel application computing attention with a running ("online") softmax.
#
# For every query tile `query[i]`, the key/value tiles are visited one by
# one while three running quantities are maintained in float32:
#   * `acc`         - the un-normalised weighted sum of value rows,
#   * `lse`         - the running sum of the exponentials,
#   * `running_max` - the running row-wise maximum of the logits.
# Whenever the maximum grows, the previous `acc` / `lse` are rescaled by
# `alpha = exp2(old_max - new_max)`. At the end `acc / lse` is written to
# `output[i]`.
#
# All exponentials are evaluated with `exp2`, so every term of the logits
# has to be expressed in base 2: the query is pre-multiplied by
# log2(e) * scale, and the additive mask is multiplied by log2(e) as well.
# Without the latter, finite (non 0 / -inf) mask values would be weighted
# incorrectly relative to the scaled dot products.
#
# Comments are used instead of a docstring so that the traced function body
# stays free of extra statements.
#
# NOTE(review): if every logit seen so far for a row is -inf (e.g. the first
# key tile is fully masked for that row), `qk - next_max` evaluates
# `-inf - -inf`; this looks like it would propagate NaN — confirm whether
# callers can produce such masks.
def application_without_kv_cache(query, key, value, attn_mask, scale, output):
    for i in range(query.shape[0]):
        # 1.4426950408889634 == log2(e); folds the softmax scale and the
        # base-e -> base-2 conversion into the query once per tile.
        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)

        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
        # The initial value of `lse` is multiplied by `alpha`, which is
        # exp2(-inf - finite) == 0 on the first tile with a finite maximum.
        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
        running_max = ntl.full(
            (query_i.shape[-2],), float("-inf"), dtype=ntl.float32
        )

        for j in range(key.shape[0]):
            # The mask is converted to base 2 exactly like the dot product;
            # 0 and -inf are unaffected by the multiplication.
            qk = ntl.dot(
                query_i, ntl.trans(key[j])
            ) + 1.4426950408889634 * attn_mask[j].to(ntl.float32)
            # Key positions beyond the real sequence length (padding of the
            # last tile) must not contribute to the softmax.
            qk = ntl.where(key[j].offsets(-2) < key.source.shape[-2], qk, float("-inf"))

            next_max = ntl.maximum(running_max, ntl.max(qk, 1))
            stable_qk = ntl.exp2(qk - next_max[:, None])

            alpha = ntl.exp2(running_max - next_max)
            # The probabilities are cast to the dtype of the value tile that
            # is actually being multiplied (`value[j]`).
            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[j].dtype), value[j])
            running_max = next_max
            lse = lse * alpha + ntl.sum(stable_qk, 1)

        acc /= lse[:, None]
        output[i] = acc  # noqa: F841
136+
137+
138+
@functools.cache
def make(with_kv_cache):
    """Build the attention kernel for one KV-cache mode.

    The result is memoised per value of `with_kv_cache` through
    `functools.cache`.

    Ten symbolic tensors are declared: five 4-D tensors (query, key, value,
    attn_mask, output) whose last dimension is a constexpr bounded by 128,
    four 4-D `present_*` tensors whose last two dimensions are constexpr
    bounded by 1 and 128 respectively, and a 0-D `scale`. All ten are handed
    to `arrangement`; which of them reach the application is decided by the
    tuple `arrangement` returns for the given `with_kv_cache`.
    """

    def _attention_operand():
        return Tensor(
            4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128})
        )

    def _cache_operand():
        return Tensor(
            4,
            shape_options=(
                None,
                None,
                {"constexpr": True, "upper_bound": 1},
                {"constexpr": True, "upper_bound": 128},
            ),
        )

    # Construction order matches the declaration order used by the tuple
    # below only partially; it is kept as-is in case tensor creation order
    # is observable.
    query = _attention_operand()
    key = _attention_operand()
    value = _attention_operand()
    attn_mask = _attention_operand()
    output = _attention_operand()

    present_key = _cache_operand()
    present_value = _cache_operand()
    present_key_slot = _cache_operand()
    present_value_slot = _cache_operand()

    scale = Tensor(0)

    application = (
        application_with_kv_cache if with_kv_cache else application_without_kv_cache
    )

    tensors = (
        query,
        key,
        value,
        present_key,
        present_value,
        present_key_slot,
        present_value_slot,
        attn_mask,
        scale,
        output,
    )

    bound_arrangement = functools.partial(arrangement, with_kv_cache=with_kv_cache)

    return ninetoothed.make(bound_arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24

35
import ntops.kernels.abs
@@ -25,6 +27,7 @@
2527
import ntops.kernels.neg
2628
import ntops.kernels.relu
2729
import ntops.kernels.rsqrt
30+
import ntops.kernels.scaled_dot_product_attention
2831
import ntops.kernels.sigmoid
2932
import ntops.kernels.sin
3033
import ntops.kernels.softmax
@@ -314,6 +317,69 @@ def rsqrt(input, *, out=None):
314317
return out
315318

316319

320+
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    # The default value here differs from that of
    # `torch.nn.functional.scaled_dot_product_attention`
    # because GQA cannot be disabled at the moment.
    enable_gqa=True,
    present_key=None,
    present_value=None,
    present_key_slot=None,
    present_value_slot=None,
):
    """Compute scaled dot-product attention with the ninetoothed kernel.

    Args:
        query, key, value: Attention operands. The mask shape is derived as
            `query.shape[:-1] + (key.shape[-2],)`.
        attn_mask: Optional mask. `None` becomes an all-zero additive mask;
            a boolean mask is converted to an additive one (`True` -> 0,
            `False` -> -inf); any other mask is used as given. The mask is
            finally expanded to the derived mask shape.
        dropout_p: Must be 0 (dropout is not supported yet).
        is_causal: Must be falsy (causal masking is not supported yet).
        scale: Softmax scale; defaults to `1 / sqrt(query.shape[-1])`.
        enable_gqa: Must be truthy (GQA cannot be disabled for now).
        present_key, present_value, present_key_slot, present_value_slot:
            KV-cache operands. Passing `present_key` selects the KV-cache
            kernel, in which case the other three are required as well.

    Returns:
        A new tensor shaped like `query` with the dtype of `value`.

    Raises:
        AssertionError: If an unsupported option is requested.
        ValueError: If `present_key` is given but any of the other
            KV-cache operands is missing.
    """
    # TODO: Support `dropout_p`.
    assert dropout_p == 0, "`dropout_p` is not supported yet."
    # TODO: Support `is_causal`.
    assert not is_causal, "`is_causal` is not supported yet."
    assert enable_gqa, "GQA must be enabled for now."

    with_kv_cache = present_key is not None

    if with_kv_cache:
        # Fail early with a clear message instead of handing `None` to the
        # kernel launch.
        missing = [
            name
            for name, tensor in (
                ("present_value", present_value),
                ("present_key_slot", present_key_slot),
                ("present_value_slot", present_value_slot),
            )
            if tensor is None
        ]

        if missing:
            raise ValueError(
                "`present_key` was given, so the following arguments are "
                f"required as well: {', '.join(missing)}."
            )

    mask_shape = query.shape[:-1] + (key.shape[-2],)

    if attn_mask is None:
        attn_mask = torch.zeros(mask_shape, dtype=query.dtype, device=query.device)
    elif attn_mask.dtype == torch.bool:
        attn_mask = torch.where(attn_mask, 0, float("-inf"))

    attn_mask = attn_mask.expand(mask_shape)

    if scale is None:
        scale = 1 / math.sqrt(query.shape[-1])

    output = torch.empty_like(query, dtype=value.dtype)

    kernel = ntops.kernels.scaled_dot_product_attention.make(with_kv_cache)

    if with_kv_cache:
        kernel(
            query,
            key,
            value,
            present_key,
            present_value,
            present_key_slot,
            present_value_slot,
            attn_mask,
            scale,
            output,
        )
    else:
        kernel(query, key, value, attn_mask, scale, output)

    return output
381+
382+
317383
def sigmoid(input, *, out=None):
318384
if out is None:
319385
out = torch.empty_like(input)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import itertools
2+
import math
3+
import random
4+
5+
import pytest
6+
import torch
7+
import torch.nn.functional as F
8+
9+
import ntops.torch
10+
from tests.skippers import skip_if_cuda_not_available
11+
12+
13+
def generate_arguments(device="cuda"):
    """Build the `pytest.mark.parametrize` arguments for the attention test.

    One case is produced for every combination of mask type (`None`,
    boolean, additive float), scale (`None` or one random value), dtype
    (float32, float16) and KV-cache mode.

    Random masks are guaranteed to leave at least one key position
    unmasked in every query row. A fully masked row has no well-defined
    softmax (it typically evaluates to NaN), and `torch.allclose` never
    treats NaN as equal, which would make the test fail at random.

    Args:
        device: Device on which the tensors are created.

    Returns:
        A `(argnames, argvalues)` pair suitable for unpacking into
        `pytest.mark.parametrize`.
    """

    def _generate_random_qkv(dtype):
        def _generate_random_size():
            return random.randint(1, 512)

        batch_size = random.randint(1, 4)
        # Head counts are powers of two, and the key/value head count never
        # exceeds the query head count, so it always divides it.
        num_heads_q_exponent = random.randint(1, 5)
        num_heads_q = 2**num_heads_q_exponent
        seq_len_q = _generate_random_size()
        head_dim = random.choice([32, 64])
        num_heads_kv = 2 ** random.randint(1, num_heads_q_exponent)
        seq_len_kv = _generate_random_size()

        shape_q = (batch_size, num_heads_q, seq_len_q, head_dim)
        shape_kv = (batch_size, num_heads_kv, seq_len_kv, head_dim)

        query = torch.randn(shape_q, dtype=dtype, device=device)
        key = torch.randn(shape_kv, dtype=dtype, device=device)
        value = torch.randn(shape_kv, dtype=dtype, device=device)

        return query, key, value

    def _generate_random_keep_mask(num_rows, num_columns):
        keep = torch.rand((num_rows, num_columns), device=device) > 0.5

        # Force one randomly chosen position per row to be kept so that no
        # row ends up fully masked.
        rows = torch.arange(num_rows, device=device)
        columns = torch.randint(num_columns, (num_rows,), device=device)
        keep[rows, columns] = True

        return keep

    arguments = []

    attn_mask_types = (None, torch.bool, torch.float32)
    scales = (None, random.uniform(0.05, 0.5))
    dtypes = (torch.float32, torch.float16)
    with_kv_cache_values = (False, True)

    for attn_mask_type, scale, dtype, with_kv_cache in itertools.product(
        attn_mask_types, scales, dtypes, with_kv_cache_values
    ):
        query, key, value = _generate_random_qkv(dtype)

        if attn_mask_type is None:
            attn_mask = None
        else:
            keep = _generate_random_keep_mask(query.shape[-2], key.shape[-2])

            if attn_mask_type is torch.bool:
                attn_mask = keep
            # TODO: Non-infinite floating-point masks may cause
            # precision issues. Revisit here later.
            else:
                attn_mask = torch.zeros(
                    keep.shape, dtype=query.dtype, device=device
                ).masked_fill(~keep, float("-inf"))

        enable_gqa = True

        if dtype is torch.float32:
            atol = 0.01
            rtol = 0.01
        else:
            atol = 0.025
            rtol = 0.025

        arguments.append(
            (query, key, value, attn_mask, scale, enable_gqa, with_kv_cache, atol, rtol)
        )

    return (
        "query, key, value, attn_mask, scale, enable_gqa, with_kv_cache, atol, rtol",
        arguments,
    )
80+
81+
82+
@skip_if_cuda_not_available
class TestScaledDotProductAttention:
    """Compare `ntops.torch.scaled_dot_product_attention` with PyTorch."""

    @pytest.mark.parametrize(*generate_arguments())
    def test_cuda(
        self, query, key, value, attn_mask, scale, enable_gqa, with_kv_cache, atol, rtol
    ):
        # Snapshot the untouched key/value for the reference computation;
        # the KV-cache branch below zeroes part of `key` / `value` in place.
        reference_key = key.clone()
        reference_value = value.clone()

        def _split_off_last_position(cache):
            # `slot` is a view of the last position along dimension 2;
            # `latest` is a copy of its contents taken before the view is
            # zeroed in place.
            slot = cache[:, :, -1:, :]
            latest = slot.clone()
            slot[...] = 0

            return latest, slot

        present_key = present_key_slot = None
        present_value = present_value_slot = None

        if with_kv_cache:
            present_key, present_key_slot = _split_off_last_position(key)
            present_value, present_value_slot = _split_off_last_position(value)

        actual = ntops.torch.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            scale=scale,
            enable_gqa=enable_gqa,
            present_key=present_key,
            present_value=present_value,
            present_key_slot=present_key_slot,
            present_value_slot=present_value_slot,
        )
        expected = F.scaled_dot_product_attention(
            query,
            reference_key,
            reference_value,
            attn_mask=attn_mask,
            scale=scale,
            enable_gqa=enable_gqa,
        )

        assert torch.allclose(actual, expected, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)