Skip to content

Commit 3a9abee

Browse files
authored
Per-instance named-data keys for AOTI delegate blobs (pytorch#20424)
1 parent 6f6225c commit 3a9abee

6 files changed

Lines changed: 123 additions & 22 deletions

File tree

backends/aoti/aoti_backend.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import contextlib
8+
import hashlib
89
import os
910
import typing
1011
from abc import ABC, abstractmethod
@@ -276,18 +277,21 @@ def preprocess(
276277

277278
# Create named data store
278279
named_data_store = NamedDataStore()
279-
method_name = cls.method_name_from_compile_specs(compile_specs)
280280

281-
named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
281+
# Key each blob by a content hash so partitions in one method get distinct
282+
# keys (a method-name-only key collides). Runtime recovers them from
283+
# processed_bytes below.
284+
so_blob_key = hashlib.sha256(so_data).hexdigest() + "_so_blob"
285+
weights_blob_key = hashlib.sha256(blob_data).hexdigest() + "_weights_blob"
286+
287+
named_data_store.add_named_data(so_blob_key, so_data, 1, None)
282288
# Determine whether to save named data externally based on backend setting
283289
# External: save to separate .ptd file, otherwise merge with .pte file
284290
external_tag = (
285291
f"aoti_{device_name}_blob" if cls.save_data_externally() else None
286292
)
287293

288-
named_data_store.add_named_data(
289-
method_name + "_weights_blob", blob_data, 1, external_tag
290-
)
294+
named_data_store.add_named_data(weights_blob_key, blob_data, 1, external_tag)
291295

292296
# Clean up the generated files
293297
os.remove(so_path)
@@ -299,8 +303,11 @@ def preprocess(
299303
# the next preprocess call (e.g. for the next method).
300304
cls.release_moved_tensors(device_edge_program, compile_specs)
301305

306+
# The runtime cannot recompute these hash keys, so carry them (one per line).
307+
processed_bytes = (so_blob_key + "\n" + weights_blob_key).encode("utf-8")
308+
302309
return PreprocessResult(
303-
processed_bytes=b"",
310+
processed_bytes=processed_bytes,
304311
debug_handle_map={},
305312
data_store_output=named_data_store.get_named_data_store_output(),
306313
)

backends/aoti/aoti_delegate_handle.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111
#include <executorch/runtime/core/error.h>
1212
#include <executorch/runtime/core/evalue.h>
13+
#include <executorch/runtime/core/freeable_buffer.h>
1314
#include <string>
1415

1516
namespace executorch {
1617
namespace backends {
1718
namespace aoti {
1819

1920
using executorch::runtime::Error;
21+
using executorch::runtime::FreeableBuffer;
2022
using executorch::runtime::etensor::Tensor;
2123

2224
extern "C" {
@@ -148,6 +150,30 @@ struct AOTIDelegateHandle {
148150
update_user_managed_constant_buffer_pairs;
149151
};
150152

153+
// New-format payload is "<so_key>\n<weights_key>"; an empty payload is a
154+
// pre-this-change artifact, so fall back to the legacy method-name keys.
155+
inline Error resolve_blob_keys(
156+
const FreeableBuffer* processed,
157+
const std::string& method_name,
158+
std::string& so_blob_key,
159+
std::string& weights_blob_key) {
160+
if (processed != nullptr && processed->size() > 0) {
161+
const std::string keys(
162+
static_cast<const char*>(processed->data()), processed->size());
163+
const size_t newline = keys.find('\n');
164+
if (newline == std::string::npos) {
165+
return Error::Internal;
166+
}
167+
so_blob_key = keys.substr(0, newline);
168+
weights_blob_key = keys.substr(newline + 1);
169+
} else {
170+
so_blob_key = method_name.empty() ? "so_blob" : method_name + "_so_blob";
171+
weights_blob_key =
172+
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
173+
}
174+
return Error::Ok;
175+
}
176+
151177
} // namespace aoti
152178
} // namespace backends
153179
} // namespace executorch

backends/aoti/tests/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@ load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
33

44
oncall("executorch")
55

6+
cpp_unittest(
7+
name = "test_resolve_blob_keys",
8+
srcs = [
9+
"test_resolve_blob_keys.cpp",
10+
],
11+
deps = [
12+
"//executorch/backends/aoti:delegate_handle",
13+
"//executorch/runtime/core:core",
14+
"//executorch/runtime/core:evalue",
15+
],
16+
)
17+
618
cpp_unittest(
719
name = "test_common_shims",
820
srcs = [
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/aoti/aoti_delegate_handle.h>
10+
11+
#include <gtest/gtest.h>
12+
#include <string>
13+
14+
#include <executorch/runtime/core/error.h>
15+
#include <executorch/runtime/core/freeable_buffer.h>
16+
17+
using executorch::backends::aoti::resolve_blob_keys;
18+
using executorch::runtime::Error;
19+
using executorch::runtime::FreeableBuffer;
20+
21+
TEST(ResolveBlobKeysTest, ParsesKeysFromPayload) {
22+
const std::string payload = "aaa_so_blob\nbbb_weights_blob";
23+
FreeableBuffer processed(payload.data(), payload.size(), nullptr);
24+
std::string so_key;
25+
std::string weights_key;
26+
27+
ASSERT_EQ(
28+
resolve_blob_keys(&processed, "forward", so_key, weights_key), Error::Ok);
29+
EXPECT_EQ(so_key, "aaa_so_blob");
30+
EXPECT_EQ(weights_key, "bbb_weights_blob");
31+
}
32+
33+
TEST(ResolveBlobKeysTest, FallsBackToMethodNameKeysWhenEmpty) {
34+
FreeableBuffer processed; // size 0: a pre-this-change artifact
35+
std::string so_key;
36+
std::string weights_key;
37+
38+
ASSERT_EQ(
39+
resolve_blob_keys(&processed, "forward", so_key, weights_key), Error::Ok);
40+
EXPECT_EQ(so_key, "forward_so_blob");
41+
EXPECT_EQ(weights_key, "forward_weights_blob");
42+
}
43+
44+
TEST(ResolveBlobKeysTest, FailsOnMalformedPayload) {
45+
const std::string payload = "missing_the_newline_separator";
46+
FreeableBuffer processed(payload.data(), payload.size(), nullptr);
47+
std::string so_key;
48+
std::string weights_key;
49+
50+
EXPECT_EQ(
51+
resolve_blob_keys(&processed, "forward", so_key, weights_key),
52+
Error::Internal);
53+
}

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,12 @@ class ET_EXPERIMENTAL MetalBackend final
245245
}
246246
}
247247

248-
std::string so_blob_key =
249-
method_name.empty() ? "so_blob" : method_name + "_so_blob";
248+
std::string so_blob_key;
249+
std::string weights_blob_key;
250+
ET_CHECK_OK_OR_RETURN_ERROR(
251+
executorch::backends::aoti::resolve_blob_keys(
252+
processed, method_name, so_blob_key, weights_blob_key),
253+
"Malformed named-data key payload");
250254
ET_LOG(Info, "MetalBackend::init - so_blob_key: %s", so_blob_key.c_str());
251255

252256
const NamedDataMap* named_data_map = context.get_named_data_map();
@@ -258,8 +262,6 @@ class ET_EXPERIMENTAL MetalBackend final
258262
// Prefetch the weights blob — trigger async readahead so pages are
259263
// resident by the time update_constants_from_blob memcpy's them.
260264
// This overlaps disk I/O with the .so write + dlopen (~200ms).
261-
std::string weights_blob_key =
262-
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
263265
{
264266
auto prefetch_buf = named_data_map->get_data(weights_blob_key.c_str());
265267
if (prefetch_buf.ok() && prefetch_buf->data() != nullptr) {

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,12 @@ class ET_EXPERIMENTAL CudaBackend final
319319
}
320320
}
321321

322-
std::string so_blob_key =
323-
method_name.empty() ? "so_blob" : method_name + "_so_blob";
322+
std::string so_blob_key;
323+
std::string weights_blob_key;
324+
ET_CHECK_OK_OR_RETURN_ERROR(
325+
executorch::backends::aoti::resolve_blob_keys(
326+
processed, method_name, so_blob_key, weights_blob_key),
327+
"Malformed named-data key payload");
324328

325329
const NamedDataMap* named_data_map = context.get_named_data_map();
326330
auto aoti_dso_buffer = named_data_map->get_data(so_blob_key.c_str());
@@ -394,11 +398,11 @@ class ET_EXPERIMENTAL CudaBackend final
394398
// methods are independent sub-graphs that may have FQN collisions
395399
// (e.g. parakeet).
396400
if (is_weight_sharing_across_methods_enabled()) {
397-
ET_CHECK_OK_OR_RETURN_ERROR(
398-
load_constants_with_cache(handle, named_data_map, method_name));
401+
ET_CHECK_OK_OR_RETURN_ERROR(load_constants_with_cache(
402+
handle, named_data_map, method_name, weights_blob_key));
399403
} else {
400404
ET_CHECK_OK_OR_RETURN_ERROR(
401-
load_constants_legacy(handle, named_data_map, method_name));
405+
load_constants_legacy(handle, named_data_map, weights_blob_key));
402406
}
403407

404408
// Use shared CUDA stream if enabled via options, otherwise create one.
@@ -1011,13 +1015,14 @@ class ET_EXPERIMENTAL CudaBackend final
10111015
Error load_constants_with_cache(
10121016
cuda::CudaDelegateHandle* handle,
10131017
const NamedDataMap* named_data_map,
1014-
const std::string& method_name) const {
1018+
const std::string& method_name,
1019+
const std::string& weights_blob_key) const {
10151020
// Check if the required APIs are available
10161021
if (!handle->get_num_constants || !handle->get_constant_name ||
10171022
!handle->get_constant_original_fqn || !handle->extract_constants_map ||
10181023
!handle->update_user_managed_constant_buffer_pairs) {
10191024
// Fall back to the legacy path
1020-
return load_constants_legacy(handle, named_data_map, method_name);
1025+
return load_constants_legacy(handle, named_data_map, weights_blob_key);
10211026
}
10221027

10231028
// Step 1: Enumerate constants and partition into cached/uncached
@@ -1069,8 +1074,6 @@ class ET_EXPERIMENTAL CudaBackend final
10691074
if (!uncached_fqns.empty()) {
10701075
// Need to load from blob — use update_constants_from_blob for all,
10711076
// then extract the new constants into the cache.
1072-
std::string weights_blob_key =
1073-
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
10741077
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
10751078

10761079
ET_CHECK_OR_RETURN_ERROR(
@@ -1190,9 +1193,7 @@ class ET_EXPERIMENTAL CudaBackend final
11901193
Error load_constants_legacy(
11911194
cuda::CudaDelegateHandle* handle,
11921195
const NamedDataMap* named_data_map,
1193-
const std::string& method_name) const {
1194-
std::string weights_blob_key =
1195-
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
1196+
const std::string& weights_blob_key) const {
11961197
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
11971198
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
11981199
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());

0 commit comments

Comments
 (0)