Skip to content

Commit 4528ae2

Browse files
authored
4/x: Add LoadBackendOptionsMap support to Module::load()
Differential Revision: D92358607 Pull Request resolved: #17687
1 parent 19de115 commit 4528ae2

4 files changed

Lines changed: 214 additions & 7 deletions

File tree

extension/module/module.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,17 @@ Module::Module(
176176
}
177177

178178
runtime::Error Module::load(const Program::Verification verification) {
179+
return load_internal(verification);
180+
}
181+
182+
runtime::Error Module::load(
183+
const LoadBackendOptionsMap& backend_options,
184+
const Program::Verification verification) {
185+
backend_options_ = &backend_options;
186+
return load_internal(verification);
187+
}
188+
189+
runtime::Error Module::load_internal(const Program::Verification verification) {
179190
if (!is_loaded()) {
180191
if (!data_loader_) {
181192
auto data_loader_result = make_data_loader(file_path_, load_mode_);
@@ -256,10 +267,15 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
256267
runtime::Error Module::load_method(
257268
const std::string& method_name,
258269
runtime::HierarchicalAllocator* planned_memory,
259-
torch::executor::EventTracer* event_tracer) {
270+
torch::executor::EventTracer* event_tracer,
271+
const LoadBackendOptionsMap* backend_options) {
260272
if (!is_method_loaded(method_name)) {
261273
ET_CHECK_OK_OR_RETURN_ERROR(load());
262274

275+
// Use passed backend_options, or fall back to stored one from load()
276+
const LoadBackendOptionsMap* effective_backend_options =
277+
backend_options ? backend_options : backend_options_;
278+
263279
MethodHolder method_holder;
264280

265281
if (!planned_memory) {
@@ -292,7 +308,8 @@ runtime::Error Module::load_method(
292308
method_name.c_str(),
293309
method_holder.memory_manager.get(),
294310
event_tracer ? event_tracer : this->event_tracer(),
295-
merged_data_map_.get());
311+
merged_data_map_.get(),
312+
effective_backend_options);
296313
if (!res_method.ok()) {
297314
return res_method.error();
298315
}

extension/module/module.h

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ using ET_RUNTIME_NAMESPACE::Method;
2929
using ET_RUNTIME_NAMESPACE::MethodMeta;
3030
using ET_RUNTIME_NAMESPACE::NamedDataMap;
3131
using ET_RUNTIME_NAMESPACE::Program;
32+
using runtime::LoadBackendOptionsMap;
3233

3334
class ExecuTorchJni;
3435

@@ -153,6 +154,22 @@ class Module {
153154
const Program::Verification verification =
154155
Program::Verification::Minimal);
155156

157+
/**
158+
* Loads the program with per-delegate runtime options.
159+
*
160+
* @param[in] backend_options A LoadBackendOptionsMap containing per-delegate
161+
* load-time configuration options. The caller must ensure this object
162+
* outlives any methods loaded with these options.
163+
* @param[in] verification The type of verification to do before returning
164+
* success.
165+
*
166+
* @returns An Error to indicate success or failure of the loading process.
167+
*/
168+
ET_NODISCARD virtual runtime::Error load(
169+
const LoadBackendOptionsMap& backend_options,
170+
const Program::Verification verification =
171+
Program::Verification::Minimal);
172+
156173
/**
157174
* Checks if the program is loaded.
158175
*
@@ -207,12 +224,13 @@ class Module {
207224
runtime::Error load_method(
208225
const std::string& method_name,
209226
runtime::HierarchicalAllocator* planned_memory = nullptr,
210-
torch::executor::EventTracer* event_tracer = nullptr);
227+
torch::executor::EventTracer* event_tracer = nullptr,
228+
const LoadBackendOptionsMap* backend_options = nullptr);
211229

212230
ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method(
213231
const std::string& method_name,
214232
torch::executor::EventTracer* event_tracer) {
215-
return load_method(method_name, nullptr, event_tracer);
233+
return load_method(method_name, nullptr, event_tracer, nullptr);
216234
}
217235

218236
/**
@@ -254,13 +272,15 @@ class Module {
254272
*/
255273
ET_NODISCARD inline runtime::Error load_forward(
256274
runtime::HierarchicalAllocator* planned_memory = nullptr,
257-
torch::executor::EventTracer* event_tracer = nullptr) {
258-
return load_method("forward", planned_memory, event_tracer);
275+
torch::executor::EventTracer* event_tracer = nullptr,
276+
const LoadBackendOptionsMap* backend_options = nullptr) {
277+
return load_method(
278+
"forward", planned_memory, event_tracer, backend_options);
259279
}
260280

261281
ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
262282
torch::executor::EventTracer* event_tracer) {
263-
return load_forward(nullptr, event_tracer);
283+
return load_forward(nullptr, event_tracer, nullptr);
264284
}
265285

266286
/**
@@ -650,6 +670,10 @@ class Module {
650670
std::vector<std::unique_ptr<NamedDataMap>> named_data_maps_;
651671
std::unique_ptr<NamedDataMap> merged_data_map_;
652672
ET_DEPRECATED std::vector<uint8_t> debug_buffer_;
673+
const LoadBackendOptionsMap* backend_options_ = nullptr;
674+
675+
ET_NODISCARD runtime::Error load_internal(
676+
const Program::Verification verification);
653677

654678
protected:
655679
std::unordered_map<std::string, MethodHolder> methods_;

extension/module/test/module_test.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
#include <executorch/extension/data_loader/file_data_loader.h>
1717
#include <executorch/extension/tensor/tensor.h>
18+
#include <executorch/runtime/backend/backend_options_map.h>
19+
#include <executorch/runtime/backend/options.h>
1820
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1921

2022
using namespace ::executorch::extension;
@@ -554,3 +556,165 @@ TEST_F(ModuleTest, TestPTD_Multiple) {
554556
auto tensor2 = make_tensor_ptr({3}, {2.f, 3.f, 4.f});
555557
ASSERT_EQ(module_linear.forward(tensor2).error(), Error::Ok);
556558
}
559+
560+
// =============================================================================
561+
// LoadBackendOptionsMap / RuntimeSpec Tests
562+
// =============================================================================
563+
564+
TEST_F(ModuleTest, TestLoadWithLoadBackendOptionsMap) {
565+
Module module(model_path_);
566+
567+
// Create a LoadBackendOptionsMap with some options for a hypothetical backend
568+
LoadBackendOptionsMap backend_options;
569+
BackendOptions<4> options;
570+
options.set_option("compute_unit", "cpu_only");
571+
options.set_option("debug_mode", true);
572+
ASSERT_EQ(
573+
backend_options.set_options("TestBackend", options.view()), Error::Ok);
574+
575+
// Load with backend options - should succeed even though the model
576+
// doesn't use this backend (options are simply passed through)
577+
EXPECT_FALSE(module.is_loaded());
578+
const auto error = module.load(backend_options);
579+
EXPECT_EQ(error, Error::Ok);
580+
EXPECT_TRUE(module.is_loaded());
581+
}
582+
583+
TEST_F(ModuleTest, TestLoadWithLoadBackendOptionsMapThenExecute) {
584+
Module module(model_path_);
585+
586+
LoadBackendOptionsMap backend_options;
587+
BackendOptions<2> options;
588+
options.set_option("key1", "value1");
589+
ASSERT_EQ(
590+
backend_options.set_options("SomeBackend", options.view()), Error::Ok);
591+
592+
const auto load_error = module.load(backend_options);
593+
EXPECT_EQ(load_error, Error::Ok);
594+
595+
// Execute should work normally
596+
auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f});
597+
const auto result = module.execute("forward", {tensor, tensor, 1.0});
598+
EXPECT_EQ(result.error(), Error::Ok);
599+
600+
const auto expected = make_tensor_ptr({2, 2}, {2.f, 4.f, 6.f, 8.f});
601+
EXPECT_TENSOR_CLOSE(result->at(0).toTensor(), *expected.get());
602+
}
603+
604+
TEST_F(ModuleTest, TestLoadMethodWithLoadBackendOptionsMap) {
605+
Module module(model_path_);
606+
607+
LoadBackendOptionsMap backend_options;
608+
BackendOptions<2> options;
609+
options.set_option("option1", 42);
610+
ASSERT_EQ(
611+
backend_options.set_options("AnotherBackend", options.view()), Error::Ok);
612+
613+
EXPECT_FALSE(module.is_method_loaded("forward"));
614+
const auto error =
615+
module.load_method("forward", nullptr, nullptr, &backend_options);
616+
EXPECT_EQ(error, Error::Ok);
617+
EXPECT_TRUE(module.is_method_loaded("forward"));
618+
EXPECT_TRUE(module.is_loaded());
619+
}
620+
621+
TEST_F(ModuleTest, TestLoadForwardWithLoadBackendOptionsMap) {
622+
Module module(model_path_);
623+
624+
LoadBackendOptionsMap backend_options;
625+
BackendOptions<2> options;
626+
options.set_option("setting", "enabled");
627+
ASSERT_EQ(
628+
backend_options.set_options("ForwardBackend", options.view()), Error::Ok);
629+
630+
EXPECT_FALSE(module.is_method_loaded("forward"));
631+
const auto error = module.load_forward(nullptr, nullptr, &backend_options);
632+
EXPECT_EQ(error, Error::Ok);
633+
EXPECT_TRUE(module.is_method_loaded("forward"));
634+
}
635+
636+
TEST_F(ModuleTest, TestLoadWithEmptyLoadBackendOptionsMap) {
637+
Module module(model_path_);
638+
639+
// Empty LoadBackendOptionsMap should work fine
640+
LoadBackendOptionsMap backend_options;
641+
642+
const auto error = module.load(backend_options);
643+
EXPECT_EQ(error, Error::Ok);
644+
EXPECT_TRUE(module.is_loaded());
645+
}
646+
647+
TEST_F(ModuleTest, TestLoadBackendOptionsMapPersistedAcrossLoadMethod) {
648+
Module module(model_path_);
649+
650+
// Set backend options via load()
651+
LoadBackendOptionsMap backend_options;
652+
BackendOptions<2> options;
653+
options.set_option("persist_test", true);
654+
ASSERT_EQ(
655+
backend_options.set_options("PersistBackend", options.view()), Error::Ok);
656+
657+
const auto load_error = module.load(backend_options);
658+
EXPECT_EQ(load_error, Error::Ok);
659+
660+
// load_method without explicit backend_options should use the stored ones
661+
const auto method_error = module.load_method("forward");
662+
EXPECT_EQ(method_error, Error::Ok);
663+
EXPECT_TRUE(module.is_method_loaded("forward"));
664+
}
665+
666+
TEST_F(ModuleTest, TestLoadMethodOverridesStoredBackendOptions) {
667+
Module module(model_path_);
668+
669+
// Set initial backend options via load()
670+
LoadBackendOptionsMap initial_options;
671+
BackendOptions<2> opts1;
672+
opts1.set_option("source", "load");
673+
ASSERT_EQ(
674+
initial_options.set_options("TestBackend", opts1.view()), Error::Ok);
675+
676+
const auto load_error = module.load(initial_options);
677+
EXPECT_EQ(load_error, Error::Ok);
678+
679+
// Unload and reload with different options passed to load_method
680+
module.unload_method("forward");
681+
682+
LoadBackendOptionsMap override_options;
683+
BackendOptions<2> opts2;
684+
opts2.set_option("source", "load_method");
685+
ASSERT_EQ(
686+
override_options.set_options("TestBackend", opts2.view()), Error::Ok);
687+
688+
// The override_options should take precedence
689+
const auto method_error =
690+
module.load_method("forward", nullptr, nullptr, &override_options);
691+
EXPECT_EQ(method_error, Error::Ok);
692+
EXPECT_TRUE(module.is_method_loaded("forward"));
693+
}
694+
695+
TEST_F(ModuleTest, TestMultipleBackendsInOptionsMap) {
696+
Module module(model_path_);
697+
698+
LoadBackendOptionsMap backend_options;
699+
700+
// Add options for multiple backends
701+
BackendOptions<2> backend1_opts;
702+
backend1_opts.set_option("compute_unit", "gpu");
703+
ASSERT_EQ(
704+
backend_options.set_options("Backend1", backend1_opts.view()), Error::Ok);
705+
706+
BackendOptions<3> backend2_opts;
707+
backend2_opts.set_option("optimization_level", 3);
708+
backend2_opts.set_option("debug", false);
709+
ASSERT_EQ(
710+
backend_options.set_options("Backend2", backend2_opts.view()), Error::Ok);
711+
712+
const auto error = module.load(backend_options);
713+
EXPECT_EQ(error, Error::Ok);
714+
EXPECT_TRUE(module.is_loaded());
715+
716+
// Should still execute normally
717+
auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f});
718+
const auto result = module.forward({tensor, tensor, 1.0});
719+
EXPECT_EQ(result.error(), Error::Ok);
720+
}

extension/module/test/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def define_common_targets(is_fbcode=False):
3737
"//executorch/extension/data_loader:file_data_loader",
3838
"//executorch/extension/module:module" + aten_suffix,
3939
"//executorch/extension/tensor:tensor" + aten_suffix,
40+
"//executorch/runtime/backend:backend_options",
41+
"//executorch/runtime/backend:backend_options_map",
4042
"//executorch/runtime/core/exec_aten/testing_util:tensor_util" + aten_suffix,
4143
],
4244
env = modules_env,

0 commit comments

Comments
 (0)