Skip to content

Commit fe809e2

Browse files
committed
Please consider the following formatting changes
1 parent 3da893a commit fe809e2

1 file changed

Lines changed: 31 additions & 14 deletions

File tree

Tools/ML/model.h

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ class OnnxModel
138138

139139
// Declaration of a raw graph input fed directly from an Arrow buffer.
140140
struct PreprocInput {
141-
enum class Type { TrackFloat, // per-track float column [N]
142-
TrackInt32, // per-track int32 column [N]
143-
TrackUint8, // per-track uint8 column [N]
144-
TrackInt8, // per-track int8 column [N]
145-
TrackBool, // per-track bool mask [N]
146-
CollisionFloat,// per-collision float array [C]
147-
ScalarFloat }; // single scalar (e.g. mass) [1]
141+
enum class Type { TrackFloat, // per-track float column [N]
142+
TrackInt32, // per-track int32 column [N]
143+
TrackUint8, // per-track uint8 column [N]
144+
TrackInt8, // per-track int8 column [N]
145+
TrackBool, // per-track bool mask [N]
146+
CollisionFloat, // per-collision float array [C]
147+
ScalarFloat }; // single scalar (e.g. mass) [1]
148148
std::string name;
149149
Type type;
150150
};
@@ -318,13 +318,30 @@ class OnnxModel
318318
std::vector<int64_t> dims = {-1};
319319
std::vector<std::string> sym = {"N"};
320320
switch (pin.type) {
321-
case PreprocInput::Type::TrackFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; break;
322-
case PreprocInput::Type::TrackInt32: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; break;
323-
case PreprocInput::Type::TrackUint8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; break;
324-
case PreprocInput::Type::TrackInt8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; break;
325-
case PreprocInput::Type::TrackBool: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; break;
326-
case PreprocInput::Type::CollisionFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; sym = {"C"}; break;
327-
case PreprocInput::Type::ScalarFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; dims = {1}; sym = {""}; break;
321+
case PreprocInput::Type::TrackFloat:
322+
et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
323+
break;
324+
case PreprocInput::Type::TrackInt32:
325+
et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
326+
break;
327+
case PreprocInput::Type::TrackUint8:
328+
et = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
329+
break;
330+
case PreprocInput::Type::TrackInt8:
331+
et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
332+
break;
333+
case PreprocInput::Type::TrackBool:
334+
et = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
335+
break;
336+
case PreprocInput::Type::CollisionFloat:
337+
et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
338+
sym = {"C"};
339+
break;
340+
case PreprocInput::Type::ScalarFloat:
341+
et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
342+
dims = {1};
343+
sym = {""};
344+
break;
328345
}
329346
Ort::TensorTypeAndShapeInfo tInfo(et, dims, &sym);
330347
auto typeInfo = Ort::TypeInfo::CreateTensorInfo(tInfo.GetConst());

0 commit comments

Comments
 (0)