@@ -25,7 +25,6 @@ limitations under the License.
2525#ifndef TF_LITE_STATIC_MEMORY
2626#include < string>
2727
28- #include " absl/types/span.h"
2928#include " tensorflow/lite/array.h"
3029#endif // TF_LITE_STATIC_MEMORY
3130
@@ -35,9 +34,7 @@ limitations under the License.
3534#include " tensorflow/lite/kernels/internal/cppmath.h"
3635#include " tensorflow/lite/kernels/internal/quantization_util.h"
3736
38- #ifndef TF_LITE_STATIC_MEMORY
39- #include " tensorflow/lite/util.h"
40- #endif
37+
4138
4239#if defined(__APPLE__)
4340#include " TargetConditionals.h"
@@ -598,27 +595,43 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor) {
598595 return false ;
599596}
600597
601- #ifndef TF_LITE_STATIC_MEMORY
602598TfLiteStatus CheckedShapeProduct (TfLiteContext* context,
603- absl::Span< const int > dims,
599+ std::initializer_list< int > dims,
604600 const char * error_message, size_t & product) {
605- // The CheckedNumElements function already checks for negative dimensions, so
606- // we don't do it here.
607- TF_LITE_ENSURE_MSG (context, CheckedNumElements (dims, product) == kTfLiteOk ,
608- " %s" , error_message);
601+ size_t checked_count = 1 ;
602+ for (const int d : dims) {
603+ if (d < 0 ) {
604+ TF_LITE_ENSURE_MSG (context, false , " %s" , error_message);
605+ }
606+ if (checked_count > 0 &&
607+ static_cast <size_t >(d) > std::numeric_limits<size_t >::max () / checked_count) {
608+ TF_LITE_ENSURE_MSG (context, false , " %s" , error_message);
609+ }
610+ checked_count *= d;
611+ }
612+ product = checked_count;
609613 return kTfLiteOk ;
610614}
611615
612616TfLiteStatus CheckedShapeProductToInt (TfLiteContext* context,
613- absl::Span< const int > dims,
617+ std::initializer_list< int > dims,
614618 const char * error_message, int & product) {
615- for (const int dim : dims) {
616- TF_LITE_ENSURE_MSG (context, dim >= 0 , " Encountered a negative dimension." );
619+ size_t checked_count = 1 ;
620+ for (const int d : dims) {
621+ if (d < 0 ) {
622+ TF_LITE_ENSURE_MSG (context, false , " Encountered a negative dimension." );
623+ }
624+ if (checked_count > 0 &&
625+ static_cast <size_t >(d) > std::numeric_limits<size_t >::max () / checked_count) {
626+ TF_LITE_ENSURE_MSG (context, false , " %s" , error_message);
627+ }
628+ checked_count *= d;
617629 }
618- TF_LITE_ENSURE_MSG (context, CheckedNumElements (dims, product) == kTfLiteOk ,
619- " %s" , error_message);
630+ if (checked_count > std::numeric_limits<int >::max ()) {
631+ TF_LITE_ENSURE_MSG (context, false , " %s" , error_message);
632+ }
633+ product = static_cast <int >(checked_count);
620634 return kTfLiteOk ;
621635}
622- #endif
623636
624637} // namespace tflite
0 commit comments