@@ -80,6 +80,7 @@ namespace {
8080constexpr char kSkipCopyOutputToCpuForMethod [] =
8181 " skip_copy_output_to_cpu_for_method" ;
8282constexpr char kUseSharedCudaStream [] = " use_shared_cuda_stream" ;
83+ constexpr char kShareKvCacheAcrossMethods [] = " share_kv_cache_across_methods" ;
8384} // anonymous namespace
8485
8586class ET_EXPERIMENTAL CudaBackend final
@@ -287,12 +288,17 @@ class ET_EXPERIMENTAL CudaBackend final
287288 ArrayRef<CompileSpec> compile_specs // This will be my empty list
288289 ) const override {
289290 std::string method_name;
291+ bool share_kv_cache = false ;
290292 for (const CompileSpec& spec : compile_specs) {
291293 if (std::strcmp (spec.key , " method_name" ) == 0 ) {
292294 method_name.assign (
293295 static_cast <const char *>(spec.value .buffer ),
294296 spec.value .nbytes ); // no nullptr guarantee, so pass size
295- break ;
297+ } else if (std::strcmp (spec.key , kShareKvCacheAcrossMethods ) == 0 ) {
298+ if (spec.value .nbytes >= 1 ) {
299+ share_kv_cache =
300+ static_cast <const uint8_t *>(spec.value .buffer )[0 ] != 0 ;
301+ }
296302 }
297303 }
298304
@@ -416,14 +422,16 @@ class ET_EXPERIMENTAL CudaBackend final
416422 // ---------------------------------------------------------------
417423 // Cross-method constant sharing (e.g., KV cache between prefill/decode).
418424 //
425+ // Only enabled when the share_kv_cache_across_methods compile spec has a non-zero first byte.
419426 // The first container to initialize extracts its constants (keyed by
420427 // original FQN) and stores the AtenTensorHandles. Subsequent containers
421428 // with matching FQNs are updated to point to the same GPU tensors via
422429 // UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
423430 // the source container retains ownership).
424431 // ---------------------------------------------------------------
425- if (handle->get_num_constants && handle->get_constant_name &&
426- handle->get_constant_original_fqn && handle->extract_constants_map &&
432+ if (share_kv_cache && handle->get_num_constants &&
433+ handle->get_constant_name && handle->get_constant_original_fqn &&
434+ handle->extract_constants_map &&
427435 handle->update_user_managed_constant_buffer_pairs ) {
428436 size_t num_constants = 0 ;
429437 handle->get_num_constants (handle->container_handle , &num_constants);
@@ -469,6 +477,8 @@ class ET_EXPERIMENTAL CudaBackend final
469477 Error,
470478 " Failed to extract constants from '%s'" ,
471479 method_name.c_str ());
480+ delete handle;
481+ return Error::Internal;
472482 }
473483 } else {
474484 // Subsequent container: share matching constants from the first.
@@ -501,14 +511,24 @@ class ET_EXPERIMENTAL CudaBackend final
501511 Error,
502512 " Failed to share constants into '%s'" ,
503513 method_name.c_str ());
514+ delete handle;
515+ return Error::Internal;
504516 }
505517 }
506518 }
507519 }
520+ } else if (share_kv_cache) {
521+ ET_LOG (
522+ Error,
523+ " share_kv_cache_across_methods requested but constant sharing APIs "
524+ " not available for method '%s'" ,
525+ method_name.c_str ());
526+ delete handle;
527+ return Error::Internal;
508528 } else {
509529 ET_LOG (
510530 Info,
511- " Constant sharing APIs not available for method '%s'" ,
531+ " Constant sharing not requested for method '%s'" ,
512532 method_name.c_str ());
513533 }
514534
0 commit comments