Commit 7f426b2

issue/1061 - feat: use template to replace int64_t in paged_attention_prefill kernel for moore gpu

1 parent: e60985d

2 files changed: 34 additions & 21 deletions

src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h

Lines changed: 9 additions & 7 deletions

```diff
@@ -1,7 +1,9 @@
 #ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
 #define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
 namespace op::paged_attention_prefill::cuda {
-__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
+
+template <typename Tindex>
+__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
     size_t low = 0, high = num_seqs - 1;
     while (low <= high) {
         size_t mid = (low + high) >> 1;
@@ -48,12 +50,12 @@ __device__ __forceinline__ float blockReduceSum(float val) {
     return shared[0];
 }
 
-template <typename Tdata, typename Tcompute>
+template <typename Tindex, typename Tdata, typename Tcompute>
 __global__ void pagedAttentionPrefillKernel(
     Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
-    const int64_t *block_tables_,
-    const int64_t *total_kv_lens_,
-    const int64_t *cum_seq_lens_q_,
+    const Tindex *block_tables_,
+    const Tindex *total_kv_lens_,
+    const Tindex *cum_seq_lens_q_,
     const float *alibi_slopes_,
     const size_t num_heads, const size_t num_kv_heads, const float scale,
     const size_t max_num_blocks_per_seq, const size_t block_size,
@@ -75,7 +77,7 @@ __global__ void pagedAttentionPrefillKernel(
     __shared__ float sh_w;
     __shared__ float sh_inv_l;
     if (dim_idx == 0) {
-        sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
+        sh_seq_idx = find_seq_id<Tindex>(global_token_idx, cum_seq_lens_q_, num_seqs);
         const size_t q_token_idx = global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
         const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
         const size_t q_len = static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
@@ -90,7 +92,7 @@ __global__ void pagedAttentionPrefillKernel(
     const size_t kv_head_idx = sh_kv_head_idx;
     const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
-    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
+    const Tindex *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
     const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
     const float qv = static_cast<float>(q_vec[dim_idx]);
     Tcompute acc = 0.0f;
```
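
These hunks hinge on `find_seq_id`, the binary search that maps a flattened query-token index back to its sequence via the cumulative query lengths. The diff shows only the head of that loop, so the sketch below is a hedged host-side reconstruction of what the search plausibly computes, inferred from how the kernel consumes `cum_seq_lens_q_` (sequence start offsets, `num_seqs + 1` entries); the loop body past the first three lines, `main()`, and the demo values are illustrative assumptions, not the committed device code.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Host-side analogue of the templated device search. Assumes
// cum_seq_lens_q is an exclusive prefix sum: cum[0] == 0 and
// cum[i] <= token_idx < cum[i + 1] exactly when token_idx is in sequence i.
template <typename Tindex>
size_t find_seq_id(size_t token_idx, const Tindex *cum_seq_lens_q, size_t num_seqs) {
    size_t low = 0, high = num_seqs - 1;
    while (low <= high) {
        size_t mid = (low + high) >> 1;
        if (static_cast<size_t>(cum_seq_lens_q[mid + 1]) <= token_idx) {
            low = mid + 1;  // token starts at or after the next boundary
        } else if (static_cast<size_t>(cum_seq_lens_q[mid]) > token_idx) {
            high = mid - 1; // token lies before this sequence's start
        } else {
            return mid;     // cum[mid] <= token_idx < cum[mid + 1]
        }
    }
    return num_seqs - 1;    // unreachable for well-formed prefix sums
}

int main() {
    // Two sequences with query lengths 3 and 5: tokens 0..2 -> seq 0,
    // tokens 3..7 -> seq 1. int32_t indices behave the same as int64_t.
    std::vector<int32_t> cum = {0, 3, 8};
    std::printf("%zu %zu\n", find_seq_id<int32_t>(2, cum.data(), 2),
                find_seq_id<int32_t>(7, cum.data(), 2)); // prints: 0 1
    return 0;
}
```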

src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu

Lines changed: 25 additions & 14 deletions

```diff
@@ -8,12 +8,12 @@
 #include "paged_attention_prefill_kernel.h"
 #include "paged_attention_prefill_moore.h"
 
-template <typename Tdata, typename Tcompute>
+template <typename Tindex, typename Tdata, typename Tcompute>
 infiniStatus_t launchPagedAttentionPrefill(
     Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
-    const int64_t *block_tables,
-    const int64_t *seq_lens,
-    const int64_t *cum_seq_lens_q,
+    const Tindex *block_tables,
+    const Tindex *seq_lens,
+    const Tindex *cum_seq_lens_q,
     const float *alibi_slopes,
     const size_t num_heads,
     const size_t num_seqs,
@@ -36,7 +36,7 @@ infiniStatus_t launchPagedAttentionPrefill(
     dim3 grid(total_q_tokens, num_heads);
     dim3 block(head_size);
 
-    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
+    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tindex, Tdata, Tcompute>
         <<<grid, block, 0, stream>>>(
             out, q, k_cache, v_cache,
             block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
@@ -99,10 +99,10 @@ infiniStatus_t Descriptor::calculate(
 
     musaStream_t stream = (musaStream_t)stream_;
 
-#define LAUNCH_KERNEL(Tdata, Tcompute) \
-    launchPagedAttentionPrefill<Tdata, Tcompute>( \
+#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
+    return launchPagedAttentionPrefill<Tindex, Tdata, Tcompute>( \
         (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
-        (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
+        static_cast<const Tindex *>(block_tables), static_cast<const Tindex *>(seq_lens), static_cast<const Tindex *>(cum_seq_lens_q), \
        (const float *)alibi_slopes, \
        _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
        _info.scale, _info.max_num_blocks_per_seq, \
@@ -112,12 +112,23 @@ infiniStatus_t Descriptor::calculate(
        _info.q_stride, _info.q_head_stride, \
        stream)
 
-    if (_info.dtype == INFINI_DTYPE_F16) {
-        return LAUNCH_KERNEL(half, float);
-    } else if (_info.dtype == INFINI_DTYPE_BF16) {
-        return LAUNCH_KERNEL(__mt_bfloat16, float);
-    } else if (_info.dtype == INFINI_DTYPE_F32) {
-        return LAUNCH_KERNEL(float, float);
+#define DISPATCH_INDEX(Tindex) \
+    do { \
+        if (_info.dtype == INFINI_DTYPE_F16) { \
+            DISPATCH_KERNEL(Tindex, half, float); \
+        } \
+        if (_info.dtype == INFINI_DTYPE_BF16) { \
+            DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
+        } \
+        return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+    } while (false)
+
+    if (_info.index_dtype == INFINI_DTYPE_I64){
+        DISPATCH_INDEX(int64_t);
+    } else if (_info.index_dtype == INFINI_DTYPE_I32){
+        DISPATCH_INDEX(int32_t);
+    } else if (_info.index_dtype == INFINI_DTYPE_U32){
+        DISPATCH_INDEX(uint32_t);
     }
 
     return INFINI_STATUS_BAD_TENSOR_DTYPE;
```
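
This file replaces the flat `LAUNCH_KERNEL` macro with a two-level dispatch: `DISPATCH_INDEX` branches on the index dtype, then expands `DISPATCH_KERNEL` per data dtype. The sketch below restates that pattern with function templates instead of macros; the enums, stub types, and `launch()` are placeholders for illustration, not the real InfiniOp API.

```cpp
#include <cstdint>
#include <cstdio>

// Placeholders standing in for INFINI_DTYPE_* values and device types.
enum class Dtype { F16, BF16 };
enum class IndexDtype { I32, I64, U32 };
struct HalfT {}; // stands in for half
struct Bf16T {}; // stands in for the device bfloat16 type

// Stub for launchPagedAttentionPrefill<Tindex, Tdata, Tcompute>(...).
template <typename Tindex, typename Tdata, typename Tcompute>
int launch() {
    std::puts("kernel launched");
    return 0; // success
}

// Inner level: pick the data dtype once the index type is fixed
// (the role DISPATCH_KERNEL plays inside DISPATCH_INDEX).
template <typename Tindex>
int dispatchData(Dtype dtype) {
    if (dtype == Dtype::F16) { return launch<Tindex, HalfT, float>(); }
    if (dtype == Dtype::BF16) { return launch<Tindex, Bf16T, float>(); }
    return -1; // bad tensor dtype
}

// Outer level: pick the index type first (the role of DISPATCH_INDEX).
int dispatch(IndexDtype index_dtype, Dtype dtype) {
    switch (index_dtype) {
    case IndexDtype::I64: return dispatchData<int64_t>(dtype);
    case IndexDtype::I32: return dispatchData<int32_t>(dtype);
    case IndexDtype::U32: return dispatchData<uint32_t>(dtype);
    }
    return -1; // bad index dtype
}

int main() { return dispatch(IndexDtype::I32, Dtype::F16); }
```

The committed code keeps macros so that `DISPATCH_KERNEL` can `return` directly out of `Descriptor::calculate`; the function-template form shown here trades that shortcut for type safety, but the dispatch structure is the same.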
