Skip to content

Commit f6c4c22

Browse files
awoll-bdaiexploy-bot
authored andcommitted
Support input initializers in onnx runtime
### What change is being made Extend the onnx runtime to support [input initializers](https://onnx.ai/onnx/intro/python.html#initializer-default-value). In the first run or after reset, evaluation of the model is called without the tensors of inputs which have initializers, such that the onnx runtime uses the default values. ### Why this change is being made Allow to initialize inputs, e.g. memory. ### Tested Extended unit tests. GitOrigin-RevId: d630b9d9baf0dddff8da2e970673f71141b7fe6d
1 parent f749f6e commit f6c4c22

4 files changed

Lines changed: 220 additions & 30 deletions

File tree

control/include/exploy/onnx_runtime.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ class OnnxRuntime {
173173
Ort::RunOptions run_options_{nullptr};
174174

175175
struct TensorData {
176-
std::size_t size;
176+
std::size_t size{0};
177177
std::vector<std::vector<int64_t>> shapes;
178178
std::vector<const char*> names;
179179
std::vector<Ort::Value> tensors;
@@ -184,6 +184,10 @@ class OnnxRuntime {
184184
TensorData input_;
185185
TensorData output_;
186186

187+
// Number of inputs without an overridable initializer.
188+
std::size_t non_initializer_input_count_{0};
189+
bool use_initializers_{true};
190+
187191
std::unordered_map<std::string, int> input_names_to_index_{};
188192
std::unordered_map<std::string, int> output_names_to_index_{};
189193
};

control/src/onnx_runtime.cpp

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,46 @@ namespace exploy::control {
1414

1515
namespace {
1616

17+
enum class TensorKind { Input, Output, Initializer };
18+
19+
std::size_t getTensorCount(const Ort::Session& session, TensorKind kind) {
20+
switch (kind) {
21+
case TensorKind::Input:
22+
return session.GetInputCount();
23+
case TensorKind::Output:
24+
return session.GetOutputCount();
25+
case TensorKind::Initializer:
26+
return session.GetOverridableInitializerCount();
27+
}
28+
return 0;
29+
}
30+
31+
Ort::AllocatedStringPtr getTensorNameAllocated(Ort::Session& session,
32+
Ort::AllocatorWithDefaultOptions& allocator,
33+
TensorKind kind, std::size_t index) {
34+
switch (kind) {
35+
case TensorKind::Input:
36+
return session.GetInputNameAllocated(index, allocator);
37+
case TensorKind::Output:
38+
return session.GetOutputNameAllocated(index, allocator);
39+
case TensorKind::Initializer:
40+
return session.GetOverridableInitializerNameAllocated(index, allocator);
41+
}
42+
return Ort::AllocatedStringPtr{nullptr, Ort::detail::AllocatedFree{allocator}};
43+
}
44+
45+
Ort::TypeInfo getTensorTypeInfo(Ort::Session& session, TensorKind kind, std::size_t index) {
46+
switch (kind) {
47+
case TensorKind::Input:
48+
return session.GetInputTypeInfo(index);
49+
case TensorKind::Output:
50+
return session.GetOutputTypeInfo(index);
51+
case TensorKind::Initializer:
52+
return session.GetOverridableInitializerTypeInfo(index);
53+
}
54+
return Ort::TypeInfo{nullptr};
55+
}
56+
1757
void resetTensorBuffer(Ort::Value& tensor, ONNXTensorElementDataType data_type) {
1858
const auto count = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
1959
switch (data_type) {
@@ -36,34 +76,35 @@ void resetTensorBuffer(Ort::Value& tensor, ONNXTensorElementDataType data_type)
3676
}
3777

3878
template <typename TensorDataType>
39-
void initializeTensorData(TensorDataType& tensor_data, std::unique_ptr<Ort::Session>& session,
40-
Ort::AllocatorWithDefaultOptions& allocator,
41-
std::unordered_map<std::string, int>& names_to_index, bool is_input) {
42-
tensor_data.size = is_input ? session->GetInputCount() : session->GetOutputCount();
43-
44-
tensor_data.names.reserve(tensor_data.size);
45-
tensor_data.shapes.reserve(tensor_data.size);
46-
tensor_data.data_types.reserve(tensor_data.size);
47-
tensor_data.tensors.reserve(tensor_data.size);
48-
tensor_data.allocated_names.reserve(tensor_data.size);
49-
50-
for (std::size_t n = 0; n < tensor_data.size; n++) {
51-
auto name_ptr = is_input ? session->GetInputNameAllocated(n, allocator)
52-
: session->GetOutputNameAllocated(n, allocator);
53-
tensor_data.allocated_names.push_back(std::move(name_ptr));
79+
void appendTensorData(TensorDataType& tensor_data, std::unique_ptr<Ort::Session>& session,
80+
Ort::AllocatorWithDefaultOptions& allocator,
81+
std::unordered_map<std::string, int>& names_to_index, TensorKind kind) {
82+
const std::size_t count = getTensorCount(*session, kind);
83+
84+
const std::size_t new_size = tensor_data.size + count;
85+
tensor_data.names.reserve(new_size);
86+
tensor_data.shapes.reserve(new_size);
87+
tensor_data.data_types.reserve(new_size);
88+
tensor_data.tensors.reserve(new_size);
89+
tensor_data.allocated_names.reserve(new_size);
90+
91+
for (std::size_t n = 0; n < count; n++) {
92+
tensor_data.allocated_names.push_back(getTensorNameAllocated(*session, allocator, kind, n));
5493
tensor_data.names.push_back(tensor_data.allocated_names.back().get());
5594

56-
auto type_info = is_input ? session->GetInputTypeInfo(n) : session->GetOutputTypeInfo(n);
95+
auto type_info = getTensorTypeInfo(*session, kind, n);
5796
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
58-
5997
tensor_data.shapes.push_back(tensor_info.GetShape());
6098
tensor_data.data_types.push_back(tensor_info.GetElementType());
6199

62-
tensor_data.tensors.push_back(Ort::Value::CreateTensor(allocator, tensor_data.shapes[n].data(),
63-
tensor_data.shapes[n].size(),
64-
tensor_data.data_types[n]));
100+
tensor_data.tensors.push_back(
101+
Ort::Value::CreateTensor(allocator, tensor_data.shapes.back().data(),
102+
tensor_data.shapes.back().size(), tensor_data.data_types.back()));
103+
104+
resetTensorBuffer(tensor_data.tensors.back(), tensor_data.data_types.back());
65105

66-
names_to_index[std::string(tensor_data.names.back())] = n;
106+
names_to_index[std::string(tensor_data.names.back())] = static_cast<int>(tensor_data.size);
107+
tensor_data.size++;
67108
}
68109
}
69110

@@ -113,22 +154,37 @@ bool OnnxRuntime::initialize(const std::string& model_path, const OnnxRuntimeOpt
113154
break;
114155
}
115156

116-
initializeTensorData(input_, session_, allocator_, input_names_to_index_, /*is_input=*/true);
117-
initializeTensorData(output_, session_, allocator_, output_names_to_index_, /*is_input=*/false);
157+
input_ = TensorData{};
158+
output_ = TensorData{};
159+
input_names_to_index_.clear();
160+
output_names_to_index_.clear();
161+
non_initializer_input_count_ = 0;
118162

163+
// Append initializer-backed inputs after regular inputs so that we can optionally let ONNX
164+
// Runtime use the model's default values for them after a reset.
165+
appendTensorData(input_, session_, allocator_, input_names_to_index_, TensorKind::Input);
166+
appendTensorData(input_, session_, allocator_, input_names_to_index_, TensorKind::Initializer);
167+
appendTensorData(output_, session_, allocator_, output_names_to_index_, TensorKind::Output);
168+
non_initializer_input_count_ = getTensorCount(*session_, TensorKind::Input);
169+
use_initializers_ = true;
119170
metadata_ = session_->GetModelMetadata();
120171

121172
return true;
122173
}
123174

124175
bool OnnxRuntime::evaluate() {
176+
// If use_initializers_ is true, we pass only the leading non-initializer inputs to let ONNX
177+
// Runtime use the model's default values for the rest. After the first run, we always pass all
178+
// inputs and ignore the model defaults.
179+
const std::size_t input_count = use_initializers_ ? non_initializer_input_count_ : input_.size;
125180
try {
126-
session_->Run(run_options_, input_.names.data(), input_.tensors.data(), input_.size,
181+
session_->Run(run_options_, input_.names.data(), input_.tensors.data(), input_count,
127182
output_.names.data(), output_.tensors.data(), output_.size);
128183
} catch (const Ort::Exception& e) {
129184
LOG_STREAM(ERROR, "ONNX Runtime evaluation failed: " << e.what());
130185
return false;
131186
}
187+
use_initializers_ = false;
132188
return true;
133189
}
134190

@@ -144,6 +200,7 @@ void OnnxRuntime::resetBuffers() {
144200
for (std::size_t n = 0; n < output_.size; n++) {
145201
resetTensorBuffer(output_.tensors[n], output_.data_types[n]);
146202
}
203+
use_initializers_ = true;
147204
}
148205

149206
std::unordered_set<std::string> OnnxRuntime::inputNames() const {

control/test/onnx_runtime_test.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <gtest/gtest.h>
66

7+
#include <algorithm>
78
#include <filesystem>
89
#include <span>
910

@@ -79,6 +80,7 @@ TEST_F(OnnxRuntimeTest, InputTensorNames) {
7980
EXPECT_TRUE(input_names.contains("float_input"));
8081
EXPECT_TRUE(input_names.contains("int_input"));
8182
EXPECT_TRUE(input_names.contains("bool_input"));
83+
EXPECT_TRUE(input_names.contains("init_float_input"));
8284
}
8385

8486
TEST_F(OnnxRuntimeTest, OutputTensorNames) {
@@ -400,4 +402,122 @@ TEST_F(OnnxRuntimeTest, CopyOutputToInputTypeMismatch) {
400402
EXPECT_FALSE(runtime.copyOutputToInput("float_output", "bool_input"));
401403
}
402404

405+
TEST_F(OnnxRuntimeTest, FirstRunUsesInitializerDefaults) {
406+
OnnxRuntime runtime;
407+
ASSERT_TRUE(runtime.initialize(simple_model_path_));
408+
runtime.resetBuffers();
409+
410+
auto float_input = runtime.inputBuffer<float>("float_input");
411+
ASSERT_TRUE(float_input.has_value());
412+
float_input.value()[0] = 1.5f;
413+
float_input.value()[1] = 2.5f;
414+
float_input.value()[2] = 3.5f;
415+
416+
// Populate the ``init_float_input`` buffer with non-zero values that would clearly
417+
// differ from the baked-in defaults [0, 0, 0]. The first run must IGNORE these because
418+
// ``init_float_input`` is an overridable initializer.
419+
auto init_float_input = runtime.inputBuffer<float>("init_float_input");
420+
ASSERT_TRUE(init_float_input.has_value());
421+
init_float_input.value()[0] = 100.0f;
422+
init_float_input.value()[1] = 200.0f;
423+
init_float_input.value()[2] = 300.0f;
424+
425+
ASSERT_TRUE(runtime.evaluate());
426+
427+
// float_output = float_input * 2 + 0 (init_float_input defaults to zeros).
428+
auto float_output = runtime.outputBuffer<float>("float_output");
429+
ASSERT_TRUE(float_output.has_value());
430+
EXPECT_FLOAT_EQ(float_output.value()[0], 3.0f);
431+
EXPECT_FLOAT_EQ(float_output.value()[1], 5.0f);
432+
EXPECT_FLOAT_EQ(float_output.value()[2], 7.0f);
433+
}
434+
435+
TEST_F(OnnxRuntimeTest, SubsequentRunsUseInitializerBuffer) {
436+
OnnxRuntime runtime;
437+
ASSERT_TRUE(runtime.initialize(simple_model_path_));
438+
runtime.resetBuffers();
439+
440+
auto float_input = runtime.inputBuffer<float>("float_input");
441+
ASSERT_TRUE(float_input.has_value());
442+
float_input.value()[0] = 1.5f;
443+
float_input.value()[1] = 2.5f;
444+
float_input.value()[2] = 3.5f;
445+
446+
auto init_float_input = runtime.inputBuffer<float>("init_float_input");
447+
ASSERT_TRUE(init_float_input.has_value());
448+
449+
// First run: defaults [0, 0, 0] used regardless of buffer contents.
450+
std::ranges::fill(init_float_input.value(), 999.0f);
451+
ASSERT_TRUE(runtime.evaluate());
452+
453+
auto float_output = runtime.outputBuffer<float>("float_output");
454+
ASSERT_TRUE(float_output.has_value());
455+
EXPECT_FLOAT_EQ(float_output.value()[0], 3.0f); // 1.5 * 2 + 0
456+
EXPECT_FLOAT_EQ(float_output.value()[1], 5.0f); // 2.5 * 2 + 0
457+
EXPECT_FLOAT_EQ(float_output.value()[2], 7.0f); // 3.5 * 2 + 0
458+
459+
// Second run: now the buffer values must override the defaults.
460+
init_float_input.value()[0] = 10.0f;
461+
init_float_input.value()[1] = 20.0f;
462+
init_float_input.value()[2] = 30.0f;
463+
ASSERT_TRUE(runtime.evaluate());
464+
465+
EXPECT_FLOAT_EQ(float_output.value()[0], 13.0f); // 1.5 * 2 + 10
466+
EXPECT_FLOAT_EQ(float_output.value()[1], 25.0f); // 2.5 * 2 + 20
467+
EXPECT_FLOAT_EQ(float_output.value()[2], 37.0f); // 3.5 * 2 + 30
468+
469+
// Third run: confirm the buffer keeps overriding (still not falling back).
470+
init_float_input.value()[0] = -1.0f;
471+
init_float_input.value()[1] = -2.0f;
472+
init_float_input.value()[2] = -3.0f;
473+
ASSERT_TRUE(runtime.evaluate());
474+
475+
EXPECT_FLOAT_EQ(float_output.value()[0], 2.0f); // 1.5 * 2 - 1
476+
EXPECT_FLOAT_EQ(float_output.value()[1], 3.0f); // 2.5 * 2 - 2
477+
EXPECT_FLOAT_EQ(float_output.value()[2], 4.0f); // 3.5 * 2 - 3
478+
}
479+
480+
TEST_F(OnnxRuntimeTest, ResetBuffersReArmsInitializerDefaults) {
481+
OnnxRuntime runtime;
482+
ASSERT_TRUE(runtime.initialize(simple_model_path_));
483+
runtime.resetBuffers();
484+
485+
auto float_input = runtime.inputBuffer<float>("float_input");
486+
auto init_float_input = runtime.inputBuffer<float>("init_float_input");
487+
ASSERT_TRUE(float_input.has_value());
488+
ASSERT_TRUE(init_float_input.has_value());
489+
490+
float_input.value()[0] = 1.5f;
491+
float_input.value()[1] = 2.5f;
492+
float_input.value()[2] = 3.5f;
493+
494+
// Run #1 uses defaults.
495+
ASSERT_TRUE(runtime.evaluate());
496+
// Run #2 uses the buffer.
497+
init_float_input.value()[0] = 10.0f;
498+
init_float_input.value()[1] = 20.0f;
499+
init_float_input.value()[2] = 30.0f;
500+
ASSERT_TRUE(runtime.evaluate());
501+
auto float_output = runtime.outputBuffer<float>("float_output");
502+
ASSERT_TRUE(float_output.has_value());
503+
EXPECT_FLOAT_EQ(float_output.value()[0], 13.0f); // 1.5 * 2 + 10
504+
505+
// After resetBuffers(), the next evaluate() must again fall back to the model's
506+
// initializer defaults, regardless of buffer contents.
507+
runtime.resetBuffers();
508+
// resetBuffers() also zeroes the buffers, so set non-zero values to prove they are
509+
// ignored on the post-reset run. Restore float_input as well (also zeroed by reset).
510+
float_input.value()[0] = 1.5f;
511+
float_input.value()[1] = 2.5f;
512+
float_input.value()[2] = 3.5f;
513+
init_float_input.value()[0] = 999.0f;
514+
init_float_input.value()[1] = 999.0f;
515+
init_float_input.value()[2] = 999.0f;
516+
517+
ASSERT_TRUE(runtime.evaluate());
518+
EXPECT_FLOAT_EQ(float_output.value()[0], 3.0f); // 1.5 * 2 + 0
519+
EXPECT_FLOAT_EQ(float_output.value()[1], 5.0f); // 2.5 * 2 + 0
520+
EXPECT_FLOAT_EQ(float_output.value()[2], 7.0f); // 3.5 * 2 + 0
521+
}
522+
403523
} // namespace exploy::control

control/test/testdata/test_onnx_generator.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66

7+
import numpy as np
78
import onnx
89
import torch
910

@@ -368,9 +369,8 @@ class SimpleTestModel(torch.nn.Module):
368369
def __init__(self):
369370
super().__init__()
370371

371-
def forward(self, float_input, int_input, bool_input):
372-
# Simple pass-through model that just forwards inputs to outputs
373-
float_output = float_input * 2.0 # Simple transformation
372+
def forward(self, float_input, int_input, bool_input, init_float_input):
373+
float_output = float_input * 2.0 + init_float_input
374374
int_output = int_input + 1 # Simple transformation
375375
bool_output = torch.logical_not(bool_input) # Simple transformation
376376

@@ -387,15 +387,24 @@ def export_simple_model(data_dir: str):
387387
float_input = torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32)
388388
int_input = torch.tensor([[10, 20, 30]], dtype=torch.int32)
389389
bool_input = torch.tensor([[True, False, True]], dtype=torch.bool)
390+
# Default values for the overridable initializer baked into the exported model.
391+
default_init_float_input = np.zeros((1, 3), dtype=np.float32)
392+
init_float_input = torch.from_numpy(default_init_float_input)
390393

391394
torch.onnx.export(
392395
simple_model,
393-
(float_input, int_input, bool_input),
396+
(float_input, int_input, bool_input, init_float_input),
394397
output_path_simple,
395-
input_names=["float_input", "int_input", "bool_input"],
398+
input_names=["float_input", "int_input", "bool_input", "init_float_input"],
396399
output_names=["float_output", "int_output", "bool_output"],
397400
)
398401

402+
onnx_model = onnx.load(output_path_simple)
403+
onnx_model.graph.initializer.append(
404+
onnx.numpy_helper.from_array(default_init_float_input, name="init_float_input")
405+
)
406+
onnx.save(onnx_model, output_path_simple)
407+
399408
# Add simple metadata to the simple test model
400409
simple_metadata = {
401410
"model_version": "1.0",

0 commit comments

Comments
 (0)