@@ -242,13 +242,6 @@ class ET_EXPERIMENTAL CudaBackend final
242242 LOAD_OPTIONAL_SYMBOL (
243243 get_constant_original_fqn,
244244 AOTInductorModelContainerGetConstantOriginalFQN);
245- LOAD_OPTIONAL_SYMBOL (
246- get_constant_data_size, AOTInductorModelContainerGetConstantDataSize);
247- LOAD_OPTIONAL_SYMBOL (
248- get_constant_from_folded,
249- AOTInductorModelContainerGetConstantFromFolded);
250- LOAD_OPTIONAL_SYMBOL (
251- get_constants_blob_size, AOTInductorModelContainerGetConstantsBlobSize);
252245 LOAD_OPTIONAL_SYMBOL (
253246 extract_constants_map, AOTInductorModelContainerExtractConstantsMap);
254247 LOAD_OPTIONAL_SYMBOL (
@@ -323,11 +316,17 @@ class ET_EXPERIMENTAL CudaBackend final
323316 ArrayRef<CompileSpec> compile_specs // This will be my empty list
324317 ) const override {
325318 std::string method_name;
319+ bool share_kv_cache = false ;
326320 for (const CompileSpec& spec : compile_specs) {
327321 if (std::strcmp (spec.key , " method_name" ) == 0 ) {
328322 method_name.assign (
329323 static_cast <const char *>(spec.value .buffer ),
330324 spec.value .nbytes ); // no nullptr guarantee, so pass size
325+ } else if (std::strcmp (spec.key , kShareKvCacheAcrossMethods ) == 0 ) {
326+ if (spec.value .nbytes >= 1 ) {
327+ share_kv_cache =
328+ static_cast <const uint8_t *>(spec.value .buffer )[0 ] != 0 ;
329+ }
331330 }
332331 }
333332
@@ -398,11 +397,30 @@ class ET_EXPERIMENTAL CudaBackend final
398397
399398 handle->container_handle = container_handle;
400399
401- // Load constants with per-weight caching.
402- // This replaces the old update_constants_from_blob + cross-method sharing
403- // with a unified approach that avoids duplicate GPU allocations.
404- ET_CHECK_OK_OR_RETURN_ERROR (
405- load_constants_with_cache (handle, named_data_map, method_name));
400+ // Look into named data map for constant data
401+ std::string weights_blob_key =
402+ method_name.empty () ? " weights_blob" : method_name + " _weights_blob" ;
403+ auto buffer_res = named_data_map->get_data (weights_blob_key.c_str ());
404+ if (buffer_res.ok () && handle->update_constants_from_blob != nullptr ) {
405+ ET_LOG (Info, " Found %s in named data map" , weights_blob_key.c_str ());
406+ const void * weights_blob = buffer_res->data ();
407+ // Feed the weights blob into the container. Under the hood it's copying
408+ // weights, so we should free the buffer immediately.
409+ auto update_err = handle->update_constants_from_blob (
410+ handle->container_handle , static_cast <const uint8_t *>(weights_blob));
411+ if (update_err != Error::Ok) {
412+ ET_LOG (Error, " update_constants_from_blob failed" );
413+ return update_err;
414+ }
415+ // Ensure all weight transfers are complete before execution
416+ cudaDeviceSynchronize ();
417+ buffer_res->Free ();
418+ } else {
419+ ET_LOG (
420+ Info,
421+ " weights_blob '%s' not found or update fn is null" ,
422+ weights_blob_key.c_str ());
423+ }
406424
407425 // Use shared CUDA stream if enabled via options, otherwise create one.
408426 // A shared stream ensures proper ordering across multiple methods
@@ -981,265 +999,21 @@ class ET_EXPERIMENTAL CudaBackend final
981999 unordered_map<cuda::CudaDelegateHandle*, std::vector<SlimTensor*>>
9821000 cached_outputs_;
9831001
984- // ---------------------------------------------------------------
985- // Per-weight constant cache.
986- //
987- // Maintains a singleton FQN → AtenTensorHandle cache across methods.
988- // When loading constants for a method, constants already in the cache
989- // are reused (zero-copy via update_user_managed_constant_buffer_pairs).
990- // Only constants not in the cache are loaded from the blob and added
991- // to the cache. This avoids duplicate GPU allocations when multiple
992- // methods (e.g., prefill/decode) share the same weights.
993- //
994- // allocate_constant_on_gpu() is the allocation primitive — kept as a
995- // separate function so the strategy can be swapped later (e.g., pool
996- // allocator, unified memory, sub-allocation from a slab).
997- // ---------------------------------------------------------------
998-
999- // Allocate a single constant from the blob onto GPU and return its
1000- // raw GPU pointer. Caller is responsible for lifetime management.
1001- // Returns nullptr on failure or if data_size is 0.
1002- static void * allocate_constant_on_gpu (
1003- const uint8_t * blob_ptr,
1004- size_t blob_offset,
1005- size_t data_size) {
1006- if (data_size == 0 ) {
1007- return nullptr ;
1008- }
1009- void * gpu_ptr = nullptr ;
1010- cudaError_t err = cudaMalloc (&gpu_ptr, data_size);
1011- if (err != cudaSuccess) {
1012- ET_LOG (
1013- Error,
1014- " cudaMalloc failed for constant (%zu bytes): %s" ,
1015- data_size,
1016- cudaGetErrorString (err));
1017- return nullptr ;
1018- }
1019- err = cudaMemcpy (
1020- gpu_ptr, blob_ptr + blob_offset, data_size, cudaMemcpyHostToDevice);
1021- if (err != cudaSuccess) {
1022- ET_LOG (
1023- Error,
1024- " cudaMemcpy failed for constant (%zu bytes): %s" ,
1025- data_size,
1026- cudaGetErrorString (err));
1027- cudaFree (gpu_ptr);
1028- return nullptr ;
1029- }
1030- return gpu_ptr;
1031- }
1032-
1033- // Load constants for a method using per-weight caching.
1034- // Returns Error::Ok on success.
1035- //
1036- // Flow:
1037- // 1. Enumerate this method's constants and their FQNs.
1038- // 2. For each constant:
1039- // - If FQN is in shared_constant_tensors_ → reuse (cache hit).
1040- // - Otherwise → mark as needing loading (cache miss).
1041- // 3. If all constants are cached → skip blob loading entirely.
1042- // Otherwise → call update_constants_from_blob to load all, then
1043- // extract and cache the new constants.
1044- // 4. For cached constants, call update_user_managed_constant_buffer_pairs
1045- // to point the container to the shared GPU tensors.
1046- Error load_constants_with_cache (
1047- cuda::CudaDelegateHandle* handle,
1048- const NamedDataMap* named_data_map,
1049- const std::string& method_name) const {
1050- // Check if the required APIs are available
1051- if (!handle->get_num_constants || !handle->get_constant_name ||
1052- !handle->get_constant_original_fqn || !handle->extract_constants_map ||
1053- !handle->update_user_managed_constant_buffer_pairs ) {
1054- // Fall back to the legacy path
1055- return load_constants_legacy (handle, named_data_map, method_name);
1056- }
1057-
1058- // Step 1: Enumerate constants and partition into cached/uncached
1059- size_t num_constants = 0 ;
1060- handle->get_num_constants (handle->container_handle , &num_constants);
1061- if (num_constants == 0 ) {
1062- ET_LOG (Info, " No constants for method '%s'" , method_name.c_str ());
1063- return Error::Ok;
1064- }
1065-
1066- // Build FQN → internal_name mapping and determine cache hits/misses
1067- std::unordered_map<std::string, std::string> fqn_to_name;
1068- std::vector<std::string> uncached_fqns;
1069-
1070- {
1071- std::lock_guard<std::mutex> guard (shared_constants_mutex_);
1072- for (size_t i = 0 ; i < num_constants; i++) {
1073- const char * name = nullptr ;
1074- const char * fqn = nullptr ;
1075- handle->get_constant_name (handle->container_handle , i, &name);
1076- handle->get_constant_original_fqn (handle->container_handle , i, &fqn);
1077- if (name && fqn && fqn[0 ] != ' \0 ' ) {
1078- fqn_to_name[fqn] = name;
1079- if (shared_constant_tensors_.find (fqn) ==
1080- shared_constant_tensors_.end ()) {
1081- uncached_fqns.push_back (fqn);
1082- }
1083- }
1084- }
1085- }
1086-
1087- size_t num_cached = fqn_to_name.size () - uncached_fqns.size ();
1088- ET_LOG (
1089- Info,
1090- " Method '%s': %zu constants, %zu cached, %zu uncached" ,
1091- method_name.c_str (),
1092- fqn_to_name.size (),
1093- num_cached,
1094- uncached_fqns.size ());
1095-
1096- // Step 2: Load uncached constants from blob (if any)
1097- if (!uncached_fqns.empty ()) {
1098- // Need to load from blob — use update_constants_from_blob for all,
1099- // then extract the new constants into the cache.
1100- std::string weights_blob_key =
1101- method_name.empty () ? " weights_blob" : method_name + " _weights_blob" ;
1102- auto buffer_res = named_data_map->get_data (weights_blob_key.c_str ());
1103-
1104- if (buffer_res.ok () && handle->update_constants_from_blob != nullptr ) {
1105- ET_LOG (
1106- Info,
1107- " Loading constants from blob '%s' for method '%s'" ,
1108- weights_blob_key.c_str (),
1109- method_name.c_str ());
1110- const void * weights_blob = buffer_res->data ();
1111- auto update_err = handle->update_constants_from_blob (
1112- handle->container_handle ,
1113- static_cast <const uint8_t *>(weights_blob));
1114- if (update_err != Error::Ok) {
1115- ET_LOG (Error, " update_constants_from_blob failed" );
1116- return update_err;
1117- }
1118- cudaDeviceSynchronize ();
1119- buffer_res->Free ();
1120- } else {
1121- ET_LOG (
1122- Error,
1123- " weights_blob '%s' not found or update fn is null" ,
1124- weights_blob_key.c_str ());
1125- return Error::NotFound;
1126- }
1127-
1128- // Extract all constants and cache the newly loaded ones
1129- std::unordered_map<std::string, AtenTensorHandle> extracted_map;
1130- auto extract_err = handle->extract_constants_map (
1131- handle->container_handle ,
1132- reinterpret_cast <AOTInductorConstantMapHandle>(&extracted_map),
1133- /* use_inactive=*/ false );
1134-
1135- if (extract_err == Error::Ok) {
1136- std::lock_guard<std::mutex> guard (shared_constants_mutex_);
1137- for (const auto & fqn : uncached_fqns) {
1138- auto it_name = fqn_to_name.find (fqn);
1139- if (it_name == fqn_to_name.end ())
1140- continue ;
1141- // extract_constants_map returns entries keyed by FQN
1142- auto it = extracted_map.find (fqn);
1143- if (it != extracted_map.end ()) {
1144- shared_constant_tensors_[fqn] = it->second ;
1145- }
1146- }
1147- ET_LOG (
1148- Info,
1149- " Cached %zu new constants from method '%s' (total cache: %zu)" ,
1150- uncached_fqns.size (),
1151- method_name.c_str (),
1152- shared_constant_tensors_.size ());
1153- } else {
1154- ET_LOG (
1155- Error,
1156- " Failed to extract constants from '%s'" ,
1157- method_name.c_str ());
1158- return Error::Internal;
1159- }
1160- } else {
1161- // All constants are cached — skip blob loading entirely!
1162- ET_LOG (
1163- Info,
1164- " All %zu constants cached — skipping blob load for method '%s'" ,
1165- fqn_to_name.size (),
1166- method_name.c_str ());
1167- }
1168-
1169- // Step 3: Point the container to cached tensors via user_managed pairs
1170- if (num_cached > 0 || uncached_fqns.empty ()) {
1171- std::vector<AOTInductorConstantMapEntry> pairs;
1172- {
1173- std::lock_guard<std::mutex> guard (shared_constants_mutex_);
1174- for (const auto & [fqn, internal_name] : fqn_to_name) {
1175- auto it = shared_constant_tensors_.find (fqn);
1176- if (it != shared_constant_tensors_.end ()) {
1177- pairs.push_back ({internal_name.c_str (), it->second });
1178- }
1179- }
1180- }
1181-
1182- if (!pairs.empty ()) {
1183- auto update_err = handle->update_user_managed_constant_buffer_pairs (
1184- handle->container_handle ,
1185- pairs.data (),
1186- pairs.size (),
1187- /* use_inactive=*/ false ,
1188- /* validate_full_update=*/ false );
1189-
1190- if (update_err != Error::Ok) {
1191- ET_LOG (
1192- Error,
1193- " Failed to set cached constants for method '%s'" ,
1194- method_name.c_str ());
1195- return Error::Internal;
1196- }
1197- ET_LOG (
1198- Info,
1199- " Shared %zu cached constants into method '%s'" ,
1200- pairs.size (),
1201- method_name.c_str ());
1202- }
1203- }
1204-
1205- return Error::Ok;
1206- }
1207-
1208- // Legacy constant loading: load the entire blob without caching.
1209- // Used as fallback when constant management APIs are unavailable.
1210- Error load_constants_legacy (
1211- cuda::CudaDelegateHandle* handle,
1212- const NamedDataMap* named_data_map,
1213- const std::string& method_name) const {
1214- std::string weights_blob_key =
1215- method_name.empty () ? " weights_blob" : method_name + " _weights_blob" ;
1216- auto buffer_res = named_data_map->get_data (weights_blob_key.c_str ());
1217- if (buffer_res.ok () && handle->update_constants_from_blob != nullptr ) {
1218- ET_LOG (Info, " Found %s in named data map" , weights_blob_key.c_str ());
1219- const void * weights_blob = buffer_res->data ();
1220- auto update_err = handle->update_constants_from_blob (
1221- handle->container_handle , static_cast <const uint8_t *>(weights_blob));
1222- if (update_err != Error::Ok) {
1223- ET_LOG (Error, " update_constants_from_blob failed" );
1224- return update_err;
1225- }
1226- cudaDeviceSynchronize ();
1227- buffer_res->Free ();
1228- } else {
1229- ET_LOG (
1230- Info,
1231- " weights_blob '%s' not found or update fn is null" ,
1232- weights_blob_key.c_str ());
1233- }
1234- return Error::Ok;
1235- }
1002+ // Cross-method constant sharing state.
1003+ // When multiple AOTI containers share mutable buffers (e.g., KV cache),
1004+ // the first container's constants are extracted and stored here. Subsequent
1005+ // containers with matching FQNs share the same GPU tensors via
1006+ // UpdateUserManagedConstantBufferPairs.
12361007 mutable std::mutex shared_constants_mutex_;
12371008
12381009 // FQN → AtenTensorHandle from the source (first) container.
12391010 // The tensor handles are owned by the source container (which is never
12401011 // explicitly deleted — see destroy() comment).
12411012 mutable std::unordered_map<std::string, AtenTensorHandle>
12421013 shared_constant_tensors_;
1014+
1015+ // Whether we've already extracted constants from a source container.
1016+ mutable bool constants_extracted_ = false ;
12431017};
12441018
12451019} // namespace executorch::backends::cuda
0 commit comments