@@ -14,6 +14,46 @@ namespace exploy::control {
1414
1515namespace {
1616
17+ enum class TensorKind { Input, Output, Initializer };
18+
19+ std::size_t getTensorCount (const Ort::Session& session, TensorKind kind) {
20+ switch (kind) {
21+ case TensorKind::Input:
22+ return session.GetInputCount ();
23+ case TensorKind::Output:
24+ return session.GetOutputCount ();
25+ case TensorKind::Initializer:
26+ return session.GetOverridableInitializerCount ();
27+ }
28+ return 0 ;
29+ }
30+
31+ Ort::AllocatedStringPtr getTensorNameAllocated (Ort::Session& session,
32+ Ort::AllocatorWithDefaultOptions& allocator,
33+ TensorKind kind, std::size_t index) {
34+ switch (kind) {
35+ case TensorKind::Input:
36+ return session.GetInputNameAllocated (index, allocator);
37+ case TensorKind::Output:
38+ return session.GetOutputNameAllocated (index, allocator);
39+ case TensorKind::Initializer:
40+ return session.GetOverridableInitializerNameAllocated (index, allocator);
41+ }
42+ return Ort::AllocatedStringPtr{nullptr , Ort::detail::AllocatedFree{allocator}};
43+ }
44+
45+ Ort::TypeInfo getTensorTypeInfo (Ort::Session& session, TensorKind kind, std::size_t index) {
46+ switch (kind) {
47+ case TensorKind::Input:
48+ return session.GetInputTypeInfo (index);
49+ case TensorKind::Output:
50+ return session.GetOutputTypeInfo (index);
51+ case TensorKind::Initializer:
52+ return session.GetOverridableInitializerTypeInfo (index);
53+ }
54+ return Ort::TypeInfo{nullptr };
55+ }
56+
1757void resetTensorBuffer (Ort::Value& tensor, ONNXTensorElementDataType data_type) {
1858 const auto count = tensor.GetTensorTypeAndShapeInfo ().GetElementCount ();
1959 switch (data_type) {
@@ -36,34 +76,35 @@ void resetTensorBuffer(Ort::Value& tensor, ONNXTensorElementDataType data_type)
3676}
3777
3878template <typename TensorDataType>
39- void initializeTensorData (TensorDataType& tensor_data, std::unique_ptr<Ort::Session>& session,
40- Ort::AllocatorWithDefaultOptions& allocator,
41- std::unordered_map<std::string, int >& names_to_index, bool is_input) {
42- tensor_data.size = is_input ? session->GetInputCount () : session->GetOutputCount ();
43-
44- tensor_data.names .reserve (tensor_data.size );
45- tensor_data.shapes .reserve (tensor_data.size );
46- tensor_data.data_types .reserve (tensor_data.size );
47- tensor_data.tensors .reserve (tensor_data.size );
48- tensor_data.allocated_names .reserve (tensor_data.size );
49-
50- for (std::size_t n = 0 ; n < tensor_data.size ; n++) {
51- auto name_ptr = is_input ? session->GetInputNameAllocated (n, allocator)
52- : session->GetOutputNameAllocated (n, allocator);
53- tensor_data.allocated_names .push_back (std::move (name_ptr));
79+ void appendTensorData (TensorDataType& tensor_data, std::unique_ptr<Ort::Session>& session,
80+ Ort::AllocatorWithDefaultOptions& allocator,
81+ std::unordered_map<std::string, int >& names_to_index, TensorKind kind) {
82+ const std::size_t count = getTensorCount (*session, kind);
83+
84+ const std::size_t new_size = tensor_data.size + count;
85+ tensor_data.names .reserve (new_size);
86+ tensor_data.shapes .reserve (new_size);
87+ tensor_data.data_types .reserve (new_size);
88+ tensor_data.tensors .reserve (new_size);
89+ tensor_data.allocated_names .reserve (new_size);
90+
91+ for (std::size_t n = 0 ; n < count; n++) {
92+ tensor_data.allocated_names .push_back (getTensorNameAllocated (*session, allocator, kind, n));
5493 tensor_data.names .push_back (tensor_data.allocated_names .back ().get ());
5594
56- auto type_info = is_input ? session-> GetInputTypeInfo (n) : session-> GetOutputTypeInfo ( n);
95+ auto type_info = getTensorTypeInfo (*session, kind, n);
5796 auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
58-
5997 tensor_data.shapes .push_back (tensor_info.GetShape ());
6098 tensor_data.data_types .push_back (tensor_info.GetElementType ());
6199
62- tensor_data.tensors .push_back (Ort::Value::CreateTensor (allocator, tensor_data.shapes [n].data (),
63- tensor_data.shapes [n].size (),
64- tensor_data.data_types [n]));
100+ tensor_data.tensors .push_back (
101+ Ort::Value::CreateTensor (allocator, tensor_data.shapes .back ().data (),
102+ tensor_data.shapes .back ().size (), tensor_data.data_types .back ()));
103+
104+ resetTensorBuffer (tensor_data.tensors .back (), tensor_data.data_types .back ());
65105
66- names_to_index[std::string (tensor_data.names .back ())] = n;
106+ names_to_index[std::string (tensor_data.names .back ())] = static_cast <int >(tensor_data.size );
107+ tensor_data.size ++;
67108 }
68109}
69110
@@ -113,22 +154,37 @@ bool OnnxRuntime::initialize(const std::string& model_path, const OnnxRuntimeOpt
113154 break ;
114155 }
115156
116- initializeTensorData (input_, session_, allocator_, input_names_to_index_, /* is_input=*/ true );
117- initializeTensorData (output_, session_, allocator_, output_names_to_index_, /* is_input=*/ false );
157+ input_ = TensorData{};
158+ output_ = TensorData{};
159+ input_names_to_index_.clear ();
160+ output_names_to_index_.clear ();
161+ non_initializer_input_count_ = 0 ;
118162
163+ // Append initializer-backed inputs after regular inputs so that we can optionally let ONNX
164+ // Runtime use the model's default values for them after a reset.
165+ appendTensorData (input_, session_, allocator_, input_names_to_index_, TensorKind::Input);
166+ appendTensorData (input_, session_, allocator_, input_names_to_index_, TensorKind::Initializer);
167+ appendTensorData (output_, session_, allocator_, output_names_to_index_, TensorKind::Output);
168+ non_initializer_input_count_ = getTensorCount (*session_, TensorKind::Input);
169+ use_initializers_ = true ;
119170 metadata_ = session_->GetModelMetadata ();
120171
121172 return true ;
122173}
123174
124175bool OnnxRuntime::evaluate () {
176+ // If use_initializers_ is true, we pass only the leading non-initializer inputs to let ONNX
177+ // Runtime use the model's default values for the rest. After the first run, we always pass all
178+ // inputs and ignore the model defaults.
179+ const std::size_t input_count = use_initializers_ ? non_initializer_input_count_ : input_.size ;
125180 try {
126- session_->Run (run_options_, input_.names .data (), input_.tensors .data (), input_. size ,
181+ session_->Run (run_options_, input_.names .data (), input_.tensors .data (), input_count ,
127182 output_.names .data (), output_.tensors .data (), output_.size );
128183 } catch (const Ort::Exception& e) {
129184 LOG_STREAM (ERROR , " ONNX Runtime evaluation failed: " << e.what ());
130185 return false ;
131186 }
187+ use_initializers_ = false ;
132188 return true ;
133189}
134190
@@ -144,6 +200,7 @@ void OnnxRuntime::resetBuffers() {
144200 for (std::size_t n = 0 ; n < output_.size ; n++) {
145201 resetTensorBuffer (output_.tensors [n], output_.data_types [n]);
146202 }
203+ use_initializers_ = true ;
147204}
148205
149206std::unordered_set<std::string> OnnxRuntime::inputNames () const {
0 commit comments