Skip to content

Commit d2806fb

Browse files
theraysmith authored and copybara-github committed
Fixed msan error by fixing padding of k_cache and v_cache
PiperOrigin-RevId: 879644219
1 parent d6c7576 commit d2806fb

8 files changed

Lines changed: 153 additions & 62 deletions

File tree

gemma/attention.cc

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
2222
#include "util/zones.h"
23+
#include "hwy/base.h"
2324
#ifndef HWY_DISABLED_TARGETS
2425
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
2526
#endif // HWY_DISABLED_TARGETS
@@ -58,8 +59,8 @@ size_t FloatsPerVector() {
5859

5960
// The k-cache and v-cache are set up without knowing NF. So if it hasn't been
6061
// done already, reshape it to take NF into account.
61-
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
62-
if (kv.Cols() > cache.Cols()) {
62+
void MaybeReshapeCache(const size_t default_cols, MatPtrT<KV_t>& cache) {
63+
if (default_cols == cache.Cols()) {
6364
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
6465
}
6566
}
@@ -71,13 +72,50 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
7172
// is a tiny fraction of the overall computation, and it is linear in the
7273
// token length.
7374
const size_t kFloatsPerTile = 2 * FloatsPerVector();
75+
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
7476
for (size_t i = 0; i < qkv_dim; i += 2) {
7577
k[i * kFloatsPerTile] = kv[i];
7678
k[i * kFloatsPerTile + 1] = kv[i + 1];
7779
}
80+
for (size_t i = qkv_dim; i < kRoundedQkvDim; i += 2) {
81+
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
82+
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
83+
}
7884
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
85+
if (i + kFloatsPerTile <= qkv_dim) {
86+
for (size_t j = 0; j < kFloatsPerTile; j++) {
87+
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
88+
}
89+
} else {
90+
for (size_t j = 0; j < qkv_dim - i; j++) {
91+
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
92+
}
93+
for (size_t j = qkv_dim - i; j < kFloatsPerTile; j++) {
94+
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
95+
}
96+
}
97+
}
98+
for (size_t i = hwy::RoundUpTo(qkv_dim, kFloatsPerTile); i < kRoundedQkvDim;
99+
i += kFloatsPerTile) {
79100
for (size_t j = 0; j < kFloatsPerTile; j++) {
80-
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
101+
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
102+
}
103+
}
104+
}
105+
106+
// Zeros out a part of k and v that corresponds to out-of-bounds cache
107+
// positions.
108+
void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v,
109+
size_t qkv_dim) {
110+
const size_t kFloatsPerTile = 2 * FloatsPerVector();
111+
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
112+
for (size_t i = 0; i < kRoundedQkvDim; i += 2) {
113+
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
114+
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
115+
}
116+
for (size_t i = 0; i < kRoundedQkvDim; i += kFloatsPerTile) {
117+
for (size_t j = 0; j < kFloatsPerTile; j++) {
118+
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
81119
}
82120
}
83121
}
@@ -314,23 +352,51 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
314352
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
315353
/*add=*/nullptr, env, kv_rows);
316354
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
317-
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
318-
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
355+
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
356+
qbatch.KV(qi).k_cache);
357+
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
358+
qbatch.KV(qi).v_cache);
319359
}
320360
const size_t kFloatsPerVector = FloatsPerVector();
361+
const size_t kRoundedTokens =
362+
hwy::RoundUpTo(num_tokens, 2 * kFloatsPerVector);
363+
const size_t kRoundedNumInterleaved =
364+
kRoundedTokens * div_qbatch.GetDivisor();
321365

322366
// Apply positional encodings for K.
323367
// Note that 2D parallelism is not worth the fork/join overhead because the
324368
// tasks are very lightweight.
325369
ParallelFor(
326-
Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
370+
Parallelism::kFlat, kv_heads * kRoundedNumInterleaved, env.ctx,
327371
/*cluster_idx=*/0, Callers::kAttComputeQKV,
328372
[&](size_t task, size_t worker) HWY_ATTR {
329373
const size_t head = task % kv_heads;
330374
const size_t interleaved_idx = task / kv_heads;
331375
const size_t qi = div_qbatch.Remainder(interleaved_idx);
332376
const size_t token_idx = div_qbatch.Divide(interleaved_idx);
333377
const size_t cache_pos = qbatch.Pos(qi) + token_idx;
378+
if (token_idx >= kRoundedTokens) {
379+
return;
380+
}
381+
// The innermost dimension of v is 2NF values from qkv_dim because they
382+
// will be loaded into a BF16 vector to be scaled and added to the
383+
// cached attention output in 2 NF-sized registers.
384+
auto& k_cache = qbatch.KV(qi).k_cache;
385+
KV_t* HWY_RESTRICT k =
386+
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
387+
qbatch.KV(qi).cache->KOffset(layer_idx, head, kFloatsPerVector,
388+
cache_pos);
389+
auto& v_cache = qbatch.KV(qi).v_cache;
390+
KV_t* HWY_RESTRICT v =
391+
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
392+
qbatch.KV(qi).cache->VOffset(layer_idx, head, kFloatsPerVector,
393+
cache_pos);
394+
if (token_idx >= num_tokens) {
395+
// Create a zero-filled K/V pair for padding for out-of-sequence
396+
// tokens.
397+
TransposeOOBKVCacheRow(k, v, qkv_dim);
398+
return;
399+
}
334400
// --seq_len must be large enough to avoid wraparound.
335401
HWY_DASSERT(cache_pos < activations.SeqLen());
336402
auto& kv_cache = qbatch.KV(qi).kv_cache;
@@ -341,22 +407,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
341407
// The innermost dimension of k is 2 values from qkv_dim because they
342408
// are going to be used in a BF16 dot product involving pairs of
343409
// values over NF k positions.
344-
// The innermost dimension of v is 2NF values from qkv_dim because they
345-
// will be loaded into a BF16 vector to be scaled and added to the
346-
// cached attention output in 2 NF-sized registers.
347-
// TODO(rays): factor out these calculations into functions.
348-
auto& k_cache = qbatch.KV(qi).k_cache;
349-
KV_t* HWY_RESTRICT k =
350-
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
351-
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
352-
kFloatsPerVector +
353-
(cache_pos % (2 * kFloatsPerVector)) * 2;
354-
auto& v_cache = qbatch.KV(qi).v_cache;
355-
KV_t* HWY_RESTRICT v =
356-
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
357-
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
358-
kFloatsPerVector +
359-
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
360410

361411
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
362412
const hn::ScalableTag<float> df;

gemma/attention.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace gcpp {
3333
namespace NAMESPACE { \
3434
size_t FloatsPerVector(); \
3535
\
36-
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
36+
void MaybeReshapeCache(size_t default_cols, MatPtrT<KV_t>& cache); \
3737
\
3838
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
3939
KV_t* HWY_RESTRICT v, size_t qkv_dim); \

gemma/configs.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@
2929
#include "io/fields.h" // IFieldsVisitor
3030
#include "io/io.h" // Path
3131
#include "util/basics.h"
32+
#include "hwy/detect_compiler_arch.h"
3233

3334
namespace gcpp {
3435

36+
// Number of BF16 elements that fit in HWY_ARCH_MAX_BYTES bytes.
// NOTE(review): presumably HWY_ARCH_MAX_BYTES is the largest vector size of
// the architecture, which would make this constant independent of the target
// that is dispatched at runtime - confirm against hwy/detect_compiler_arch.h.
constexpr size_t kMaxBF16PerVector = HWY_ARCH_MAX_BYTES / sizeof(BF16);
37+
3538
HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
3639

3740
HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;

gemma/flash_attention.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,7 +1700,6 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
17001700
// A "head group" in the context of GQA refers to a collection of query
17011701
// heads that share the same key and value heads.
17021702
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
1703-
const size_t cache_layer_size = layer_config.CacheLayerSize();
17041703
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
17051704
const size_t total_tasks = token_batch * layer_config.heads;
17061705
size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks,
@@ -1716,11 +1715,9 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
17161715
params.clear();
17171716
for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) {
17181717
for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) {
1719-
const size_t head_offset = kv_head * qkv_dim * 2;
1720-
const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset;
17211718
params.push_back(Tile148Params{
17221719
.qi_index = qi,
1723-
.kv_offset = kv_offset,
1720+
.kv_head = kv_head,
17241721
});
17251722
for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
17261723
const size_t pos = qbatch.Pos(qi) + batch_idx;
@@ -1746,7 +1743,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
17461743
// current tile is full so start new tile.
17471744
params.push_back(Tile148Params{
17481745
.qi_index = qi,
1749-
.kv_offset = kv_offset,
1746+
.kv_head = kv_head,
17501747
});
17511748
}
17521749
const size_t head = head_group + kHeadGroups * kv_head;
@@ -2157,13 +2154,20 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
21572154
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
21582155
auto& param = params[task];
21592156
auto& kT_cache = qbatch.KV(param.qi_index).k_cache;
2157+
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
21602158
MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
2161-
qkv_dim * 2 * kNF));
2162-
kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride());
2159+
kRoundedQkvDim * 2 * kNF));
2160+
kT.SetPtr(
2161+
kT_cache.Row(0) + qbatch.KV(param.qi_index)
2162+
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
2163+
kT_cache.Stride());
21632164
auto& vT_cache = qbatch.KV(param.qi_index).v_cache;
21642165
MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
2165-
qkv_dim * 2 * kNF));
2166-
vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride());
2166+
kRoundedQkvDim * 2 * kNF));
2167+
vT.SetPtr(
2168+
vT_cache.Row(0) + qbatch.KV(param.qi_index)
2169+
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
2170+
vT_cache.Stride());
21672171
MatPtrT<float>& att_out =
21682172
param.i_of_n == 0 ? activations.att_out : activations.att_out_reps;
21692173
DispatchTileFlashAttention148(param, activations.q_bf, kT, vT, layer_idx,

gemma/flash_attention_test.cc

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,15 @@ void TestFlashAttention(size_t target_parallelism,
144144
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
145145
const size_t seq_len =
146146
static_cast<size_t>(attention.div_seq_len.GetDivisor());
147-
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
148-
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
147+
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
148+
qbatch.KV(0).k_cache);
149+
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
150+
qbatch.KV(0).v_cache);
149151
auto& kvc = qbatch.KV(0).kv_cache;
150-
const size_t kFloatsPerTile = 2 * FloatsPerVector();
152+
using DF = hn::ScalableTag<float>;
153+
const DF df;
154+
const size_t kNF = hn::Lanes(df);
155+
const size_t kFloatsPerTile = 2 * kNF;
151156
for (size_t h = 0; h < layer_config.heads; ++h) {
152157
// Make strided views into the kv cache for
153158
// this query and head.
@@ -160,12 +165,12 @@ void TestFlashAttention(size_t target_parallelism,
160165
SetMat(h + layer_config.heads * 2, v);
161166
for (size_t p = 0; p < tokens.size(); ++p) {
162167
KV_t* HWY_RESTRICT k_src = k.Row(p);
163-
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
164-
head_offset * kFloatsPerTile / 2 +
165-
p % kFloatsPerTile * 2;
166-
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
167-
head_offset * kFloatsPerTile / 2 +
168-
p % kFloatsPerTile * kFloatsPerTile;
168+
KV_t* HWY_RESTRICT k_dest =
169+
qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
170+
qbatch.KV(0).cache->KOffset(0, h / kHeadGroups, kNF, p);
171+
KV_t* HWY_RESTRICT v_dest =
172+
qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
173+
qbatch.KV(0).cache->VOffset(0, h / kHeadGroups, kNF, p);
169174

170175
TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
171176
}
@@ -176,9 +181,6 @@ void TestFlashAttention(size_t target_parallelism,
176181
// Copy the output to saved_att to allow for comparison.
177182
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
178183
SetMat(1, attention.q);
179-
using DF = hn::ScalableTag<float>;
180-
const DF df;
181-
const size_t kNF = hn::Lanes(df);
182184
const size_t total_tasks =
183185
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
184186
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),

gemma/flash_structs.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ struct Tile148Params {
4848
uint32_t max_last_pos = 0;
4949
// Index into the qbatch.KV is the same for each row in the tile.
5050
uint32_t qi_index;
51-
// Index into the kv_cache is the same for each row in the tile.
52-
uint32_t kv_offset;
51+
// kv_head is the same for each row in the tile.
52+
uint32_t kv_head;
5353
// In the original task, the index to the split tasks of the first split task.
5454
uint32_t split_index = 0;
5555
// The index of the split for running split attention.

gemma/kv_cache.cc

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@
2929

3030
namespace gcpp {
3131

32-
// TODO: rays - Remove this once hwy is updated.
33-
#ifndef HWY_ARCH_MAX_BYTES
34-
#define HWY_ARCH_MAX_BYTES 256
35-
#endif
36-
3732
// Number of rows for KV cache. Note that both rows and cols are u32, and
3833
// the total number of elements can exceed 2^32.
3934
static size_t CappedSeqLen(const ModelConfig& config,
@@ -46,8 +41,13 @@ static size_t CappedSeqLen(const ModelConfig& config,
4641
return inference_args.seq_len;
4742
}
4843

49-
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
50-
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
44+
KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers,
45+
size_t kv_heads, size_t qkv_dim, const Allocator& allocator)
46+
: num_layers(num_layers),
47+
kv_heads(kv_heads),
48+
qkv_dim(qkv_dim),
49+
rounded_qkv_dim(hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector)),
50+
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
5151
// WARNING: the rows and cols of k_cache and v_cache will be modified
5252
// before use!
5353
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
@@ -56,22 +56,21 @@ KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
5656
// machine architecture, since kFloatsPerVector is architecture dependent.
5757
// The change in shape is safe only if the padding is kPacked.
5858
k_cache("k",
59-
Extents2D(HWY_MAX(kv_extents.rows,
60-
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
61-
kv_extents.cols / 2),
59+
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
60+
KOrVDefaultCols()),
6261
allocator, MatPadding::kPacked),
6362
v_cache("v",
64-
Extents2D(HWY_MAX(kv_extents.rows,
65-
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
66-
kv_extents.cols / 2),
63+
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
64+
KOrVDefaultCols()),
6765
allocator, MatPadding::kPacked),
6866
allocator_(allocator) {}
6967

7068
// Builds a KVCache from a model configuration by delegating to the private
// constructor. Rows come from CappedSeqLen(config, inference_args) and columns
// from config.KVCacheCols(). After this change the layer count is
// config.layer_configs.size(), and kv_heads / qkv_dim are read from
// config.layer_configs[0].
// NOTE(review): this indexes layer_configs[0] unconditionally and applies
// layer 0's kv_heads / qkv_dim to the whole cache - confirm that
// layer_configs is never empty and that all layers share these values.
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
7169
const Allocator& allocator)
7270
: KVCache(
7371
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
74-
allocator) {}
72+
config.layer_configs.size(), config.layer_configs[0].kv_heads,
73+
config.layer_configs[0].qkv_dim, allocator) {}
7574

7675
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
7776
const RuntimeConfig& runtime_config,
@@ -135,7 +134,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
135134
}
136135

137136
KVCache KVCache::Copy() {
138-
KVCache copy(kv_cache.Extents(), allocator_);
137+
KVCache copy(kv_cache.Extents(), num_layers, kv_heads, qkv_dim, allocator_);
139138

140139
CopyMat(kv_cache, copy.kv_cache);
141140
return copy;

gemma/kv_cache.h

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,38 @@ struct KVCache {
9191
return {start_ptr, source_ptr};
9292
}
9393

94+
// Returns the default size of a row in k_cache or v_cache, before scaling by
95+
// 2 * kNF.
96+
size_t KOrVDefaultCols() const {
97+
return num_layers * kv_heads * rounded_qkv_dim;
98+
}
99+
100+
// Returns an offset into a row of k_cache or v_cache at a position that is
101+
// aligned to the tile size (a multiple of 2kNF).
102+
size_t KOrVOffset(const size_t layer_idx, const size_t kv_head_idx,
103+
const size_t kNF) const {
104+
return (layer_idx * kv_heads + kv_head_idx) * rounded_qkv_dim * 2 * kNF;
105+
}
106+
107+
// Returns an offset into k_cache at any given position.
108+
size_t KOffset(const size_t layer_idx, const size_t kv_head_idx,
109+
const size_t kNF, const size_t pos) const {
110+
return KOrVOffset(layer_idx, kv_head_idx, kNF) + (pos % (2 * kNF)) * 2;
111+
}
112+
113+
// Returns an offset into v_cache at any given position.
114+
size_t VOffset(const size_t layer_idx, const size_t kv_head_idx,
115+
const size_t kNF, const size_t pos) const {
116+
return KOrVOffset(layer_idx, kv_head_idx, kNF) +
117+
(pos % (2 * kNF)) * 2 * kNF;
118+
}
119+
120+
// Saved sizes for computing offsets into the KV cache.
121+
size_t num_layers = 0;
122+
size_t kv_heads = 0;
123+
size_t qkv_dim = 0;
124+
size_t rounded_qkv_dim = 0;
125+
94126
static constexpr size_t kTileSize = 32;
95127
std::optional<uint32_t> tiled_seq_len = std::nullopt;
96128
// Default Format
@@ -159,7 +191,8 @@ struct KVCache {
159191
const Allocator& allocator_;
160192

161193
// For use by other ctor and Copy()
162-
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
194+
KVCache(const Extents2D& kv_extents, size_t num_layers, size_t kv_heads,
195+
size_t qkv_dim, const Allocator& allocator);
163196
};
164197

165198
inline size_t KVCachePtr::SeqLen() const {

0 commit comments

Comments
 (0)