Skip to content

Commit e4009f1

Browse files
voltjia and Ziminli committed
Add scaled_dot_product_attention operator
Co-authored-by: Zimin Li <coollizimin@gmail.com>
1 parent 2da73f8 commit e4009f1

3 files changed

Lines changed: 396 additions & 0 deletions

File tree

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
# Default tile extent along dimension -2 of `query`, `output` and `attn_mask`
# in `arrangement`. Presumably a tunable meta-parameter -- confirm against the
# ninetoothed documentation for `block_size`.
BLOCK_SIZE_M = ninetoothed.block_size()
8+
# Default tile extent along dimension -2 of `key`/`value` and dimension -1 of
# `attn_mask` in `arrangement`. Presumably a tunable meta-parameter -- confirm
# against the ninetoothed documentation for `block_size`.
BLOCK_SIZE_N = ninetoothed.block_size()
9+
10+
11+
def arrangement(
    query,
    key,
    value,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    scale,
    output,
    with_kv_cache,
    BLOCK_SIZE_M=BLOCK_SIZE_M,
    BLOCK_SIZE_N=BLOCK_SIZE_N,
):
    """Tile the attention operands into the layout the kernel bodies index.

    All tensor operands are treated as 4-D (presumably
    ``(batch, heads, sequence, head_dim)`` -- confirm against the caller).

    * ``query`` / ``output``: tiled into ``(1, 1, BLOCK_SIZE_M, -1)`` blocks,
      then grouped along dimension 1 by ``query.shape[-3] // key.shape[-3]``.
    * ``key`` / ``value``: tiled into ``(1, 1, BLOCK_SIZE_N, -1)`` blocks,
      gathered with a ``(1, 1, -1, -1)`` tile and expanded along dimension -2
      to ``query_arranged.shape[-2]``.
    * ``present_*``: a single ``(1, 1, -1, -1)`` tile per leading index pair.
    * ``attn_mask``: ``(1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N)`` blocks gathered
      along the last dimension.
    * ``scale``: passed through untouched.

    Returns:
        A tuple of arranged tensors. With ``with_kv_cache`` it has ten entries
        (query, key, value, the four ``present_*`` tensors, mask, scale,
        output); otherwise the four ``present_*`` entries are omitted.
    """

    def tile_query_like(tensor):
        # Row blocks first, then bundle the query heads of one group together.
        row_blocks = tensor.tile((1, 1, BLOCK_SIZE_M, -1))
        grouped = row_blocks.tile((1, query.shape[-3] // key.shape[-3], 1, 1))
        grouped.dtype = grouped.dtype.squeeze((0, 2, 3))
        grouped.dtype.dtype = grouped.dtype.dtype.squeeze((0, 1))

        return grouped

    def tile_key_like(tensor, num_query_blocks):
        row_blocks = tensor.tile((1, 1, BLOCK_SIZE_N, -1))
        gathered = row_blocks.tile((1, 1, -1, -1))
        # Replicate along dimension -2 so the outer shape lines up with the
        # arranged query.
        expanded = gathered.expand((-1, -1, num_query_blocks, -1))
        expanded.dtype = expanded.dtype.squeeze((0, 1, 3))
        expanded.dtype.dtype = expanded.dtype.dtype.squeeze((0, 1))

        return expanded

    def tile_cache_like(tensor):
        whole = tensor.tile((1, 1, -1, -1))
        whole.dtype = whole.dtype.squeeze((0, 1))

        return whole

    def tile_mask(tensor):
        blocks = tensor.tile((1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N))
        gathered = blocks.tile((1, 1, 1, -1))
        gathered.dtype = gathered.dtype.squeeze((0, 1, 2))
        gathered.dtype.dtype = gathered.dtype.dtype.squeeze((0, 1))

        return gathered

    query_arranged = tile_query_like(query)
    key_arranged = tile_key_like(key, query_arranged.shape[-2])
    value_arranged = tile_key_like(value, query_arranged.shape[-2])

    # The cache operands are arranged unconditionally and only dropped from
    # the result when the cache is not in use.
    cache_arranged = [
        tile_cache_like(present_key),
        tile_cache_like(present_value),
        tile_cache_like(present_key_slot),
        tile_cache_like(present_value_slot),
    ]

    attn_mask_arranged = tile_mask(attn_mask)
    output_arranged = tile_query_like(output)

    result = [query_arranged, key_arranged, value_arranged]

    if with_kv_cache:
        result.extend(cache_arranged)

    result.extend((attn_mask_arranged, scale, output_arranged))

    return tuple(result)
94+
95+
96+
# Kernel body for the KV-cache variant: copy the present key/value into their
# cache slots, then run the cache-free attention body on the same operands.
#
# Plain comments are used rather than a docstring because this function is
# passed to `ninetoothed.make` as the application.
def application_with_kv_cache(
    query,
    key,
    value,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    scale,
    output,
):
    # Rebinding a tensor parameter looks unused to a linter (hence the `noqa`);
    # presumably ninetoothed lowers it to a store into that tensor -- confirm
    # against the ninetoothed documentation.
    present_key_slot = present_key  # noqa: F841
    present_value_slot = present_value  # noqa: F841

    application_without_kv_cache(query, key, value, attn_mask, scale, output)
112+
113+
114+
# Kernel body for attention without a KV cache. For each query tile it walks
# over the key/value tiles and accumulates softmax(scale * Q K^T + mask) V with
# running-maximum rescaling, so the full score matrix is never materialized.
#
# Plain comments are used rather than a docstring because this function is
# passed to `ninetoothed.make` as the application.
def application_without_kv_cache(query, key, value, attn_mask, scale, output):
    for i in range(query.shape[0]):
        # 1.4426950408889634 is log2(e). Folding it into the scaled query moves
        # the scores into the base-2 domain so the softmax can use `exp2`.
        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)

        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
        # Running softmax denominator. The initial 1 is multiplied by
        # `alpha == exp2(-inf - next_max)` on the first key tile.
        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)

        for j in range(key.shape[0]):
            # The additive mask is expressed in the natural-log domain, so it
            # needs the same log2(e) factor as the scores; otherwise a finite
            # float mask would be weighted by ln(2) instead of 1. Masks that
            # only contain 0 and -inf are unaffected by the factor.
            qk = (
                ntl.dot(query_i, ntl.trans(key[j]))
                + 1.4426950408889634 * attn_mask[j]
            )
            # Columns past the end of the key sequence must not contribute.
            qk = ntl.where(key[j].offsets(-2) < key.source.shape[-2], qk, float("-inf"))

            next_max = ntl.maximum(max, ntl.max(qk, 1))
            stable_qk = ntl.exp2(qk - next_max[:, None])

            # Rescale what has been accumulated so far to the new maximum.
            alpha = ntl.exp2(max - next_max)
            # Cast the probabilities to the dtype of the value tile actually
            # being multiplied (index `j`, not the query-loop index `i`).
            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[j].dtype), value[j])
            max = next_max
            lse = lse * alpha + ntl.sum(stable_qk, 1)

        acc /= lse[:, None]
        output[i] = acc  # noqa: F841
136+
137+
138+
@functools.cache
def make(with_kv_cache):
    """Build the scaled-dot-product-attention kernel via `ninetoothed.make`.

    The result is memoized per ``with_kv_cache`` value by `functools.cache`.

    Args:
        with_kv_cache: Selects `application_with_kv_cache` when truthy and
            `application_without_kv_cache` otherwise; it is also bound into
            `arrangement` so the arranged tuple matches the chosen body.

    Returns:
        Whatever `ninetoothed.make` returns for the bound arrangement, the
        chosen application and the ten symbolic tensors declared here.
    """

    def attention_operand():
        # 4-D tensor whose last dimension is constexpr with upper bound 128.
        return Tensor(
            4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128})
        )

    def cache_operand():
        # 4-D tensor whose dimension -2 is constexpr with upper bound 1 and
        # whose last dimension is constexpr with upper bound 128.
        return Tensor(
            4,
            shape_options=(
                None,
                None,
                {"constexpr": True, "upper_bound": 1},
                {"constexpr": True, "upper_bound": 128},
            ),
        )

    # Keep the creation order of the symbolic tensors stable.
    query = attention_operand()
    key = attention_operand()
    value = attention_operand()
    attn_mask = attention_operand()
    output = attention_operand()

    present_key = cache_operand()
    present_value = cache_operand()
    present_key_slot = cache_operand()
    present_value_slot = cache_operand()

    scale = Tensor(0)

    application = (
        application_with_kv_cache if with_kv_cache else application_without_kv_cache
    )
    bound_arrangement = functools.partial(arrangement, with_kv_cache=with_kv_cache)

    tensors = (
        query,
        key,
        value,
        present_key,
        present_value,
        present_key_slot,
        present_value_slot,
        attn_mask,
        scale,
        output,
    )

    return ninetoothed.make(bound_arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24

35
import ntops.kernels.abs
@@ -25,6 +27,7 @@
2527
import ntops.kernels.neg
2628
import ntops.kernels.relu
2729
import ntops.kernels.rsqrt
30+
import ntops.kernels.scaled_dot_product_attention
2831
import ntops.kernels.sigmoid
2932
import ntops.kernels.sin
3033
import ntops.kernels.softmax
@@ -314,6 +317,69 @@ def rsqrt(input, *, out=None):
314317
return out
315318

316319

320+
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    # The default value here differs from that of
    # `torch.nn.functional.scaled_dot_product_attention`
    # because GQA cannot be disabled at the moment.
    enable_gqa=True,
    present_key=None,
    present_value=None,
    present_key_slot=None,
    present_value_slot=None,
):
    """Compute scaled dot-product attention with the ntops kernel.

    Args:
        query, key, value: Attention operands. The mask is built with shape
            ``query.shape[:-1] + (key.shape[-2],)``.
        attn_mask: ``None`` (no masking), a boolean mask (``True`` keeps a
            position, ``False`` maps to ``-inf``) or an additive mask. It is
            expanded to the mask shape above.
        dropout_p: Must be ``0``; dropout is not supported yet.
        is_causal: Must be falsy; causal masking is not supported yet.
        scale: Score scaling factor; defaults to ``1 / sqrt(query.shape[-1])``.
        enable_gqa: Must be truthy; GQA cannot be disabled at the moment.
        present_key, present_value, present_key_slot, present_value_slot:
            Optional KV-cache operands. Supplying ``present_key`` selects the
            KV-cache kernel, in which case the other three are required too.

    Returns:
        A new tensor shaped like ``query`` with the dtype of ``value``.

    Raises:
        AssertionError: If an unsupported option is requested.
        ValueError: If ``present_key`` is given without the remaining
            KV-cache operands.
    """
    # TODO: Support `dropout_p`.
    assert dropout_p == 0, "`dropout_p` is not supported yet."
    # TODO: Support `is_causal`.
    assert not is_causal, "`is_causal` is not supported yet."
    assert enable_gqa, "GQA must be enabled for now."

    with_kv_cache = present_key is not None

    if with_kv_cache:
        # Fail early with a clear message instead of handing `None` to the
        # kernel when the cache is only partially specified.
        missing = [
            f"`{name}`"
            for name, tensor in (
                ("present_value", present_value),
                ("present_key_slot", present_key_slot),
                ("present_value_slot", present_value_slot),
            )
            if tensor is None
        ]

        if missing:
            raise ValueError(
                "`present_key` was given, so "
                + ", ".join(missing)
                + " must be provided as well."
            )

    mask_shape = query.shape[:-1] + (key.shape[-2],)

    if attn_mask is None:
        attn_mask = torch.zeros(mask_shape, dtype=query.dtype, device=query.device)
    elif attn_mask.dtype == torch.bool:
        # Build the additive mask directly in `query.dtype`, like the
        # `attn_mask is None` branch, rather than relying on the result dtype
        # of a scalar-only `torch.where`.
        attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(
            ~attn_mask, float("-inf")
        )

    attn_mask = attn_mask.expand(mask_shape)

    if scale is None:
        scale = 1 / math.sqrt(query.shape[-1])

    output = torch.empty_like(query, dtype=value.dtype)

    kernel = ntops.kernels.scaled_dot_product_attention.make(with_kv_cache)

    if with_kv_cache:
        kernel(
            query,
            key,
            value,
            present_key,
            present_value,
            present_key_slot,
            present_value_slot,
            attn_mask,
            scale,
            output,
        )
    else:
        kernel(query, key, value, attn_mask, scale, output)

    return output
381+
382+
317383
def sigmoid(input, *, out=None):
318384
if out is None:
319385
out = torch.empty_like(input)

0 commit comments

Comments
 (0)