Skip to content

Commit 92b81ac

Browse files
committed
issue/932 - feat: add paged attention operator referencing nvidia implementation
1 parent 180674d commit 92b81ac

3 files changed

Lines changed: 168 additions & 0 deletions

File tree

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Moore backend header for the paged attention operator.
#ifndef __PAGED_ATTENTION_MOORE_H__
#define __PAGED_ATTENTION_MOORE_H__

#include "../paged_attention.h"

// NOTE(review): DESCRIPTOR is not defined in this header; presumably it comes
// from ../paged_attention.h and declares the backend Descriptor class for the
// namespace named by its argument -- confirm against that header.
DESCRIPTOR(moore)

#endif // __PAGED_ATTENTION_MOORE_H__
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#include <cub/block/block_reduce.cuh>
2+
3+
#include "../../../devices/moore/moore_common.h"
4+
#include "../../../devices/moore/moore_kernel_common.h"
5+
6+
#include "../../../reduce/cuda/reduce.cuh"
7+
#include "../cuda/kernel.cuh"
8+
#include "paged_attention_moore.h"
9+
10+
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
11+
INFINIOP_MOORE_KERNEL pagedAttention(
12+
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
13+
const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
14+
const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
15+
const size_t block_size,
16+
const ptrdiff_t q_stride,
17+
const ptrdiff_t kv_block_stride,
18+
const ptrdiff_t kv_head_stride,
19+
const ptrdiff_t o_stride) {
20+
op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
21+
out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
22+
max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
23+
}
24+
25+
namespace op::paged_attention::moore {
26+
27+
// Backend-private state carried by the Moore paged attention descriptor.
struct Descriptor::Opaque {
    // Shared ownership of the Moore handle internals; Descriptor::calculate
    // reads maxThreadsPerBlock() through this pointer.
    std::shared_ptr<device::moore::Handle::Internal> internal;
};
30+
31+
// Frees the backend-private Opaque state. Presumably this is the object
// allocated with `new Opaque{...}` in Descriptor::create -- the constructor
// that stores it is not visible in this file.
Descriptor::~Descriptor() {
    delete this->_opaque;
}
34+
35+
/**
 * Builds a Moore paged attention descriptor.
 *
 * The tensor descriptors and scale are validated/summarised by
 * PagedAttentionInfo::create; if that fails, CHECK_RESULT handles the failure
 * (presumably by returning its status -- the macro is defined elsewhere).
 * On success a new Descriptor is heap-allocated and stored in *desc_ptr.
 *
 * @param handle             Operator handle; reinterpreted as
 *                           device::moore::Handle to obtain its internals.
 * @param desc_ptr           Output: receives the newly allocated Descriptor.
 * @param out_desc           Descriptor of the output tensor.
 * @param q_desc             Descriptor of the query tensor.
 * @param k_cache_desc       Descriptor of the key cache tensor.
 * @param v_cache_desc       Descriptor of the value cache tensor.
 * @param block_tables_desc  Descriptor of the block table tensor.
 * @param seq_lens_desc      Descriptor of the sequence length tensor.
 * @param alibi_slopes_desc  Optional descriptor of the ALiBi slopes tensor.
 * @param scale              Scale factor forwarded to PagedAttentionInfo.
 * @return INFINI_STATUS_SUCCESS once the descriptor has been created.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    auto info_result = PagedAttentionInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
    CHECK_RESULT(info_result);

    // The generic handle is the Moore handle for this backend.
    auto *moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    auto *opaque = new Opaque{moore_handle->internal()};

    // NOTE(review): the literal 0 is presumably the workspace size (this
    // operator uses no workspace) -- confirm against the Descriptor constructor.
    *desc_ptr = new Descriptor(
        opaque,
        info_result.take(),
        0,
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
54+
55+
template <size_t HEAD_SIZE, size_t NUM_THREADS>
56+
infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
57+
infiniDtype_t dtype,
58+
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
59+
size_t num_heads, size_t num_seqs,
60+
size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
61+
ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
62+
musaStream_t stream) {
63+
dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
64+
dim3 block(NUM_THREADS);
65+
size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
66+
67+
if (dtype == INFINI_DTYPE_F16) {
68+
pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
69+
<<<grid, block, shared_mem_size, stream>>>(
70+
(half *)out,
71+
(const half *)q, (const half *)k_cache, (const half *)v_cache,
72+
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
73+
scale, max_num_blocks_per_seq, block_size,
74+
q_stride, kv_block_stride, kv_head_stride, o_stride);
75+
} else if (dtype == INFINI_DTYPE_BF16) {
76+
pagedAttention<__mt_bfloat16, float, HEAD_SIZE, NUM_THREADS>
77+
<<<grid, block, shared_mem_size, stream>>>(
78+
(__mt_bfloat16 *)out, (const __mt_bfloat16 *)q, (const __mt_bfloat16 *)k_cache, (const __mt_bfloat16 *)v_cache,
79+
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
80+
scale, max_num_blocks_per_seq, block_size,
81+
q_stride, kv_block_stride, kv_head_stride, o_stride);
82+
} else if (dtype == INFINI_DTYPE_F32) {
83+
pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
84+
<<<grid, block, shared_mem_size, stream>>>(
85+
(float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
86+
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
87+
scale, max_num_blocks_per_seq, block_size,
88+
q_stride, kv_block_stride, kv_head_stride, o_stride);
89+
} else {
90+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
91+
}
92+
return INFINI_STATUS_SUCCESS;
93+
}
94+
95+
/**
 * Runs paged attention on the Moore backend.
 *
 * Picks the launchKernel instantiation from two runtime values:
 *   - the thread count, from the device's maxThreadsPerBlock()
 *     (MOORE_BLOCK_SIZE_1024 or MOORE_BLOCK_SIZE_512), and
 *   - the head size, from _info.head_size (16, 32, 64, 128 or 256).
 *
 * The status returned by launchKernel is propagated to the caller, so an
 * unsupported dtype is reported instead of being masked by a success code.
 *
 * @param workspace, workspace_size  Not used by this implementation.
 * @param out, q, k_cache, v_cache   Tensor pointers forwarded to launchKernel.
 * @param block_tables, seq_lens, alibi_slopes  Forwarded to launchKernel.
 * @param stream_                    Stream, cast to musaStream_t.
 * @return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED if the device's
 *         maxThreadsPerBlock() is neither supported value,
 *         INFINI_STATUS_BAD_TENSOR_SHAPE for an unsupported head size,
 *         otherwise whatever launchKernel returns.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables, const void *seq_lens, const void *alibi_slopes,
    void *stream_) const {
    musaStream_t stream = (musaStream_t)stream_;

    // Single exit point: every dispatch branch below records its outcome here.
    infiniStatus_t status = INFINI_STATUS_SUCCESS;

// Template arguments must be compile-time constants, hence the macro dispatch.
// The launch result is stored in `status` rather than discarded.
#define LAUNCH_HEADSIZE_BLOCKSIZE(H_SIZE_, B_SIZE_) \
    status = launchKernel<H_SIZE_, B_SIZE_>( \
        out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
        _info.num_heads, _info.num_seqs, \
        _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
        _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
        stream)

#define SWITCH_HEAD_SIZE(B_SIZE_) \
    switch (_info.head_size) { \
    case 16: \
        LAUNCH_HEADSIZE_BLOCKSIZE(16, B_SIZE_); \
        break; \
    case 32: \
        LAUNCH_HEADSIZE_BLOCKSIZE(32, B_SIZE_); \
        break; \
    case 64: \
        LAUNCH_HEADSIZE_BLOCKSIZE(64, B_SIZE_); \
        break; \
    case 128: \
        LAUNCH_HEADSIZE_BLOCKSIZE(128, B_SIZE_); \
        break; \
    case 256: \
        LAUNCH_HEADSIZE_BLOCKSIZE(256, B_SIZE_); \
        break; \
    default: \
        status = INFINI_STATUS_BAD_TENSOR_SHAPE; \
        break; \
    }

    // Query the device limit once; it selects the thread count per block.
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();
    if (max_threads == MOORE_BLOCK_SIZE_1024) {
        SWITCH_HEAD_SIZE(MOORE_BLOCK_SIZE_1024)
    } else if (max_threads == MOORE_BLOCK_SIZE_512) {
        SWITCH_HEAD_SIZE(MOORE_BLOCK_SIZE_512)
    } else {
        status = INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }

#undef LAUNCH_HEADSIZE_BLOCKSIZE
#undef SWITCH_HEAD_SIZE

    return status;
}
144+
145+
} // namespace op::paged_attention::moore

src/infiniop/ops/paged_attention/operator.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#ifdef ENABLE_NVIDIA_API
66
#include "nvidia/paged_attention_nvidia.cuh"
77
#endif
8+
#ifdef ENABLE_MOORE_API
9+
#include "moore/paged_attention_moore.h"
10+
#endif
811
// #ifdef ENABLE_METAX_API
912
// #include "metax/paged_attention_metax.h"
1013
// #endif
@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
3336
switch (handle->device) {
3437
#ifdef ENABLE_NVIDIA_API
3538
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
39+
#endif
40+
#ifdef ENABLE_MOORE_API
41+
CREATE(INFINI_DEVICE_MOORE, moore)
3642
#endif
3743
// #ifdef ENABLE_METAX_API
3844
// CREATE(INFINI_DEVICE_METAX, metax)
@@ -54,6 +60,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
5460
switch (desc->device_type) {
5561
#ifdef ENABLE_NVIDIA_API
5662
GET(INFINI_DEVICE_NVIDIA, nvidia)
63+
#endif
64+
#ifdef ENABLE_MOORE_API
65+
GET(INFINI_DEVICE_MOORE, moore)
5766
#endif
5867
// #ifdef ENABLE_METAX_API
5968
// GET(INFINI_DEVICE_METAX, metax)
@@ -79,6 +88,9 @@ __C infiniStatus_t infiniopPagedAttention(
7988
switch (desc->device_type) {
8089
#ifdef ENABLE_NVIDIA_API
8190
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
91+
#endif
92+
#ifdef ENABLE_MOORE_API
93+
CALCULATE(INFINI_DEVICE_MOORE, moore)
8294
#endif
8395
// #ifdef ENABLE_METAX_API
8496
// CALCULATE(INFINI_DEVICE_METAX, metax)
@@ -99,6 +111,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
99111
switch (desc->device_type) {
100112
#ifdef ENABLE_NVIDIA_API
101113
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
114+
#endif
115+
#ifdef ENABLE_MOORE_API
116+
DESTROY(INFINI_DEVICE_MOORE, moore)
102117
#endif
103118
// #ifdef ENABLE_METAX_API
104119
// DESTROY(INFINI_DEVICE_METAX, metax)

0 commit comments

Comments
 (0)