Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions backends/apple/coreml/runtime/delegate/backend_delegate.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

#include "model_logging_options.h"

#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/span.h>

#include <system_error>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -75,12 +78,15 @@ class BackendDelegate {
/// @param method_name The ExecuTorch method name for metadata lookup (optional, may be nullptr).
/// @param function_name The CoreML function name to invoke (optional, may be nullptr).
/// If nullptr, method_name is used as the function name.
/// @param runtime_specs Runtime options passed via BackendOptions (e.g., cache_dir).
/// @retval An opaque handle to the initialized blob or `nullptr` if the
/// initialization failed.
virtual Handle* init(Buffer processed,
const std::unordered_map<std::string, Buffer>& specs,
const char* method_name = nullptr,
const char* function_name = nullptr) const noexcept = 0;
virtual Handle*
init(Buffer processed,
const std::unordered_map<std::string, Buffer>& specs,
const char* method_name = nullptr,
const char* function_name = nullptr,
executorch::runtime::Span<const executorch::runtime::BackendOption> runtime_specs = {}) const noexcept = 0;

/// Must execute the CoreML model with the specified handle.
///
Expand Down
63 changes: 59 additions & 4 deletions backends/apple/coreml/runtime/delegate/backend_delegate.mm
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#import "ETCoreMLAssetManager.h"
#import "ETCoreMLLogging.h"
#import "ETCoreMLModel.h"
#import "ETCoreMLModelCache.h"
#import "ETCoreMLModelManager.h"
#import "ETCoreMLStrings.h"
#import "model_event_logger.h"
Expand Down Expand Up @@ -100,7 +101,14 @@ - (ModelHandle*)loadModelFromAOTData:(NSData*)data
configuration:(MLModelConfiguration*)configuration
methodName:(nullable NSString*)methodName
functionName:(nullable NSString*)functionName
error:(NSError* __autoreleasing*)error;
error:(NSError* __autoreleasing*)error;

- (ModelHandle*)loadModelFromAOTData:(NSData*)data
configuration:(MLModelConfiguration*)configuration
methodName:(nullable NSString*)methodName
functionName:(nullable NSString*)functionName
cachePath:(nullable NSString*)cachePath
error:(NSError* __autoreleasing*)error;

- (ModelHandle*)loadModelFromAOTData:(NSData*)data
configuration:(MLModelConfiguration*)configuration
Expand Down Expand Up @@ -199,14 +207,47 @@ - (ModelHandle*)loadModelFromAOTData:(NSData*)data
methodName:(nullable NSString*)methodName
functionName:(nullable NSString*)functionName
error:(NSError* __autoreleasing*)error {
return [self loadModelFromAOTData:data
configuration:configuration
methodName:methodName
functionName:functionName
cachePath:nil
error:error];
}

- (ModelHandle*)loadModelFromAOTData:(NSData*)data
configuration:(MLModelConfiguration*)configuration
methodName:(nullable NSString*)methodName
functionName:(nullable NSString*)functionName
cachePath:(nullable NSString*)cachePath
error:(NSError* __autoreleasing*)error {
if (![self loadAndReturnError:error]) {
return nil;
}

id<ETCoreMLCache> cache = nil;
if (cachePath != nil) {
// Use NEW filesystem cache at specified path
NSURL *cacheURL = [NSURL fileURLWithPath:cachePath isDirectory:YES];
ETCoreMLModelCache *modelCache = [[ETCoreMLModelCache alloc] initWithCacheRootDirectory:cacheURL];
if (!modelCache.isReady) {
// Fallback error if initializationError is unexpectedly nil
NSError *cacheError = modelCache.initializationError
?: [NSError errorWithDomain:ETCoreMLModelCacheErrorDomain
code:ETCoreMLModelCacheErrorCodeInitializationFailed
userInfo:@{NSLocalizedDescriptionKey: @"Cache initialization failed"}];
if (error) *error = cacheError;
return nil;
}
cache = modelCache;
}
// cache == nil means loadModelFromAOTData will use self.cache (default cache)

auto handle = [self.impl loadModelFromAOTData:data
configuration:configuration
methodName:methodName
functionName:functionName
cache:cache
error:error];
if ((handle != NULL) && self.config.should_prewarm_model) {
[self.impl prewarmModelWithHandle:handle error:nil];
Expand Down Expand Up @@ -291,9 +332,10 @@ explicit BackendDelegateImpl(const Config& config) noexcept
BackendDelegateImpl& operator=(BackendDelegateImpl const&) = delete;

Handle *init(Buffer processed,
const std::unordered_map<std::string, Buffer>& specs,
const char* method_name = nullptr,
const char* function_name = nullptr) const noexcept override {
const std::unordered_map<std::string, Buffer>& specs,
const char* method_name = nullptr,
const char* function_name = nullptr,
executorch::runtime::Span<const executorch::runtime::BackendOption> runtime_specs = {}) const noexcept override {
NSError *localError = nil;
MLModelConfiguration *configuration = get_model_configuration(specs, &localError);
if (configuration == nil) {
Expand All @@ -304,13 +346,26 @@ explicit BackendDelegateImpl(const Config& config) noexcept
NSString *methodNameStr = method_name ? @(method_name) : nil;
NSString *functionNameStr = function_name ? @(function_name) : nil;

// Parse cache_dir from runtime_specs
NSString *cachePath = nil;
for (size_t i = 0; i < runtime_specs.size(); ++i) {
const auto& opt = runtime_specs[i];
if (std::strcmp(opt.key, "cache_dir") == 0) {
if (auto* arr = std::get_if<std::array<char, executorch::runtime::kMaxOptionValueLength>>(&opt.value)) {
cachePath = @(arr->data());
}
break;
}
}

NSData *data = [NSData dataWithBytesNoCopy:const_cast<void *>(processed.data())
length:processed.size()
freeWhenDone:NO];
ModelHandle *modelHandle = [model_manager_ loadModelFromAOTData:data
configuration:configuration
methodName:methodNameStr
functionName:functionNameStr
cachePath:cachePath
error:&localError];
if (localError != nil) {
ETCoreMLLogError(localError, "Model init failed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {

std::error_code error;
const char* function_name_cstr = functionName.empty() ? nullptr : functionName.c_str();
auto handle = impl_->init(std::move(buffer), specs_map, method_name, function_name_cstr);
auto handle = impl_->init(std::move(buffer), specs_map, method_name, function_name_cstr, runtime_specs);
ET_CHECK_OR_RETURN_ERROR(handle != nullptr,
InvalidProgram,
"%s: Failed to init the model.", ETCoreMLStrings.delegateIdentifier.UTF8String);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ class LoadOptionsBuilder {
return *this;
}

/**
 * Directs compiled-model storage to a caller-chosen directory.
 *
 * Supplying a path opts this model into the filesystem-backed cache
 * (ETCoreMLModelCache) rooted at that location; omitting it keeps the
 * default SQLite-backed cache (ETCoreMLAssetManager).
 *
 * This allows per-model cache selection for experimentation: set
 * cache_dir to try the new filesystem cache, leave it unset to stay on
 * the legacy cache.
 *
 * @param path Directory to hold the cache. Must be a valid filesystem
 *             path that the process can write to.
 * @return Reference to this builder, enabling call chaining.
 */
LoadOptionsBuilder& setCacheDirectory(const char* path) {
  options_.set_option("cache_dir", path);
  return *this;
}

/**
* Returns the backend identifier for this options builder.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,83 @@ TEST_F(CoreMLBackendOptionsTest, SetComputeUnitMultipleTimes) {
FAIL() << "Expected string value for compute_unit";
}
}

// Verifies that setCacheDirectory records exactly one "cache_dir" option.
TEST_F(CoreMLBackendOptionsTest, SetCacheDirectory) {
  LoadOptionsBuilder opts_builder;
  opts_builder.setCacheDirectory("/path/to/cache");

  auto recorded = opts_builder.view();
  EXPECT_EQ(recorded.size(), 1);
  EXPECT_STREQ(recorded[0].key, "cache_dir");

  // Option values are stored as fixed-size char arrays inside a variant.
  auto* stored = std::get_if<std::array<char, kMaxOptionValueLength>>(&recorded[0].value);
  if (stored != nullptr) {
    EXPECT_STREQ(stored->data(), "/path/to/cache");
  } else {
    FAIL() << "Expected string value for cache_dir";
  }
}

// Verifies that setCacheDirectory returns *this so calls can be chained.
TEST_F(CoreMLBackendOptionsTest, SetCacheDirectoryChaining) {
  LoadOptionsBuilder opts_builder;
  LoadOptionsBuilder& returned = opts_builder.setCacheDirectory("/tmp/cache");

  // The returned reference must alias the builder it was invoked on.
  EXPECT_EQ(&returned, &opts_builder);
}

// Verifies that compute_unit and cache_dir can be set together via chaining.
TEST_F(CoreMLBackendOptionsTest, CombinedOptions) {
  LoadOptionsBuilder opts_builder;
  opts_builder.setComputeUnit(LoadOptionsBuilder::ComputeUnit::CPU_AND_NE).setCacheDirectory("/data/experiment_cache");

  auto recorded = opts_builder.view();
  EXPECT_EQ(recorded.size(), 2);

  // Options appear in insertion order: compute_unit first, then cache_dir.
  EXPECT_STREQ(recorded[0].key, "compute_unit");
  auto* unit = std::get_if<std::array<char, kMaxOptionValueLength>>(&recorded[0].value);
  if (unit != nullptr) {
    EXPECT_STREQ(unit->data(), "cpu_and_ne");
  } else {
    FAIL() << "Expected string value for compute_unit";
  }

  EXPECT_STREQ(recorded[1].key, "cache_dir");
  auto* dir = std::get_if<std::array<char, kMaxOptionValueLength>>(&recorded[1].value);
  if (dir != nullptr) {
    EXPECT_STREQ(dir->data(), "/data/experiment_cache");
  } else {
    FAIL() << "Expected string value for cache_dir";
  }
}

// Verifies that cache_dir survives a round trip through LoadBackendOptionsMap.
TEST_F(CoreMLBackendOptionsTest, IntegrationWithOptionsMapCacheDir) {
  LoadOptionsBuilder coreml_builder;
  coreml_builder.setComputeUnit(LoadOptionsBuilder::ComputeUnit::ALL).setCacheDirectory("/custom/cache/path");

  LoadBackendOptionsMap options_map;
  EXPECT_EQ(options_map.set_options(coreml_builder), Error::Ok);

  EXPECT_EQ(options_map.size(), 1);
  EXPECT_TRUE(options_map.has_options("CoreMLBackend"));

  auto stored = options_map.get_options("CoreMLBackend");
  EXPECT_EQ(stored.size(), 2);

  // Scan the retrieved options for the cache_dir entry and check its value.
  bool saw_cache_dir = false;
  for (size_t idx = 0; idx < stored.size(); ++idx) {
    if (std::strcmp(stored[idx].key, "cache_dir") != 0) {
      continue;
    }
    saw_cache_dir = true;
    auto* dir = std::get_if<std::array<char, kMaxOptionValueLength>>(&stored[idx].value);
    if (dir != nullptr) {
      EXPECT_STREQ(dir->data(), "/custom/cache/path");
    } else {
      FAIL() << "Expected string value for cache_dir";
    }
    break;
  }
  EXPECT_TRUE(saw_cache_dir) << "cache_dir option not found";
}
Loading