Commit b8f5f1a

Add KernelInfo string-array attribute APIs to the ORT C and C++ APIs (microsoft#27599)
## Summary

This PR adds a new `OrtApi` entry point for reading repeated string attributes from `OrtKernelInfo`:

- `KernelInfoGetAttributeArray_string`

It also wires that support through the C++ wrapper so callers can use:

- `Ort::ConstKernelInfo::GetAttributes<std::string>(...)`

## Problem

The existing kernel info APIs already support scalar and array attribute access for numeric types, but there was no C API for reading string-array attributes from `OrtKernelInfo`.

That created a gap for code paths that rely on repeated string attributes in kernel metadata, such as:

- custom op / kernel consumers using the public C API
- C++ wrapper callers expecting `GetAttributes<std::string>` to work end-to-end
- plugin EP scenarios that need to compile existing kernels against the adapter/C API surface

One concrete case is CUDA plugin EP RNN support, where the RNN kernels read the `activations` attribute via `GetAttrs<std::string>("activations", ...)`. The adapter path needed a corresponding ORT C API to expose that data.

## Changes

### C API

Added `OrtApi::KernelInfoGetAttributeArray_string` to fetch repeated string attributes from `OrtKernelInfo`.

Behavior:

- If `out == nullptr`, the API returns the attribute count in `size`.
- Otherwise, the API allocates the pointer array and each UTF-8 string with the provided `OrtAllocator`.
- For empty attributes, `*out` is set to `nullptr` and `*size` is set to `0`.
- The caller frees each string and the pointer array with the same allocator.

### Implementation

Added the implementation in the ORT session/custom-op API layer by:

- reading the underlying attribute with `OpKernelInfo::GetAttrs<std::string>`
- copying the result into allocator-owned C-style string storage for the public API

### C++ wrapper

Completed C++ wrapper support so `Ort::ConstKernelInfo::GetAttributes<std::string>(name)` works through the new C API.

The wrapper follows the standard two-call pattern and then copies the result:

1. query the number of strings
2. allocate and fetch the returned string array
3. copy into `std::vector<std::string>` and release allocator-owned memory

### Tests

Added framework tests covering:

- non-empty string-array attributes
- empty string-array attributes
- missing attribute failure path
- C++ wrapper access through `Ort::ConstKernelInfo`

## Files Changed

- `include/onnxruntime/core/session/onnxruntime_c_api.h`
- `include/onnxruntime/core/session/onnxruntime_cxx_api.h`
- `include/onnxruntime/core/session/onnxruntime_cxx_inline.h`
- `onnxruntime/core/session/custom_ops.cc`
- `onnxruntime/core/session/onnxruntime_c_api.cc`
- `onnxruntime/core/session/ort_apis.h`
- `onnxruntime/test/framework/kernel_info_test.cc`

## Why This Change

This closes a real API gap in kernel attribute access and makes the public API surface more consistent with the existing numeric attribute helpers.

It also unblocks plugin/adapter-based kernel code that depends on repeated string attributes without requiring those kernels to special-case plugin builds. For example, porting the RNN operator to the CUDA plugin EP needs this API.

## Validation

Validated with new unit coverage in `kernel_info_test.cc` for:

- `KernelInfoGetAttributeArray_string` with populated attributes
- `KernelInfoGetAttributeArray_string` with empty attributes
- missing-attribute error handling
- `Ort::ConstKernelInfo::GetAttributes<std::string>` parity with the C API
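The caller-side contract described under "C API" can be sketched without linking against ONNX Runtime. In the sketch below, `DemoAllocator`, `DemoGetStringArray`, and `ReadStrings` are hypothetical stand-ins written for illustration only. They mirror the documented behavior (size query when `out` is null, `nullptr`/`0` for empty attributes, caller frees each string and then the pointer array) and are not the ORT implementation. The stand-in allocator simply aborts on allocation failure to keep the example short.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical stand-in for an OrtAllocator: malloc/free plus a live-block counter.
struct DemoAllocator {
  int live = 0;
  void* Alloc(size_t n) {
    void* p = std::malloc(n);
    if (p == nullptr) std::abort();  // demo simplification: no failure path
    ++live;
    return p;
  }
  void Free(void* p) {
    if (p != nullptr) {
      --live;
      std::free(p);
    }
  }
};

// Hypothetical stand-in that mirrors the documented contract of
// KernelInfoGetAttributeArray_string. It is not the ORT implementation.
void DemoGetStringArray(const std::vector<std::string>& attr, DemoAllocator& alloc,
                        char*** out, size_t* size) {
  *size = attr.size();
  if (out == nullptr) return;  // size query only
  *out = nullptr;
  if (attr.empty()) return;    // empty attribute: *out == nullptr, *size == 0
  char** array = static_cast<char**>(alloc.Alloc(attr.size() * sizeof(char*)));
  for (size_t i = 0; i < attr.size(); ++i) {
    array[i] = static_cast<char*>(alloc.Alloc(attr[i].size() + 1));
    std::memcpy(array[i], attr[i].c_str(), attr[i].size() + 1);  // includes '\0'
  }
  *out = array;
}

// Caller side: query the count, fetch, copy, then free with the same allocator.
std::vector<std::string> ReadStrings(const std::vector<std::string>& attr, DemoAllocator& alloc) {
  std::vector<std::string> result;
  size_t size = 0;
  DemoGetStringArray(attr, alloc, nullptr, &size);
  if (size == 0) return result;

  char** out = nullptr;
  DemoGetStringArray(attr, alloc, &out, &size);
  result.reserve(size);
  for (size_t i = 0; i < size; ++i) {
    result.emplace_back(out[i]);
    alloc.Free(out[i]);  // each string is freed individually
  }
  alloc.Free(out);  // then the pointer array
  return result;
}
```

This sketch frees memory inline and would leak if a string copy threw; the C++ wrapper in this commit avoids that by holding the returned memory in `std::unique_ptr` guards.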
1 parent 0edb66b commit b8f5f1a

7 files changed

Lines changed: 249 additions & 1 deletion

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 22 additions & 0 deletions
@@ -7305,6 +7305,28 @@ struct OrtApi {
    * \since Version 1.25.
    */
   ORT_API2_STATUS(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options);
+
+  /** \brief Fetch an array of strings stored as an attribute in the graph node
+   *
+   * If `out` is nullptr, the value of `size` is set to the true size of the attribute
+   * array and a success status is returned.
+   *
+   * Otherwise, the strings and pointer array are allocated using `allocator`.
+   * The caller must free each string and the pointer array with `allocator`.
+   * If the attribute array is empty, `*out` is set to nullptr and `*size` is set to 0.
+   *
+   * \param[in] info instance
+   * \param[in] name name of the attribute to be parsed
+   * \param[in] allocator allocator used to allocate the returned string array and strings
+   * \param[out] out pointer to the returned array of null-terminated UTF-8 strings
+   * \param[out] size actual size of attribute array
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.25.
+   */
+  ORT_API2_STATUS(KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name,
+                  _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size);
 };

 /*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2868,6 +2868,7 @@ void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
28682868
void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
28692869
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
28702870
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
2871+
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<std::string>&);
28712872
} // namespace attr_utils
28722873

28732874
template <typename T>
@@ -2884,7 +2885,7 @@ struct KernelInfoImpl : Base<T> {
28842885
return val;
28852886
}
28862887

2887-
template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
2888+
template <typename R> // R is only implemented for float, int64_t, and string
28882889
std::vector<R> GetAttributes(const char* name) const {
28892890
std::vector<R> result;
28902891
attr_utils::GetAttrs(this->p_, name, result);

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 33 additions & 0 deletions
@@ -3063,6 +3063,39 @@ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::
   Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
   out.swap(result);
 }
+
+inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<std::string>& result) {
+  AllocatorWithDefaultOptions allocator;
+  char** out = nullptr;
+  size_t size = 0;
+
+  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_string(p, name, allocator, nullptr, &size));
+  if (size == 0) {
+    result.clear();
+    return;
+  }
+
+  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_string(p, name, allocator, &out, &size));
+
+  auto deleter = detail::AllocatedFree(allocator);
+  std::unique_ptr<void, decltype(deleter)> array_guard(out, deleter);
+  auto strings_deleter = [&deleter, size](char** values) {
+    for (size_t i = 0; i < size; ++i) {
+      if (values[i] != nullptr) {
+        deleter(values[i]);
+      }
+    }
+  };
+  std::unique_ptr<char*, decltype(strings_deleter)> strings_guard(out, strings_deleter);
+
+  std::vector<std::string> strings;
+  strings.reserve(size);
+  for (size_t i = 0; i < size; ++i) {
+    strings.emplace_back(out[i]);
+  }
+
+  strings.swap(result);
+}
 } // namespace detail

 inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
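
The wrapper above relies on two `std::unique_ptr` guards over the same pointer, and on reverse destruction order, so the strings are freed before the pointer array that holds them. The following is a minimal standalone sketch of that ownership pattern. `DemoAlloc`, `DemoFree`, `MakeStrings`, `TakeStrings`, and the `g_live_blocks` counter are hypothetical names used only for illustration, with `malloc`/`free` standing in for the ORT allocator.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

// Hypothetical counter so the example can show that every block is released.
static int g_live_blocks = 0;

void* DemoAlloc(size_t n) {
  void* p = std::malloc(n);
  if (p == nullptr) std::abort();  // demo simplification: no failure path
  ++g_live_blocks;
  return p;
}

void DemoFree(void* p) {
  if (p != nullptr) {
    --g_live_blocks;
    std::free(p);
  }
}

// Demo helper: builds an allocator-owned array of C strings (nullptr if empty).
char** MakeStrings(const std::vector<std::string>& values) {
  if (values.empty()) return nullptr;
  char** array = static_cast<char**>(DemoAlloc(values.size() * sizeof(char*)));
  for (size_t i = 0; i < values.size(); ++i) {
    array[i] = static_cast<char*>(DemoAlloc(values[i].size() + 1));
    std::memcpy(array[i], values[i].c_str(), values[i].size() + 1);
  }
  return array;
}

// Takes ownership of `out` and copies its `size` strings into a vector.
std::vector<std::string> TakeStrings(char** out, size_t size) {
  auto free_block = [](void* p) { DemoFree(p); };
  // Declared first, so destroyed last: releases the pointer array itself.
  std::unique_ptr<void, decltype(free_block)> array_guard(out, free_block);

  auto free_strings = [size](char** values) {
    for (size_t i = 0; i < size; ++i) DemoFree(values[i]);
  };
  // Declared second, so destroyed first: releases each string while the
  // pointer array is still valid.
  std::unique_ptr<char*, decltype(free_strings)> strings_guard(out, free_strings);

  std::vector<std::string> result;
  result.reserve(size);
  for (size_t i = 0; i < size; ++i) {
    result.emplace_back(out[i]);
  }
  return result;  // both guards run on return, and also if a copy throws
}
```

Because `std::unique_ptr` skips its deleter when it holds a null pointer, passing `nullptr` with a size of `0` is safe in this sketch.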

onnxruntime/core/session/custom_ops.cc

Lines changed: 73 additions & 0 deletions
@@ -11,6 +11,8 @@
 #include <unordered_set>

 #include <gsl/gsl>
+#include "core/common/safeint.h"
+#include "core/common/string_helper.h"
 #include "core/common/logging/logging.h"
 #include "core/framework/data_types.h"
 #include "core/framework/error_code_helper.h"
@@ -545,6 +547,65 @@ static Status CopyDataFromVectorToMemory(const std::vector<T>& values, T* out, s
   return Status::OK();
 }

+static char* DuplicateStringToAllocatorMemory(const std::string& value, OrtAllocator* allocator) {
+  SafeInt<size_t> allocation_size(value.size());
+  allocation_size += 1;
+
+  char* duplicated_value = static_cast<char*>(allocator->Alloc(allocator, allocation_size));
+  if (duplicated_value == nullptr) {
+    return nullptr;
+  }
+
+  std::memcpy(duplicated_value, value.data(), value.size());
+  duplicated_value[value.size()] = '\0';
+  return duplicated_value;
+}
+
+static Status CopyStringDataFromVectorToMemory(const std::vector<std::string>& values, OrtAllocator* allocator, char*** out, size_t* size) {
+  *size = values.size();
+
+  if (out == nullptr) {
+    return Status::OK();
+  }
+
+  ORT_RETURN_IF_NOT(allocator != nullptr, "allocator must not be null when out is provided");
+  *out = nullptr;
+
+  if (values.empty()) {
+    return Status::OK();
+  }
+
+  auto free_with_allocator = [allocator](void* value) {
+    allocator->Free(allocator, value);
+  };
+  SafeInt<size_t> alloc_count(values.size());
+  char** array = reinterpret_cast<char**>(allocator->Alloc(allocator, alloc_count * sizeof(char*)));
+  ORT_RETURN_IF_NOT(array != nullptr, "Failed to allocate string attribute pointer array");
+  std::unique_ptr<void, decltype(free_with_allocator)> array_guard(array, free_with_allocator);
+
+  size_t allocated_string_count = 0;
+  for (size_t i = 0; i < values.size(); ++i) {
+    char* duplicated_value = DuplicateStringToAllocatorMemory(values[i], allocator);
+    if (duplicated_value == nullptr) {
+      for (size_t j = 0; j < allocated_string_count; ++j) {
+        if (array[j] != nullptr) {
+          allocator->Free(allocator, array[j]);
+        }
+      }
+
+      return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL,
+                    "Failed to allocate string attribute array");
+    }
+
+    array[i] = duplicated_value;
+    ++allocated_string_count;
+  }
+
+  *out = array;
+  array_guard.release();
+  return Status::OK();
+}
+
 ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
                     _Out_ float* out, _Inout_ size_t* size) {
   return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr {
@@ -569,6 +630,18 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKe
   });
 }

+ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name,
+                    _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size) {
+  return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr {
+    std::vector<std::string> values;
+    auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttrs<std::string>(name, values);
+    if (status.IsOK()) {
+      status = CopyStringDataFromVectorToMemory(values, allocator, out, size);
+    }
+    return onnxruntime::ToOrtStatus(status);
+  });
+}
+
 ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name,
                     _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) {
   return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr {
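
The failure path in `CopyStringDataFromVectorToMemory` is easier to see with an allocator that can be made to fail. The sketch below is a simplified standalone illustration, not the ORT code: `FlakyAllocator` and `CopyStrings` are hypothetical names, a `bool` replaces `Status`, and the `SafeInt` overflow checks are omitted. It shows the same rollback idea: free the strings copied so far, and let a guard free the pointer array.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

// Hypothetical allocator that returns nullptr once `budget` allocations succeeded.
struct FlakyAllocator {
  int budget;
  int live = 0;
  explicit FlakyAllocator(int b) : budget(b) {}
  void* Alloc(size_t n) {
    if (budget <= 0) return nullptr;
    void* p = std::malloc(n);
    if (p != nullptr) {
      --budget;
      ++live;
    }
    return p;
  }
  void Free(void* p) {
    if (p != nullptr) {
      --live;
      std::free(p);
    }
  }
};

// Copies `values` into allocator-owned C strings. On any allocation failure it
// releases everything it allocated, leaves *out as nullptr, and returns false.
bool CopyStrings(const std::vector<std::string>& values, FlakyAllocator& alloc, char*** out) {
  *out = nullptr;
  if (values.empty()) return true;

  auto free_block = [&alloc](void* p) { alloc.Free(p); };
  char** array = static_cast<char**>(alloc.Alloc(values.size() * sizeof(char*)));
  if (array == nullptr) return false;
  std::unique_ptr<void, decltype(free_block)> array_guard(array, free_block);

  for (size_t i = 0; i < values.size(); ++i) {
    char* copy = static_cast<char*>(alloc.Alloc(values[i].size() + 1));
    if (copy == nullptr) {
      // Roll back the strings copied so far; array_guard frees the array.
      for (size_t j = 0; j < i; ++j) alloc.Free(array[j]);
      return false;
    }
    std::memcpy(copy, values[i].c_str(), values[i].size() + 1);
    array[i] = copy;
  }

  *out = array;
  array_guard.release();  // success: ownership passes to the caller
  return true;
}
```

With a budget of 2, the pointer array and the first string are allocated, the second string fails, and the live-block count returns to zero after the rollback.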

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4807,6 +4807,8 @@ static constexpr OrtApi ort_api_1_to_25 = {
48074807

48084808
&OrtApis::RunOptionsEnableProfiling,
48094809
&OrtApis::RunOptionsDisableProfiling,
4810+
&OrtApis::KernelInfoGetAttributeArray_string,
4811+
// End of Version 25 - DO NOT MODIFY ABOVE (see above text for more information)
48104812
};
48114813

48124814
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
@@ -4844,6 +4846,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz
48444846
static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
48454847
static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change");
48464848
static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change");
4849+
static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_string) / sizeof(void*) == 417, "Size of version 25 API cannot change");
48474850

48484851
// So that nobody forgets to finish an API version, this check will serve as a reminder:
48494852
static_assert(std::string_view(ORT_VERSION) == "1.25.0",

onnxruntime/core/session/ort_apis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ ORT_API_STATUS_IMPL(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options);
125125
ORT_API(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);
126126
ORT_API_STATUS_IMPL(RunOptionsEnableProfiling, _Inout_ OrtRunOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
127127
ORT_API_STATUS_IMPL(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options);
128+
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name,
129+
_Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size);
128130

129131
ORT_API_STATUS_IMPL(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
130132
_In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type,

onnxruntime/test/framework/kernel_info_test.cc

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/test_environment.h"
+#include "test/util/include/asserts.h"
+#include "core/graph/model.h"
+#include "core/graph/op.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/framework/execution_providers.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/external_data_loader_manager.h"
+#include "core/framework/session_state.h"
+#include "core/providers/cpu/cpu_execution_provider.h"
+#include "core/session/onnxruntime_cxx_api.h"
+
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace test {
+
+ONNX_OPERATOR_SCHEMA(KernelInfoStringArrayAttrOp)
+    .SetDoc("Test op for kernel info string-array attributes.")
+    .Attr("strings_attr", "Repeated string attribute for kernel info API tests.",
+          AttrType::AttributeProto_AttributeType_STRINGS, std::vector<std::string>{})
+    .Output(0, "output_1", "docstr for output_1.", "tensor(int32)");
+
+static void VerifyKernelInfoStringArrayAttribute(const std::vector<std::string>& attribute_values) {
+  OrtThreadPoolParams to;
+  auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
+
+  onnxruntime::Model model("graph_kernel_info_string_attr", false, DefaultLoggingManager().DefaultLogger());
+  auto& graph = model.MainGraph();
+
+  ExecutionProviders execution_providers;
+  auto tmp_cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
+  auto* cpu_execution_provider = tmp_cpu_execution_provider.get();
+  ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider)));
+
+  DataTransferManager dtm;
+  ExternalDataLoaderManager edlm;
+  profiling::Profiler profiler;
+
+  SessionOptions sess_options;
+  sess_options.enable_mem_pattern = true;
+  sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
+  sess_options.use_deterministic_compute = false;
+  sess_options.enable_mem_reuse = true;
+
+  SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm,
+                             DefaultLoggingManager().DefaultLogger(), profiler, sess_options);
+
+  std::vector<onnxruntime::NodeArg*> inputs;
+  std::vector<onnxruntime::NodeArg*> outputs;
+  TypeProto output_type;
+  output_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
+  output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+  onnxruntime::NodeArg output_arg("node_1_out_1", &output_type);
+  outputs.push_back(&output_arg);
+
+  onnxruntime::Node& node = graph.AddNode("node_1", "KernelInfoStringArrayAttrOp", "node 1.", inputs, outputs);
+  node.AddAttribute("strings_attr", gsl::make_span(attribute_values));
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  auto kernel_def = KernelDefBuilder().SetName("KernelInfoStringArrayAttrOp").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
+
+  OpKernelInfo kernel_info(node, *kernel_def, *cpu_execution_provider, session_state.GetConstantInitializedTensors(),
+                           session_state.GetOrtValueNameIdxMap(), session_state.GetDataTransferMgr(), session_state.GetAllocators(),
+                           session_state.GetSessionOptions().config_options);
+
+  const OrtApi& ort_api = Ort::GetApi();
+  OrtAllocator* allocator = nullptr;
+  ASSERT_EQ(nullptr, ort_api.GetAllocatorWithDefaultOptions(&allocator));
+
+  size_t size = 0;
+  ASSERT_EQ(nullptr, ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast<const OrtKernelInfo*>(&kernel_info), "strings_attr",
+                                                                allocator, nullptr, &size));
+  ASSERT_EQ(attribute_values.size(), size);
+
+  char** out = nullptr;
+  ASSERT_EQ(nullptr, ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast<const OrtKernelInfo*>(&kernel_info), "strings_attr",
+                                                                allocator, &out, &size));
+  ASSERT_EQ(attribute_values.size(), size);
+
+  if (attribute_values.empty()) {
+    ASSERT_EQ(nullptr, out);
+  } else {
+    ASSERT_NE(nullptr, out);
+    for (size_t i = 0; i < size; ++i) {
+      EXPECT_STREQ(attribute_values[i].c_str(), out[i]);
+      allocator->Free(allocator, out[i]);
+    }
+    allocator->Free(allocator, out);
+  }
+
+  Ort::ConstKernelInfo ort_kernel_info{reinterpret_cast<const OrtKernelInfo*>(&kernel_info)};
+  EXPECT_EQ(attribute_values, ort_kernel_info.GetAttributes<std::string>("strings_attr"));
+
+  OrtStatus* status = ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast<const OrtKernelInfo*>(&kernel_info), "missing_attr",
+                                                                 allocator, nullptr, &size);
+  ASSERT_NE(nullptr, status);
+  ort_api.ReleaseStatus(status);
+}
+
+TEST(KernelInfoTests, KernelInfoGetAttributeArrayString) {
+  VerifyKernelInfoStringArrayAttribute({"alpha", "beta", "gamma"});
+}
+
+TEST(KernelInfoTests, KernelInfoGetAttributeArrayStringEmpty) {
+  VerifyKernelInfoStringArrayAttribute({});
+}
+
+} // namespace test
+} // namespace onnxruntime
