Skip to content

Commit 875f7c8

Browse files
authored
Share kv cache compile spec (#18864)
Currently we blindly share the KV cache across all prefill + decode methods, making the Parakeet model generate garbage output. This PR creates a CUDA backend compile spec to control KV cache sharing across different methods. The default is not to share.
1 parent 1f7f466 commit 875f7c8

3 files changed

Lines changed: 43 additions & 6 deletions

File tree

backends/aoti/aoti_backend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
class COMPILE_SPEC_KEYS(Enum):
2727
METHOD_NAME = "method_name"
28+
SHARE_KV_CACHE_ACROSS_METHODS = "share_kv_cache_across_methods"
2829

2930

3031
@experimental(
@@ -286,3 +287,13 @@ def method_name_from_compile_specs(
286287
raise RuntimeError(
287288
f"Could not find method name in compile specs: {compile_specs}"
288289
)
290+
291+
@classmethod
292+
def generate_share_kv_cache_compile_spec(cls) -> CompileSpec:
293+
"""
294+
Generate a CompileSpec to enable cross-method KV cache sharing.
295+
"""
296+
return CompileSpec(
297+
COMPILE_SPEC_KEYS.SHARE_KV_CACHE_ACROSS_METHODS.value,
298+
bytes([1]),
299+
)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ namespace {
8080
constexpr char kSkipCopyOutputToCpuForMethod[] =
8181
"skip_copy_output_to_cpu_for_method";
8282
constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
83+
constexpr char kShareKvCacheAcrossMethods[] = "share_kv_cache_across_methods";
8384
} // anonymous namespace
8485

8586
class ET_EXPERIMENTAL CudaBackend final
@@ -287,12 +288,17 @@ class ET_EXPERIMENTAL CudaBackend final
287288
ArrayRef<CompileSpec> compile_specs // This will be my empty list
288289
) const override {
289290
std::string method_name;
291+
bool share_kv_cache = false;
290292
for (const CompileSpec& spec : compile_specs) {
291293
if (std::strcmp(spec.key, "method_name") == 0) {
292294
method_name.assign(
293295
static_cast<const char*>(spec.value.buffer),
294296
spec.value.nbytes); // no nullptr guarantee, so pass size
295-
break;
297+
} else if (std::strcmp(spec.key, kShareKvCacheAcrossMethods) == 0) {
298+
if (spec.value.nbytes >= 1) {
299+
share_kv_cache =
300+
static_cast<const uint8_t*>(spec.value.buffer)[0] != 0;
301+
}
296302
}
297303
}
298304

@@ -416,14 +422,16 @@ class ET_EXPERIMENTAL CudaBackend final
416422
// ---------------------------------------------------------------
417423
// Cross-method constant sharing (e.g., KV cache between prefill/decode).
418424
//
425+
// Only enabled when share_kv_cache_across_methods compile spec is set.
419426
// The first container to initialize extracts its constants (keyed by
420427
// original FQN) and stores the AtenTensorHandle's. Subsequent containers
421428
// with matching FQNs are updated to point to the same GPU tensors via
422429
// UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
423430
// the source container retains ownership).
424431
// ---------------------------------------------------------------
425-
if (handle->get_num_constants && handle->get_constant_name &&
426-
handle->get_constant_original_fqn && handle->extract_constants_map &&
432+
if (share_kv_cache && handle->get_num_constants &&
433+
handle->get_constant_name && handle->get_constant_original_fqn &&
434+
handle->extract_constants_map &&
427435
handle->update_user_managed_constant_buffer_pairs) {
428436
size_t num_constants = 0;
429437
handle->get_num_constants(handle->container_handle, &num_constants);
@@ -469,6 +477,8 @@ class ET_EXPERIMENTAL CudaBackend final
469477
Error,
470478
"Failed to extract constants from '%s'",
471479
method_name.c_str());
480+
delete handle;
481+
return Error::Internal;
472482
}
473483
} else {
474484
// Subsequent container: share matching constants from the first.
@@ -501,14 +511,24 @@ class ET_EXPERIMENTAL CudaBackend final
501511
Error,
502512
"Failed to share constants into '%s'",
503513
method_name.c_str());
514+
delete handle;
515+
return Error::Internal;
504516
}
505517
}
506518
}
507519
}
520+
} else if (share_kv_cache) {
521+
ET_LOG(
522+
Error,
523+
"share_kv_cache_across_methods requested but constant sharing APIs "
524+
"not available for method '%s'",
525+
method_name.c_str());
526+
delete handle;
527+
return Error::Internal;
508528
} else {
509529
ET_LOG(
510530
Info,
511-
"Constant sharing APIs not available for method '%s'",
531+
"Constant sharing not requested for method '%s'",
512532
method_name.c_str());
513533
}
514534

examples/models/qwen3_5_moe/export.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,12 +659,18 @@ def _export_cuda(model, config, args):
659659
partitioner={
660660
"decode": [
661661
CudaPartitioner(
662-
[CudaBackend.generate_method_name_compile_spec("decode")]
662+
[
663+
CudaBackend.generate_method_name_compile_spec("decode"),
664+
CudaBackend.generate_share_kv_cache_compile_spec(),
665+
]
663666
)
664667
],
665668
"prefill": [
666669
CudaPartitioner(
667-
[CudaBackend.generate_method_name_compile_spec("prefill")]
670+
[
671+
CudaBackend.generate_method_name_compile_spec("prefill"),
672+
CudaBackend.generate_share_kv_cache_compile_spec(),
673+
]
668674
)
669675
],
670676
},

0 commit comments

Comments
 (0)