
Commit c0e8395

Make kv cache update op support transposed cache
Differential Revision: [D93870392](https://our.internmc.facebook.com/intern/diff/D93870392/) [ghstack-poisoned]
1 parent: 3ce257f

5 files changed: 260 additions & 83 deletions

extension/llm/custom_ops/custom_ops.py

Lines changed: 32 additions & 13 deletions
@@ -190,9 +190,15 @@ def _validate_update_cache_params(
     value,
     cache,
     start_pos,
+    is_seq_dim_2=False,
     indices=None,
 ):
-    seq_len = value.size(1)
+    # Determine sequence dimension based on is_seq_dim_2
+    # If is_seq_dim_2 is False: [batch, seq, heads, head_dim]
+    # If is_seq_dim_2 is True: [batch, heads, seq, head_dim]
+    seq_dim = 2 if is_seq_dim_2 else 1
+    seq_len = value.size(seq_dim)
+
     assert (
         value.dim() == 4
     ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
@@ -201,22 +207,31 @@ def _validate_update_cache_params(
         value.dtype == cache.dtype
     ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"
 
-    for i in [0, 2, 3]:
-        assert value.size(i) == cache.size(
-            i
-        ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
+    # Validate batch and head_dim dimensions match
+    assert value.size(0) == cache.size(
+        0
+    ), f"Expected value and cache to have same size in dimension 0 (batch) but got {value.size(0)} and {cache.size(0)}"
+    assert value.size(3) == cache.size(
+        3
+    ), f"Expected value and cache to have same size in dimension 3 (head_dim) but got {value.size(3)} and {cache.size(3)}"
+
+    # Validate heads dimension matches based on layout
+    heads_dim = 1 if is_seq_dim_2 else 2
+    assert value.size(heads_dim) == cache.size(
+        heads_dim
+    ), f"Expected value and cache to have same size in dimension {heads_dim} (heads) but got {value.size(heads_dim)} and {cache.size(heads_dim)}"
 
     torch._check_is_size(start_pos)
     if indices is None:
-        torch._check(start_pos < cache.size(1))
+        torch._check(start_pos < cache.size(seq_dim))
         assert start_pos < cache.size(
-            1
-        ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+            seq_dim
+        ), f"Start position {start_pos} must be less than sequence length {cache.size(seq_dim)}"
 
-        torch._check((start_pos + seq_len) <= cache.size(1))
+        torch._check((start_pos + seq_len) <= cache.size(seq_dim))
         assert (start_pos + seq_len) <= cache.size(
-            1
-        ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+            seq_dim
+        ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(seq_dim)}"
 
     if indices is not None:
         assert (
@@ -229,20 +244,22 @@ def _validate_update_cache_params(
             0
         ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
         assert indices.size(1) == value.size(
-            1
-        ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"
+            seq_dim
+        ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(seq_dim)}"
 
 
 @impl(custom_ops_lib, "update_cache", "Meta")
 def update_cache_meta(
     value,
     cache,
     start_pos,
+    is_seq_dim_2=False,
 ):
     _validate_update_cache_params(
         value,
         cache,
         start_pos,
+        is_seq_dim_2,
     )
 
     # Update cache doesnt really return anything but I dont know a better
@@ -257,11 +274,13 @@ def update_cache_with_indices_meta(
     cache,
     start_pos,
     indices,
+    is_seq_dim_2=False,
 ):
     _validate_update_cache_params(
         value,
         cache,
         start_pos,
+        is_seq_dim_2,
         indices,
     )
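
For context, a minimal standalone sketch (not part of the diff) of the dimension selection that the new is_seq_dim_2 flag drives in _validate_update_cache_params; the tensor shapes are illustrative assumptions:

import torch

def layout_dims(is_seq_dim_2: bool):
    # Mirrors the layout logic in _validate_update_cache_params:
    # False -> [batch, seq, heads, head_dim]; True -> [batch, heads, seq, head_dim]
    seq_dim = 2 if is_seq_dim_2 else 1
    heads_dim = 1 if is_seq_dim_2 else 2
    return seq_dim, heads_dim

# Assumed shapes: 1 batch, 8 heads, 4 new tokens, head_dim 64,
# with a transposed cache holding 128 positions.
value = torch.randn(1, 8, 4, 64)    # [batch, heads, seq, head_dim]
cache = torch.zeros(1, 8, 128, 64)  # transposed layout: seq on dim 2
seq_dim, heads_dim = layout_dims(is_seq_dim_2=True)
assert value.size(heads_dim) == cache.size(heads_dim)  # heads match on dim 1
assert value.size(seq_dim) <= cache.size(seq_dim)      # update fits in the cache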

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 24 additions & 16 deletions
@@ -122,26 +122,30 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const bool is_seq_dim_2,
     Tensor& output);
 
 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos);
+    const int64_t start_pos,
+    const bool is_seq_dim_2);
 
 // New functions for update_cache_with_indices
 Tensor& update_cache_with_indices_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
     const Tensor& indices,
+    const bool is_seq_dim_2,
     Tensor& output);
 
 at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const at::Tensor& indices);
+    const at::Tensor& indices,
+    const bool is_seq_dim_2);
 
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
@@ -338,19 +342,21 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const bool is_seq_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_out(
-      context, value, cache, start_pos, output);
+      context, value, cache, start_pos, is_seq_dim_2, output);
 }
 
 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos) {
+    const int64_t start_pos,
+    const bool is_seq_dim_2) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_out_no_context, 3)
-  (value, cache, start_pos, output);
+  WRAP_TO_ATEN(update_cache_out_no_context, 4)
+  (value, cache, start_pos, is_seq_dim_2, output);
   return output;
 }
 
@@ -360,20 +366,22 @@ Tensor& update_cache_with_indices_out_no_context(
     Tensor& cache,
     const int64_t start_pos,
     const Tensor& indices,
+    const bool is_seq_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_with_indices_out(
-      context, value, cache, start_pos, indices, output);
+      context, value, cache, start_pos, indices, is_seq_dim_2, output);
 }
 
 at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const at::Tensor& indices) {
+    const at::Tensor& indices,
+    const bool is_seq_dim_2) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
-  (value, cache, start_pos, indices, output);
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 5)
+  (value, cache, start_pos, indices, is_seq_dim_2, output);
   return output;
 }
 
@@ -400,16 +408,16 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_cache(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos) -> Tensor");
+      "SymInt start_pos, bool is_seq_dim_2=False) -> Tensor");
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor indices) -> Tensor");
+      "SymInt start_pos, Tensor indices, bool is_seq_dim_2=False) -> Tensor");
   m.def(
       "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, Tensor indices, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -439,15 +447,15 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
       "update_cache.out",
-      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
   m.impl(
       "update_cache_with_indices",
       torch::executor::native::update_cache_with_indices_aten);
   m.impl(
       "update_cache_with_indices.out",
       WRAP_TO_ATEN(
           torch::executor::native::update_cache_with_indices_out_no_context,
-          4));
+          5));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
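
For illustration, a hedged sketch of calling the op from Python once these schemas are registered. The import path below is an assumption (the module lives at extension/llm/custom_ops/custom_ops.py, but the installed package path may differ), and the shapes are made up:

import torch

# Assumption: importing this module registers the llama custom ops with PyTorch.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

cache = torch.zeros(1, 8, 128, 64)  # transposed layout: [batch, heads, max_seq, head_dim]
value = torch.randn(1, 8, 4, 64)    # 4 new tokens in the same layout
start_pos = 16

# With is_seq_dim_2=True, the update should land along dim 2,
# i.e. at cache[:, :, 16:20, :], updating the cache in place.
torch.ops.llama.update_cache(value, cache, start_pos, is_seq_dim_2=True)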
