Skip to content

Commit c23cd85

Browse files
committed
issue/932 - feat: add paged attention prefill operator referencing nvidia implementation
1 parent 2714eb0 commit c23cd85

3 files changed

Lines changed: 149 additions & 0 deletions

File tree

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __PAGED_ATTENTION_PREFILL_MOORE_H__
#define __PAGED_ATTENTION_PREFILL_MOORE_H__

// Shared operator header for paged-attention prefill.
// NOTE(review): presumably the source of the DESCRIPTOR macro used below -- confirm.
#include "../paged_attention_prefill.h"

// Instantiate the operator descriptor for the Moore backend. The matching
// implementation defines Descriptor members inside
// op::paged_attention_prefill::moore.
DESCRIPTOR(moore)

#endif // __PAGED_ATTENTION_PREFILL_MOORE_H__
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#include <musa_fp16.h>
2+
#include <float.h>
3+
#include <math.h>
4+
#include <stdint.h>
5+
6+
#include "../../../devices/moore/moore_common.h"
7+
#include "../../../devices/moore/moore_kernel_common.h"
8+
#include "../cuda/kernel.cuh"
9+
#include "paged_attention_prefill_moore.h"
10+
11+
/**
 * Host-side launcher for the shared paged-attention prefill kernel
 * (op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel) on a Moore
 * stream.
 *
 * Launch layout: grid = (total_q_tokens, num_heads), block = (head_size),
 * no dynamic shared memory, enqueued asynchronously on `stream`.
 *
 * @tparam Tdata    Element type of out / q / k_cache / v_cache.
 * @tparam Tcompute Forwarded to the kernel as its second template argument.
 *
 * All pointer, stride and size arguments are forwarded to the kernel
 * unchanged; `total_q_tokens`, `num_heads` and `head_size` additionally
 * determine the launch configuration.
 *
 * @return INFINI_STATUS_BAD_TENSOR_SHAPE if any launch dimension is zero or
 *         does not fit into a dim3 component;
 *         INFINI_STATUS_INTERNAL_ERROR if the runtime rejects the launch;
 *         INFINI_STATUS_SUCCESS once the kernel has been enqueued.
 */
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int64_t *block_tables,
    const int64_t *seq_lens,
    const int64_t *cum_seq_lens_q,
    const float *alibi_slopes,
    const size_t num_heads,
    const size_t num_seqs,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const size_t total_q_tokens,
    const size_t head_size,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t q_stride,
    const ptrdiff_t q_head_stride,
    musaStream_t stream) {

    // A zero-sized grid or block is an invalid launch configuration.
    // head_size is checked as well because it becomes the block dimension.
    if (total_q_tokens == 0 || num_heads == 0 || head_size == 0) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

    // dim3 components are unsigned int; reject values that would otherwise be
    // truncated silently and launch a grid of the wrong size.
    constexpr size_t max_dim = static_cast<size_t>(static_cast<unsigned int>(-1));
    if (total_q_tokens > max_dim || num_heads > max_dim || head_size > max_dim) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

    const dim3 grid(static_cast<unsigned int>(total_q_tokens),
                    static_cast<unsigned int>(num_heads));
    const dim3 block(static_cast<unsigned int>(head_size));

    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
        <<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache,
            block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
            num_heads, num_kv_heads, scale,
            max_num_blocks_per_seq, block_size,
            kv_block_stride, kv_head_stride,
            q_stride, q_head_stride,
            head_size,
            num_seqs);

    // A kernel launch returns no status of its own; configuration errors
    // (e.g. a block larger than the device allows) only surface through the
    // runtime's last-error query.
    if (musaGetLastError() != musaSuccess) {
        return INFINI_STATUS_INTERNAL_ERROR;
    }

    return INFINI_STATUS_SUCCESS;
}
52+
53+
namespace op::paged_attention_prefill::moore {
54+
55+
// Backend-private part of the descriptor: holds shared ownership of the
// Moore handle's internal state.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};
58+
59+
// Frees the Opaque object allocated in Descriptor::create.
Descriptor::~Descriptor() {
    delete _opaque;
}
62+
63+
/**
 * Builds a Moore paged-attention prefill descriptor.
 *
 * The tensor descriptors and `scale` are handed to
 * PagedAttentionPrefillInfo::create; the result is checked with CHECK_RESULT
 * (NOTE(review): presumably returns the failure status early -- confirm
 * against the macro definition).
 *
 * @param handle            Operator handle; reinterpreted as a
 *                          device::moore::Handle to obtain its internal state.
 * @param desc_ptr          Receives the newly allocated Descriptor.
 * @param alibi_slopes_desc Optional descriptor, forwarded as-is to the info
 *                          builder.
 * @param scale             Forwarded as-is to the info builder.
 * @return INFINI_STATUS_SUCCESS once `*desc_ptr` has been set.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {

    auto info_result = PagedAttentionPrefillInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, seq_lens_desc,
        cum_seq_lens_q_desc,
        alibi_slopes_desc, scale);
    CHECK_RESULT(info_result);

    // Keep the device handle's internal state reachable from the descriptor.
    auto *moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    auto *opaque = new Opaque{moore_handle->internal()};

    // The third constructor argument is 0.
    // NOTE(review): presumably the required workspace size -- confirm.
    *desc_ptr = new Descriptor(
        opaque, info_result.take(), 0, handle->device, handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
90+
91+
/**
 * Runs paged-attention prefill on the given stream.
 *
 * Dispatches on `_info.dtype` to launchPagedAttentionPrefill with float as
 * the compute type; all shape and stride arguments come from `_info`.
 *
 * @param workspace, workspace_size Unused by this implementation.
 * @param out, q, k_cache, v_cache  Buffers reinterpreted as the element type
 *                                  selected by `_info.dtype`.
 * @param block_tables, seq_lens, cum_seq_lens_q Reinterpreted as
 *                                  `const int64_t *`.
 * @param alibi_slopes              Reinterpreted as `const float *`.
 * @param stream_                   Reinterpreted as `musaStream_t`.
 * @return The launcher's status for F16 / BF16 / F32, or
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for any other dtype.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *cum_seq_lens_q,
    const void *alibi_slopes,
    void *stream_) const {

    // No scratch memory is needed; silence unused-parameter warnings.
    (void)workspace;
    (void)workspace_size;

    musaStream_t stream = reinterpret_cast<musaStream_t>(stream_);

#define LAUNCH_KERNEL(Tdata, Tcompute)                                         \
    launchPagedAttentionPrefill<Tdata, Tcompute>(                              \
        static_cast<Tdata *>(out),                                             \
        static_cast<const Tdata *>(q),                                         \
        static_cast<const Tdata *>(k_cache),                                   \
        static_cast<const Tdata *>(v_cache),                                   \
        static_cast<const int64_t *>(block_tables),                            \
        static_cast<const int64_t *>(seq_lens),                                \
        static_cast<const int64_t *>(cum_seq_lens_q),                          \
        static_cast<const float *>(alibi_slopes),                              \
        _info.num_heads, _info.num_seqs, _info.num_kv_heads,                   \
        _info.scale, _info.max_num_blocks_per_seq,                             \
        _info.block_size, _info.total_q_tokens,                                \
        _info.head_size,                                                       \
        _info.kv_block_stride, _info.kv_head_stride,                           \
        _info.q_stride, _info.q_head_stride,                                   \
        stream)

    // Unsupported dtypes fall through with this status.
    infiniStatus_t status = INFINI_STATUS_BAD_TENSOR_DTYPE;
    switch (_info.dtype) {
    case INFINI_DTYPE_F16:
        status = LAUNCH_KERNEL(half, float);
        break;
    case INFINI_DTYPE_BF16:
        status = LAUNCH_KERNEL(__mt_bfloat16, float);
        break;
    case INFINI_DTYPE_F32:
        status = LAUNCH_KERNEL(float, float);
        break;
    default:
        break;
    }

// Keep the helper macro local to this function so it cannot leak into the
// rest of the translation unit.
#undef LAUNCH_KERNEL

    return status;
}
125+
126+
} // namespace op::paged_attention_prefill::moore

src/infiniop/ops/paged_attention_prefill/operator.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#ifdef ENABLE_NVIDIA_API
66
#include "nvidia/paged_attention_prefill_nvidia.cuh"
77
#endif
8+
#ifdef ENABLE_MOORE_API
9+
#include "moore/paged_attention_prefill_moore.h"
10+
#endif
811

912
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
1013
infiniopHandle_t handle,
@@ -32,6 +35,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
3235
switch (handle->device) {
3336
#ifdef ENABLE_NVIDIA_API
3437
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
38+
#endif
39+
#ifdef ENABLE_MOORE_API
40+
CREATE(INFINI_DEVICE_MOORE, moore)
3541
#endif
3642
default:
3743
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -50,6 +56,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
5056
switch (desc->device_type) {
5157
#ifdef ENABLE_NVIDIA_API
5258
GET(INFINI_DEVICE_NVIDIA, nvidia)
59+
#endif
60+
#ifdef ENABLE_MOORE_API
61+
GET(INFINI_DEVICE_MOORE, moore)
5362
#endif
5463
default:
5564
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
7584
switch (desc->device_type) {
7685
#ifdef ENABLE_NVIDIA_API
7786
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
87+
#endif
88+
#ifdef ENABLE_MOORE_API
89+
CALCULATE(INFINI_DEVICE_MOORE, moore)
7890
#endif
7991
default:
8092
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
92104
switch (desc->device_type) {
93105
#ifdef ENABLE_NVIDIA_API
94106
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
107+
#endif
108+
#ifdef ENABLE_MOORE_API
109+
DESTROY(INFINI_DEVICE_MOORE, moore)
95110
#endif
96111
default:
97112
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

0 commit comments

Comments
 (0)