@@ -25,6 +25,7 @@ limitations under the License.
2525#ifndef TF_LITE_STATIC_MEMORY
2626#include < string>
2727
28+ #include " absl/types/span.h"
2829#include " tensorflow/lite/array.h"
2930#endif // TF_LITE_STATIC_MEMORY
3031
@@ -33,6 +34,7 @@ limitations under the License.
3334#include " tensorflow/lite/core/c/common.h"
3435#include " tensorflow/lite/kernels/internal/cppmath.h"
3536#include " tensorflow/lite/kernels/internal/quantization_util.h"
37+ #include " tensorflow/lite/util.h"
3638
3739#if defined(__APPLE__)
3840#include " TargetConditionals.h"
@@ -101,9 +103,8 @@ inline TfLiteStatus GetMutableInputSafe(const TfLiteContext* context,
101103 const TfLiteNode* node, int index,
102104 const TfLiteTensor** tensor) {
103105 int tensor_index;
104- TF_LITE_ENSURE_OK (
105- context, ValidateTensorIndexingSafe (context, index, node->inputs ->size ,
106- node->inputs ->data , &tensor_index));
106+ TF_LITE_ENSURE_STATUS (ValidateTensorIndexingSafe (
107+ context, index, node->inputs ->size , node->inputs ->data , &tensor_index));
107108 *tensor = GetTensorAtIndex (context, tensor_index);
108109 return kTfLiteOk ;
109110}
@@ -140,9 +141,8 @@ TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
140141TfLiteStatus GetOutputSafe (const TfLiteContext* context, const TfLiteNode* node,
141142 int index, TfLiteTensor** tensor) {
142143 int tensor_index;
143- TF_LITE_ENSURE_OK (
144- context, ValidateTensorIndexingSafe (context, index, node->outputs ->size ,
145- node->outputs ->data , &tensor_index));
144+ TF_LITE_ENSURE_STATUS (ValidateTensorIndexingSafe (
145+ context, index, node->outputs ->size , node->outputs ->data , &tensor_index));
146146 *tensor = GetTensorAtIndex (context, tensor_index);
147147 return kTfLiteOk ;
148148}
@@ -167,8 +167,8 @@ TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
167167 const TfLiteNode* node, int index,
168168 TfLiteTensor** tensor) {
169169 int tensor_index;
170- TF_LITE_ENSURE_OK (context, ValidateTensorIndexingSafe (
171- context, index, node->temporaries ->size ,
170+ TF_LITE_ENSURE_STATUS (
171+ ValidateTensorIndexingSafe ( context, index, node->temporaries ->size ,
172172 node->temporaries ->data , &tensor_index));
173173 *tensor = GetTensorAtIndex (context, tensor_index);
174174 return kTfLiteOk ;
@@ -188,8 +188,8 @@ TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
188188 const TfLiteNode* node, int index,
189189 TfLiteTensor** tensor) {
190190 int tensor_index;
191- TF_LITE_ENSURE_OK (context, ValidateTensorIndexingSafe (
192- context, index, node->intermediates ->size ,
191+ TF_LITE_ENSURE_STATUS (
192+ ValidateTensorIndexingSafe ( context, index, node->intermediates ->size ,
193193 node->intermediates ->data , &tensor_index));
194194 *tensor = GetTensorAtIndex (context, tensor_index);
195195 return kTfLiteOk ;
@@ -595,4 +595,25 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor) {
595595 return false ;
596596}
597597
598+ TfLiteStatus CheckedShapeProduct (TfLiteContext* context,
599+ absl::Span<const int > dims,
600+ const char * error_message, size_t & product) {
601+ // The CheckedNumElements function already checks for negative dimensions, so
602+ // we don't do it here.
603+ TF_LITE_ENSURE_MSG (context, CheckedNumElements (dims, product) == kTfLiteOk ,
604+ " %s" , error_message);
605+ return kTfLiteOk ;
606+ }
607+
608+ TfLiteStatus CheckedShapeProductToInt (TfLiteContext* context,
609+ absl::Span<const int > dims,
610+ const char * error_message, int & product) {
611+ for (const int dim : dims) {
612+ TF_LITE_ENSURE_MSG (context, dim >= 0 , " Encountered a negative dimension." );
613+ }
614+ TF_LITE_ENSURE_MSG (context, CheckedNumElements (dims, product) == kTfLiteOk ,
615+ " %s" , error_message);
616+ return kTfLiteOk ;
617+ }
618+
598619} // namespace tflite
0 commit comments