
Commit 4c677af

Merge pull request #47 from InfiniTensor/add-causal-variants

Add causal variants

2 parents: b2623d4 + 4035dca

3 files changed: 90 additions & 8 deletions


src/ntops/kernels/scaled_dot_product_attention.py

Lines changed: 45 additions & 3 deletions
@@ -1,3 +1,4 @@
+import enum
 import functools
 
 import ninetoothed
@@ -8,6 +9,14 @@
 BLOCK_SIZE_N = ninetoothed.block_size()
 
 
+class CausalVariant(enum.IntEnum):
+    """Please refer to `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_."""
+
+    UPPER_LEFT = enum.auto()
+
+    LOWER_RIGHT = enum.auto()
+
+
 def arrangement(
     query,
     key,
@@ -21,6 +30,7 @@ def arrangement(
     scale,
     output,
     with_attn_mask,
+    causal_variant,
     with_kv_cache,
     block_size_m=None,
     block_size_n=None,
@@ -78,6 +88,7 @@ def arrange_attn_mask(input):
     scale_arranged = scale
     output_arranged = arrange_query_or_output(output)
     with_attn_mask_arranged = with_attn_mask
+    causal_variant_arranged = causal_variant
 
     if with_kv_cache:
         return (
@@ -93,6 +104,7 @@ def arrange_attn_mask(input):
             scale_arranged,
             output_arranged,
             with_attn_mask_arranged,
+            causal_variant_arranged,
         )
 
     return (
@@ -104,6 +116,7 @@ def arrange_attn_mask(input):
         scale_arranged,
         output_arranged,
         with_attn_mask_arranged,
+        causal_variant_arranged,
     )


@@ -120,17 +133,34 @@ def application_with_kv_cache(
     scale,
     output,
     with_attn_mask,
+    causal_variant,
 ):
     present_key_slot = present_key  # noqa: F841
     present_value_slot = present_value  # noqa: F841
 
     application_without_kv_cache(
-        query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
+        query,
+        key,
+        value,
+        attn_mask,
+        is_causal,
+        scale,
+        output,
+        with_attn_mask,
+        causal_variant,
     )
 
 
 def application_without_kv_cache(
-    query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
+    query,
+    key,
+    value,
+    attn_mask,
+    is_causal,
+    scale,
+    output,
+    with_attn_mask,
+    causal_variant,
 ):
     for i in range(query.shape[0]):
         query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
@@ -147,7 +177,16 @@ def application_without_kv_cache(
                 qk += attn_mask[j]
 
             if is_causal:
-                mask = query[i].offsets(-2)[:, None] >= key[j].offsets(-2)[None, :]
+                if causal_variant == 2:  # CausalVariant.LOWER_RIGHT
+                    mask = (
+                        query[i].offsets(-2)[:, None]
+                        + key.source.shape[-2]
+                        - query.source.shape[-2]
+                        >= key[j].offsets(-2)[None, :]
+                    )
+                else:
+                    mask = query[i].offsets(-2)[:, None] >= key[j].offsets(-2)[None, :]
+
                 qk = ntl.where(mask, qk, float("-inf"))
 
             next_max = ntl.maximum(max, ntl.max(qk, 1))
@@ -167,6 +206,7 @@ def premake(
     emb_dim=None,
     is_causal=None,
     with_attn_mask=None,
+    causal_variant=None,
     dtype=None,
     block_size_m=None,
     block_size_n=None,
@@ -192,6 +232,7 @@ def premake(
     scale = Tensor(0, dtype=dtype)
     is_causal = Tensor(0, dtype=dtype, constexpr=True, value=is_causal)
     with_attn_mask = Tensor(0, dtype=dtype, constexpr=True, value=with_attn_mask)
+    causal_variant = Tensor(0, dtype=dtype, constexpr=True, value=causal_variant)
 
     if emb_dim is not None:
         for tensor in (query, key, value, attn_mask, output):
@@ -215,6 +256,7 @@ def premake(
         scale,
         output,
         with_attn_mask,
+        causal_variant,
     )
 
     return arrangement_, application, tensors
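
The substance of the kernel change is the mask alignment: UPPER_LEFT keeps the previous behavior, letting query row i attend to key columns j <= i anchored at the top-left corner of the score matrix, while LOWER_RIGHT shifts every row by key_len - query_len so the causal diagonal ends in the bottom-right corner, the natural alignment when a short query block attends over a longer, partly cached key sequence. A minimal sketch of the two alignments in plain PyTorch, where causal_mask is an illustrative helper rather than kernel code:

import torch


def causal_mask(seq_len_q, seq_len_kv, lower_right=False):
    # LOWER_RIGHT shifts each query row by (seq_len_kv - seq_len_q), mirroring
    # the kernel's `+ key.source.shape[-2] - query.source.shape[-2]` term.
    offset = seq_len_kv - seq_len_q if lower_right else 0
    q_idx = torch.arange(seq_len_q)[:, None]
    kv_idx = torch.arange(seq_len_kv)[None, :]
    return q_idx + offset >= kv_idx


print(causal_mask(2, 4))                    # upper-left: [[T, F, F, F], [T, T, F, F]]
print(causal_mask(2, 4, lower_right=True))  # lower-right: [[T, T, T, F], [T, T, T, T]]

When seq_len_q == seq_len_kv the offset vanishes and the two variants coincide, which is why UPPER_LEFT remains a safe default.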

src/ntops/torch.py

Lines changed: 17 additions & 1 deletion
@@ -42,6 +42,7 @@
 import ntops.kernels.softmax
 import ntops.kernels.sub
 import ntops.kernels.tanh
+from ntops.kernels.scaled_dot_product_attention import CausalVariant
 
 
 def abs(input, *, out=None):
@@ -445,6 +446,7 @@ def scaled_dot_product_attention(
     is_causal=False,
     scale=None,
     enable_gqa=False,
+    causal_variant=None,
     present_key=None,
     present_value=None,
     present_key_slot=None,
@@ -490,6 +492,9 @@ def scaled_dot_product_attention(
     if scale is None:
         scale = 1 / math.sqrt(query.shape[-1])
 
+    if causal_variant is None:
+        causal_variant = CausalVariant.UPPER_LEFT
+
     if present_key is not None:
         with_kv_cache = True
     else:
@@ -515,9 +520,20 @@ def scaled_dot_product_attention(
             scale,
             output,
             with_attn_mask,
+            causal_variant,
         )
     else:
-        kernel(query, key, value, attn_mask, is_causal, scale, output, with_attn_mask)
+        kernel(
+            query,
+            key,
+            value,
+            attn_mask,
+            is_causal,
+            scale,
+            output,
+            with_attn_mask,
+            causal_variant,
+        )
 
     return output
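
At the torch-facing API, causal_variant is a new optional keyword that falls back to CausalVariant.UPPER_LEFT (the previous behavior) when left as None. A hypothetical call sketch, assuming a CUDA device and illustrative shapes:

import torch

import ntops.torch
from ntops.kernels.scaled_dot_product_attention import CausalVariant

# Two query positions attending over four key/value positions; LOWER_RIGHT
# aligns the causal mask so the last query row sees every key position.
query = torch.randn(1, 8, 2, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 8, 4, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 8, 4, 64, device="cuda", dtype=torch.float16)

output = ntops.torch.scaled_dot_product_attention(
    query,
    key,
    value,
    is_causal=True,
    causal_variant=CausalVariant.LOWER_RIGHT,
)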

tests/test_scaled_dot_product_attention.py

Lines changed: 28 additions & 4 deletions
@@ -5,8 +5,10 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from torch.nn.attention.bias import causal_lower_right
 
 import ntops.torch
+from ntops.kernels.scaled_dot_product_attention import CausalVariant
 from tests.skippers import skip_if_cuda_not_available
 
 
@@ -21,9 +23,22 @@ def _generate_random_size():
     scales = (None, random.uniform(0.05, 0.5))
     dtypes = (torch.float32, torch.float16)
     with_kv_cache_values = (False, True)
-
-    for attn_mask_type, is_causal, scale, dtype, with_kv_cache in itertools.product(
-        attn_mask_types, is_causal_values, scales, dtypes, with_kv_cache_values
+    causal_variants = (None, CausalVariant.LOWER_RIGHT, CausalVariant.UPPER_LEFT)
+
+    for (
+        attn_mask_type,
+        is_causal,
+        scale,
+        dtype,
+        with_kv_cache,
+        causal_variant,
+    ) in itertools.product(
+        attn_mask_types,
+        is_causal_values,
+        scales,
+        dtypes,
+        with_kv_cache_values,
+        causal_variants,
     ):
         if attn_mask_type is not None and is_causal:
             continue
@@ -56,6 +71,7 @@ def _generate_random_size():
                 is_causal,
                 scale,
                 enable_gqa,
+                causal_variant,
                 with_kv_cache,
                 dtype,
                 atol,
@@ -64,7 +80,7 @@ def _generate_random_size():
         )
 
     return (
-        "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, with_kv_cache, dtype, atol, rtol",
+        "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, atol, rtol",
         arguments,
     )
 
@@ -82,6 +98,7 @@ def test_cuda(
     is_causal,
     scale,
     enable_gqa,
+    causal_variant,
     with_kv_cache,
     dtype,
     atol,
@@ -138,11 +155,18 def _generate_present_and_slot(tensor):
         is_causal=is_causal,
         scale=scale,
         enable_gqa=enable_gqa,
+        causal_variant=causal_variant,
         present_key=present_key,
         present_value=present_value,
         present_key_slot=present_key_slot,
         present_value_slot=present_value_slot,
     )
+
+    if is_causal:
+        if causal_variant == CausalVariant.LOWER_RIGHT:
+            attn_mask = causal_lower_right(query.shape[-2], key.shape[-2])
+            is_causal = False
+
     reference_output = F.scaled_dot_product_attention(
         query,
         key_cloned,
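
Because is_causal=True in F.scaled_dot_product_attention always means upper-left alignment, the test derives the lower-right reference by clearing the flag and passing an explicit bias from torch.nn.attention.bias instead. A standalone sketch of that substitution, with toy shapes and assuming a PyTorch release that ships causal_lower_right:

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

q = torch.randn(1, 1, 2, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# The CausalBias object encodes the lower-right-aligned mask and can be
# passed where a regular attn_mask is expected.
bias = causal_lower_right(q.shape[-2], k.shape[-2])
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)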
