Skip to content

Commit db6d83b

Browse files
adrianlizarragaCopilotedgchen1
authored
[EP ABI] Initial support for kernel-based EPs (microsoft#26206)
### Description This PR adds an initial set of C APIs necessary to support kernel registration for plugin EPs. ### Example use The example plugin EP implementation now registers `MemcpyFromHost` and `MemcpyToHost` operator kernels using the new APIs. New utilities in the example implementation make the process of defining operator kernels very similar to the existing process used by provider-bridge EPs. First, the operator kernel class is defined: ```c++ // File: onnxruntime/test/autoep/library/kernels/memcpy.h struct Memcpy : public OrtKernelImpl { static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Memcpy>& kernel); Memcpy(const OrtKernelInfo* info, void* state); static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) noexcept; private: const OrtKernelInfo* info_; void* state_; // Custom state passed from OrtEp }; ``` Then, a macro defines a function that can be called to register the operator with the EP's kernel registry: ```c++ // File: onnxruntime/test/autoep/library/kernels/memcpy.cc ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, kOnnxDomain, 1, (Ort::KernelDefBuilder() .SetInputMemType(0, OrtMemType::OrtMemTypeCPUInput) .AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), Memcpy) ONNX_OPERATOR_KERNEL_EX( MemcpyToHost, kOnnxDomain, 1, (Ort::KernelDefBuilder() .SetOutputMemType(0, OrtMemType::OrtMemTypeCPUOutput) .AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), Memcpy) ``` Lastly, the functions defined by the above macro are entered into a table: ```c++ // File: onnxruntime/test/autoep/library/ep_kernel_registration.cc // Include kernel files: #include "kernels/memcpy.h" // Forward declarations of kernel classes used as template args for BuildKernelCreateInfo class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyFromHost); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyToHost); // Table of BuildKernelCreateInfo functions for each operator static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = { BuildKernelCreateInfo<void>, // Dummy to avoid table becoming empty. BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyFromHost)>, BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyToHost)>, }; ``` The [example EP processes the entries in the above table](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/ep_kernel_registration.cc) to add information about the supported operator kernels to the EP's kernel registry (`OrtKernelRegistry`). Additionally, during the call to `OrtEp::GetCapability`, an EP can now lookup registered kernel definitions via the new API `EpGraphSupportInfo_LookUpKernel`. Note that an EP would not normally lookup kernels for `Memcpy**Host`, which are inserted by ORT. Instead, it would be used to look up other registered operator kernels like `Conv`, for example. ```c++ static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) noexcept { // ... for (const OrtNode* node : nodes) { const OrtKernelDef* kernel_def = nullptr; OrtStatus* status = this_ep->ep_api->EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def); if (status != nullptr) { return status; } if (kernel_def != nullptr) { // Take node if this EP has a registered kernel for it. if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, node); st != nullptr) { return st; } } } return nullptr; } ``` ### EP implementation details An EP instance (i.e., `OrtEp`) that needs to register operator kernels with ONNX Runtime must implement the following `OrtEp::GetKernelRegistry()` function: | Function Signature | Description | |--------------------|-------------| |**GetKernelRegistry**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtEp* this_ptr`: The OrtEp instance.</li><li>`const OrtKernelRegistry** kernel_registry`: Output parameter set to the EP's kernel registry, which must remain valid throughout the lifetime of the EP.</li></ul>| Gets the execution provider's kernel registry, if any.<br/><br/>**Remarks:** A kernel registry contains kernel creation information for operator kernels supported by an EP.<br/><br/>**Note:** Implementation of this function is optional. If set to NULL, ORT assumes the EP compiles nodes. | If defined by the EP, the `OrtEp::GetKernelRegistry()` function is [called by ONNX Runtime](https://github.com/microsoft/onnxruntime/blob/0f7145f3809103c123de2d281a6b310677e6d56c/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc#L146-L147) after creating an instance of the `OrtEp` in order to retrieve the EP's kernel registry. #### APIs used by EP to add entries to kernel registry An EP's kernel registry (`OrtKernelRegistry`) contains **information** necessary for the (later) creation of operator kernels supported by an EP. Conceptually, a kernel registry contains an array of "kernel creation information" elements, one per operator. Each such element consists of: - A kernel **definition** (`OrtKernelDef`), which specifies operator type, supported versions, type constraints, I/O memory types, etc. - A function of type `OrtKernelCreateFunc` that ORT calls to create an instance of the kernel (`OrtKernelImpl`). - Custom opaque state (provided by the `OrtEp`) that is passed to the `OrtKernelCreateFunc`. An EP uses the following `OrtEpApi::KernelRegistry_AddKernel()` function to add an entry for one supported operator. | Function Signature | Description | |--------------------|-------------| |**KernelRegistry_AddKernel**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelRegistry* kernel_registry`: The OrtKernelRegistry instance.</li><li>`const OrtKernelDef* kernel_def`: The kernel definition, which includes operator type, version, EP name, type constraints, etc.</li><li>`OrtKernelCreateFunc kernel_create_func`: Function that creates an instance of the operator kernel as a OrtKernelImpl instance.</li><li>`void* kernel_create_func_state`: Custom state passed to the kernel creation function. Can be null.</li></ul>| Adds kernel creation information for a supported operator kernel to the given kernel registry.<br/><br/>**Remarks:** Refer to OrtEp::GetKernelRegistry, which returns an EP's kernel registry to ORT. | ##### Building a kernel definition An EP uses a kernel definition builder (`OrtKernelDefBuilder`) to create a kernel definition (`OrtKernelDef`). The following table lists **some** of the C APIs related to building a kernel definition. The above `ONNX_OPERATOR_KERNEL_EX` macro [uses these APIs](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/kernels/utils.h#L42). | Function Signature | Description | |--------------------|-------------| |**KernelDefBuilder_SetOperatorType**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelDefBuilder* kernel_def_builder`: The OrtKernelDefBuilder instance.</li><li>`const char* op_type`: A null-terminated string representing the operator type.</li></ul>| Sets the kernel's operator type. | |**KernelDefBuilder_SetDomain**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelDefBuilder* kernel_def_builder`: The OrtKernelDefBuilder instance.</li><li>`const char* domain`: A null-terminated string representing the operator's domain.</li></ul>| Sets the kernel's domain. | | ... | ... | |**KernelDefBuilder_Build**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelDefBuilder* kernel_def_builder`: The OrtKernelDefBuilder instance.</li><li>`OrtKernelDef** kernel_def_out`: The new OrtKernelDef instance.</li></ul>| Creates a OrtKernelDef instance from the given kernel definition builder. | ##### Defining a kernel implementation An EP defines a kernel implementation by initializing an instance of `OrtKernelImpl` (shown below) with function pointers for computation, release, etc. ```c++ struct OrtKernelImpl { uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION /** \brief Computation function called to execute the kernel on an EP. * * \param[in] this_ptr The OrtKernelImpl instance. * \param[in] context The OrtKernelContext instance that provides access to the inputs and outputs. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.24. */ ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context); /** \brief Called by ORT to release the OrtKernelImpl instance and its resources. * * \param[in] this_ptr The OrtKernelImpl instance. * * \since Version 1.24. */ ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr); }; ``` As shown previously, the example EP creates a `Memcpy` class that inherits from `OrtKernelImpl` and [implements the above functions](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/kernels/memcpy.cc). ##### Defining a kernel creation function An EP must provide a function of type `OrtKernelCreateFunc` that ORT can later call to create an instance of a kernel (`OrtKernelImpl`). The signature of the `OrtKernelCreateFunc` is shown below. ```c++ /** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel. * * \param[in] ctx Unused/reserved for future use. * \param[in] kernel_create_func_state Opaque state initially provided by the EP that registered the kernel. * Refer to OrtEpApi::KernelRegistry_AddKernel(). May be null. * \param[in] info The OrtKernelInfo instance that provides access to the kernel's input and output characteristics. * \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.24. */ typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ OrtKernelCreateContext* ctx, // unused/reserved as of 1.24 _In_ void* kernel_create_func_state, _In_ const OrtKernelInfo* info, _Outptr_result_maybenull_ OrtKernelImpl** kernel_out); ``` The example EP declares kernel creation functions via use of the previously mentioned `ONNX_OPERATOR_KERNEL_EX` [macro](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/kernels/utils.h#L56-L64). If one were to expand the macro call, the kernel creation function for `MemcpyFromHost` would look similar to the following snippet: ```c++ OrtStatus* ORT_API_CALL CreateMemcpyKernel(OrtKernelCreateContext* /*ctx*/, void* kernel_create_func_state, const OrtKernelInfo* info, OrtKernelImpl** kernel_out) { *kernel_out = nullptr; std::unique_ptr<Memcpy> kernel; RETURN_IF_ERROR(Memcpy::Create(info, kernel_create_func_state, kernel)); *kernel_out = kernel.release(); return nullptr; } ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
1 parent 8a433fb commit db6d83b

42 files changed

Lines changed: 2741 additions & 64 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cmake/onnxruntime_unittests.cmake

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2094,6 +2094,50 @@ if (onnxruntime_BUILD_SHARED_LIB AND
20942094
set_target_properties(example_plugin_ep_virt_gpu PROPERTIES FOLDER "ONNXRuntimeTest")
20952095
source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src})
20962096

2097+
#
2098+
# example_plugin_ep_kernel_registry
2099+
#
2100+
set(onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src
2101+
"${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h"
2102+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc"
2103+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h"
2104+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
2105+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h"
2106+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc"
2107+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h"
2108+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc"
2109+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h"
2110+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h"
2111+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc"
2112+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h"
2113+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc"
2114+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"
2115+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc"
2116+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h"
2117+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc")
2118+
onnxruntime_add_shared_library_module(example_plugin_ep_kernel_registry ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src})
2119+
target_include_directories(example_plugin_ep_kernel_registry PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
2120+
target_link_libraries(example_plugin_ep_kernel_registry PRIVATE onnxruntime ${GSL_TARGET})
2121+
2122+
if(UNIX)
2123+
if (APPLE)
2124+
set(ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG "-Xlinker -dead_strip")
2125+
elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX")
2126+
string(CONCAT ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG
2127+
"-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib.lds "
2128+
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
2129+
endif()
2130+
else()
2131+
set(ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG
2132+
"-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib.def")
2133+
endif()
2134+
2135+
set_property(TARGET example_plugin_ep_kernel_registry APPEND_STRING PROPERTY LINK_FLAGS
2136+
${ONNXRUNTIME_EXAMPLE_PLUGIN_EP_KERNEL_REGISTRY_LINK_FLAG})
2137+
2138+
set_target_properties(example_plugin_ep_kernel_registry PROPERTIES FOLDER "ONNXRuntimeTest")
2139+
source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src})
2140+
20972141
#
20982142
# test library
20992143
#
@@ -2129,7 +2173,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND
21292173
TARGET onnxruntime_autoep_test
21302174
SOURCES ${onnxruntime_autoep_test_SRC} ${onnxruntime_unittest_main_src}
21312175
LIBS ${onnxruntime_autoep_test_LIBS}
2132-
DEPENDS ${all_dependencies} example_plugin_ep example_plugin_ep_virt_gpu
2176+
DEPENDS ${all_dependencies} example_plugin_ep example_plugin_ep_virt_gpu example_plugin_ep_kernel_registry
21332177
)
21342178
endif()
21352179

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,9 @@ ORT_DEFINE_RELEASE(ValueInfo);
644644

645645
ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
646646
ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
647+
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi);
648+
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);
649+
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi);
647650

648651
// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
649652
// but the struct has V2 in its name to indicate that it is the second version of the options.
@@ -3292,5 +3295,89 @@ struct Model : detail::ModelImpl<OrtModel> {
32923295
explicit Model(const std::vector<DomainOpsetPair>& opsets);
32933296
#endif
32943297
};
3298+
3299+
namespace detail {
3300+
template <typename T>
3301+
struct ConstKernelDefImpl : Base<T> {
3302+
using B = Base<T>;
3303+
using B::B;
3304+
3305+
///< Wraps OrtEpApi::KernelDef_GetOperatorType
3306+
const char* GetOperatorType() const;
3307+
3308+
///< Wraps OrtEpApi::KernelDef_GetDomain
3309+
const char* GetDomain() const;
3310+
3311+
///< Wraps OrtEpApi::KernelDef_GetSinceVersion
3312+
std::pair<int, int> GetSinceVersion() const;
3313+
3314+
///< Wraps OrtEpApi::KernelDef_GetExecutionProvider
3315+
const char* GetExecutionProvider() const;
3316+
3317+
///< Wraps OrtEpApi::KernelDef_GetInputMemType
3318+
OrtMemType GetInputMemType(size_t input_index) const;
3319+
3320+
///< Wraps OrtEpApi::KernelDef_GetOutputMemType
3321+
OrtMemType GetOutputMemType(size_t output_index) const;
3322+
};
3323+
} // namespace detail
3324+
3325+
using ConstKernelDef = detail::ConstKernelDefImpl<detail::Unowned<const OrtKernelDef>>;
3326+
3327+
struct KernelDef : detail::ConstKernelDefImpl<OrtKernelDef> {
3328+
using Base = detail::ConstKernelDefImpl<OrtKernelDef>;
3329+
using Base::Base;
3330+
3331+
explicit KernelDef(std::nullptr_t) {}
3332+
explicit KernelDef(OrtKernelDef* p) : detail::ConstKernelDefImpl<OrtKernelDef>{p} {}
3333+
3334+
ConstKernelDef GetConst() const { return ConstKernelDef{this->p_}; }
3335+
};
3336+
3337+
/** \brief Builder for OrtKernelDef.
3338+
*
3339+
* Used by plugin EPs to build a kernel definition.
3340+
*/
3341+
struct KernelDefBuilder : detail::Base<OrtKernelDefBuilder> {
3342+
KernelDefBuilder(); ///< Wraps OrtEpApi::CreateKernelDefBuilder
3343+
explicit KernelDefBuilder(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
3344+
explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder);
3345+
3346+
KernelDefBuilder& SetOperatorType(const char* op_type);
3347+
KernelDefBuilder& SetDomain(const char* domain);
3348+
KernelDefBuilder& SetSinceVersion(int since_version_start, int since_version_end);
3349+
KernelDefBuilder& SetExecutionProvider(const char* ep_name);
3350+
KernelDefBuilder& SetInputMemType(size_t input_index, OrtMemType mem_type);
3351+
KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type);
3352+
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtDataType* data_type);
3353+
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector<const OrtDataType*>& data_types);
3354+
KernelDefBuilder& AddInputOutputAlias(int input_index, int output_index);
3355+
KernelDefBuilder& AddInputOutputAliases(const std::vector<int>& input_indices,
3356+
const std::vector<int>& output_indices);
3357+
KernelDefBuilder& AddInputOutputMutableAlias(int input_index, int output_index);
3358+
KernelDefBuilder& AddInputOutputMutableAliases(const std::vector<int>& input_indices,
3359+
const std::vector<int>& output_indices);
3360+
3361+
KernelDef Build();
3362+
};
3363+
3364+
/** \brief Registry for kernels supported by an EP.
3365+
*
3366+
* Used by plugin EPs to register definitions for supported kernels.
3367+
*/
3368+
struct KernelRegistry : detail::Base<OrtKernelRegistry> {
3369+
///< Wrapper around OrtEpApi::CreateKernelRegistry
3370+
KernelRegistry();
3371+
3372+
///< Create an empty object, must be assigned a valid one to be used
3373+
explicit KernelRegistry(std::nullptr_t) {}
3374+
3375+
///< Take ownership of a pointer created with the C API.
3376+
explicit KernelRegistry(OrtKernelRegistry* ort_kernel_registry);
3377+
3378+
///< Wraps KernelRegistry_AddKernel
3379+
Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
3380+
void* kernel_create_func_state);
3381+
};
32953382
} // namespace Ort
32963383
#include "onnxruntime_cxx_inline.h"

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <iterator>
1313
#include <string>
1414
#include <type_traits>
15+
#include <utility>
1516
#include <vector>
1617

1718
// Convert OrtStatus to Ort::Status and return
@@ -3572,4 +3573,144 @@ inline Model::Model(const std::vector<DomainOpsetPair>& opsets) {
35723573
}
35733574
#endif
35743575

3576+
namespace detail {
3577+
template <typename T>
3578+
inline const char* ConstKernelDefImpl<T>::GetOperatorType() const {
3579+
return GetEpApi().KernelDef_GetOperatorType(this->p_);
3580+
}
3581+
3582+
template <typename T>
3583+
inline const char* ConstKernelDefImpl<T>::GetDomain() const {
3584+
return GetEpApi().KernelDef_GetDomain(this->p_);
3585+
}
3586+
3587+
template <typename T>
3588+
inline std::pair<int, int> ConstKernelDefImpl<T>::GetSinceVersion() const {
3589+
int start = 0;
3590+
int end = 0;
3591+
3592+
ThrowOnError(GetEpApi().KernelDef_GetSinceVersion(this->p_, &start, &end));
3593+
return std::pair<int, int>(start, end);
3594+
}
3595+
3596+
template <typename T>
3597+
inline const char* ConstKernelDefImpl<T>::GetExecutionProvider() const {
3598+
return GetEpApi().KernelDef_GetExecutionProvider(this->p_);
3599+
}
3600+
3601+
template <typename T>
3602+
inline OrtMemType ConstKernelDefImpl<T>::GetInputMemType(size_t input_index) const {
3603+
OrtMemType mem_type{};
3604+
ThrowOnError(GetEpApi().KernelDef_GetInputMemType(this->p_, input_index, &mem_type));
3605+
3606+
return mem_type;
3607+
}
3608+
3609+
template <typename T>
3610+
inline OrtMemType ConstKernelDefImpl<T>::GetOutputMemType(size_t output_index) const {
3611+
OrtMemType mem_type{};
3612+
ThrowOnError(GetEpApi().KernelDef_GetOutputMemType(this->p_, output_index, &mem_type));
3613+
3614+
return mem_type;
3615+
}
3616+
} // namespace detail
3617+
3618+
inline KernelDefBuilder::KernelDefBuilder() {
3619+
ThrowOnError(GetEpApi().CreateKernelDefBuilder(&p_));
3620+
}
3621+
3622+
inline KernelDefBuilder::KernelDefBuilder(OrtKernelDefBuilder* p) : detail::Base<OrtKernelDefBuilder>{p} {
3623+
}
3624+
3625+
inline KernelDefBuilder& KernelDefBuilder::SetOperatorType(const char* op_type) {
3626+
ThrowOnError(GetEpApi().KernelDefBuilder_SetOperatorType(p_, op_type));
3627+
return *this;
3628+
}
3629+
3630+
inline KernelDefBuilder& KernelDefBuilder::SetDomain(const char* domain) {
3631+
ThrowOnError(GetEpApi().KernelDefBuilder_SetDomain(p_, domain));
3632+
return *this;
3633+
}
3634+
3635+
inline KernelDefBuilder& KernelDefBuilder::SetSinceVersion(int since_version_start, int since_version_end) {
3636+
ThrowOnError(GetEpApi().KernelDefBuilder_SetSinceVersion(p_, since_version_start, since_version_end));
3637+
return *this;
3638+
}
3639+
3640+
inline KernelDefBuilder& KernelDefBuilder::SetExecutionProvider(const char* ep_name) {
3641+
ThrowOnError(GetEpApi().KernelDefBuilder_SetExecutionProvider(p_, ep_name));
3642+
return *this;
3643+
}
3644+
3645+
inline KernelDefBuilder& KernelDefBuilder::SetInputMemType(size_t input_index, OrtMemType mem_type) {
3646+
ThrowOnError(GetEpApi().KernelDefBuilder_SetInputMemType(p_, input_index, mem_type));
3647+
return *this;
3648+
}
3649+
3650+
inline KernelDefBuilder& KernelDefBuilder::SetOutputMemType(size_t output_index, OrtMemType mem_type) {
3651+
ThrowOnError(GetEpApi().KernelDefBuilder_SetOutputMemType(p_, output_index, mem_type));
3652+
return *this;
3653+
}
3654+
3655+
inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name,
3656+
const OrtDataType* data_type) {
3657+
ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, &data_type, 1));
3658+
return *this;
3659+
}
3660+
3661+
inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_name,
3662+
const std::vector<const OrtDataType*>& data_types) {
3663+
ThrowOnError(GetEpApi().KernelDefBuilder_AddTypeConstraint(p_, arg_name, data_types.data(), data_types.size()));
3664+
return *this;
3665+
}
3666+
3667+
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputAlias(int input_index, int output_index) {
3668+
ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputAliases(p_, &input_index, &output_index, 1));
3669+
return *this;
3670+
}
3671+
3672+
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputAliases(const std::vector<int>& input_indices,
3673+
const std::vector<int>& output_indices) {
3674+
if (input_indices.size() != output_indices.size()) {
3675+
ORT_CXX_API_THROW("Expecting input and output indices to have the same element count", ORT_INVALID_ARGUMENT);
3676+
}
3677+
3678+
ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputAliases(p_, input_indices.data(), output_indices.data(),
3679+
input_indices.size()));
3680+
return *this;
3681+
}
3682+
3683+
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputMutableAlias(int input_index, int output_index) {
3684+
ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputMutableAliases(p_, &input_index, &output_index, 1));
3685+
return *this;
3686+
}
3687+
3688+
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputMutableAliases(const std::vector<int>& input_indices,
3689+
const std::vector<int>& output_indices) {
3690+
if (input_indices.size() != output_indices.size()) {
3691+
ORT_CXX_API_THROW("Expecting input and output indices to have the same element count", ORT_INVALID_ARGUMENT);
3692+
}
3693+
3694+
ThrowOnError(GetEpApi().KernelDefBuilder_AddInputOutputMutableAliases(p_, input_indices.data(), output_indices.data(),
3695+
input_indices.size()));
3696+
return *this;
3697+
}
3698+
3699+
inline KernelDef KernelDefBuilder::Build() {
3700+
OrtKernelDef* kernel_def = nullptr;
3701+
ThrowOnError(GetEpApi().KernelDefBuilder_Build(p_, &kernel_def));
3702+
return KernelDef(kernel_def);
3703+
}
3704+
3705+
inline KernelRegistry::KernelRegistry() {
3706+
ThrowOnError(GetEpApi().CreateKernelRegistry(&p_));
3707+
}
3708+
3709+
inline KernelRegistry::KernelRegistry(OrtKernelRegistry* p) : detail::Base<OrtKernelRegistry>{p} {
3710+
}
3711+
3712+
inline Status KernelRegistry::AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
3713+
void* kernel_create_func_state) {
3714+
return Status{GetEpApi().KernelRegistry_AddKernel(p_, kernel_def, kernel_create_func, kernel_create_func_state)};
3715+
}
35753716
} // namespace Ort

0 commit comments

Comments
 (0)