diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d90189496af42..2bb73ea23854f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -593,6 +593,62 @@ typedef OrtStatus*(ORT_API_CALL* OrtWriteBufferFunc)(_In_ void* state, _In_ const void* buffer, _In_ size_t buffer_num_bytes); +/** \brief Function called to write EPContext binary data during compilation. + * + * This function is called synchronously by OrtEpApi::WriteEpContextData on the calling thread. ORT does not retain + * buffer after the callback returns, does not reorder callback invocations, and does not serialize invocations made by + * different EP instances or EP worker threads. + * + * Each callback invocation represents one complete write operation for file_name. The callback signature does not + * provide an offset, sequence number, or final-chunk marker, so EPs that need chunked streaming must define their own + * ordering and completion contract with the application. EPs should prefer a single callback invocation per EPContext + * binary unless chunking semantics are documented by that EP. + * + * The application's implementation can process the data in any way (e.g., encrypt and store, upload to cloud storage, + * or compress) before persisting it. + * + * \param[in] state Opaque pointer holding the user's state. ORT does not own or manage this pointer. The application + * must keep it valid for the duration of any compile operation that may invoke this callback and must + * provide any synchronization required if it can be used concurrently. + * \param[in] file_name The intended EPContext binary file name as a null-terminated UTF-8 string. + * \param[in] buffer The buffer containing EPContext binary data to write. + * \param[in] buffer_num_bytes The size of the buffer in bytes. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info with ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtWriteEpContextDataFunc)(_In_ void* state, + _In_ const char* file_name, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes); + +/** \brief Function called by ORT to read EPContext binary data during session load. + * + * The application reads, processes (e.g., decrypts, decompresses, downloads), and returns the EPContext binary data. + * ORT provides an allocator so the application can allocate the output buffer directly. The callback is called + * synchronously by OrtEpApi::ReadEpContextData on the calling thread. ORT does not serialize invocations made by + * different EP instances or EP worker threads. + * + * \param[in] state Opaque pointer holding the user's state. ORT does not own or manage this pointer. The application + * must keep it valid while any session or EP created from the associated OrtSessionOptions may invoke + * this callback and must provide any synchronization required if it can be used concurrently. + * \param[in] file_name The EPContext binary file name as a null-terminated UTF-8 string. + * \param[in] allocator ORT-provided allocator. The application must use this to allocate the output buffer. + * \param[out] buffer Set by the implementation to the allocated buffer containing the output data. + * \param[out] data_size Set by the implementation to the size of the output data in bytes. + * + * \return OrtStatus* Read status. Return nullptr on success. + * On failure, ORT ignores callback outputs and treats buffer/data_size as unset. + * Use CreateStatus to provide error info with ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtReadEpContextDataFunc)(_In_ void* state, + _In_ const char* file_name, + _In_ OrtAllocator* allocator, + _Outptr_ void** buffer, + _Out_ size_t* data_size); + /** \brief Function called by ORT to allow user to specify how an initializer should be saved, that is, either * written to an external file or stored within the model. ORT calls this function for every initializer when * generating a model. @@ -7486,6 +7542,26 @@ struct OrtApi { * \since Version 1.27. */ ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id); + + /** \brief Registers a callback to provide EPContext binary data during session load. + * + * When loading a compiled model with external (non-embedded) EPContext binary data, an execution provider can use + * OrtEpApi::ReadEpContextData to call this callback instead of reading the binary data from disk. + * + * The state pointer is stored as-is and is not owned by ORT. It must remain valid while any session or EP created + * from these options may call the callback. If the same state may be used by multiple EPs or threads, the application + * is responsible for synchronization. + * + * \param[in] options The OrtSessionOptions instance. + * \param[in] read_func The OrtReadEpContextDataFunc callback. + * \param[in] state Opaque state passed to read_func. Can be NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + */ + ORT_API2_STATUS(SessionOptions_SetEpContextDataReadFunc, _Inout_ OrtSessionOptions* options, + _In_ OrtReadEpContextDataFunc read_func, _In_opt_ void* state); }; /* @@ -8307,6 +8383,27 @@ struct OrtCompileApi { ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const OrtModel* model); + + /** \brief Sets a callback for writing EPContext binary data during compilation. + * + * When EPContext embed mode is disabled, execution providers can use OrtEpApi::WriteEpContextData to call this + * callback instead of writing EPContext binary data directly to disk. + * + * The state pointer is stored as-is and is not owned by ORT. It must remain valid for the duration of the compile + * operation that may call the callback. If the same state may be used by multiple EPs or threads, the application is + * responsible for synchronization. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] write_func The OrtWriteEpContextDataFunc called to write EPContext bytes. + * \param[in] state Opaque state passed to write_func. Can be NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetEpContextDataWriteFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtWriteEpContextDataFunc write_func, _In_opt_ void* state); }; /** diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 42eeac19da377..f2af639664401 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -658,6 +658,7 @@ ORT_DEFINE_RELEASE(Value); ORT_DEFINE_RELEASE(ValueInfo); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpContextConfig, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi); @@ -786,6 +787,7 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; +struct EpContextConfig; struct EpDevice; struct ExternalInitializerInfo; struct Graph; @@ -1185,6 +1187,22 @@ struct EpDevice : detail::EpDeviceImpl { ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {}); }; +/** \brief Wrapper around ::OrtEpContextConfig + * + * Owns an OrtEpContextConfig handle and releases it via OrtEpApi::ReleaseEpContextConfig on destruction. + * The underlying pointer implicitly converts to OrtEpContextConfig* so it can be passed directly to + * OrtEpApi::ReadEpContextData / OrtEpApi::WriteEpContextData. + */ +struct EpContextConfig : detail::Base { + using B = detail::Base; + using B::B; // inherit default and take-ownership-from-pointer constructors + + explicit EpContextConfig(std::nullptr_t) {} ///< No instance is created + + /// \brief Wraps OrtEpApi::SessionOptions_GetEpContextConfig + explicit EpContextConfig(const OrtSessionOptions* session_options); +}; + /** \brief Validate a compiled model's compatibility for one or more EP devices. * * Throws on error. Returns the resulting compatibility status. @@ -1668,6 +1686,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { const std::vector& external_initializer_file_buffer_array, const std::vector& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory + SessionOptionsImpl& SetEpContextDataReadFunc(OrtReadEpContextDataFunc read_func, void* state); ///< Wraps OrtApi::SessionOptions_SetEpContextDataReadFunc + SessionOptionsImpl& AppendExecutionProvider_CPU(int use_arena); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CPU SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 @@ -1769,6 +1789,9 @@ struct ModelCompilationOptions : detail::Base { ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelWriteFunc ModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + ///< Wraps OrtCompileApi::ModelCompilationOptions_SetEpContextDataWriteFunc + ModelCompilationOptions& SetEpContextDataWriteFunc(OrtWriteEpContextDataFunc write_func, void* state); + ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 61bc31736f5b5..cd6ff97a2c8b3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -769,6 +769,10 @@ inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardwar ThrowOnError(GetEpApi().CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_)); } +inline EpContextConfig::EpContextConfig(const OrtSessionOptions* session_options) { + ThrowOnError(GetEpApi().SessionOptions_GetEpContextConfig(session_options, &this->p_)); +} + namespace detail { template inline std::string EpAssignedSubgraphImpl::GetEpName() const { @@ -1335,6 +1339,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelWriteFunc return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextDataWriteFunc( + OrtWriteEpContextDataFunc write_func, void* state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextDataWriteFunc(this->p_, write_func, state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( bool embed_ep_context_in_model) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode( @@ -1574,6 +1584,13 @@ inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializersFrom return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetEpContextDataReadFunc(OrtReadEpContextDataFunc read_func, + void* state) { + ThrowOnError(GetApi().SessionOptions_SetEpContextDataReadFunc(this->p_, read_func, state)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CPU(int use_arena) { ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(this->p_, use_arena)); diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index b816528f1f2ba..f26650809239c 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -18,6 +18,7 @@ extern "C" { * @{ */ ORT_RUNTIME_CLASS(Ep); +ORT_RUNTIME_CLASS(EpContextConfig); ORT_RUNTIME_CLASS(EpFactory); ORT_RUNTIME_CLASS(EpGraphSupportInfo); ORT_RUNTIME_CLASS(MemoryDevice); // opaque class to wrap onnxruntime::OrtDevice @@ -2077,6 +2078,92 @@ struct OrtEpApi { ORT_API2_STATUS(ProfilingEventsContainer_AddEvents, _In_ OrtProfilingEventsContainer* events_container, _In_reads_(num_events) const OrtProfilingEvent* const* events, _In_ size_t num_events); + + /** \brief Get the EPContext configuration from session options. + * + * Extracts EPContext-related data I/O callbacks from the session options into an opaque OrtEpContextConfig handle. + * The EP should call this during CreateEp() while session_options is still valid, and store the returned handle for + * use during Compile(). The returned config is always non-NULL and must be released with ReleaseEpContextConfig. + * + * The returned handle owns only ORT's copy of callback function pointers and opaque state pointer values. It does not + * own the application-provided state. The application is responsible for keeping callback state valid and + * synchronized while an EP may call ReadEpContextData or WriteEpContextData with this config. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[out] config The extracted OrtEpContextConfig. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + */ + ORT_API2_STATUS(SessionOptions_GetEpContextConfig, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtEpContextConfig** config); + + /** \brief Release an OrtEpContextConfig instance. + * + * \param[in] input The OrtEpContextConfig instance to release. May be NULL. + * + * \since Version 1.27. + */ + ORT_CLASS_RELEASE(EpContextConfig); + + /** \brief Read EPContext binary data, using an application read callback or falling back to disk. + * + * If config contains a read callback, the callback is invoked with the provided allocator. Otherwise, ORT reads the + * file from disk. The disk fallback derives the base directory from the graph's model path. + * + * This function is synchronous. If a callback is present, it is invoked on the calling thread and its OrtStatus is + * returned to the caller. ORT does not serialize concurrent calls across EP instances or EP worker threads. + * + * \param[in] config The OrtEpContextConfig from SessionOptions_GetEpContextConfig. May be NULL for disk fallback. + * \param[in] file_name EPContext file name as a null-terminated UTF-8 string. + * \param[in] graph The OrtGraph from which ORT derives the model path for disk fallback. May be NULL. + * \param[in] allocator Allocator for the output buffer. + * \param[out] buffer Output buffer containing the EPContext binary data. + * \param[out] buffer_size Size of the output buffer in bytes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + */ + ORT_API2_STATUS(ReadEpContextData, + _In_opt_ const OrtEpContextConfig* config, + _In_ const char* file_name, + _In_opt_ const OrtGraph* graph, + _Inout_ OrtAllocator* allocator, + _Outptr_ void** buffer, + _Out_ size_t* buffer_size); + + /** \brief Write EPContext binary data, using an application write callback or falling back to disk. + * + * If config contains a write callback, the data is forwarded to the application's callback. Otherwise, ORT writes the + * data to disk. The disk fallback derives the base directory from the graph's model path. + * + * This function is synchronous. If a callback is present, it is invoked on the calling thread and its OrtStatus is + * returned to the caller. ORT does not retain buffer after the callback returns, reorder callback invocations, or + * serialize concurrent calls across EP instances or EP worker threads. + * + * Each call is one complete write operation for file_name. The API does not provide an offset, sequence number, or + * final-chunk marker. EPs should prefer one call per EPContext binary, or document EP-specific chunk ordering and + * completion semantics if multiple calls are made for the same file_name. + * + * \param[in] config The OrtEpContextConfig from SessionOptions_GetEpContextConfig. May be NULL for disk fallback. + * \param[in] file_name EPContext file name as a null-terminated UTF-8 string. + * \param[in] graph The OrtGraph from which ORT derives the model path for disk fallback. May be NULL. + * \param[in] buffer The buffer containing EPContext binary data to write. + * \param[in] buffer_size Size of the buffer in bytes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + */ + ORT_API2_STATUS(WriteEpContextData, + _In_opt_ const OrtEpContextConfig* config, + _In_ const char* file_name, + _In_opt_ const OrtGraph* graph, + _In_ const void* buffer, + _In_ size_t buffer_size); }; /** diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc index 99fa21b1e4be8..b53a99084152f 100644 --- a/onnxruntime/core/framework/ep_context_options.cc +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -56,6 +56,10 @@ const BufferWriteFuncHolder* ModelGenOptions::TryGetOutputModelWriteFunc() const return std::get_if(&output_model_location); } +const EpContextDataWriteFuncHolder* ModelGenOptions::TryGetEpContextDataWriteFunc() const { + return ep_context_data_write_func.write_func != nullptr ? &ep_context_data_write_func : nullptr; +} + bool ModelGenOptions::AreInitializersEmbeddedInOutputModel() const { return std::holds_alternative(initializers_location); } diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index 6643516bfb4c3..f05d0d95df73a 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -27,6 +27,14 @@ struct BufferWriteFuncHolder { void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. }; +/// +/// Holds the opaque state and write function that EPs use to write EPContext binary data. +/// +struct EpContextDataWriteFuncHolder { + OrtWriteEpContextDataFunc write_func = nullptr; + void* state = nullptr; +}; + /// /// Holds path and size threshold used to write out initializers to an external file. /// @@ -84,10 +92,13 @@ struct ModelGenOptions { InitializerHandler> // Custom function called for every initializer to determine location. initializers_location = std::monostate{}; + EpContextDataWriteFuncHolder ep_context_data_write_func = {}; + bool HasOutputModelLocation() const; const std::filesystem::path* TryGetOutputModelPath() const; const BufferHolder* TryGetOutputModelBuffer() const; const BufferWriteFuncHolder* TryGetOutputModelWriteFunc() const; + const EpContextDataWriteFuncHolder* TryGetEpContextDataWriteFunc() const; bool AreInitializersEmbeddedInOutputModel() const; const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index b328fc916f885..ce406305cc17d 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -226,6 +226,9 @@ struct SessionOptions { bool has_explicit_ep_context_gen_options = false; epctx::ModelGenOptions ep_context_gen_options = {}; epctx::ModelGenOptions GetEpContextGenerationOptions() const; + + OrtReadEpContextDataFunc ep_context_data_read_func = nullptr; + void* ep_context_data_read_state = nullptr; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 06bd5c4d84089..3e5670caea676 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -127,6 +127,18 @@ ORT_API_STATUS_IMPL(OrtApis::GetSessionExecutionMode, _In_ const OrtSessionOptio API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptions_SetEpContextDataReadFunc, _Inout_ OrtSessionOptions* options, + _In_ OrtReadEpContextDataFunc read_func, _In_opt_ void* state) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(options == nullptr, ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL"); + ORT_API_RETURN_IF(read_func == nullptr, ORT_INVALID_ARGUMENT, "'read_func' parameter must not be NULL"); + + options->value.ep_context_data_read_func = read_func; + options->value.ep_context_data_read_state = state; + return nullptr; + API_IMPL_END +} + // set filepath to save optimized onnx model. ORT_API_STATUS_IMPL(OrtApis::SetOptimizedModelFilePath, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath) { options->value.optimized_model_filepath = optimized_model_filepath; diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 54d26021d8c99..5a5567a1d4b92 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -259,6 +259,32 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInit API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextDataWriteFunc, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtWriteEpContextDataFunc write_func, _In_opt_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (model_compile_options == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtModelCompilationOptions is null"); + } + + if (write_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtWriteEpContextDataFunc function is null"); + } + + model_compile_options->SetEpContextDataWriteFunc(write_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* ort_model_compile_options, bool embed_ep_context_in_model) { @@ -367,6 +393,8 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetInputModel, // End of Version 24 - DO NOT MODIFY ABOVE + + &OrtCompileAPI::ModelCompilationOptions_SetEpContextDataWriteFunc, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index e8f171ee24295..60c75bd5386a9 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -44,5 +44,8 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocation ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const OrtModel* model); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextDataWriteFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtWriteEpContextDataFunc write_func, _In_opt_ void* state); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index efaf28fbeefc0..54ad31cb98887 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -129,6 +129,13 @@ void ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( }; } +void ModelCompilationOptions::SetEpContextDataWriteFunc(OrtWriteEpContextDataFunc write_func, void* state) { + session_options_.value.ep_context_gen_options.ep_context_data_write_func = epctx::EpContextDataWriteFuncHolder{ + write_func, + state, + }; +} + Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::filesystem::path& output_directory, const std::filesystem::path& model_name) { if (output_directory.empty() || model_name.empty()) { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 47529e794677e..e24286df2b512 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -97,6 +97,13 @@ class ModelCompilationOptions { void SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, void* state); + /// + /// Sets a user-provided function to handle EPContext binary data writes. + /// + /// The user-provided function called to write EPContext data + /// The user's state. + void SetEpContextDataWriteFunc(OrtWriteEpContextDataFunc write_func, void* state); + /// /// Sets information relate to EP context binary file. /// EP use this information to decide the location and context binary file name. diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 549334564a1cf..95dca5fac8a06 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4926,6 +4926,7 @@ static constexpr OrtApi ort_api_1_to_27 = { &OrtApis::GetMemPatternEnabled, &OrtApis::GetSessionExecutionMode, &OrtApis::SessionReleaseCapturedGraph, + &OrtApis::SessionOptions_SetEpContextDataReadFunc, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index adccfe09bc3f7..141279c824eb0 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -66,6 +66,8 @@ ORT_API_STATUS_IMPL(EnableMemPattern, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(DisableMemPattern, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(GetMemPatternEnabled, _In_ const OrtSessionOptions* options, _Out_ int* out); ORT_API_STATUS_IMPL(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out); +ORT_API_STATUS_IMPL(SessionOptions_SetEpContextDataReadFunc, _Inout_ OrtSessionOptions* options, + _In_ OrtReadEpContextDataFunc read_func, _In_opt_ void* state); ORT_API_STATUS_IMPL(EnableCpuMemArena, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(DisableCpuMemArena, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(SetSessionLogId, _In_ OrtSessionOptions* options, const char* logid); diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index d56f4299402b5..c1d01224e7c5c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -5,11 +5,15 @@ #include #include +#include +#include +#include #include #include #include #include +#include "core/common/path_string.h" #include "core/common/semver.h" #include "core/framework/error_code_helper.h" #include "core/framework/func_api.h" @@ -24,6 +28,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/abi_devices.h" #include "core/session/abi_ep_types.h" +#include "core/session/abi_session_options_impl.h" #include "core/session/abi_opschema.h" #include "core/session/environment.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" @@ -36,7 +41,79 @@ #include "core/session/plugin_ep/ep_event_profiling.h" using namespace onnxruntime; + +struct OrtEpContextConfig { + OrtWriteEpContextDataFunc write_func = nullptr; + void* write_state = nullptr; + OrtReadEpContextDataFunc read_func = nullptr; + void* read_state = nullptr; +}; + namespace OrtExecutionProviderApi { + +namespace { + +// Deleter that frees a buffer allocated via an OrtAllocator. Used to make the read path exception-safe. +struct OrtAllocatorBufferDeleter { + OrtAllocator* allocator; + void operator()(void* p) const { + if (p != nullptr) { + allocator->Free(allocator, p); + } + } +}; + +// Resolves the on-disk path for EPContext binary data. A relative file_name is joined to the directory +// of the model referenced by graph. To prevent path traversal (e.g., a malicious model influencing an +// EP-derived file name like "../../secret"), the joined path is canonicalized and verified to stay within +// the model directory. +Status ResolveEpContextDataPath(const char* file_name, const OrtGraph* graph, + std::filesystem::path& resolved_path) { + std::filesystem::path data_path{ToPathString(file_name)}; + if (data_path.is_absolute() || graph == nullptr) { + resolved_path = std::move(data_path); + return Status::OK(); + } + + const ORTCHAR_T* model_path = graph->GetModelPath(); + if (model_path == nullptr || model_path[0] == 0) { + resolved_path = std::move(data_path); + return Status::OK(); + } + + const std::filesystem::path base_dir = std::filesystem::path{model_path}.parent_path(); + const std::filesystem::path joined = base_dir / data_path; + + // Canonicalize lexically/with the filesystem so ".." segments are collapsed before the containment check. + // weakly_canonical does not require the target to exist (the file may not yet be written). + std::error_code ec; + std::filesystem::path canonical_base = std::filesystem::weakly_canonical(base_dir, ec); + if (ec) { + canonical_base = base_dir.lexically_normal(); + ec.clear(); + } + std::filesystem::path canonical_joined = std::filesystem::weakly_canonical(joined, ec); + if (ec) { + canonical_joined = joined.lexically_normal(); + } + + // The resolved path must be the model directory itself or a descendant of it. + auto base_it = canonical_base.begin(); + auto joined_it = canonical_joined.begin(); + for (; base_it != canonical_base.end(); ++base_it, ++joined_it) { + if (joined_it == canonical_joined.end() || *joined_it != *base_it) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "EPContext data file name '", file_name, + "' resolves to a location outside the model directory"); + } + } + + resolved_path = std::move(canonical_joined); + return Status::OK(); +} + +} // namespace + ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, _In_ const OrtHardwareDevice* hardware_device, _In_opt_ const OrtKeyValuePairs* ep_metadata, @@ -1198,6 +1275,121 @@ ORT_API_STATUS_IMPL(ProfilingEventsContainer_AddEvents, API_IMPL_END } +ORT_API_STATUS_IMPL(SessionOptions_GetEpContextConfig, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtEpContextConfig** config) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(session_options == nullptr, ORT_INVALID_ARGUMENT, "OrtSessionOptions is NULL"); + ORT_API_RETURN_IF(config == nullptr, ORT_INVALID_ARGUMENT, "Output OrtEpContextConfig is NULL"); + + auto ep_context_config = std::make_unique(); + if (const auto* write_config = session_options->value.ep_context_gen_options.TryGetEpContextDataWriteFunc()) { + ep_context_config->write_func = write_config->write_func; + ep_context_config->write_state = write_config->state; + } + ep_context_config->read_func = session_options->value.ep_context_data_read_func; + ep_context_config->read_state = session_options->value.ep_context_data_read_state; + + *config = ep_context_config.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseEpContextConfig, _Frees_ptr_opt_ OrtEpContextConfig* config) { + delete config; +} + +ORT_API_STATUS_IMPL(ReadEpContextData, + _In_opt_ const OrtEpContextConfig* config, + _In_ const char* file_name, + _In_opt_ const OrtGraph* graph, + _Inout_ OrtAllocator* allocator, + _Outptr_ void** buffer, + _Out_ size_t* buffer_size) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(file_name == nullptr, ORT_INVALID_ARGUMENT, "file_name is NULL"); + ORT_API_RETURN_IF(allocator == nullptr, ORT_INVALID_ARGUMENT, "OrtAllocator is NULL"); + ORT_API_RETURN_IF(buffer == nullptr, ORT_INVALID_ARGUMENT, "Output buffer is NULL"); + ORT_API_RETURN_IF(buffer_size == nullptr, ORT_INVALID_ARGUMENT, "Output buffer_size is NULL"); + + *buffer = nullptr; + *buffer_size = 0; + + if (config != nullptr && config->read_func != nullptr) { + OrtStatus* status = config->read_func(config->read_state, file_name, allocator, buffer, buffer_size); + if (status != nullptr) { + if (*buffer != nullptr) { + allocator->Free(allocator, *buffer); + } + *buffer = nullptr; + *buffer_size = 0; + return status; + } + + ORT_API_RETURN_IF(*buffer_size != 0 && *buffer == nullptr, ORT_FAIL, + "OrtReadEpContextDataFunc returned a null buffer for non-empty EPContext data"); + return nullptr; + } + + std::filesystem::path data_path; + ORT_API_RETURN_IF_STATUS_NOT_OK(ResolveEpContextDataPath(file_name, graph, data_path)); + size_t file_size = 0; + ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().GetFileLength(data_path.native().c_str(), file_size)); + + if (file_size == 0) { + return nullptr; + } + + std::unique_ptr allocated_buffer( + allocator->Alloc(allocator, file_size), OrtAllocatorBufferDeleter{allocator}); + ORT_API_RETURN_IF(allocated_buffer == nullptr, ORT_FAIL, "Failed to allocate buffer for EPContext data"); + + ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().ReadFileIntoBuffer( + data_path.native().c_str(), 0, file_size, + gsl::make_span(static_cast(allocated_buffer.get()), file_size))); + + *buffer = allocated_buffer.release(); + *buffer_size = file_size; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(WriteEpContextData, + _In_opt_ const OrtEpContextConfig* config, + _In_ const char* file_name, + _In_opt_ const OrtGraph* graph, + _In_ const void* buffer, + _In_ size_t buffer_size) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(file_name == nullptr, ORT_INVALID_ARGUMENT, "file_name is NULL"); + ORT_API_RETURN_IF(buffer_size != 0 && buffer == nullptr, ORT_INVALID_ARGUMENT, + "EPContext data buffer is NULL for non-empty data"); + + if (config != nullptr && config->write_func != nullptr) { + return config->write_func(config->write_state, file_name, buffer, buffer_size); + } + + std::filesystem::path data_path; + ORT_API_RETURN_IF_STATUS_NOT_OK(ResolveEpContextDataPath(file_name, graph, data_path)); + std::ofstream output_stream(data_path, std::ios::binary); + ORT_API_RETURN_IF(!output_stream, ORT_FAIL, "Failed to open EPContext data file for write: ", + PathToUTF8String(data_path.native())); + + if (buffer_size != 0) { + // Compare in a common unsigned wide type to avoid signed/unsigned and width mismatches between + // size_t and std::streamsize across platforms. + constexpr auto max_stream_size = static_cast(std::numeric_limits::max()); + ORT_API_RETURN_IF(static_cast(buffer_size) > max_stream_size, + ORT_INVALID_ARGUMENT, "EPContext data buffer is too large to write"); + output_stream.write(static_cast(buffer), static_cast(buffer_size)); + ORT_API_RETURN_IF(!output_stream, ORT_FAIL, "Failed to write EPContext data file: ", + PathToUTF8String(data_path.native())); + } + + return nullptr; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -1287,6 +1479,11 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::ProfilingEvent_GetArgValue, &OrtExecutionProviderApi::ProfilingEventsContainer_AddEvents, // End of Version 25 - DO NOT MODIFY ABOVE + + &OrtExecutionProviderApi::SessionOptions_GetEpContextConfig, + &OrtExecutionProviderApi::ReleaseEpContextConfig, + &OrtExecutionProviderApi::ReadEpContextData, + &OrtExecutionProviderApi::WriteEpContextData, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index e32e267a75ba5..8ca542b59a425 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -179,4 +179,23 @@ ORT_API_STATUS_IMPL(ProfilingEvent_GetDurationUs, _In_ const OrtProfilingEvent* ORT_API_STATUS_IMPL(ProfilingEvent_GetArgValue, _In_ const OrtProfilingEvent* event, _In_ const char* key, _Outptr_result_maybenull_ const char** out); +// EPContext data I/O +ORT_API_STATUS_IMPL(SessionOptions_GetEpContextConfig, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtEpContextConfig** config); +ORT_API(void, ReleaseEpContextConfig, _Frees_ptr_opt_ OrtEpContextConfig* config); +ORT_API_STATUS_IMPL(ReadEpContextData, + _In_opt_ const OrtEpContextConfig* config, + _In_ const char* file_name, + _In_opt_ const OrtGraph* graph, + _Inout_ OrtAllocator* allocator, + _Outptr_ void** buffer, + _Out_ size_t* buffer_size); +ORT_API_STATUS_IMPL(WriteEpContextData, + _In_opt_ const OrtEpContextConfig* config, + _In_ const char* file_name, + _In_opt_ const OrtGraph* graph, + _In_ const void* buffer, + _In_ size_t buffer_size); + } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index ca9a296501b04..4ff42cfea19ac 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -167,13 +167,15 @@ struct EpContextNodeComputeInfo : NodeComputeInfoBase { ExampleEp& ep; }; -ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger) +ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger, + Ort::EpContextConfig ep_context_config) : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized ApiPtrs{static_cast(factory)}, factory_{factory}, name_{name}, config_{config}, - logger_{logger} { + logger_{logger}, + ep_context_config_{std::move(ep_context_config)} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. // Initialize the execution provider's function table @@ -408,6 +410,28 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const auto fused_node_name = fused_node.GetName(); if (is_ep_context_node) { + Ort::ConstOpAttr embed_mode_attr; + RETURN_IF_ERROR(nodes[0].GetAttributeByName("embed_mode", embed_mode_attr)); + int64_t embed_mode = 1; + RETURN_IF_ERROR(embed_mode_attr.GetValue(embed_mode)); + + if (embed_mode == 0) { + Ort::ConstOpAttr ep_cache_context_attr; + RETURN_IF_ERROR(nodes[0].GetAttributeByName("ep_cache_context", ep_cache_context_attr)); + std::string ep_cache_context; + RETURN_IF_ERROR(ep_cache_context_attr.GetValue(ep_cache_context)); + + Ort::AllocatorWithDefaultOptions allocator; + void* ep_context_data = nullptr; + size_t ep_context_data_size = 0; + RETURN_IF_ERROR(ep->ep_api.ReadEpContextData(ep->ep_context_config_, ep_cache_context.c_str(), ort_graphs[0], + allocator, &ep_context_data, &ep_context_data_size)); + (void)ep_context_data_size; + if (ep_context_data != nullptr) { + allocator.Free(ep_context_data); + } + } + // Create EpContextKernel for EPContext nodes - clearly separates from MulKernel ep->ep_context_kernels_.emplace(fused_node_name, std::make_unique(ep->ort_api, ep->logger_)); @@ -448,7 +472,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const // Create EpContext nodes for the fused nodes we compiled (only for Mul, not EPContext). if (ep->config_.enable_ep_context) { assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + RETURN_IF_ERROR(ep->CreateEpContextNodes(ort_graphs[0], gsl::span(fused_nodes, count), gsl::span(ep_context_nodes, count))); } } @@ -478,7 +502,8 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, // Creates EPContext nodes from the given fused nodes. // This is an example implementation that can be used to generate an EPContext model. However, this example EP // cannot currently run the EPContext model. -OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes, +OrtStatus* ExampleEp::CreateEpContextNodes(const OrtGraph* graph, + gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes) { try { assert(fused_nodes.size() == ep_context_nodes.size()); @@ -511,11 +536,19 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes collect_input_output_names(fused_node_outputs, /*out*/ output_names); int64_t is_main_context = (i == 0); - int64_t embed_mode = 1; + int64_t embed_mode = config_.embed_ep_context_in_model ? 1 : 0; // Create node attributes. The CreateNode() function copies the attributes. + // The "ep_cache_context" attribute holds the raw EPContext binary data only when embed_mode != 0. + // When embed_mode == 0, it instead holds the EPContext file name; the binary data is written out + // separately via WriteEpContextData below. std::array attributes = {}; - std::string ep_ctx = "binary_data"; + std::string ep_ctx = config_.embed_ep_context_in_model ? "binary_data" : fused_node_name + ".ctx"; + if (!config_.embed_ep_context_in_model) { + const std::string ep_context_data = "binary_data"; + RETURN_IF_ERROR(ep_api.WriteEpContextData(ep_context_config_, ep_ctx.c_str(), graph, + ep_context_data.data(), ep_context_data.size())); + } attributes[0] = Ort::OpAttr("ep_cache_context", ep_ctx.data(), static_cast(ep_ctx.size()), ORT_OP_ATTR_STRING); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 2ba13658c3364..184f0c5fc624e 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -61,11 +61,13 @@ class ExampleEp : public OrtEp, public ApiPtrs { public: struct Config { bool enable_ep_context = false; + bool embed_ep_context_in_model = true; bool enable_weightless_ep_context_nodes = false; // Other EP configs (typically extracted from OrtSessionOptions or OrtHardwareDevice(s)) }; - ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger); + ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger, + Ort::EpContextConfig ep_context_config); ~ExampleEp(); @@ -105,7 +107,8 @@ class ExampleEp : public OrtEp, public ApiPtrs { static OrtStatus* ORT_API_CALL SyncImpl(_In_ OrtEp* this_ptr) noexcept; - OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, + OrtStatus* CreateEpContextNodes(const OrtGraph* graph, + gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); // Returns true if the EP should save constant initializers so that they are available during inference. @@ -119,6 +122,7 @@ class ExampleEp : public OrtEp, public ApiPtrs { std::string name_; Config config_{}; const OrtLogger& logger_; + Ort::EpContextConfig ep_context_config_{nullptr}; std::unordered_map> mul_kernels_; std::unordered_map> ep_context_kernels_; std::unordered_map float_initializers_; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index e003f3bd93786..7b1027f4baac2 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -217,17 +217,26 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, // Create EP configuration from session options, if needed. // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. std::string ep_context_enable; + std::string ep_context_embed_mode; std::string weightless_ep_context_nodes_enable; RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextEnable, "0", ep_context_enable)); + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextEmbedMode, "1", + ep_context_embed_mode)); RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpEnableWeightlessEpContextNodes, "0", weightless_ep_context_nodes_enable)); ExampleEp::Config config = {}; config.enable_ep_context = ep_context_enable == "1"; + config.embed_ep_context_in_model = ep_context_embed_mode != "0"; config.enable_weightless_ep_context_nodes = weightless_ep_context_nodes_enable == "1"; - auto dummy_ep = std::make_unique(*factory, factory->ep_name_, config, *logger); + OrtEpContextConfig* ep_context_config_raw = nullptr; + RETURN_IF_ERROR(factory->ep_api.SessionOptions_GetEpContextConfig(session_options, &ep_context_config_raw)); + Ort::EpContextConfig ep_context_config{ep_context_config_raw}; + + auto dummy_ep = std::make_unique(*factory, factory->ep_name_, config, *logger, + std::move(ep_context_config)); *ep = dummy_ep.release(); return nullptr; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 93633f9a375bb..35cbdc9009acc 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include +#include #include // #include #include @@ -29,6 +31,47 @@ namespace test { namespace { +struct EpContextDataCallbackState { + bool write_called = false; + bool read_called = false; + std::string write_file_name; + std::string read_file_name; + std::vector payload; +}; + +OrtStatus* ORT_API_CALL StoreEpContextDataCallback(void* state, const char* file_name, const void* buffer, + size_t buffer_size) { + auto* callback_state = static_cast(state); + callback_state->write_called = true; + callback_state->write_file_name = file_name; + callback_state->payload.clear(); + if (buffer_size != 0) { + callback_state->payload.assign(static_cast(buffer), static_cast(buffer) + buffer_size); + } + return nullptr; +} + +OrtStatus* ORT_API_CALL LoadEpContextDataCallback(void* state, const char* file_name, OrtAllocator* allocator, + void** buffer, size_t* data_size) { + auto* callback_state = static_cast(state); + callback_state->read_called = true; + callback_state->read_file_name = file_name; + + *buffer = nullptr; + *data_size = callback_state->payload.size(); + if (callback_state->payload.empty()) { + return nullptr; + } + + OrtStatus* status = Ort::GetApi().AllocatorAlloc(allocator, callback_state->payload.size(), buffer); + if (status != nullptr) { + return status; + } + + std::copy(callback_state->payload.begin(), callback_state->payload.end(), static_cast(*buffer)); + return nullptr; +} + void RunMulModelWithPluginEp(const ORTCHAR_T* model_path, const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, model_path, session_options); @@ -521,6 +564,79 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { } } +TEST(OrtEpLibrary, PluginEp_GenEpContextModel_ExternalDataUsesWriteCallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_external_ctx.onnx"); + std::filesystem::remove(output_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(output_model_file); }); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + EpContextDataCallbackState callback_state; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(false); + compile_options.SetEpContextDataWriteFunc(StoreEpContextDataCallback, &callback_state); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + ASSERT_TRUE(callback_state.write_called); + EXPECT_FALSE(callback_state.write_file_name.empty()); + EXPECT_EQ(std::string(callback_state.payload.begin(), callback_state.payload.end()), "binary_data"); +} + +TEST(OrtEpLibrary, PluginEp_LoadEpContextModel_ExternalDataUsesReadCallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* compiled_model_file = ORT_TSTR("plugin_ep_mul_1_external_ctx_load.onnx"); + std::filesystem::remove(compiled_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(compiled_model_file); }); + + EpContextDataCallbackState write_callback_state; + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(compiled_model_file); + compile_options.SetEpContextEmbedMode(false); + compile_options.SetEpContextDataWriteFunc(StoreEpContextDataCallback, &write_callback_state); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(compiled_model_file)); + ASSERT_TRUE(write_callback_state.write_called); + } + + EpContextDataCallbackState read_callback_state; + read_callback_state.payload = write_callback_state.payload; + { + Ort::SessionOptions session_options; + session_options.SetEpContextDataReadFunc(LoadEpContextDataCallback, &read_callback_state); + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, compiled_model_file, session_options); + } + + ASSERT_TRUE(read_callback_state.read_called); + EXPECT_EQ(read_callback_state.read_file_name, write_callback_state.write_file_name); +} + TEST(OrtEpLibrary, PluginEp_GenWeightlessEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 8dcc56bcfea44..8199b904708b1 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -5,21 +5,25 @@ #include #include +#include #include #include #include "gsl/gsl" #include "gtest/gtest.h" +#include "core/common/path_string.h" #include "core/common/logging/sinks/file_sink.h" #include "core/framework/config_options.h" #include "core/framework/kernel_def_builder.h" #include "core/framework/op_kernel.h" #include "core/framework/resource_accountant.h" #include "core/graph/constants.h" +#include "core/graph/ep_api_types.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/optimizer/graph_optimizer_registry.h" #include "core/session/abi_devices.h" +#include "core/session/model_compilation_options.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "test/util/include/api_asserts.h" @@ -56,6 +60,110 @@ static void CheckFileIsEmpty(const PathString& filename) { EXPECT_TRUE(content.empty()); } +static void ExpectOrtStatus(OrtStatus* status_ptr, OrtErrorCode expected_code, const char* expected_message) { + Ort::Status status{status_ptr}; + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), expected_code); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr(expected_message)); +} + +static void ExpectOrtStatusNotOk(OrtStatus* status_ptr) { + Ort::Status status{status_ptr}; + EXPECT_FALSE(status.IsOK()); +} + +static std::filesystem::path MakeEpContextDataTestDir(const char* test_name) { + std::filesystem::path test_dir = std::filesystem::temp_directory_path() / test_name; + std::filesystem::remove_all(test_dir); + std::filesystem::create_directories(test_dir); + return test_dir; +} + +struct EpContextReadCallbackState { + bool called = false; + std::string file_name; + std::vector payload; +}; + +static OrtStatus* ORT_API_CALL EpContextReadCallback(void* state, const char* file_name, OrtAllocator* allocator, + void** buffer, size_t* data_size) { + auto* read_state = static_cast(state); + read_state->called = true; + read_state->file_name = file_name; + + *buffer = nullptr; + *data_size = read_state->payload.size(); + + if (read_state->payload.empty()) { + return nullptr; + } + + OrtStatus* status = Ort::GetApi().AllocatorAlloc(allocator, read_state->payload.size(), buffer); + if (status != nullptr) { + return status; + } + + std::memcpy(*buffer, read_state->payload.data(), read_state->payload.size()); + return nullptr; +} + +struct EpContextCallbackErrorState { + OrtErrorCode error_code = ORT_FAIL; + const char* message = nullptr; +}; + +static OrtStatus* ORT_API_CALL EpContextFailingReadCallback(void* state, const char* /*file_name*/, + OrtAllocator* /*allocator*/, void** /*buffer*/, + size_t* /*data_size*/) { + const auto* error_state = static_cast(state); + return Ort::GetApi().CreateStatus(error_state->error_code, error_state->message); +} + +static OrtStatus* ORT_API_CALL EpContextReadCallbackFailsAfterAlloc(void* state, const char* /*file_name*/, + OrtAllocator* allocator, void** buffer, + size_t* data_size) { + const auto* error_state = static_cast(state); + *data_size = 4; + OrtStatus* alloc_status = Ort::GetApi().AllocatorAlloc(allocator, *data_size, buffer); + if (alloc_status != nullptr) { + return alloc_status; + } + + return Ort::GetApi().CreateStatus(error_state->error_code, error_state->message); +} + +static OrtStatus* ORT_API_CALL EpContextNonEmptyNullBufferReadCallback(void* /*state*/, const char* /*file_name*/, + OrtAllocator* /*allocator*/, void** buffer, + size_t* data_size) { + *buffer = nullptr; + *data_size = 4; + return nullptr; +} + +struct EpContextWriteCallbackState { + bool called = false; + std::string file_name; + std::vector payload; +}; + +static OrtStatus* ORT_API_CALL EpContextWriteCallback(void* state, const char* file_name, const void* buffer, + size_t buffer_size) { + auto* write_state = static_cast(state); + write_state->called = true; + write_state->file_name = file_name; + write_state->payload.clear(); + if (buffer_size != 0) { + write_state->payload.assign(static_cast(buffer), static_cast(buffer) + buffer_size); + } + return nullptr; +} + +static OrtStatus* ORT_API_CALL EpContextFailingWriteCallback(void* state, const char* /*file_name*/, + const void* /*buffer*/, size_t /*buffer_size*/) { + const auto* error_state = static_cast(state); + return Ort::GetApi().CreateStatus(error_state->error_code, error_state->message); +} + // Normally, a plugin EP would be implemented in a separate library. // The `test_plugin_ep` namespace contains a local implementation intended for unit testing. namespace test_plugin_ep { @@ -1608,6 +1716,309 @@ TEST(PluginExecutionProviderTest, GetGraphCaptureNodeAssignmentPolicy) { } } +TEST(PluginExecutionProviderTest, EpContextDataReadFuncIsCalledViaEpApi) { + const auto& ep_api = Ort::GetEpApi(); + Ort::SessionOptions session_options; + + EpContextReadCallbackState read_state{ + false, + {}, + {'e', 'p', 'c', 't', 'x'}, + }; + session_options.SetEpContextDataReadFunc(EpContextReadCallback, &read_state); + + OrtEpContextConfig* ep_context_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(session_options, &ep_context_config)); + auto release_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(ep_context_config); }); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ASSERT_ORTSTATUS_OK(ep_api.ReadEpContextData(ep_context_config, "context.bin", nullptr, allocator, + &buffer, &buffer_size)); + auto release_buffer = gsl::finally([&]() { allocator.Free(buffer); }); + + ASSERT_TRUE(read_state.called); + EXPECT_EQ(read_state.file_name, "context.bin"); + ASSERT_EQ(buffer_size, read_state.payload.size()); + EXPECT_EQ(std::vector(static_cast(buffer), static_cast(buffer) + buffer_size), + read_state.payload); +} + +TEST(PluginExecutionProviderTest, EpContextDataApiRejectsInvalidArguments) { + const auto& ort_api = Ort::GetApi(); + const auto& ep_api = Ort::GetEpApi(); + + Ort::SessionOptions session_options; + OrtEpContextConfig* ep_context_config = nullptr; + ExpectOrtStatus(ep_api.SessionOptions_GetEpContextConfig(nullptr, &ep_context_config), ORT_INVALID_ARGUMENT, + "OrtSessionOptions is NULL"); + ExpectOrtStatus(ep_api.SessionOptions_GetEpContextConfig(session_options, nullptr), ORT_INVALID_ARGUMENT, + "Output OrtEpContextConfig is NULL"); + + ExpectOrtStatus(ort_api.SessionOptions_SetEpContextDataReadFunc(nullptr, EpContextReadCallback, nullptr), + ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL"); + ExpectOrtStatus(ort_api.SessionOptions_SetEpContextDataReadFunc(session_options, nullptr, nullptr), + ORT_INVALID_ARGUMENT, "'read_func' parameter must not be NULL"); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ExpectOrtStatus(ep_api.ReadEpContextData(nullptr, nullptr, nullptr, allocator, &buffer, &buffer_size), + ORT_INVALID_ARGUMENT, "file_name is NULL"); + ExpectOrtStatus(ep_api.ReadEpContextData(nullptr, "context.bin", nullptr, nullptr, &buffer, &buffer_size), + ORT_INVALID_ARGUMENT, "OrtAllocator is NULL"); + ExpectOrtStatus(ep_api.ReadEpContextData(nullptr, "context.bin", nullptr, allocator, nullptr, &buffer_size), + ORT_INVALID_ARGUMENT, "Output buffer is NULL"); + ExpectOrtStatus(ep_api.ReadEpContextData(nullptr, "context.bin", nullptr, allocator, &buffer, nullptr), + ORT_INVALID_ARGUMENT, "Output buffer_size is NULL"); + + const std::vector payload{'x'}; + ExpectOrtStatus(ep_api.WriteEpContextData(nullptr, nullptr, nullptr, payload.data(), payload.size()), + ORT_INVALID_ARGUMENT, "file_name is NULL"); + ExpectOrtStatus(ep_api.WriteEpContextData(nullptr, "context.bin", nullptr, nullptr, payload.size()), + ORT_INVALID_ARGUMENT, "EPContext data buffer is NULL for non-empty data"); + +#if !defined(ORT_MINIMAL_BUILD) + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataApiRejectsInvalidArguments"}; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + const auto& compile_api = Ort::GetCompileApi(); + ExpectOrtStatus(compile_api.ModelCompilationOptions_SetEpContextDataWriteFunc(nullptr, EpContextWriteCallback, + nullptr), + ORT_INVALID_ARGUMENT, "OrtModelCompilationOptions is null"); + ExpectOrtStatus(compile_api.ModelCompilationOptions_SetEpContextDataWriteFunc(compilation_options, nullptr, + nullptr), + ORT_INVALID_ARGUMENT, "OrtWriteEpContextDataFunc function is null"); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +TEST(PluginExecutionProviderTest, EpContextDataCallbackErrorsArePropagated) { + const auto& ep_api = Ort::GetEpApi(); + Ort::SessionOptions session_options; + + EpContextCallbackErrorState read_error{ORT_FAIL, "read callback failed"}; + session_options.SetEpContextDataReadFunc(EpContextFailingReadCallback, &read_error); + + OrtEpContextConfig* ep_context_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(session_options, &ep_context_config)); + auto release_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(ep_context_config); }); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ExpectOrtStatus(ep_api.ReadEpContextData(ep_context_config, "context.bin", nullptr, allocator, + &buffer, &buffer_size), + ORT_FAIL, "read callback failed"); + +#if !defined(ORT_MINIMAL_BUILD) + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataCallbackErrorsArePropagated"}; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + EpContextCallbackErrorState write_error{ORT_EP_FAIL, "write callback failed"}; + compilation_options.SetEpContextDataWriteFunc(EpContextFailingWriteCallback, &write_error); + + const auto* internal_options = reinterpret_cast( + static_cast(compilation_options)); + OrtEpContextConfig* write_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(&internal_options->GetSessionOptions(), &write_config)); + auto release_write_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(write_config); }); + + const std::vector payload{'x'}; + ExpectOrtStatus(ep_api.WriteEpContextData(write_config, "context.bin", nullptr, payload.data(), payload.size()), + ORT_EP_FAIL, "write callback failed"); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +TEST(PluginExecutionProviderTest, EpContextDataReadCallbackFailureClearsOutputBuffer) { + const auto& ep_api = Ort::GetEpApi(); + Ort::SessionOptions session_options; + constexpr uintptr_t kNonNullSentinel = 0x1; + + EpContextCallbackErrorState read_error{ORT_FAIL, "read callback failed after allocation"}; + session_options.SetEpContextDataReadFunc(EpContextReadCallbackFailsAfterAlloc, &read_error); + + OrtEpContextConfig* ep_context_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(session_options, &ep_context_config)); + auto release_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(ep_context_config); }); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = reinterpret_cast(kNonNullSentinel); + size_t buffer_size = 1; + ExpectOrtStatus(ep_api.ReadEpContextData( + ep_context_config, "context.bin", nullptr, allocator, &buffer, &buffer_size), + ORT_FAIL, "read callback failed after allocation"); + EXPECT_EQ(buffer, nullptr); + EXPECT_EQ(buffer_size, 0U); +} + +TEST(PluginExecutionProviderTest, EpContextDataAllowsEmptyPayloads) { + const auto& ep_api = Ort::GetEpApi(); + Ort::SessionOptions session_options; + + EpContextReadCallbackState read_state{}; + session_options.SetEpContextDataReadFunc(EpContextReadCallback, &read_state); + + OrtEpContextConfig* ep_context_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(session_options, &ep_context_config)); + auto release_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(ep_context_config); }); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = reinterpret_cast(0x1); + size_t buffer_size = 1; + ASSERT_ORTSTATUS_OK(ep_api.ReadEpContextData(ep_context_config, "empty.bin", nullptr, allocator, + &buffer, &buffer_size)); + EXPECT_TRUE(read_state.called); + EXPECT_EQ(read_state.file_name, "empty.bin"); + EXPECT_EQ(buffer, nullptr); + EXPECT_EQ(buffer_size, 0U); + +#if !defined(ORT_MINIMAL_BUILD) + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataAllowsEmptyPayloads"}; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + EpContextWriteCallbackState write_state{}; + compilation_options.SetEpContextDataWriteFunc(EpContextWriteCallback, &write_state); + + const auto* internal_options = reinterpret_cast( + static_cast(compilation_options)); + OrtEpContextConfig* write_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(&internal_options->GetSessionOptions(), &write_config)); + auto release_write_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(write_config); }); + + ASSERT_ORTSTATUS_OK(ep_api.WriteEpContextData(write_config, "empty.bin", nullptr, nullptr, 0)); + EXPECT_TRUE(write_state.called); + EXPECT_EQ(write_state.file_name, "empty.bin"); + EXPECT_TRUE(write_state.payload.empty()); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +TEST(PluginExecutionProviderTest, EpContextDataReadRejectsNonEmptyNullCallbackBuffer) { + const auto& ep_api = Ort::GetEpApi(); + Ort::SessionOptions session_options; + session_options.SetEpContextDataReadFunc(EpContextNonEmptyNullBufferReadCallback, nullptr); + + OrtEpContextConfig* ep_context_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(session_options, &ep_context_config)); + auto release_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(ep_context_config); }); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ExpectOrtStatus(ep_api.ReadEpContextData(ep_context_config, "context.bin", nullptr, allocator, + &buffer, &buffer_size), + ORT_FAIL, "returned a null buffer for non-empty EPContext data"); +} + +#if !defined(ORT_MINIMAL_BUILD) +TEST(PluginExecutionProviderTest, EpContextDataWriteFuncIsCalledViaEpApi) { + const auto& ep_api = Ort::GetEpApi(); + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataWriteFuncIsCalledViaEpApi"}; + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + + EpContextWriteCallbackState write_state{}; + compilation_options.SetEpContextDataWriteFunc(EpContextWriteCallback, &write_state); + + const auto* internal_options = reinterpret_cast( + static_cast(compilation_options)); + OrtEpContextConfig* ep_context_config = nullptr; + ASSERT_ORTSTATUS_OK(ep_api.SessionOptions_GetEpContextConfig(&internal_options->GetSessionOptions(), + &ep_context_config)); + auto release_config = gsl::finally([&]() { ep_api.ReleaseEpContextConfig(ep_context_config); }); + + const std::vector payload{'b', 'i', 'n', 'a', 'r', 'y'}; + ASSERT_ORTSTATUS_OK(ep_api.WriteEpContextData(ep_context_config, "engine.bin", nullptr, + payload.data(), payload.size())); + + ASSERT_TRUE(write_state.called); + EXPECT_EQ(write_state.file_name, "engine.bin"); + EXPECT_EQ(write_state.payload, payload); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +TEST(PluginExecutionProviderTest, EpContextDataFallsBackToDisk) { + const auto& ep_api = Ort::GetEpApi(); + const std::filesystem::path test_dir = MakeEpContextDataTestDir("ort_ep_context_data_test"); + const std::filesystem::path data_path = test_dir / "context.bin"; + const std::string data_path_utf8 = PathToUTF8String(data_path.native()); + auto cleanup = gsl::finally([&]() { + std::error_code ec; + std::filesystem::remove_all(test_dir, ec); + }); + + const std::vector payload{'d', 'i', 's', 'k'}; + ASSERT_ORTSTATUS_OK(ep_api.WriteEpContextData(nullptr, data_path_utf8.c_str(), nullptr, + payload.data(), payload.size())); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ASSERT_ORTSTATUS_OK(ep_api.ReadEpContextData(nullptr, data_path_utf8.c_str(), nullptr, allocator, + &buffer, &buffer_size)); + auto release_buffer = gsl::finally([&]() { allocator.Free(buffer); }); + + ASSERT_EQ(buffer_size, payload.size()); + EXPECT_EQ(std::vector(static_cast(buffer), static_cast(buffer) + buffer_size), payload); +} + +TEST(PluginExecutionProviderTest, EpContextDataDiskFallbackResolvesRelativePathAgainstGraphModelPath) { + const auto& ep_api = Ort::GetEpApi(); + const std::filesystem::path test_dir = MakeEpContextDataTestDir("ort_ep_context_data_relative_path_test"); + auto cleanup = gsl::finally([&]() { + std::error_code ec; + std::filesystem::remove_all(test_dir, ec); + }); + + const std::filesystem::path source_model_path{ORT_TSTR("testdata/add_mul_add.onnx")}; + const std::filesystem::path model_path = test_dir / "model.onnx"; + std::filesystem::copy_file(source_model_path, model_path, std::filesystem::copy_options::overwrite_existing); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_path.native().c_str(), model, nullptr, + DefaultLoggingManager().DefaultLogger())); + GraphViewer graph_viewer(model->MainGraph()); + std::unique_ptr ep_graph = nullptr; + ASSERT_STATUS_OK(EpGraph::Create(graph_viewer, ep_graph, true)); + + const std::vector payload{'r', 'e', 'l'}; + ASSERT_ORTSTATUS_OK(ep_api.WriteEpContextData(nullptr, "context.bin", ep_graph.get(), + payload.data(), payload.size())); + + const std::filesystem::path expected_context_path = test_dir / "context.bin"; + ASSERT_TRUE(std::filesystem::exists(expected_context_path)); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ASSERT_ORTSTATUS_OK(ep_api.ReadEpContextData(nullptr, "context.bin", ep_graph.get(), allocator, + &buffer, &buffer_size)); + auto release_buffer = gsl::finally([&]() { allocator.Free(buffer); }); + + ASSERT_EQ(buffer_size, payload.size()); + EXPECT_EQ(std::vector(static_cast(buffer), static_cast(buffer) + buffer_size), payload); +} + +TEST(PluginExecutionProviderTest, EpContextDataDiskFallbackReportsFileErrors) { + const auto& ep_api = Ort::GetEpApi(); + const std::filesystem::path test_dir = MakeEpContextDataTestDir("ort_ep_context_data_file_error_test"); + auto cleanup = gsl::finally([&]() { + std::error_code ec; + std::filesystem::remove_all(test_dir, ec); + }); + + const std::filesystem::path missing_file_path = test_dir / "missing" / "context.bin"; + const std::string missing_file_path_utf8 = PathToUTF8String(missing_file_path.native()); + const std::vector payload{'x'}; + + ExpectOrtStatus(ep_api.WriteEpContextData(nullptr, missing_file_path_utf8.c_str(), nullptr, + payload.data(), payload.size()), + ORT_FAIL, "Failed to open EPContext data file for write"); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ExpectOrtStatusNotOk(ep_api.ReadEpContextData(nullptr, missing_file_path_utf8.c_str(), nullptr, allocator, + &buffer, &buffer_size)); +} + // Helper: create a no-threshold resource accountant via the real factory (config ","). static IResourceAccountant* CreateNoThresholdAccountant(std::optional& acc_map) { ConfigOptions config;