88#include "paged_attention_prefill_kernel.h"
99#include "paged_attention_prefill_moore.h"
1010
11- template < typename Tdata , typename Tcompute >
11+ template < typename Tindex , typename Tdata , typename Tcompute >
1212infiniStatus_t launchPagedAttentionPrefill (
1313 Tdata * out , const Tdata * q , const Tdata * k_cache , const Tdata * v_cache ,
14- const int64_t * block_tables ,
15- const int64_t * seq_lens ,
16- const int64_t * cum_seq_lens_q ,
14+ const Tindex * block_tables ,
15+ const Tindex * seq_lens ,
16+ const Tindex * cum_seq_lens_q ,
1717 const float * alibi_slopes ,
1818 const size_t num_heads ,
1919 const size_t num_seqs ,
@@ -36,7 +36,7 @@ infiniStatus_t launchPagedAttentionPrefill(
3636 dim3 grid (total_q_tokens , num_heads );
3737 dim3 block (head_size );
3838
39- op ::paged_attention_prefill ::cuda ::pagedAttentionPrefillKernel < Tdata , Tcompute >
39+ op ::paged_attention_prefill ::cuda ::pagedAttentionPrefillKernel < Tindex , Tdata , Tcompute >
4040 <<< grid , block , 0 , stream >>> (
4141 out , q , k_cache , v_cache ,
4242 block_tables , seq_lens , cum_seq_lens_q , alibi_slopes ,
@@ -99,10 +99,10 @@ infiniStatus_t Descriptor::calculate(
9999
100100 musaStream_t stream = (musaStream_t )stream_ ;
101101
102- #define LAUNCH_KERNEL ( Tdata , Tcompute ) \
103- launchPagedAttentionPrefill < Tdata , Tcompute > ( \
102+ #define DISPATCH_KERNEL ( Tindex , Tdata , Tcompute ) \
103+ return launchPagedAttentionPrefill < Tindex , Tdata , Tcompute > ( \
104104 (Tdata * )out , (const Tdata * )q , (const Tdata * )k_cache , (const Tdata * )v_cache , \
105- ( const int64_t * ) block_tables , ( const int64_t * ) seq_lens , ( const int64_t * ) cum_seq_lens_q , \
105+ static_cast < const Tindex * > ( block_tables ), static_cast < const Tindex * > ( seq_lens ), static_cast < const Tindex * > ( cum_seq_lens_q ) , \
106106 (const float * )alibi_slopes , \
107107 _info . num_heads , _info . num_seqs , _info . num_kv_heads , \
108108 _info . scale , _info . max_num_blocks_per_seq , \
@@ -112,12 +112,26 @@ infiniStatus_t Descriptor::calculate(
112112 _info . q_stride , _info . q_head_stride , \
113113 stream )
114114
115- if (_info . dtype == INFINI_DTYPE_F16 ) {
116- return LAUNCH_KERNEL (half , float );
117- } else if (_info . dtype == INFINI_DTYPE_BF16 ) {
118- return LAUNCH_KERNEL (__mt_bfloat16 , float );
119- } else if (_info . dtype == INFINI_DTYPE_F32 ) {
120- return LAUNCH_KERNEL (float , float );
115+ #define DISPATCH_INDEX (Tindex ) \
116+ do { \
117+ if (_info . dtype == INFINI_DTYPE_F16 ) { \
118+ DISPATCH_KERNEL (Tindex , half , float ); \
119+ } \
120+ if (_info . dtype == INFINI_DTYPE_BF16 ) { \
121+ DISPATCH_KERNEL (Tindex , __mt_bfloat16 , float ); \
122+ } \
123+ if (_info . dtype == INFINI_DTYPE_F32 ) { \
124+ DISPATCH_KERNEL (Tindex , float , float ); \
125+ } \
126+ return INFINI_STATUS_BAD_TENSOR_DTYPE ; \
127+ } while (false )
128+
129+ if (_info . index_dtype == INFINI_DTYPE_I64 ){
130+ DISPATCH_INDEX (int64_t );
131+ } else if (_info . index_dtype == INFINI_DTYPE_I32 ){
132+ DISPATCH_INDEX (int32_t );
133+ } else if (_info . index_dtype == INFINI_DTYPE_U32 ){
134+ DISPATCH_INDEX (uint32_t );
121135 }
122136
123137 return INFINI_STATUS_BAD_TENSOR_DTYPE ;
0 commit comments