Skip to content

Commit 995df95

Browse files
committed
feat: add cuda kv_caching infinilm
1 parent cc0bc83 commit 995df95

8 files changed

Lines changed: 414 additions & 0 deletions

File tree

src/base/kv_caching_infinilm.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#ifndef INFINI_OPS_BASE_KV_CACHING_INFINILM_H_
2+
#define INFINI_OPS_BASE_KV_CACHING_INFINILM_H_
3+
4+
#include <cassert>
5+
6+
#include "operator.h"
7+
8+
namespace infini::ops {
9+
10+
class KvCachingInfinilm : public Operator<KvCachingInfinilm> {
11+
public:
12+
KvCachingInfinilm(const Tensor k, const Tensor v,
13+
const Tensor past_kv_lengths, Tensor k_cache,
14+
Tensor v_cache)
15+
: k_cache_shape_{k_cache.shape()},
16+
k_cache_strides_{k_cache.strides()},
17+
v_cache_shape_{v_cache.shape()},
18+
v_cache_strides_{v_cache.strides()},
19+
k_shape_{k.shape()},
20+
k_strides_{k.strides()},
21+
v_shape_{v.shape()},
22+
v_strides_{v.strides()},
23+
past_kv_lengths_shape_{past_kv_lengths.shape()},
24+
data_type_{k_cache.dtype()},
25+
past_kv_lengths_type_{past_kv_lengths.dtype()},
26+
batch_size_{k_cache.size(0)},
27+
num_kv_heads_{k_cache.size(1)},
28+
max_seq_len_{k_cache.size(2)},
29+
seq_len_{k.size(2)},
30+
hidden_size_{k_cache.size(3)},
31+
output_size_{k.numel()},
32+
device_index_{k_cache.device().index()} {
33+
assert(k_cache.ndim() == 4 && v_cache.ndim() == 4 && k.ndim() == 4 &&
34+
v.ndim() == 4 && "`KvCachingInfinilm` tensors must be 4D");
35+
assert(k_cache_shape_ == v_cache_shape_ &&
36+
"`KvCachingInfinilm` cache shapes must match");
37+
assert(k_shape_ == v_shape_ &&
38+
"`KvCachingInfinilm` source shapes must match");
39+
assert(k.size(0) == batch_size_ && k.size(1) == num_kv_heads_ &&
40+
k.size(3) == hidden_size_ &&
41+
"`KvCachingInfinilm` source shape must match cache "
42+
"batch/head/hidden dims");
43+
assert(seq_len_ <= max_seq_len_ &&
44+
"`KvCachingInfinilm` source sequence length exceeds cache length");
45+
assert(k_cache.dtype() == v_cache.dtype() && k_cache.dtype() == k.dtype() &&
46+
k_cache.dtype() == v.dtype() &&
47+
"`KvCachingInfinilm` K/V tensors must have the same dtype");
48+
assert(
49+
(data_type_ == DataType::kFloat16 ||
50+
data_type_ == DataType::kBFloat16 ||
51+
data_type_ == DataType::kFloat32) &&
52+
"`KvCachingInfinilm` K/V dtype must be float16, bfloat16, or float32");
53+
assert((past_kv_lengths_type_ == DataType::kInt32 ||
54+
past_kv_lengths_type_ == DataType::kInt64) &&
55+
"`KvCachingInfinilm` past_kv_lengths dtype must be int32 or int64");
56+
assert(past_kv_lengths.ndim() == 1 &&
57+
past_kv_lengths.size(0) == batch_size_ &&
58+
"`KvCachingInfinilm` past_kv_lengths shape must be (batch_size,)");
59+
assert(!k_cache.HasBroadcastDim() && !v_cache.HasBroadcastDim() &&
60+
"`KvCachingInfinilm` caches must not have broadcasted dimensions");
61+
}
62+
63+
virtual void operator()(const Tensor k, const Tensor v,
64+
const Tensor past_kv_lengths, Tensor k_cache,
65+
Tensor v_cache) const = 0;
66+
67+
protected:
68+
Tensor::Shape k_cache_shape_;
69+
70+
Tensor::Strides k_cache_strides_;
71+
72+
Tensor::Shape v_cache_shape_;
73+
74+
Tensor::Strides v_cache_strides_;
75+
76+
Tensor::Shape k_shape_;
77+
78+
Tensor::Strides k_strides_;
79+
80+
Tensor::Shape v_shape_;
81+
82+
Tensor::Strides v_strides_;
83+
84+
Tensor::Shape past_kv_lengths_shape_;
85+
86+
DataType data_type_;
87+
88+
DataType past_kv_lengths_type_;
89+
90+
Tensor::Size batch_size_{0};
91+
92+
Tensor::Size num_kv_heads_{0};
93+
94+
Tensor::Size max_seq_len_{0};
95+
96+
Tensor::Size seq_len_{0};
97+
98+
Tensor::Size hidden_size_{0};
99+
100+
Tensor::Size output_size_{0};
101+
102+
int device_index_{0};
103+
};
104+
105+
} // namespace infini::ops
106+
107+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef INFINI_OPS_ILUVATAR_KV_CACHING_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_ILUVATAR_KV_CACHING_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/iluvatar/caster.cuh"
7+
#include "native/cuda/iluvatar/runtime_.h"
8+
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<KvCachingInfinilm, Device::Type::kIluvatar>
14+
: public CudaKvCachingInfinilm<Runtime<Device::Type::kIluvatar>> {
15+
public:
16+
using CudaKvCachingInfinilm<
17+
Runtime<Device::Type::kIluvatar>>::CudaKvCachingInfinilm;
18+
};
19+
20+
} // namespace infini::ops
21+
22+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef INFINI_OPS_METAX_KV_CACHING_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_METAX_KV_CACHING_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/metax/caster.cuh"
7+
#include "native/cuda/metax/runtime_.h"
8+
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<KvCachingInfinilm, Device::Type::kMetax>
14+
: public CudaKvCachingInfinilm<Runtime<Device::Type::kMetax>> {
15+
public:
16+
using CudaKvCachingInfinilm<
17+
Runtime<Device::Type::kMetax>>::CudaKvCachingInfinilm;
18+
};
19+
20+
} // namespace infini::ops
21+
22+
#endif
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef INFINI_OPS_MOORE_KV_CACHING_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_MOORE_KV_CACHING_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/moore/caster.cuh"
7+
#include "native/cuda/moore/polyfills.cuh"
8+
#include "native/cuda/moore/runtime_.h"
9+
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
10+
11+
namespace infini::ops {
12+
13+
template <>
14+
class Operator<KvCachingInfinilm, Device::Type::kMoore>
15+
: public CudaKvCachingInfinilm<Runtime<Device::Type::kMoore>> {
16+
public:
17+
using CudaKvCachingInfinilm<
18+
Runtime<Device::Type::kMoore>>::CudaKvCachingInfinilm;
19+
};
20+
21+
} // namespace infini::ops
22+
23+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef INFINI_OPS_NVIDIA_KV_CACHING_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_NVIDIA_KV_CACHING_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/nvidia/caster.cuh"
7+
#include "native/cuda/nvidia/runtime_.h"
8+
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<KvCachingInfinilm, Device::Type::kNvidia>
14+
: public CudaKvCachingInfinilm<Runtime<Device::Type::kNvidia>> {
15+
public:
16+
using CudaKvCachingInfinilm<
17+
Runtime<Device::Type::kNvidia>>::CudaKvCachingInfinilm;
18+
};
19+
20+
} // namespace infini::ops
21+
22+
#endif
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ifndef INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_CUH_
2+
#define INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_CUH_
3+
4+
#include <cstddef>
5+
#include <cstdint>
6+
7+
namespace infini::ops {
8+
9+
template <typename T, typename TIndex, unsigned int block_size>
10+
__global__ void KvCachingInfinilmKernel(
11+
T* __restrict__ k_cache, T* __restrict__ v_cache, const T* __restrict__ k,
12+
const T* __restrict__ v, const TIndex* __restrict__ past_kv_lengths,
13+
const ptrdiff_t* __restrict__ k_cache_strides,
14+
const ptrdiff_t* __restrict__ v_cache_strides,
15+
const ptrdiff_t* __restrict__ k_strides,
16+
const ptrdiff_t* __restrict__ v_strides, size_t output_size,
17+
size_t num_kv_heads, size_t seq_len, size_t hidden_size) {
18+
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
19+
20+
for (size_t idx = tid; idx < output_size; idx += blockDim.x * gridDim.x) {
21+
size_t offset = idx;
22+
size_t d = offset % hidden_size;
23+
offset /= hidden_size;
24+
size_t s = offset % seq_len;
25+
offset /= seq_len;
26+
size_t h = offset % num_kv_heads;
27+
size_t b = offset / num_kv_heads;
28+
29+
size_t cache_s = static_cast<size_t>(past_kv_lengths[b]) + s;
30+
ptrdiff_t k_cache_offset = b * k_cache_strides[0] + h * k_cache_strides[1] +
31+
cache_s * k_cache_strides[2] +
32+
d * k_cache_strides[3];
33+
ptrdiff_t v_cache_offset = b * v_cache_strides[0] + h * v_cache_strides[1] +
34+
cache_s * v_cache_strides[2] +
35+
d * v_cache_strides[3];
36+
ptrdiff_t k_offset = b * k_strides[0] + h * k_strides[1] +
37+
s * k_strides[2] + d * k_strides[3];
38+
ptrdiff_t v_offset = b * v_strides[0] + h * v_strides[1] +
39+
s * v_strides[2] + d * v_strides[3];
40+
41+
k_cache[k_cache_offset] = k[k_offset];
42+
v_cache[v_cache_offset] = v[v_offset];
43+
}
44+
}
45+
46+
} // namespace infini::ops
47+
48+
#endif
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#ifndef INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_H_
3+
4+
#include <algorithm>
5+
#include <cstddef>
6+
#include <cstring>
7+
#include <vector>
8+
9+
#include "base/kv_caching_infinilm.h"
10+
#include "common/generic_utils.h"
11+
#include "data_type.h"
12+
#include "dispatcher.h"
13+
#include "native/cuda/ops/kv_caching_infinilm/kernel.cuh"
14+
#include "native/cuda/runtime_utils.h"
15+
16+
namespace infini::ops {
17+
18+
template <typename Backend>
19+
class CudaKvCachingInfinilm : public KvCachingInfinilm {
20+
public:
21+
CudaKvCachingInfinilm(const Tensor k, const Tensor v,
22+
const Tensor past_kv_lengths, Tensor k_cache,
23+
Tensor v_cache)
24+
: KvCachingInfinilm{k, v, past_kv_lengths, k_cache, v_cache} {
25+
constexpr size_t ndim = 4;
26+
size_t strides_size = ndim * sizeof(*d_k_cache_strides_);
27+
const size_t metadata_size = 4 * strides_size;
28+
std::vector<std::byte> metadata(metadata_size);
29+
30+
Backend::Malloc((void**)&d_metadata_, metadata_size);
31+
32+
size_t offset = 0;
33+
d_k_cache_strides_ =
34+
reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
35+
std::memcpy(metadata.data() + offset, k_cache_strides_.data(),
36+
strides_size);
37+
offset += strides_size;
38+
39+
d_v_cache_strides_ =
40+
reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
41+
std::memcpy(metadata.data() + offset, v_cache_strides_.data(),
42+
strides_size);
43+
offset += strides_size;
44+
45+
d_k_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
46+
std::memcpy(metadata.data() + offset, k_strides_.data(), strides_size);
47+
offset += strides_size;
48+
49+
d_v_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
50+
std::memcpy(metadata.data() + offset, v_strides_.data(), strides_size);
51+
52+
Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
53+
Backend::MemcpyHostToDevice);
54+
}
55+
56+
~CudaKvCachingInfinilm() { Backend::Free(d_metadata_); }
57+
58+
void operator()(const Tensor k, const Tensor v, const Tensor past_kv_lengths,
59+
Tensor k_cache, Tensor v_cache) const override {
60+
auto cuda_stream =
61+
static_cast<typename Backend::Stream>(stream_ ? stream_ : 0);
62+
int block_size = std::min(
63+
RuntimeUtils<Backend::kDeviceType>::GetOptimalBlockSize(), 1024);
64+
dim3 block(std::min(static_cast<Tensor::Size>(block_size), output_size_));
65+
dim3 grid(utils::CeilDiv(output_size_, block.x));
66+
67+
using IndexTypes = List<DataType::kInt32, DataType::kInt64>;
68+
DispatchFunc<AllFloatTypes, IndexTypes, List<128, 256, 512, 1024>>(
69+
{static_cast<int64_t>(data_type_),
70+
static_cast<int64_t>(past_kv_lengths_type_), block_size},
71+
[&](auto list_tag) {
72+
using T = TypeMapType<Backend::kDeviceType, ListGet<0>(list_tag)>;
73+
using TIndex =
74+
TypeMapType<Backend::kDeviceType, ListGet<1>(list_tag)>;
75+
constexpr int kBlockSize = ListGet<2>(list_tag);
76+
77+
KvCachingInfinilmKernel<T, TIndex, kBlockSize>
78+
<<<grid, block, 0, cuda_stream>>>(
79+
reinterpret_cast<T*>(k_cache.data()),
80+
reinterpret_cast<T*>(v_cache.data()),
81+
reinterpret_cast<const T*>(k.data()),
82+
reinterpret_cast<const T*>(v.data()),
83+
reinterpret_cast<const TIndex*>(past_kv_lengths.data()),
84+
d_k_cache_strides_, d_v_cache_strides_, d_k_strides_,
85+
d_v_strides_, output_size_, num_kv_heads_, seq_len_,
86+
hidden_size_);
87+
},
88+
"CudaKvCachingInfinilm::operator()");
89+
}
90+
91+
private:
92+
std::byte* d_metadata_{nullptr};
93+
94+
Tensor::Stride* d_k_cache_strides_{nullptr};
95+
96+
Tensor::Stride* d_v_cache_strides_{nullptr};
97+
98+
Tensor::Stride* d_k_strides_{nullptr};
99+
100+
Tensor::Stride* d_v_strides_{nullptr};
101+
};
102+
103+
} // namespace infini::ops
104+
105+
#endif

0 commit comments

Comments
 (0)