Skip to content

Commit 46e025c

Browse files
committed
Make inference const to allow inference inside const objects that include a OrtModel
1 parent 8361c42 commit 46e025c

2 files changed

Lines changed: 23 additions & 22 deletions

File tree

Common/ML/include/ML/OrtInterface.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,16 @@ class OrtModel
9191

9292
// Inferencing
9393
template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. OrtDataType::Float16_t from O2/Common/ML/include/ML/GPUORTFloat16.h
94-
std::vector<O> inference(std::vector<I>&);
94+
std::vector<O> inference(std::vector<I>&) const;
9595

9696
template <class I, class O>
97-
std::vector<O> inference(std::vector<std::vector<I>>&);
97+
std::vector<O> inference(std::vector<std::vector<I>>&) const;
9898

9999
template <class I, class O>
100-
void inference(I*, int64_t, O*);
100+
void inference(I*, int64_t, O*) const;
101101

102102
template <class I, class O>
103-
void inference(I**, int64_t, O*);
103+
void inference(I**, int64_t, O*) const;
104104

105105
void release(bool = false);
106106

@@ -112,7 +112,8 @@ class OrtModel
112112
// Input & Output specifications of the loaded network
113113
std::vector<const char*> mInputNamesChar, mOutputNamesChar;
114114
std::vector<std::string> mInputNames, mOutputNames;
115-
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes, mInputShapesCopy, mOutputShapesCopy; // Input shapes
115+
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;
116+
mutable std::vector<std::vector<int64_t>> mInputShapesCopy, mOutputShapesCopy; // Input shapes
116117
std::vector<int64_t> mInputSizePerNode, mOutputSizePerNode; // Output shapes
117118
int32_t mInputsTotal = 0, mOutputsTotal = 0; // Total number of inputs and outputs
118119

Common/ML/src/OrtInterface.cxx

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ void OrtModel::setEnv(Ort::Env* env)
289289

290290
// Inference
291291
template <class I, class O>
292-
std::vector<O> OrtModel::inference(std::vector<I>& input)
292+
std::vector<O> OrtModel::inference(std::vector<I>& input) const
293293
{
294294
std::vector<int64_t> inputShape = mInputShapes[0];
295295
inputShape[0] = input.size();
@@ -310,12 +310,12 @@ std::vector<O> OrtModel::inference(std::vector<I>& input)
310310
return outputValuesVec;
311311
}
312312

313-
template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&);
314-
template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
315-
template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
313+
template std::vector<float> o2::ml::OrtModel::inference<float, float>(std::vector<float>&) const;
314+
template std::vector<float> o2::ml::OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&) const;
315+
template std::vector<OrtDataType::Float16_t> o2::ml::OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&) const;
316316

317317
template <class I, class O>
318-
void OrtModel::inference(I* input, int64_t input_size, O* output)
318+
void OrtModel::inference(I* input, int64_t input_size, O* output) const
319319
{
320320
// std::vector<std::string> providers = Ort::GetAvailableProviders();
321321
// for (const auto& provider : providers) {
@@ -350,13 +350,13 @@ void OrtModel::inference(I* input, int64_t input_size, O* output)
350350
// mOutputNamesChar.size());
351351
}
352352

353-
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*);
354-
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*);
355-
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*);
356-
template void OrtModel::inference<float, float>(float*, int64_t, float*);
353+
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t*, int64_t, OrtDataType::Float16_t*) const;
354+
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*) const;
355+
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*) const;
356+
template void OrtModel::inference<float, float>(float*, int64_t, float*) const;
357357

358358
template <class I, class O>
359-
void OrtModel::inference(I** input, int64_t input_size, O* output)
359+
void OrtModel::inference(I** input, int64_t input_size, O* output) const
360360
{
361361
std::vector<Ort::Value> inputTensors(mInputShapesCopy.size());
362362

@@ -410,13 +410,13 @@ void OrtModel::inference(I** input, int64_t input_size, O* output)
410410
mOutputNamesChar.size());
411411
}
412412

413-
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*);
414-
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*);
415-
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*);
416-
template void OrtModel::inference<float, float>(float**, int64_t, float*);
413+
template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(OrtDataType::Float16_t**, int64_t, OrtDataType::Float16_t*) const;
414+
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*) const;
415+
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*) const;
416+
template void OrtModel::inference<float, float>(float**, int64_t, float*) const;
417417

418418
template <class I, class O>
419-
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
419+
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs) const
420420
{
421421
std::vector<Ort::Value> input_tensors;
422422

@@ -461,8 +461,8 @@ std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)
461461
return output_vec;
462462
}
463463

464-
template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&);
465-
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&);
464+
template std::vector<float> OrtModel::inference<float, float>(std::vector<std::vector<float>>&) const;
465+
template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>&) const;
466466

467467
// Release session
468468
void OrtModel::release(bool profilingEnabled)

0 commit comments

Comments
 (0)