Skip to content

Commit 2c4a0c1

Browse files
authored
Fix plugin EP test failure when host ORT lacks newer data types (#28659)
### Description

Fix the CUDA plugin EP package test pipeline failure where the plugin is built with the latest code (which includes `float8e8m0` and other newer data types), but the host ORT 1.26 release does not support these types. When the plugin attempts to register kernel type constraints containing unsupported types, `GetTensorDataType` fails and the plugin load crashes.

### Motivation and Context

The plugin EP architecture allows plugins to be built against a newer version of the ONNX Runtime headers while being loaded into an older host ORT. However, the existing `KernelDefBuilder::TypeConstraint` methods call `GetTensorType` (which throws on unsupported types), making it impossible for a forward-compatible plugin to register kernels that include newer data types in their type constraint lists.

### Changes

- Add `TryGetTensorType()` — a non-throwing variant of `GetTensorType()` that returns `nullptr` when the host ORT does not recognize a tensor element type.
- Add `TryMLDataTypeToOrtDataType()` — a non-throwing variant of `MLDataTypeToOrtDataType()` that returns `nullptr` instead of asserting/throwing when the host ORT does not recognize the element type.
- Update `KernelDefBuilder::TypeConstraint` (both vector and single-type overloads) to use the `Try` variants and gracefully skip unsupported types rather than failing. When no supported type remains for a constraint, the builder is marked invalid and `Build()` returns a null `Ort::KernelDef`.

### Impact

- Plugins built with newer headers can now load into older host ORT releases without crashing on unknown data types.
- If all types in a constraint list are unsupported, the constraint is not added and `Build()` yields a null kernel definition (the kernel will not match, which is the correct fallback behavior).
- No behavioral change when the host supports all types — the code path is identical to before.
1 parent 30b6528 commit 2c4a0c1

2 files changed

Lines changed: 50 additions & 5 deletions

File tree

include/onnxruntime/ep/adapter/kernel_def_builder.h

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#endif
99

1010
#include <memory>
11+
#include <vector>
1112

1213
#include "core/framework/data_types.h"
1314

@@ -28,6 +29,23 @@ inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) {
2829
return result;
2930
}
3031

32+
/// <summary>
33+
/// Gets an OrtDataType for a tensor type. Returns nullptr if the host ORT does not support the type.
34+
/// </summary>
35+
inline const OrtDataType* TryGetTensorType(ONNXTensorElementDataType elem_type) {
36+
const OrtEpApi& ep_api = Ort::GetEpApi();
37+
const OrtDataType* result = nullptr;
38+
39+
Ort::Status status(ep_api.GetTensorDataType(elem_type, &result));
40+
if (!status.IsOK()) {
41+
if (status.GetErrorCode() == ORT_INVALID_ARGUMENT || status.GetErrorCode() == ORT_NOT_IMPLEMENTED) {
42+
return nullptr;
43+
}
44+
Ort::ThrowOnError(status);
45+
}
46+
return result;
47+
}
48+
3149
inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) {
3250
auto tensor_type = ml_type->AsTensorType();
3351
EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types.");
@@ -37,6 +55,20 @@ inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) {
3755
return GetTensorType(onnx_type);
3856
}
3957

58+
/// <summary>
59+
/// Converts an MLDataType to an OrtDataType. Returns nullptr if the host ORT does not support the type.
60+
/// This enables forward-compatible plugins to register kernels with type constraints that include newer
61+
/// data types without failing when loaded into an older host ORT.
62+
/// </summary>
63+
inline const OrtDataType* TryMLDataTypeToOrtDataType(MLDataType ml_type) {
64+
auto tensor_type = ml_type->AsTensorType();
65+
EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types.");
66+
auto elem_type = tensor_type->GetElementType();
67+
auto primitive_type = static_cast<const PrimitiveDataTypeBase*>(elem_type);
68+
auto onnx_type = static_cast<ONNXTensorElementDataType>(primitive_type->GetDataType());
69+
return TryGetTensorType(onnx_type);
70+
}
71+
4072
/// <summary>
4173
/// An adapter class partially implementing the interface of `onnxruntime::KernelDefBuilder`.
4274
/// </summary>
@@ -73,14 +105,26 @@ struct KernelDefBuilder {
73105
std::vector<const OrtDataType*> ort_types;
74106
ort_types.reserve(types.size());
75107
for (const auto& type : types) {
76-
ort_types.push_back(MLDataTypeToOrtDataType(type));
108+
const OrtDataType* ort_type = TryMLDataTypeToOrtDataType(type);
109+
if (ort_type != nullptr) {
110+
ort_types.push_back(ort_type);
111+
}
112+
}
113+
if (types.empty() || !ort_types.empty()) {
114+
builder_.AddTypeConstraint(arg_name, ort_types);
115+
} else {
116+
valid_ = false;
77117
}
78-
builder_.AddTypeConstraint(arg_name, ort_types);
79118
return *this;
80119
}
81120

82121
KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type) {
83-
builder_.AddTypeConstraint(arg_name, MLDataTypeToOrtDataType(type));
122+
const OrtDataType* ort_type = TryMLDataTypeToOrtDataType(type);
123+
if (ort_type != nullptr) {
124+
builder_.AddTypeConstraint(arg_name, ort_type);
125+
} else {
126+
valid_ = false;
127+
}
84128
return *this;
85129
}
86130

@@ -134,10 +178,11 @@ struct KernelDefBuilder {
134178
// assignment externally; the queue id hint is not needed.
135179
KernelDefBuilder& ExecQueueId(int /*queue_id*/) { return *this; }
136180

137-
Ort::KernelDef Build() { return builder_.Build(); }
181+
Ort::KernelDef Build() { return valid_ ? builder_.Build() : Ort::KernelDef{nullptr}; }
138182

139183
private:
140184
Ort::KernelDefBuilder builder_;
185+
bool valid_ = true;
141186
};
142187

143188
} // namespace adapter

include/onnxruntime/ep/adapter/kernel_registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct KernelCreateInfo {
3535
KernelCreatePtrFn create_func)
3636
: kernel_def(std::move(definition)),
3737
kernel_create_func(create_func) {
38-
assert(kernel_def != nullptr);
38+
assert(kernel_def == nullptr || kernel_create_func != nullptr);
3939
}
4040

4141
KernelCreateInfo(KernelCreateInfo&& other) noexcept

0 commit comments

Comments
 (0)