Skip to content

Commit c93f8ae

Browse files
gasoonjia (Gasoonjia)
authored and committed
Revert non-SDPA changes to match main
Keep only sdpa.py changes on this branch; revert all other files (aoti_delegate_handle.h, benchmark_sdpa.py, cuda_backend.cpp, main.cpp, model.py) to their main branch state.
1 parent c6a4b38 commit c93f8ae

3 files changed

Lines changed: 39 additions & 284 deletions

File tree

backends/aoti/aoti_delegate_handle.h

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -104,23 +104,6 @@ using AOTInductorModelContainerGetConstantOriginalFQNFunc =
104104
size_t idx,
105105
const char** original_fqn);
106106

107-
// Retrieves a constant's data size in bytes by index.
108-
using AOTInductorModelContainerGetConstantDataSizeFunc = AOTIRuntimeError (*)(
109-
AOTInductorModelContainerHandle container_handle,
110-
size_t idx,
111-
size_t* data_size);
112-
113-
// Retrieves whether a constant was produced by constant folding.
114-
using AOTInductorModelContainerGetConstantFromFoldedFunc = AOTIRuntimeError (*)(
115-
AOTInductorModelContainerHandle container_handle,
116-
size_t idx,
117-
bool* from_folded);
118-
119-
// Retrieves the total size of the constants blob.
120-
using AOTInductorModelContainerGetConstantsBlobSizeFunc = AOTIRuntimeError (*)(
121-
AOTInductorModelContainerHandle container_handle,
122-
uint64_t* ret_size);
123-
124107
// Extracts the constants map from the container (active or inactive buffer).
125108
// constant_map_handle should point to a
126109
// std::unordered_map<std::string, AtenTensorHandle>.
@@ -160,9 +143,6 @@ struct AOTIDelegateHandle {
160143
AOTInductorModelContainerGetNumConstantsFunc get_num_constants;
161144
AOTInductorModelContainerGetConstantNameFunc get_constant_name;
162145
AOTInductorModelContainerGetConstantOriginalFQNFunc get_constant_original_fqn;
163-
AOTInductorModelContainerGetConstantDataSizeFunc get_constant_data_size;
164-
AOTInductorModelContainerGetConstantFromFoldedFunc get_constant_from_folded;
165-
AOTInductorModelContainerGetConstantsBlobSizeFunc get_constants_blob_size;
166146
AOTInductorModelContainerExtractConstantsMapFunc extract_constants_map;
167147
AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc
168148
update_user_managed_constant_buffer_pairs;

backends/cuda/benchmarks/benchmark_sdpa.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
import torch
2323
import torch.nn.functional as F
24+
2425
from executorch.backends.cuda.triton.kernels.sdpa import (
2526
sdpa as triton_sdpa,
2627
sdpa_decode_splitk as triton_splitk,

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 38 additions & 264 deletions
Original file line number | Diff line number | Diff line change
@@ -242,13 +242,6 @@ class ET_EXPERIMENTAL CudaBackend final
242242
LOAD_OPTIONAL_SYMBOL(
243243
get_constant_original_fqn,
244244
AOTInductorModelContainerGetConstantOriginalFQN);
245-
LOAD_OPTIONAL_SYMBOL(
246-
get_constant_data_size, AOTInductorModelContainerGetConstantDataSize);
247-
LOAD_OPTIONAL_SYMBOL(
248-
get_constant_from_folded,
249-
AOTInductorModelContainerGetConstantFromFolded);
250-
LOAD_OPTIONAL_SYMBOL(
251-
get_constants_blob_size, AOTInductorModelContainerGetConstantsBlobSize);
252245
LOAD_OPTIONAL_SYMBOL(
253246
extract_constants_map, AOTInductorModelContainerExtractConstantsMap);
254247
LOAD_OPTIONAL_SYMBOL(
@@ -323,11 +316,17 @@ class ET_EXPERIMENTAL CudaBackend final
323316
ArrayRef<CompileSpec> compile_specs // This will be my empty list
324317
) const override {
325318
std::string method_name;
319+
bool share_kv_cache = false;
326320
for (const CompileSpec& spec : compile_specs) {
327321
if (std::strcmp(spec.key, "method_name") == 0) {
328322
method_name.assign(
329323
static_cast<const char*>(spec.value.buffer),
330324
spec.value.nbytes); // no nullptr guarantee, so pass size
325+
} else if (std::strcmp(spec.key, kShareKvCacheAcrossMethods) == 0) {
326+
if (spec.value.nbytes >= 1) {
327+
share_kv_cache =
328+
static_cast<const uint8_t*>(spec.value.buffer)[0] != 0;
329+
}
331330
}
332331
}
333332

@@ -398,11 +397,30 @@ class ET_EXPERIMENTAL CudaBackend final
398397

399398
handle->container_handle = container_handle;
400399

401-
// Load constants with per-weight caching.
402-
// This replaces the old update_constants_from_blob + cross-method sharing
403-
// with a unified approach that avoids duplicate GPU allocations.
404-
ET_CHECK_OK_OR_RETURN_ERROR(
405-
load_constants_with_cache(handle, named_data_map, method_name));
400+
// Look into named data map for constant data
401+
std::string weights_blob_key =
402+
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
403+
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
404+
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
405+
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
406+
const void* weights_blob = buffer_res->data();
407+
// Feed the weights blob into the container. Under the hood it's copying
408+
// weights, so we should free the buffer immediately.
409+
auto update_err = handle->update_constants_from_blob(
410+
handle->container_handle, static_cast<const uint8_t*>(weights_blob));
411+
if (update_err != Error::Ok) {
412+
ET_LOG(Error, "update_constants_from_blob failed");
413+
return update_err;
414+
}
415+
// Ensure all weight transfers are complete before execution
416+
cudaDeviceSynchronize();
417+
buffer_res->Free();
418+
} else {
419+
ET_LOG(
420+
Info,
421+
"weights_blob '%s' not found or update fn is null",
422+
weights_blob_key.c_str());
423+
}
406424

407425
// Use shared CUDA stream if enabled via options, otherwise create one.
408426
// A shared stream ensures proper ordering across multiple methods
@@ -981,265 +999,21 @@ class ET_EXPERIMENTAL CudaBackend final
981999
unordered_map<cuda::CudaDelegateHandle*, std::vector<SlimTensor*>>
9821000
cached_outputs_;
9831001

984-
// ---------------------------------------------------------------
985-
// Per-weight constant cache.
986-
//
987-
// Maintains a singleton FQN → AtenTensorHandle cache across methods.
988-
// When loading constants for a method, constants already in the cache
989-
// are reused (zero-copy via update_user_managed_constant_buffer_pairs).
990-
// Only constants not in the cache are loaded from the blob and added
991-
// to the cache. This avoids duplicate GPU allocations when multiple
992-
// methods (e.g., prefill/decode) share the same weights.
993-
//
994-
// allocate_constant_on_gpu() is the allocation primitive — kept as a
995-
// separate function so the strategy can be swapped later (e.g., pool
996-
// allocator, unified memory, sub-allocation from a slab).
997-
// ---------------------------------------------------------------
998-
999-
// Allocate a single constant from the blob onto GPU and return its
1000-
// raw GPU pointer. Caller is responsible for lifetime management.
1001-
// Returns nullptr on failure or if data_size is 0.
1002-
static void* allocate_constant_on_gpu(
1003-
const uint8_t* blob_ptr,
1004-
size_t blob_offset,
1005-
size_t data_size) {
1006-
if (data_size == 0) {
1007-
return nullptr;
1008-
}
1009-
void* gpu_ptr = nullptr;
1010-
cudaError_t err = cudaMalloc(&gpu_ptr, data_size);
1011-
if (err != cudaSuccess) {
1012-
ET_LOG(
1013-
Error,
1014-
"cudaMalloc failed for constant (%zu bytes): %s",
1015-
data_size,
1016-
cudaGetErrorString(err));
1017-
return nullptr;
1018-
}
1019-
err = cudaMemcpy(
1020-
gpu_ptr, blob_ptr + blob_offset, data_size, cudaMemcpyHostToDevice);
1021-
if (err != cudaSuccess) {
1022-
ET_LOG(
1023-
Error,
1024-
"cudaMemcpy failed for constant (%zu bytes): %s",
1025-
data_size,
1026-
cudaGetErrorString(err));
1027-
cudaFree(gpu_ptr);
1028-
return nullptr;
1029-
}
1030-
return gpu_ptr;
1031-
}
1032-
1033-
// Load constants for a method using per-weight caching.
1034-
// Returns Error::Ok on success.
1035-
//
1036-
// Flow:
1037-
// 1. Enumerate this method's constants and their FQNs.
1038-
// 2. For each constant:
1039-
// - If FQN is in shared_constant_tensors_ → reuse (cache hit).
1040-
// - Otherwise → mark as needing loading (cache miss).
1041-
// 3. If all constants are cached → skip blob loading entirely.
1042-
// Otherwise → call update_constants_from_blob to load all, then
1043-
// extract and cache the new constants.
1044-
// 4. For cached constants, call update_user_managed_constant_buffer_pairs
1045-
// to point the container to the shared GPU tensors.
1046-
Error load_constants_with_cache(
1047-
cuda::CudaDelegateHandle* handle,
1048-
const NamedDataMap* named_data_map,
1049-
const std::string& method_name) const {
1050-
// Check if the required APIs are available
1051-
if (!handle->get_num_constants || !handle->get_constant_name ||
1052-
!handle->get_constant_original_fqn || !handle->extract_constants_map ||
1053-
!handle->update_user_managed_constant_buffer_pairs) {
1054-
// Fall back to the legacy path
1055-
return load_constants_legacy(handle, named_data_map, method_name);
1056-
}
1057-
1058-
// Step 1: Enumerate constants and partition into cached/uncached
1059-
size_t num_constants = 0;
1060-
handle->get_num_constants(handle->container_handle, &num_constants);
1061-
if (num_constants == 0) {
1062-
ET_LOG(Info, "No constants for method '%s'", method_name.c_str());
1063-
return Error::Ok;
1064-
}
1065-
1066-
// Build FQN → internal_name mapping and determine cache hits/misses
1067-
std::unordered_map<std::string, std::string> fqn_to_name;
1068-
std::vector<std::string> uncached_fqns;
1069-
1070-
{
1071-
std::lock_guard<std::mutex> guard(shared_constants_mutex_);
1072-
for (size_t i = 0; i < num_constants; i++) {
1073-
const char* name = nullptr;
1074-
const char* fqn = nullptr;
1075-
handle->get_constant_name(handle->container_handle, i, &name);
1076-
handle->get_constant_original_fqn(handle->container_handle, i, &fqn);
1077-
if (name && fqn && fqn[0] != '\0') {
1078-
fqn_to_name[fqn] = name;
1079-
if (shared_constant_tensors_.find(fqn) ==
1080-
shared_constant_tensors_.end()) {
1081-
uncached_fqns.push_back(fqn);
1082-
}
1083-
}
1084-
}
1085-
}
1086-
1087-
size_t num_cached = fqn_to_name.size() - uncached_fqns.size();
1088-
ET_LOG(
1089-
Info,
1090-
"Method '%s': %zu constants, %zu cached, %zu uncached",
1091-
method_name.c_str(),
1092-
fqn_to_name.size(),
1093-
num_cached,
1094-
uncached_fqns.size());
1095-
1096-
// Step 2: Load uncached constants from blob (if any)
1097-
if (!uncached_fqns.empty()) {
1098-
// Need to load from blob — use update_constants_from_blob for all,
1099-
// then extract the new constants into the cache.
1100-
std::string weights_blob_key =
1101-
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
1102-
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
1103-
1104-
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
1105-
ET_LOG(
1106-
Info,
1107-
"Loading constants from blob '%s' for method '%s'",
1108-
weights_blob_key.c_str(),
1109-
method_name.c_str());
1110-
const void* weights_blob = buffer_res->data();
1111-
auto update_err = handle->update_constants_from_blob(
1112-
handle->container_handle,
1113-
static_cast<const uint8_t*>(weights_blob));
1114-
if (update_err != Error::Ok) {
1115-
ET_LOG(Error, "update_constants_from_blob failed");
1116-
return update_err;
1117-
}
1118-
cudaDeviceSynchronize();
1119-
buffer_res->Free();
1120-
} else {
1121-
ET_LOG(
1122-
Error,
1123-
"weights_blob '%s' not found or update fn is null",
1124-
weights_blob_key.c_str());
1125-
return Error::NotFound;
1126-
}
1127-
1128-
// Extract all constants and cache the newly loaded ones
1129-
std::unordered_map<std::string, AtenTensorHandle> extracted_map;
1130-
auto extract_err = handle->extract_constants_map(
1131-
handle->container_handle,
1132-
reinterpret_cast<AOTInductorConstantMapHandle>(&extracted_map),
1133-
/*use_inactive=*/false);
1134-
1135-
if (extract_err == Error::Ok) {
1136-
std::lock_guard<std::mutex> guard(shared_constants_mutex_);
1137-
for (const auto& fqn : uncached_fqns) {
1138-
auto it_name = fqn_to_name.find(fqn);
1139-
if (it_name == fqn_to_name.end())
1140-
continue;
1141-
// extract_constants_map returns entries keyed by FQN
1142-
auto it = extracted_map.find(fqn);
1143-
if (it != extracted_map.end()) {
1144-
shared_constant_tensors_[fqn] = it->second;
1145-
}
1146-
}
1147-
ET_LOG(
1148-
Info,
1149-
"Cached %zu new constants from method '%s' (total cache: %zu)",
1150-
uncached_fqns.size(),
1151-
method_name.c_str(),
1152-
shared_constant_tensors_.size());
1153-
} else {
1154-
ET_LOG(
1155-
Error,
1156-
"Failed to extract constants from '%s'",
1157-
method_name.c_str());
1158-
return Error::Internal;
1159-
}
1160-
} else {
1161-
// All constants are cached — skip blob loading entirely!
1162-
ET_LOG(
1163-
Info,
1164-
"All %zu constants cached — skipping blob load for method '%s'",
1165-
fqn_to_name.size(),
1166-
method_name.c_str());
1167-
}
1168-
1169-
// Step 3: Point the container to cached tensors via user_managed pairs
1170-
if (num_cached > 0 || uncached_fqns.empty()) {
1171-
std::vector<AOTInductorConstantMapEntry> pairs;
1172-
{
1173-
std::lock_guard<std::mutex> guard(shared_constants_mutex_);
1174-
for (const auto& [fqn, internal_name] : fqn_to_name) {
1175-
auto it = shared_constant_tensors_.find(fqn);
1176-
if (it != shared_constant_tensors_.end()) {
1177-
pairs.push_back({internal_name.c_str(), it->second});
1178-
}
1179-
}
1180-
}
1181-
1182-
if (!pairs.empty()) {
1183-
auto update_err = handle->update_user_managed_constant_buffer_pairs(
1184-
handle->container_handle,
1185-
pairs.data(),
1186-
pairs.size(),
1187-
/*use_inactive=*/false,
1188-
/*validate_full_update=*/false);
1189-
1190-
if (update_err != Error::Ok) {
1191-
ET_LOG(
1192-
Error,
1193-
"Failed to set cached constants for method '%s'",
1194-
method_name.c_str());
1195-
return Error::Internal;
1196-
}
1197-
ET_LOG(
1198-
Info,
1199-
"Shared %zu cached constants into method '%s'",
1200-
pairs.size(),
1201-
method_name.c_str());
1202-
}
1203-
}
1204-
1205-
return Error::Ok;
1206-
}
1207-
1208-
// Legacy constant loading: load the entire blob without caching.
1209-
// Used as fallback when constant management APIs are unavailable.
1210-
Error load_constants_legacy(
1211-
cuda::CudaDelegateHandle* handle,
1212-
const NamedDataMap* named_data_map,
1213-
const std::string& method_name) const {
1214-
std::string weights_blob_key =
1215-
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
1216-
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
1217-
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
1218-
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
1219-
const void* weights_blob = buffer_res->data();
1220-
auto update_err = handle->update_constants_from_blob(
1221-
handle->container_handle, static_cast<const uint8_t*>(weights_blob));
1222-
if (update_err != Error::Ok) {
1223-
ET_LOG(Error, "update_constants_from_blob failed");
1224-
return update_err;
1225-
}
1226-
cudaDeviceSynchronize();
1227-
buffer_res->Free();
1228-
} else {
1229-
ET_LOG(
1230-
Info,
1231-
"weights_blob '%s' not found or update fn is null",
1232-
weights_blob_key.c_str());
1233-
}
1234-
return Error::Ok;
1235-
}
1002+
// Cross-method constant sharing state.
1003+
// When multiple AOTI containers share mutable buffers (e.g., KV cache),
1004+
// the first container's constants are extracted and stored here. Subsequent
1005+
// containers with matching FQNs share the same GPU tensors via
1006+
// UpdateUserManagedConstantBufferPairs.
12361007
mutable std::mutex shared_constants_mutex_;
12371008

12381009
// FQN → AtenTensorHandle from the source (first) container.
12391010
// The tensor handles are owned by the source container (which is never
12401011
// explicitly deleted — see destroy() comment).
12411012
mutable std::unordered_map<std::string, AtenTensorHandle>
12421013
shared_constant_tensors_;
1014+
1015+
// Whether we've already extracted constants from a source container.
1016+
mutable bool constants_extracted_ = false;
12431017
};
12441018

12451019
} // namespace executorch::backends::cuda

0 commit comments

Comments
 (0)