@@ -138,13 +138,13 @@ class OnnxModel
138138
139139 // Declaration of a raw graph input fed directly from an Arrow buffer.
140140 struct PreprocInput {
141- enum class Type { TrackFloat, // per-track float column [N]
142- TrackInt32, // per-track int32 column [N]
143- TrackUint8, // per-track uint8 column [N]
144- TrackInt8, // per-track int8 column [N]
145- TrackBool, // per-track bool mask [N]
146- CollisionFloat,// per-collision float array [C]
147- ScalarFloat }; // single scalar (e.g. mass) [1]
141+ enum class Type { TrackFloat, // per-track float column [N]
142+ TrackInt32, // per-track int32 column [N]
143+ TrackUint8, // per-track uint8 column [N]
144+ TrackInt8, // per-track int8 column [N]
145+ TrackBool, // per-track bool mask [N]
146+ CollisionFloat, // per-collision float array [C]
147+ ScalarFloat }; // single scalar (e.g. mass) [1]
148148 std::string name;
149149 Type type;
150150 };
@@ -318,13 +318,30 @@ class OnnxModel
318318 std::vector<int64_t > dims = {-1 };
319319 std::vector<std::string> sym = {" N" };
320320 switch (pin.type ) {
321- case PreprocInput::Type::TrackFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ; break ;
322- case PreprocInput::Type::TrackInt32: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 ; break ;
323- case PreprocInput::Type::TrackUint8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 ; break ;
324- case PreprocInput::Type::TrackInt8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 ; break ;
325- case PreprocInput::Type::TrackBool: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL ; break ;
326- case PreprocInput::Type::CollisionFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ; sym = {" C" }; break ;
327- case PreprocInput::Type::ScalarFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ; dims = {1 }; sym = {" " }; break ;
321+ case PreprocInput::Type::TrackFloat:
322+ et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ;
323+ break ;
324+ case PreprocInput::Type::TrackInt32:
325+ et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 ;
326+ break ;
327+ case PreprocInput::Type::TrackUint8:
328+ et = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 ;
329+ break ;
330+ case PreprocInput::Type::TrackInt8:
331+ et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 ;
332+ break ;
333+ case PreprocInput::Type::TrackBool:
334+ et = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL ;
335+ break ;
336+ case PreprocInput::Type::CollisionFloat:
337+ et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ;
338+ sym = {" C" };
339+ break ;
340+ case PreprocInput::Type::ScalarFloat:
341+ et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ;
342+ dims = {1 };
343+ sym = {" " };
344+ break ;
328345 }
329346 Ort::TensorTypeAndShapeInfo tInfo (et, dims, &sym);
330347 auto typeInfo = Ort::TypeInfo::CreateTensorInfo (tInfo.GetConst ());
0 commit comments