Skip to content

Commit 2714eb0

Browse files
committed
issue/932 - feat: add paged caching operator referencing nvidia implementation
1 parent 92b81ac commit 2714eb0

3 files changed

Lines changed: 179 additions & 0 deletions

File tree

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __PAGED_CACHING_MOORE_H__
2+
#define __PAGED_CACHING_MOORE_H__
3+
4+
#include "../paged_caching.h"
5+
6+
// Instantiates the DESCRIPTOR macro from ../paged_caching.h for the `moore` backend.
// NOTE(review): presumably this declares op::paged_caching::moore::Descriptor (the
// class whose create/calculate/destructor are defined in the accompanying Moore
// source file) - confirm against the macro definition.
DESCRIPTOR(moore)
7+
8+
#endif // __PAGED_CACHING_MOORE_H__
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#include "../../../devices/moore/moore_common.h"
2+
#include "../../../devices/moore/moore_kernel_common.h"
3+
#include "../cuda/kernel.cuh"
4+
#include "paged_caching_moore.h"
5+
6+
// MUSA kernel entry point for the paged-caching operator.
//
// This is a thin wrapper: every argument is forwarded unchanged, in the same
// order, to the shared device routine op::paged_caching::cuda::pagedCachingKernel
// from ../cuda/kernel.cuh. The host-side launchKernel helper in this file starts
// it with a (num_kv_heads, num_tokens, 1) grid, NUM_THREADS threads per block and
// no dynamic shared memory.
//
// Template parameters:
//   Tdata        element type of k, v and both caches
//   NUM_THREADS  threads per block, forwarded to the shared routine
//
// NOTE(review): the exact indexing performed on slot_mapping and the caches lives
// in pagedCachingKernel and is not visible here - consult ../cuda/kernel.cuh.
template <typename Tdata, int NUM_THREADS>
INFINIOP_MOORE_KERNEL pagedCaching(
    Tdata *k_cache,
    Tdata *v_cache,
    const Tdata *k,
    const Tdata *v,
    const int64_t *slot_mapping,
    const size_t head_size,
    const size_t block_size,
    const ptrdiff_t k_src_stride,
    const ptrdiff_t v_src_stride,
    const ptrdiff_t k_cache_block_stride,
    const ptrdiff_t v_cache_block_stride) {
    namespace shared_impl = op::paged_caching::cuda;

    shared_impl::pagedCachingKernel<Tdata, NUM_THREADS>(
        k_cache, v_cache,
        k, v,
        slot_mapping,
        head_size, block_size,
        k_src_stride, v_src_stride,
        k_cache_block_stride, v_cache_block_stride);
}
18+
19+
namespace op::paged_caching::moore {
20+
// Backend-private state of the descriptor (PIMPL).
// Holds a shared reference to the Moore handle's internal object; calculate()
// queries it for maxThreadsPerBlock() when choosing a launch configuration.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};
24+
25+
// Frees the backend-private Opaque object referenced by _opaque.
// create() allocates one Opaque with `new` and hands it to the Descriptor
// constructor, so ownership ends here.
Descriptor::~Descriptor() {
    delete _opaque;
}
29+
30+
// Factory for the Moore paged-caching descriptor.
//
// Runs PagedCachingInfo::create on the five tensor descriptors, passes the
// result through CHECK_RESULT (presumably an early return with the failure
// status - confirm against the macro), and on success stores a newly allocated
// Descriptor in *desc_ptr.
//
// Parameters:
//   handle             device handle; reinterpreted as device::moore::Handle to
//                      obtain its internal() object
//   desc_ptr           output; receives the new Descriptor
//   k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc
//                      tensor descriptors forwarded to PagedCachingInfo::create
//
// Returns INFINI_STATUS_SUCCESS once the descriptor has been created.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t slot_mapping_desc) {

    auto info_result = PagedCachingInfo::create(
        k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
    CHECK_RESULT(info_result);

    // Keep the handle's internal object alive for as long as the descriptor exists.
    auto *moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    auto *opaque = new Opaque{moore_handle->internal()};

    // The literal 0 is the third constructor argument.
    // NOTE(review): presumably the required workspace size - confirm against DESCRIPTOR.
    *desc_ptr = new Descriptor(
        opaque,
        info_result.take(),
        0,
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
50+
51+
// Launches pagedCaching<Tdata, NUM_THREADS> on `stream` for one concrete element type.
//
// Launch configuration:
//   grid  : 2D, gridDim.x = num_kv_heads, gridDim.y = num_tokens
//   block : 1D, NUM_THREADS threads
//   shared memory : none (0 bytes of dynamic shared memory)
//
// An empty problem (no tokens or no KV heads) is skipped entirely: a grid with a
// zero dimension is an invalid launch configuration, and there is nothing to copy.
//
// NOTE(review): num_tokens maps to gridDim.y, which CUDA-class devices cap at
// 65535 - confirm the MUSA limit if very long token batches are expected.
template <typename Tdata, int NUM_THREADS>
static void launchTypedKernel(
    void *k_cache, void *v_cache,
    const void *k, const void *v,
    const void *slot_mapping,
    size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
    ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
    ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
    musaStream_t stream) {

    if (num_tokens == 0 || num_kv_heads == 0) {
        return;
    }

    const dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
    const dim3 block(NUM_THREADS);
    const size_t shared_mem_size = 0;

    pagedCaching<Tdata, NUM_THREADS>
        <<<grid, block, shared_mem_size, stream>>>(
            static_cast<Tdata *>(k_cache),
            static_cast<Tdata *>(v_cache),
            static_cast<const Tdata *>(k),
            static_cast<const Tdata *>(v),
            static_cast<const int64_t *>(slot_mapping),
            head_size,
            block_size,
            k_src_stride,
            v_src_stride,
            k_cache_block_stride,
            v_cache_block_stride);
}

// Host-side helper that selects the element type from `dtype` and launches the
// MUSA kernel with NUM_THREADS threads per block.
//
// Parameters:
//   info                         operator metadata (currently unused; kept for the
//                                existing call sites)
//   k_cache, v_cache             destination cache buffers
//   dtype                        element type; F16, BF16 and F32 are supported
//   k, v                         source key/value buffers
//   slot_mapping                 int64 buffer forwarded to the kernel
//   num_tokens, num_kv_heads     grid extents (gridDim.y and gridDim.x)
//   head_size, block_size        forwarded to the kernel
//   *_src_stride, *_cache_block_stride
//                                strides forwarded to the kernel
//   stream                       MUSA stream the kernel is enqueued on
//
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for any other dtype, otherwise
// INFINI_STATUS_SUCCESS (also when the problem is empty and nothing is launched).
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
                            void *k_cache, void *v_cache,
                            infiniDtype_t dtype,
                            const void *k, const void *v,
                            const void *slot_mapping,
                            size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
                            ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
                            ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
                            musaStream_t stream) {

    if (dtype == INFINI_DTYPE_F16) {
        launchTypedKernel<half, NUM_THREADS>(
            k_cache, v_cache, k, v, slot_mapping,
            num_tokens, num_kv_heads, head_size, block_size,
            k_src_stride, v_src_stride,
            k_cache_block_stride, v_cache_block_stride, stream);
    } else if (dtype == INFINI_DTYPE_BF16) {
        launchTypedKernel<__mt_bfloat16, NUM_THREADS>(
            k_cache, v_cache, k, v, slot_mapping,
            num_tokens, num_kv_heads, head_size, block_size,
            k_src_stride, v_src_stride,
            k_cache_block_stride, v_cache_block_stride, stream);
    } else if (dtype == INFINI_DTYPE_F32) {
        launchTypedKernel<float, NUM_THREADS>(
            k_cache, v_cache, k, v, slot_mapping,
            num_tokens, num_kv_heads, head_size, block_size,
            k_src_stride, v_src_stride,
            k_cache_block_stride, v_cache_block_stride, stream);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
120+
121+
// Executes the paged-caching operator by enqueuing the kernel on the caller's stream.
//
// The block size is chosen from the device's maxThreadsPerBlock(): 1024 threads
// when supported, otherwise 512. All shape and stride arguments come from the
// PagedCachingInfo stored in the descriptor.
//
// Parameters:
//   workspace, workspace_size  unused by this backend
//   k_cache, v_cache           destination cache buffers
//   k, v                       source key/value buffers
//   slot_mapping               buffer forwarded to the kernel
//   stream_                    MUSA stream, passed as an opaque pointer
//
// Returns:
//   the status reported by launchKernel (INFINI_STATUS_SUCCESS, or
//   INFINI_STATUS_BAD_TENSOR_DTYPE for an unsupported element type), or
//   INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED when the device supports
//   fewer than 512 threads per block.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *k_cache, void *v_cache,
    const void *k, const void *v,
    const void *slot_mapping,
    void *stream_) const {

    musaStream_t stream = (musaStream_t)stream_;

    // Query the device limit once; it drives the block-size dispatch below.
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();

    // Return the launch helper's status directly so an unsupported dtype is
    // reported to the caller instead of being masked by a blanket success.
    if (max_threads >= MOORE_BLOCK_SIZE_1024) {
        return launchKernel<MOORE_BLOCK_SIZE_1024>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride,
            _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    }
    if (max_threads >= MOORE_BLOCK_SIZE_512) {
        return launchKernel<MOORE_BLOCK_SIZE_512>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride,
            _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    }

    // Devices that cannot run 512 threads per block are not supported.
    return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
155+
156+
} // namespace op::paged_caching::moore

src/infiniop/ops/paged_caching/operator.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#ifdef ENABLE_NVIDIA_API
66
#include "nvidia/paged_caching_nvidia.cuh"
77
#endif
8+
#ifdef ENABLE_MOORE_API
9+
#include "moore/paged_caching_moore.h"
10+
#endif
811
// #ifdef ENABLE_METAX_API
912
// #include "metax/paged_caching_metax.h"
1013
// #endif
@@ -28,6 +31,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
2831
switch (handle->device) {
2932
#ifdef ENABLE_NVIDIA_API
3033
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
34+
#endif
35+
#ifdef ENABLE_MOORE_API
36+
CREATE(INFINI_DEVICE_MOORE, moore)
3137
#endif
3238
// #ifdef ENABLE_METAX_API
3339
// CREATE(INFINI_DEVICE_METAX, metax)
@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
4955
switch (desc->device_type) {
5056
#ifdef ENABLE_NVIDIA_API
5157
GET(INFINI_DEVICE_NVIDIA, nvidia)
58+
#endif
59+
#ifdef ENABLE_MOORE_API
60+
GET(INFINI_DEVICE_MOORE, moore)
5261
#endif
5362
// #ifdef ENABLE_METAX_API
5463
// GET(INFINI_DEVICE_METAX, metax)
@@ -74,6 +83,9 @@ __C infiniStatus_t infiniopPagedCaching(
7483
switch (desc->device_type) {
7584
#ifdef ENABLE_NVIDIA_API
7685
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
86+
#endif
87+
#ifdef ENABLE_MOORE_API
88+
CALCULATE(INFINI_DEVICE_MOORE, moore)
7789
#endif
7890
// #ifdef ENABLE_METAX_API
7991
// CALCULATE(INFINI_DEVICE_METAX, metax)
@@ -94,6 +106,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
94106
switch (desc->device_type) {
95107
#ifdef ENABLE_NVIDIA_API
96108
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
109+
#endif
110+
#ifdef ENABLE_MOORE_API
111+
DESTROY(INFINI_DEVICE_MOORE, moore)
97112
#endif
98113
// #ifdef ENABLE_METAX_API
99114
// DESTROY(INFINI_DEVICE_METAX, metax)

0 commit comments

Comments
 (0)