Skip to content

Commit 13d44a6

Browse files
committed
issue/889 - changed size_t to tensor in flash attn interface
1 parent 7891488 commit 13d44a6

12 files changed

Lines changed: 478 additions & 58 deletions

File tree

include/infinicore/ops/flash_attention.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
namespace infinicore::op {
77

8-
INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, Tensor, Tensor, Tensor, std::size_t, float, bool);
8+
INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool);
99

10-
Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal);
11-
void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal);
10+
Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
11+
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
1212
} // namespace infinicore::op

include/infiniop/ops/flash_attention.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ __C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
1212
infiniopTensorDescriptor_t q_desc,
1313
infiniopTensorDescriptor_t k_desc,
1414
infiniopTensorDescriptor_t v_desc,
15-
std::size_t total_kv_len,
15+
infiniopTensorDescriptor_t total_kv_len,
1616
float scale,
1717
char is_causal);
1818

@@ -28,6 +28,7 @@ __C __export infiniStatus_t infiniopFlashAttention(
2828
const void *q,
2929
const void *k,
3030
const void *v,
31+
const void *total_kv_len,
3132
void *stream);
3233

3334
__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .causal_softmax import causal_softmax
22
from .embedding import embedding
3+
from .flash_attention import flash_attention
34
from .linear import linear
45
from .random_sample import random_sample
56
from .rms_norm import rms_norm
@@ -10,13 +11,14 @@
1011

1112
__all__ = [
1213
"causal_softmax",
14+
"embedding",
15+
"flash_attention",
16+
"linear",
1317
"random_sample",
1418
"rms_norm",
19+
"rope",
1520
"scaled_dot_product_attention",
1621
"silu",
1722
"swiglu",
18-
"linear",
19-
"embedding",
20-
"rope",
2123
"RopeAlgo",
2224
]
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import math
2+
3+
from infinicore.lib import _infinicore
4+
from infinicore.tensor import Tensor
5+
6+
7+
def flash_attention(
8+
query,
9+
key,
10+
value,
11+
total_kv_len,
12+
attn_mask=None,
13+
dropout_p=0,
14+
is_causal=False,
15+
scale=None,
16+
enable_gqa=False,
17+
):
18+
assert attn_mask is None and dropout_p == 0 and not enable_gqa
19+
20+
emb_dim = query.shape[-1]
21+
22+
if scale is None:
23+
scale = 1 / math.sqrt(emb_dim)
24+
25+
return Tensor(
26+
_infinicore.flash_attention(
27+
query._underlying,
28+
key._underlying,
29+
value._underlying,
30+
total_kv_len._underlying,
31+
scale,
32+
is_causal,
33+
)
34+
)

python/infinicore/nn/functional/scaled_dot_product_attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ def scaled_dot_product_attention(
1414
scale=None,
1515
enable_gqa=False,
1616
):
17+
raise NotImplementedError("Scaled Dot Product Attention is not yet supported.")
18+
1719
assert attn_mask is None and dropout_p == 0 and not enable_gqa
1820

1921
emb_dim = query.shape[-1]

src/infinicore/ops/flash_attention/flash_attention.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,24 @@ namespace infinicore::op {
66

77
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention);
88

9-
FlashAttention::FlashAttention(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
9+
FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
1010
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
1111
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
1212
out, q, k, v, total_kv_len, scale, is_causal);
1313
}
1414

15-
void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
15+
void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
1616
INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal);
1717
}
1818

19-
Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
19+
Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
2020
Shape shape = q->shape();
2121
auto out = Tensor::empty(shape, q->dtype(), q->device());
2222
flash_attention_(out, q, k, v, total_kv_len, scale, is_causal);
2323
return out;
2424
}
2525

26-
void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
26+
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
2727
FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal);
2828
}
2929
} // namespace infinicore::op

src/infinicore/ops/flash_attention/flash_attention_infiniop.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,17 @@ INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100);
1111

1212
struct PlannedMeta {
1313
std::shared_ptr<Descriptor> descriptor;
14-
graph::GraphTensor workspace, out, q, k, v;
15-
std::size_t total_kv_len;
14+
graph::GraphTensor workspace, out, q, k, v, total_kv_len;
1615
float scale;
1716
bool is_causal;
1817
};
1918

20-
void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
19+
void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
2120
size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal);
2221

2322
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
2423
Descriptor, descriptor, FlashAttention,
25-
seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len, scale, is_causal);
24+
seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal);
2625

2726
INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor);
2827

@@ -33,7 +32,7 @@ void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, f
3332
graph::GraphTensor(q),
3433
graph::GraphTensor(k),
3534
graph::GraphTensor(v),
36-
total_kv_len, scale, is_causal};
35+
graph::GraphTensor(total_kv_len), scale, is_causal};
3736

3837
return planned;
3938
}
@@ -43,7 +42,7 @@ void run(void *planned_meta) {
4342

4443
INFINICORE_CHECK_ERROR(infiniopFlashAttention(
4544
planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
46-
planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), context::getStream()));
45+
planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream()));
4746
}
4847

4948
void cleanup(void **planned_meta_ptr) {

src/infiniop/ops/flash_attention/ninetoothed/build.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import ninetoothed
2-
from ntops.kernels import scaled_dot_product_attention
3-
from ntops.kernels.scaled_dot_product_attention import CausalVariant
2+
from . import flash_attention
3+
from .flash_attention import CausalVariant
44

55
import infiniop.ninetoothed.build
66

@@ -27,7 +27,7 @@ def build():
2727
}
2828

2929
infiniop.ninetoothed.build.build(
30-
scaled_dot_product_attention.premake,
30+
flash_attention.premake,
3131
constexpr_param_grid,
3232
caller="cuda",
3333
op_name="flash_attention",

src/infiniop/ops/flash_attention/ninetoothed/descriptor.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Descriptor final : public InfiniopDescriptor {
1717
infiniopTensorDescriptor_t q_desc,
1818
infiniopTensorDescriptor_t k_desc,
1919
infiniopTensorDescriptor_t v_desc,
20-
std::size_t total_kv_len,
20+
infiniopTensorDescriptor_t total_kv_len,
2121
double scale,
2222
char is_causal) : InfiniopDescriptor{handle->device, handle->device_id},
2323
_query_shape{q_desc->shape()},
@@ -26,12 +26,12 @@ class Descriptor final : public InfiniopDescriptor {
2626
_key_strides{k_desc->strides()},
2727
_value_shape{v_desc->shape()},
2828
_value_strides{v_desc->strides()},
29+
_total_kv_shape{total_kv_len->shape()},
30+
_total_kv_strides{total_kv_len->strides()},
2931
_output_strides{out_desc->strides()},
3032
_dtype{q_desc->dtype()},
3133
_scale{scale},
3234
_is_causal{is_causal} {
33-
_key_shape[_key_shape.size() - 2] = total_kv_len;
34-
_value_shape[_key_shape.size() - 2] = total_kv_len;
3535
}
3636

3737
~Descriptor() = default;
@@ -46,13 +46,15 @@ class Descriptor final : public InfiniopDescriptor {
4646
const void *q,
4747
const void *k,
4848
const void *v,
49+
const void *total_kv_len,
4950
void *stream) const {
5051
uint64_t empty_shape[4];
5152
int64_t empty_strides[4];
5253

5354
auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
5455
auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
5556
auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
57+
auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
5658

5759
NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
5860
NineToothedTensor is_causal;
@@ -75,6 +77,7 @@ class Descriptor final : public InfiniopDescriptor {
7577
query,
7678
key,
7779
value,
80+
total_kv_length,
7881
attn_mask,
7982
is_causal,
8083
scale,
@@ -101,7 +104,7 @@ class Descriptor final : public InfiniopDescriptor {
101104
infiniopTensorDescriptor_t q_desc,
102105
infiniopTensorDescriptor_t k_desc,
103106
infiniopTensorDescriptor_t v_desc,
104-
std::size_t total_kv_len,
107+
infiniopTensorDescriptor_t total_kv_len,
105108
double scale,
106109
char is_causal) {
107110
*desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
@@ -126,6 +129,10 @@ class Descriptor final : public InfiniopDescriptor {
126129

127130
std::vector<Stride> _value_strides;
128131

132+
std::vector<Size> _total_kv_shape;
133+
134+
std::vector<Stride> _total_kv_strides;
135+
129136
std::vector<Stride> _output_strides;
130137

131138
infiniDtype_t _dtype;

0 commit comments

Comments
 (0)