48 changes: 45 additions & 3 deletions src/ntops/kernels/scaled_dot_product_attention.py
@@ -1,3 +1,4 @@
import enum
import functools

import ninetoothed
@@ -8,6 +9,14 @@
BLOCK_SIZE_N = ninetoothed.block_size()


class CausalVariant(enum.IntEnum):
"""Please refer to `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_."""

UPPER_LEFT = enum.auto()

LOWER_RIGHT = enum.auto()


def arrangement(
query,
key,
@@ -21,6 +30,7 @@ def arrangement(
scale,
output,
with_attn_mask,
causal_variant,
with_kv_cache,
block_size_m=None,
block_size_n=None,
@@ -78,6 +88,7 @@ def arrange_attn_mask(input):
scale_arranged = scale
output_arranged = arrange_query_or_output(output)
with_attn_mask_arranged = with_attn_mask
causal_variant_arranged = causal_variant

if with_kv_cache:
return (
@@ -93,6 +104,7 @@ def arrange_attn_mask(input):
scale_arranged,
output_arranged,
with_attn_mask_arranged,
causal_variant_arranged,
)

return (
@@ -104,6 +116,7 @@ def arrange_attn_mask(input):
scale_arranged,
output_arranged,
with_attn_mask_arranged,
causal_variant_arranged,
)


@@ -120,17 +133,34 @@ def application_with_kv_cache(
scale,
output,
with_attn_mask,
causal_variant,
):
present_key_slot = present_key # noqa: F841
present_value_slot = present_value # noqa: F841

application_without_kv_cache(
query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
query,
key,
value,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
)


def application_without_kv_cache(
query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
query,
key,
value,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
):
for i in range(query.shape[0]):
query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
@@ -147,7 +177,16 @@ def application_without_kv_cache(
qk += attn_mask[j]

if is_causal:
mask = query[i].offsets(-2)[:, None] >= key[j].offsets(-2)[None, :]
if causal_variant == 2:  # CausalVariant.LOWER_RIGHT
mask = (
query[i].offsets(-2)[:, None]
+ key.source.shape[-2]
- query.source.shape[-2]
>= key[j].offsets(-2)[None, :]
)
else:
mask = query[i].offsets(-2)[:, None] >= key[j].offsets(-2)[None, :]

qk = ntl.where(mask, qk, float("-inf"))

next_max = ntl.maximum(max, ntl.max(qk, 1))
@@ -167,6 +206,7 @@ def premake(
emb_dim=None,
is_causal=None,
with_attn_mask=None,
causal_variant=None,
dtype=None,
block_size_m=None,
block_size_n=None,
@@ -192,6 +232,7 @@ def premake(
scale = Tensor(0, dtype=dtype)
is_causal = Tensor(0, dtype=dtype, constexpr=True, value=is_causal)
with_attn_mask = Tensor(0, dtype=dtype, constexpr=True, value=with_attn_mask)
causal_variant = Tensor(0, dtype=dtype, constexpr=True, value=causal_variant)

if emb_dim is not None:
for tensor in (query, key, value, attn_mask, output):
@@ -215,6 +256,7 @@
scale,
output,
with_attn_mask,
causal_variant,
)

return arrangement_, application, tensors
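
For reference, the LOWER_RIGHT branch above shifts the causal diagonal by key.source.shape[-2] - query.source.shape[-2], so the allowed region always ends at the last key position. A minimal standalone sketch of the two variants in plain PyTorch (illustrative only; it sidesteps the ninetoothed tile offsets the kernel actually uses):

    import torch

    def causal_mask(seq_len_q, seq_len_kv, lower_right=False):
        # UPPER_LEFT anchors the allowed triangle at position (0, 0);
        # LOWER_RIGHT shifts it by seq_len_kv - seq_len_q so the diagonal
        # ends at the bottom-right corner of the score matrix.
        q = torch.arange(seq_len_q)[:, None]
        k = torch.arange(seq_len_kv)[None, :]
        offset = seq_len_kv - seq_len_q if lower_right else 0
        return q + offset >= k

    # With seq_len_q=2 and seq_len_kv=4:
    # UPPER_LEFT:  [[True, False, False, False],
    #               [True, True,  False, False]]
    # LOWER_RIGHT: [[True, True,  True,  False],
    #               [True, True,  True,  True ]]

The two variants only differ when the query and key sequence lengths differ, which is why the test parametrization below carries separate seq_len_q and seq_len_kv values.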
18 changes: 17 additions & 1 deletion src/ntops/torch.py
@@ -42,6 +42,7 @@
import ntops.kernels.softmax
import ntops.kernels.sub
import ntops.kernels.tanh
from ntops.kernels.scaled_dot_product_attention import CausalVariant


def abs(input, *, out=None):
@@ -445,6 +446,7 @@ def scaled_dot_product_attention(
is_causal=False,
scale=None,
enable_gqa=False,
causal_variant=None,
present_key=None,
present_value=None,
present_key_slot=None,
@@ -490,6 +492,9 @@ def scaled_dot_product_attention(
if scale is None:
scale = 1 / math.sqrt(query.shape[-1])

if causal_variant is None:
causal_variant = CausalVariant.UPPER_LEFT

if present_key is not None:
with_kv_cache = True
else:
@@ -515,9 +520,20 @@
scale,
output,
with_attn_mask,
causal_variant,
)
else:
kernel(query, key, value, attn_mask, is_causal, scale, output, with_attn_mask)
kernel(
query,
key,
value,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
)

return output
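
With this wiring in place, callers select the variant through the new keyword argument; leaving it as None falls back to UPPER_LEFT, matching the behavior before this change. A hypothetical call (shapes are made up here, chosen so the two variants actually disagree, i.e. fewer query positions than key positions):

    import torch

    import ntops.torch
    from ntops.kernels.scaled_dot_product_attention import CausalVariant

    query = torch.randn(1, 8, 6, 64, device="cuda", dtype=torch.float16)
    key = torch.randn(1, 8, 8, 64, device="cuda", dtype=torch.float16)
    value = torch.randn(1, 8, 8, 64, device="cuda", dtype=torch.float16)

    out = ntops.torch.scaled_dot_product_attention(
        query,
        key,
        value,
        is_causal=True,
        causal_variant=CausalVariant.LOWER_RIGHT,
    )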

32 changes: 28 additions & 4 deletions tests/test_scaled_dot_product_attention.py
@@ -5,8 +5,10 @@
import pytest
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

import ntops.torch
from ntops.kernels.scaled_dot_product_attention import CausalVariant
from tests.skippers import skip_if_cuda_not_available


@@ -21,9 +23,22 @@ def _generate_random_size():
scales = (None, random.uniform(0.05, 0.5))
dtypes = (torch.float32, torch.float16)
with_kv_cache_values = (False, True)

for attn_mask_type, is_causal, scale, dtype, with_kv_cache in itertools.product(
attn_mask_types, is_causal_values, scales, dtypes, with_kv_cache_values
causal_variants = (None, CausalVariant.LOWER_RIGHT, CausalVariant.UPPER_LEFT)

for (
attn_mask_type,
is_causal,
scale,
dtype,
with_kv_cache,
causal_variant,
) in itertools.product(
attn_mask_types,
is_causal_values,
scales,
dtypes,
with_kv_cache_values,
causal_variants,
):
if attn_mask_type is not None and is_causal:
continue
@@ -56,6 +71,7 @@ def _generate_random_size():
is_causal,
scale,
enable_gqa,
causal_variant,
with_kv_cache,
dtype,
atol,
@@ -64,7 +80,7 @@
)

return (
"batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, with_kv_cache, dtype, atol, rtol",
"batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, causal_variant, with_kv_cache, dtype, atol, rtol",
arguments,
)

Expand All @@ -82,6 +98,7 @@ def test_cuda(
is_causal,
scale,
enable_gqa,
causal_variant,
with_kv_cache,
dtype,
atol,
@@ -138,11 +155,18 @@ def _generate_present_and_slot(tensor):
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
causal_variant=causal_variant,
present_key=present_key,
present_value=present_value,
present_key_slot=present_key_slot,
present_value_slot=present_value_slot,
)

if is_causal:
if causal_variant == CausalVariant.LOWER_RIGHT:
attn_mask = causal_lower_right(query.shape[-2], key.shape[-2])
is_causal = False

reference_output = F.scaled_dot_product_attention(
query,
key_cloned,
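
The reference path above relies on an equivalence worth spelling out: is_causal=True in F.scaled_dot_product_attention always means UPPER_LEFT, so a LOWER_RIGHT reference has to be expressed as an explicit bias via torch.nn.attention.bias.causal_lower_right, which is exactly the substitution the test performs. A condensed sketch of that substitution (toy shapes, not taken from the test):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention.bias import causal_lower_right

    q = torch.randn(1, 2, 3, 8)
    k = torch.randn(1, 2, 5, 8)
    v = torch.randn(1, 2, 5, 8)

    # LOWER_RIGHT reference: pass the bias as attn_mask, is_causal stays False.
    bias = causal_lower_right(q.shape[-2], k.shape[-2])
    lower_right_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

    # UPPER_LEFT reference: plain is_causal=True.
    upper_left_ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)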