|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <executorch/runtime/backend/backend_init_context.h> |
| 10 | +#include <executorch/runtime/backend/options.h> |
| 11 | +#include <executorch/runtime/platform/runtime.h> |
| 12 | + |
| 13 | +#include <gtest/gtest.h> |
| 14 | + |
| 15 | +using namespace ::testing; |
| 16 | +using executorch::runtime::BackendInitContext; |
| 17 | +using executorch::runtime::BackendOption; |
| 18 | +using executorch::runtime::BackendOptions; |
| 19 | +using executorch::runtime::Error; |
| 20 | +using executorch::runtime::Span; |
| 21 | + |
| 22 | +class BackendInitContextTest : public ::testing::Test { |
| 23 | + protected: |
| 24 | + void SetUp() override { |
| 25 | + executorch::runtime::runtime_init(); |
| 26 | + } |
| 27 | +}; |
| 28 | + |
| 29 | +// Test default constructor without runtime specs |
| 30 | +TEST_F(BackendInitContextTest, DefaultConstructorNoRuntimeSpecs) { |
| 31 | + BackendInitContext context(nullptr); |
| 32 | + |
| 33 | + auto specs = context.runtime_specs(); |
| 34 | + EXPECT_EQ(specs.size(), 0); |
| 35 | +} |
| 36 | + |
| 37 | +// Test constructor with runtime specs |
| 38 | +TEST_F(BackendInitContextTest, ConstructorWithRuntimeSpecs) { |
| 39 | + BackendOptions<4> opts; |
| 40 | + opts.set_option("compute_unit", "cpu_and_gpu"); |
| 41 | + opts.set_option("num_threads", 4); |
| 42 | + opts.set_option("enable_profiling", true); |
| 43 | + |
| 44 | + // Create a const span from the mutable view |
| 45 | + auto view = opts.view(); |
| 46 | + Span<const BackendOption> const_span(view.data(), view.size()); |
| 47 | + |
| 48 | + BackendInitContext context( |
| 49 | + nullptr, // runtime_allocator |
| 50 | + nullptr, // event_tracer |
| 51 | + "forward", // method_name |
| 52 | + nullptr, // named_data_map |
| 53 | + const_span // runtime_specs |
| 54 | + ); |
| 55 | + |
| 56 | + auto specs = context.runtime_specs(); |
| 57 | + EXPECT_EQ(specs.size(), 3); |
| 58 | +} |
| 59 | + |
| 60 | +// Test get_runtime_spec<bool> with valid key |
| 61 | +TEST_F(BackendInitContextTest, GetRuntimeSpecBoolValid) { |
| 62 | + BackendOptions<2> opts; |
| 63 | + opts.set_option("enable_profiling", true); |
| 64 | + opts.set_option("debug_mode", false); |
| 65 | + |
| 66 | + auto view = opts.view(); |
| 67 | + Span<const BackendOption> const_span(view.data(), view.size()); |
| 68 | + |
| 69 | + BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span); |
| 70 | + |
| 71 | + auto result1 = context.get_runtime_spec<bool>("enable_profiling"); |
| 72 | + EXPECT_TRUE(result1.ok()); |
| 73 | + EXPECT_TRUE(result1.get()); |
| 74 | + |
| 75 | + auto result2 = context.get_runtime_spec<bool>("debug_mode"); |
| 76 | + EXPECT_TRUE(result2.ok()); |
| 77 | + EXPECT_FALSE(result2.get()); |
| 78 | +} |
| 79 | + |
| 80 | +// Test get_runtime_spec<int> with valid key |
| 81 | +TEST_F(BackendInitContextTest, GetRuntimeSpecIntValid) { |
| 82 | + BackendOptions<2> opts; |
| 83 | + opts.set_option("num_threads", 8); |
| 84 | + opts.set_option("batch_size", 32); |
| 85 | + |
| 86 | + auto view = opts.view(); |
| 87 | + Span<const BackendOption> const_span(view.data(), view.size()); |
| 88 | + |
| 89 | + BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span); |
| 90 | + |
| 91 | + auto result1 = context.get_runtime_spec<int>("num_threads"); |
| 92 | + EXPECT_TRUE(result1.ok()); |
| 93 | + EXPECT_EQ(result1.get(), 8); |
| 94 | + |
| 95 | + auto result2 = context.get_runtime_spec<int>("batch_size"); |
| 96 | + EXPECT_TRUE(result2.ok()); |
| 97 | + EXPECT_EQ(result2.get(), 32); |
| 98 | +} |
| 99 | + |
| 100 | +// Test get_runtime_spec<const char*> with valid key |
| 101 | +TEST_F(BackendInitContextTest, GetRuntimeSpecStringValid) { |
| 102 | + BackendOptions<2> opts; |
| 103 | + opts.set_option("compute_unit", "cpu_and_gpu"); |
| 104 | + opts.set_option("cache_dir", "/tmp/cache"); |
| 105 | + |
| 106 | + auto view = opts.view(); |
| 107 | + Span<const BackendOption> const_span(view.data(), view.size()); |
| 108 | + |
| 109 | + BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span); |
| 110 | + |
| 111 | + auto result1 = context.get_runtime_spec<const char*>("compute_unit"); |
| 112 | + EXPECT_TRUE(result1.ok()); |
| 113 | + EXPECT_STREQ(result1.get(), "cpu_and_gpu"); |
| 114 | + |
| 115 | + auto result2 = context.get_runtime_spec<const char*>("cache_dir"); |
| 116 | + EXPECT_TRUE(result2.ok()); |
| 117 | + EXPECT_STREQ(result2.get(), "/tmp/cache"); |
| 118 | +} |
| 119 | + |
| 120 | +// Test get_runtime_spec<T> with non-existent key returns NotFound |
| 121 | +TEST_F(BackendInitContextTest, GetRuntimeSpecNotFound) { |
| 122 | + BackendOptions<1> opts; |
| 123 | + opts.set_option("key", "value"); |
| 124 | + |
| 125 | + auto view = opts.view(); |
| 126 | + Span<const BackendOption> const_span(view.data(), view.size()); |
| 127 | + |
| 128 | + BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span); |
| 129 | + |
| 130 | + auto bool_result = context.get_runtime_spec<bool>("nonexistent"); |
| 131 | + EXPECT_FALSE(bool_result.ok()); |
| 132 | + EXPECT_EQ(bool_result.error(), Error::NotFound); |
| 133 | + |
| 134 | + auto int_result = context.get_runtime_spec<int>("nonexistent"); |
| 135 | + EXPECT_FALSE(int_result.ok()); |
| 136 | + EXPECT_EQ(int_result.error(), Error::NotFound); |
| 137 | + |
| 138 | + auto string_result = context.get_runtime_spec<const char*>("nonexistent"); |
| 139 | + EXPECT_FALSE(string_result.ok()); |
| 140 | + EXPECT_EQ(string_result.error(), Error::NotFound); |
| 141 | +} |
| 142 | + |
| 143 | +// Test get_runtime_spec<T> with wrong type returns InvalidArgument |
| 144 | +TEST_F(BackendInitContextTest, GetRuntimeSpecTypeMismatch) { |
| 145 | + BackendOptions<3> opts; |
| 146 | + opts.set_option("bool_opt", true); |
| 147 | + opts.set_option("int_opt", 42); |
| 148 | + opts.set_option("string_opt", "hello"); |
| 149 | + |
| 150 | + auto view = opts.view(); |
| 151 | + Span<const BackendOption> const_span(view.data(), view.size()); |
| 152 | + |
| 153 | + BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span); |
| 154 | + |
| 155 | + // Try to get bool as int |
| 156 | + auto result1 = context.get_runtime_spec<int>("bool_opt"); |
| 157 | + EXPECT_FALSE(result1.ok()); |
| 158 | + EXPECT_EQ(result1.error(), Error::InvalidArgument); |
| 159 | + |
| 160 | + // Try to get int as string |
| 161 | + auto result2 = context.get_runtime_spec<const char*>("int_opt"); |
| 162 | + EXPECT_FALSE(result2.ok()); |
| 163 | + EXPECT_EQ(result2.error(), Error::InvalidArgument); |
| 164 | + |
| 165 | + // Try to get string as bool |
| 166 | + auto result3 = context.get_runtime_spec<bool>("string_opt"); |
| 167 | + EXPECT_FALSE(result3.ok()); |
| 168 | + EXPECT_EQ(result3.error(), Error::InvalidArgument); |
| 169 | +} |
| 170 | + |
| 171 | +// Test empty runtime specs |
| 172 | +TEST_F(BackendInitContextTest, EmptyRuntimeSpecs) { |
| 173 | + Span<const BackendOption> empty_span; |
| 174 | + BackendInitContext context(nullptr, nullptr, nullptr, nullptr, empty_span); |
| 175 | + |
| 176 | + EXPECT_EQ(context.runtime_specs().size(), 0); |
| 177 | + |
| 178 | + // All lookups should return NotFound |
| 179 | + auto bool_result = context.get_runtime_spec<bool>("any_key"); |
| 180 | + EXPECT_FALSE(bool_result.ok()); |
| 181 | + EXPECT_EQ(bool_result.error(), Error::NotFound); |
| 182 | +} |
| 183 | + |
| 184 | +// Test that other context fields still work |
| 185 | +TEST_F(BackendInitContextTest, OtherFieldsStillWork) { |
| 186 | + BackendOptions<1> opts; |
| 187 | + opts.set_option("key", "value"); |
| 188 | + |
| 189 | + auto view = opts.view(); |
| 190 | + Span<const BackendOption> const_span(view.data(), view.size()); |
| 191 | + |
| 192 | + BackendInitContext context( |
| 193 | + nullptr, // runtime_allocator |
| 194 | + nullptr, // event_tracer |
| 195 | + "forward", // method_name |
| 196 | + nullptr, // named_data_map |
| 197 | + const_span // runtime_specs |
| 198 | + ); |
| 199 | + |
| 200 | + EXPECT_EQ(context.get_runtime_allocator(), nullptr); |
| 201 | + EXPECT_EQ(context.event_tracer(), nullptr); |
| 202 | + EXPECT_STREQ(context.get_method_name(), "forward"); |
| 203 | + EXPECT_EQ(context.get_named_data_map(), nullptr); |
| 204 | +} |
0 commit comments