@@ -207,6 +207,30 @@ class ET_EXPERIMENTAL CudaBackend final
207207 Info,
208208 " Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)" );
209209 }
210+
211+ // Load constant management symbols (optional — needed for cross-method
212+ // buffer sharing). These are available in torch >= 2.6.
213+ #define LOAD_OPTIONAL_SYMBOL (member, name ) \
214+ do { \
215+ auto res = get_function (so_handle, #name); \
216+ handle->member = \
217+ res.ok () ? reinterpret_cast <name##Func>(res.get ()) : nullptr ; \
218+ } while (0 )
219+
220+ LOAD_OPTIONAL_SYMBOL (
221+ get_num_constants, AOTInductorModelContainerGetNumConstants);
222+ LOAD_OPTIONAL_SYMBOL (
223+ get_constant_name, AOTInductorModelContainerGetConstantName);
224+ LOAD_OPTIONAL_SYMBOL (
225+ get_constant_original_fqn,
226+ AOTInductorModelContainerGetConstantOriginalFQN);
227+ LOAD_OPTIONAL_SYMBOL (
228+ extract_constants_map, AOTInductorModelContainerExtractConstantsMap);
229+ LOAD_OPTIONAL_SYMBOL (
230+ update_user_managed_constant_buffer_pairs,
231+ AOTInductorModelContainerUpdateUserManagedConstantBufferPairs);
232+ #undef LOAD_OPTIONAL_SYMBOL
233+
210234 return Error::Ok;
211235 }
212236
@@ -348,9 +372,20 @@ class ET_EXPERIMENTAL CudaBackend final
348372 const void * weights_blob = buffer_res->data ();
349373 // Feed the weights blob into the container. Under the hood it's copying
350374 // weights, so we should free the buffer immediately.
351- ET_CHECK_OK_OR_RETURN_ERROR (handle->update_constants_from_blob (
352- handle->container_handle , static_cast <const uint8_t *>(weights_blob)));
375+ auto update_err = handle->update_constants_from_blob (
376+ handle->container_handle , static_cast <const uint8_t *>(weights_blob));
377+ if (update_err != Error::Ok) {
378+ ET_LOG (Error, " update_constants_from_blob failed" );
379+ return update_err;
380+ }
381+ // Ensure all weight transfers are complete before execution
382+ cudaDeviceSynchronize ();
353383 buffer_res->Free ();
384+ } else {
385+ ET_LOG (
386+ Info,
387+ " weights_blob '%s' not found or update fn is null" ,
388+ weights_blob_key.c_str ());
354389 }
355390
356391 // Use shared CUDA stream if enabled via options, otherwise create one.
@@ -378,6 +413,105 @@ class ET_EXPERIMENTAL CudaBackend final
378413 method_name.c_str ());
379414 }
380415
416+ // ---------------------------------------------------------------
417+ // Cross-method constant sharing (e.g., KV cache between prefill/decode).
418+ //
419+ // The first container to initialize extracts its constants (keyed by
420+ // original FQN) and stores the AtenTensorHandle's. Subsequent containers
421+ // with matching FQNs are updated to point to the same GPU tensors via
422+ // UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
423+ // the source container retains ownership).
424+ // ---------------------------------------------------------------
425+ if (handle->get_num_constants && handle->get_constant_name &&
426+ handle->get_constant_original_fqn && handle->extract_constants_map &&
427+ handle->update_user_managed_constant_buffer_pairs ) {
428+ size_t num_constants = 0 ;
429+ handle->get_num_constants (handle->container_handle , &num_constants);
430+
431+ if (num_constants > 0 ) {
432+ // Build FQN → internal_name mapping for this container.
433+ std::unordered_map<std::string, std::string> fqn_to_name;
434+ for (size_t i = 0 ; i < num_constants; i++) {
435+ const char * name = nullptr ;
436+ const char * fqn = nullptr ;
437+ handle->get_constant_name (handle->container_handle , i, &name);
438+ handle->get_constant_original_fqn (handle->container_handle , i, &fqn);
439+ if (name && fqn && fqn[0 ] != ' \0 ' ) {
440+ fqn_to_name[fqn] = name;
441+ }
442+ }
443+
444+ std::lock_guard<std::mutex> guard (shared_constants_mutex_);
445+
446+ if (!constants_extracted_) {
447+ // First container: extract its constants and store by FQN.
448+ std::unordered_map<std::string, AtenTensorHandle> extracted_map;
449+ auto extract_err = handle->extract_constants_map (
450+ handle->container_handle ,
451+ reinterpret_cast <AOTInductorConstantMapHandle>(&extracted_map),
452+ /* use_inactive=*/ false );
453+
454+ if (extract_err == Error::Ok) {
455+ for (const auto & [fqn, internal_name] : fqn_to_name) {
456+ auto it = extracted_map.find (fqn);
457+ if (it != extracted_map.end ()) {
458+ shared_constant_tensors_[fqn] = it->second ;
459+ }
460+ }
461+ constants_extracted_ = true ;
462+ ET_LOG (
463+ Info,
464+ " Extracted %zu shared constants from method '%s'" ,
465+ shared_constant_tensors_.size (),
466+ method_name.c_str ());
467+ } else {
468+ ET_LOG (
469+ Error,
470+ " Failed to extract constants from '%s'" ,
471+ method_name.c_str ());
472+ }
473+ } else {
474+ // Subsequent container: share matching constants from the first.
475+ std::vector<AOTInductorConstantMapEntry> pairs;
476+ for (const auto & [fqn, internal_name] : fqn_to_name) {
477+ auto it = shared_constant_tensors_.find (fqn);
478+ if (it != shared_constant_tensors_.end ()) {
479+ // UpdateUserManagedConstantBufferPairs matches against the
480+ // codegen constant name (underscored), not the original FQN.
481+ pairs.push_back ({internal_name.c_str (), it->second });
482+ }
483+ }
484+
485+ if (!pairs.empty ()) {
486+ auto update_err = handle->update_user_managed_constant_buffer_pairs (
487+ handle->container_handle ,
488+ pairs.data (),
489+ pairs.size (),
490+ /* use_inactive=*/ false ,
491+ /* validate_full_update=*/ false );
492+
493+ if (update_err == Error::Ok) {
494+ ET_LOG (
495+ Info,
496+ " Shared %zu constants into method '%s'" ,
497+ pairs.size (),
498+ method_name.c_str ());
499+ } else {
500+ ET_LOG (
501+ Error,
502+ " Failed to share constants into '%s'" ,
503+ method_name.c_str ());
504+ }
505+ }
506+ }
507+ }
508+ } else {
509+ ET_LOG (
510+ Info,
511+ " Constant sharing APIs not available for method '%s'" ,
512+ method_name.c_str ());
513+ }
514+
381515 return (DelegateHandle*)handle; // Return the handle post-processing
382516 }
383517
@@ -623,6 +757,22 @@ class ET_EXPERIMENTAL CudaBackend final
623757 mutable std::
624758 unordered_map<cuda::CudaDelegateHandle*, std::vector<SlimTensor*>>
625759 cached_outputs_;
760+
761+ // Cross-method constant sharing state.
762+ // When multiple AOTI containers share mutable buffers (e.g., KV cache),
763+ // the first container's constants are extracted and stored here. Subsequent
764+ // containers with matching FQNs share the same GPU tensors via
765+ // UpdateUserManagedConstantBufferPairs.
766+ mutable std::mutex shared_constants_mutex_;
767+
768+ // FQN → AtenTensorHandle from the source (first) container.
769+ // The tensor handles are owned by the source container (which is never
770+ // explicitly deleted — see destroy() comment).
771+ mutable std::unordered_map<std::string, AtenTensorHandle>
772+ shared_constant_tensors_;
773+
774+ // Whether we've already extracted constants from a source container.
775+ mutable bool constants_extracted_ = false ;
626776};
627777
628778} // namespace executorch::backends::cuda
0 commit comments