Skip to content

Commit c14e29d

Browse files
committed
issue/923 - ninetoothed kv_caching
1 parent 53a1969 commit c14e29d

12 files changed

Lines changed: 360 additions & 44 deletions

File tree

include/infinicore/ops/kv_caching.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#pragma
1+
#pragma once
22

33
#include "../device.hpp"
44
#include "common/op.hpp"
@@ -15,11 +15,6 @@ class KVCaching {
1515
static common::OpDispatcher<schema> &dispatcher();
1616
};
1717

18-
Tensor kv_caching(Tensor k_cache,
19-
Tensor v_cache,
20-
Tensor k,
21-
Tensor v,
22-
Tensor past_kv_lengths);
2318
void kv_caching_(Tensor k_cache,
2419
Tensor v_cache,
2520
Tensor k,

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from infinicore.ops.add import add
4646
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
4747
from infinicore.ops.attention import attention
48+
from infinicore.ops.kv_caching import kv_caching
4849
from infinicore.ops.matmul import matmul
4950
from infinicore.ops.mul import mul
5051
from infinicore.ops.narrow import narrow
@@ -115,6 +116,7 @@
115116
"add_rms_norm",
116117
"add_rms_norm_",
117118
"attention",
119+
"kv_caching",
118120
"matmul",
119121
"mul",
120122
"narrow",
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from infinicore.lib import _infinicore
2+
3+
4+
def kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
5+
_infinicore.kv_caching_(
6+
k_cache._underlying,
7+
v_cache._underlying,
8+
k._underlying,
9+
v._underlying,
10+
past_kv_lengths._underlying,
11+
)
12+
13+
return k_cache, v_cache

src/infinicore/ops/kv_caching/kv_caching.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,6 @@ void KVCaching::execute(Tensor k_cache,
2828
func(k_cache, v_cache, k, v, past_kv_lengths);
2929
}
3030

31-
Tensor kv_caching(Tensor k_cache,
32-
Tensor v_cache,
33-
Tensor k,
34-
Tensor v,
35-
Tensor past_kv_lengths) {
36-
KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths);
37-
return k_cache; // or v_cache, depending on the intended use
38-
}
39-
4031
void kv_caching_(Tensor k_cache,
4132
Tensor v_cache,
4233
Tensor k,

src/infinicore/pybind11/ops.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "ops/causal_softmax.hpp"
99
#include "ops/embedding.hpp"
1010
#include "ops/flash_attention.hpp"
11+
#include "ops/kv_caching.hpp"
1112
#include "ops/linear.hpp"
1213
#include "ops/matmul.hpp"
1314
#include "ops/mul.hpp"
@@ -30,20 +31,21 @@ inline void bind(py::module &m) {
3031
bind_add_rms_norm(m);
3132
bind_attention(m);
3233
bind_causal_softmax(m);
34+
bind_embedding(m);
3335
bind_flash_attention(m);
34-
bind_random_sample(m);
36+
bind_kv_caching(m);
3537
bind_linear(m);
3638
bind_matmul(m);
3739
bind_mul(m);
3840
bind_paged_attention(m);
3941
bind_paged_attention_prefill(m);
4042
bind_paged_caching(m);
43+
bind_random_sample(m);
4144
bind_rearrange(m);
4245
bind_rms_norm(m);
46+
bind_rope(m);
4347
bind_silu(m);
4448
bind_swiglu(m);
45-
bind_rope(m);
46-
bind_embedding(m);
4749
}
4850

4951
} // namespace infinicore::ops
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/kv_caching.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
inline void bind_kv_caching(py::module &m) {
12+
m.def("kv_caching_",
13+
&op::kv_caching_,
14+
py::arg("k_cache"),
15+
py::arg("v_cache"),
16+
py::arg("k"),
17+
py::arg("v"),
18+
py::arg("past_kv_lengths"),
19+
R"doc(In-place Key-Value Caching.
20+
21+
Updates the KV cache in-place with new key and value tensors.
22+
23+
Args:
24+
k_cache: Key cache tensor to update in-place
25+
v_cache: Value cache tensor to update in-place
26+
k: New key tensor to append
27+
v: New value tensor to append
28+
past_kv_lengths: Tensor containing current sequence lengths for each batch
29+
)doc");
30+
}
31+
32+
} // namespace infinicore::ops

src/infiniop/ops/flash_attention/ninetoothed/descriptor.h

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,26 @@ class Descriptor final : public InfiniopDescriptor {
6767
constexpr auto block_size_m_{64};
6868
constexpr auto block_size_n_{64};
6969

70-
launch_flash_attention(stream,
71-
query,
72-
key,
73-
value,
74-
attn_mask,
75-
is_causal,
76-
scale,
77-
output,
78-
with_attn_mask,
79-
causal_variant,
80-
with_kv_cache_,
81-
emb_dim_,
82-
is_causal_,
83-
with_attn_mask_,
84-
causal_variant_,
85-
dtype_,
86-
block_size_m_,
87-
block_size_n_);
70+
if (launch_flash_attention(stream,
71+
query,
72+
key,
73+
value,
74+
attn_mask,
75+
is_causal,
76+
scale,
77+
output,
78+
with_attn_mask,
79+
causal_variant,
80+
with_kv_cache_,
81+
emb_dim_,
82+
is_causal_,
83+
with_attn_mask_,
84+
causal_variant_,
85+
dtype_,
86+
block_size_m_,
87+
block_size_n_)) {
88+
return INFINI_STATUS_NOT_IMPLEMENTED;
89+
}
8890

8991
return INFINI_STATUS_SUCCESS;
9092
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import ninetoothed
2+
from ntops.kernels import kv_caching
3+
4+
import infiniop.ninetoothed.build
5+
6+
7+
def build():
8+
dtype_values = (
9+
ninetoothed.float16,
10+
ninetoothed.bfloat16,
11+
ninetoothed.float32,
12+
)
13+
14+
constexpr_param_grid = {
15+
"emb_dim": (1, 16, 32, 64, 128, 256),
16+
"dtype": dtype_values,
17+
"block_size_m": (64,),
18+
"block_size_n": (64,),
19+
}
20+
21+
infiniop.ninetoothed.build.build(
22+
kv_caching.premake,
23+
constexpr_param_grid,
24+
caller="cuda",
25+
op_name="kv_caching",
26+
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
27+
)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#ifndef KV_CACHING_H
2+
#define KV_CACHING_H
3+
4+
#include "../../../handle.h"
5+
#include "../../../operator.h"
6+
#include "../../../tensor.h"
7+
8+
#include "../../../../../build/ninetoothed/kv_caching.h"
9+
#include "../../../ninetoothed/utils.h"
10+
11+
namespace op::kv_caching::ninetoothed {
12+
class Descriptor final : public InfiniopDescriptor {
13+
14+
public:
15+
Descriptor(
16+
infiniopHandle_t handle,
17+
infiniopTensorDescriptor_t k_cache_desc,
18+
infiniopTensorDescriptor_t v_cache_desc,
19+
infiniopTensorDescriptor_t k_desc,
20+
infiniopTensorDescriptor_t v_desc,
21+
infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id},
22+
k_cache_shape_{k_cache_desc->shape()},
23+
k_cache_strides_{k_cache_desc->strides()},
24+
v_cache_shape_{v_cache_desc->shape()},
25+
v_cache_strides_{v_cache_desc->strides()},
26+
k_shape_{k_desc->shape()},
27+
k_strides_{k_desc->strides()},
28+
v_shape_{v_desc->shape()},
29+
v_strides_{v_desc->strides()},
30+
past_kv_lengths_shape_{past_kv_lengths_desc->shape()},
31+
past_kv_lengths_strides_{past_kv_lengths_desc->strides()},
32+
dtype_{k_desc->dtype()} {}
33+
34+
~Descriptor() = default;
35+
36+
size_t get_workspace_size() const { return 0; };
37+
38+
static infiniStatus_t create(
39+
infiniopHandle_t handle,
40+
Descriptor **desc_ptr,
41+
infiniopTensorDescriptor_t k_cache,
42+
infiniopTensorDescriptor_t v_cache,
43+
infiniopTensorDescriptor_t k,
44+
infiniopTensorDescriptor_t v,
45+
infiniopTensorDescriptor_t past_kv_lengths) {
46+
*desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths};
47+
return INFINI_STATUS_SUCCESS;
48+
}
49+
50+
infiniStatus_t calculate(
51+
void *workspace, size_t workspace_size,
52+
void *k_cache,
53+
void *v_cache,
54+
const void *k,
55+
const void *v,
56+
const void *past_kv_lengths,
57+
void *stream) const {
58+
auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}};
59+
auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}};
60+
auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}};
61+
auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}};
62+
auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}};
63+
64+
if (launch_kv_caching(stream,
65+
k_cache_nt,
66+
v_cache_nt,
67+
k_nt,
68+
v_nt,
69+
past_kv_lengths_nt,
70+
k_shape_[3],
71+
dtype_,
72+
64, 64)) {
73+
return INFINI_STATUS_NOT_IMPLEMENTED;
74+
}
75+
76+
return INFINI_STATUS_SUCCESS;
77+
}
78+
79+
private:
80+
using Size = ::ninetoothed::Tensor<>::Size;
81+
using Stride = ::ninetoothed::Tensor<>::Stride;
82+
83+
std::vector<Size> k_cache_shape_;
84+
std::vector<Stride> k_cache_strides_;
85+
86+
std::vector<Size> v_cache_shape_;
87+
std::vector<Stride> v_cache_strides_;
88+
89+
std::vector<Size> k_shape_;
90+
std::vector<Stride> k_strides_;
91+
std::vector<Size> v_shape_;
92+
std::vector<Stride> v_strides_;
93+
94+
std::vector<Size> past_kv_lengths_shape_;
95+
std::vector<Stride> past_kv_lengths_strides_;
96+
97+
infiniDtype_t dtype_;
98+
};
99+
} // namespace op::kv_caching::ninetoothed
100+
101+
#endif // KV_CACHING_H

src/infiniop/ops/kv_caching/operator.cc

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
// #include "cpu/kv_caching_cpu.h"
77
#endif
88
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
9-
// #include "nvidia/kv_caching_nvidia.cuh"
9+
#if defined(ENABLE_NINETOOTHED)
10+
#include "ninetoothed/kv_caching.h"
11+
#endif
1012
#endif
1113

1214
__C infiniStatus_t infiniopCreateKVCachingDescriptor(
@@ -35,7 +37,9 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor(
3537
// CREATE(INFINI_DEVICE_CPU, cpu);
3638
#endif
3739
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
38-
// CREATE(INFINI_DEVICE_NVIDIA, nvidia);
40+
#if defined(ENABLE_NINETOOTHED)
41+
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
42+
#endif
3943
#endif
4044
default:
4145
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -59,7 +63,9 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
5963
// GET_SIZE(INFINI_DEVICE_CPU, cpu);
6064
#endif
6165
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
62-
// GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
66+
#if defined(ENABLE_NINETOOTHED)
67+
GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
68+
#endif
6369
#endif
6470
default:
6571
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -90,7 +96,9 @@ __C infiniStatus_t infiniopKVCaching(
9096
// CALCULATE(INFINI_DEVICE_CPU, cpu);
9197
#endif
9298
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
93-
// CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
99+
#if defined(ENABLE_NINETOOTHED)
100+
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
101+
#endif
94102
#endif
95103
default:
96104
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -112,7 +120,9 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor(
112120
// DELETE(INFINI_DEVICE_CPU, cpu);
113121
#endif
114122
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
115-
// DELETE(INFINI_DEVICE_NVIDIA, nvidia);
123+
#if defined(ENABLE_NINETOOTHED)
124+
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
125+
#endif
116126
#endif
117127
default:
118128
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

0 commit comments

Comments
 (0)