Skip to content

Commit fded1da

Browse files
committed
Sync from upstream TF.
1 parent 9f5ac25 commit fded1da

7 files changed

Lines changed: 64 additions & 7 deletions

File tree

tensorflow/lite/core/c/common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,13 @@ typedef struct TfLiteContext {
10601060
/// WARNING: This is an experimental interface that is subject to change.
10611061
TfLiteStatus (*ReleaseSubgraphContext)(struct TfLiteContext* context,
10621062
int subgraph_index);
1063+
#if defined(_WIN32)
1064+
/// Create a array of a given `size` (uninitialized entries).
1065+
TfLiteIntArray* (*TfLiteIntArrayCreate)(int size); // NOLINT
1066+
1067+
/// Free memory of array `a`.
1068+
void (*TfLiteIntArrayFree)(TfLiteIntArray* a); // NOLINT
1069+
#endif // defined(_WIN32)
10631070
} TfLiteContext;
10641071

10651072
/// `TfLiteOperator` is an external version of `TfLiteRegistration`

tensorflow/lite/kernels/internal/reference/concatenation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_
1818

1919
#include <algorithm>
20+
#include <cstddef>
2021

2122
#include "tensorflow/lite/kernels/internal/common.h"
2223
#include "tensorflow/lite/kernels/internal/compatibility.h"
@@ -109,7 +110,7 @@ inline void Concatenation<Int4>(const ConcatenationParams& params,
109110
// not garbage.
110111
// Note: output_shape.FlatSize() gives number of elements (nibbles).
111112
// Bytes needed: (elements + 1) / 2.
112-
memset(output_ptr, 0, (output_shape.FlatSize() + 1) / 2);
113+
memset(output_ptr, 0, (static_cast<size_t>(output_shape.FlatSize()) + 1) / 2);
113114

114115
int64_t output_offset = 0;
115116
for (int k = 0; k < outer_size; k++) {

tensorflow/lite/kernels/kernel_util.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#ifndef TF_LITE_STATIC_MEMORY
2626
#include <string>
2727

28+
#include "absl/types/span.h"
2829
#include "tensorflow/lite/array.h"
2930
#endif // TF_LITE_STATIC_MEMORY
3031

@@ -33,6 +34,7 @@ limitations under the License.
3334
#include "tensorflow/lite/core/c/common.h"
3435
#include "tensorflow/lite/kernels/internal/cppmath.h"
3536
#include "tensorflow/lite/kernels/internal/quantization_util.h"
37+
#include "tensorflow/lite/util.h"
3638

3739
#if defined(__APPLE__)
3840
#include "TargetConditionals.h"
@@ -595,4 +597,25 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor) {
595597
return false;
596598
}
597599

600+
TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
601+
absl::Span<const int> dims,
602+
const char* error_message, size_t& product) {
603+
// The CheckedNumElements function already checks for negative dimensions, so
604+
// we don't do it here.
605+
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
606+
"%s", error_message);
607+
return kTfLiteOk;
608+
}
609+
610+
TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
611+
absl::Span<const int> dims,
612+
const char* error_message, int& product) {
613+
for (const int dim : dims) {
614+
TF_LITE_ENSURE_MSG(context, dim >= 0, "Encountered a negative dimension.");
615+
}
616+
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
617+
"%s", error_message);
618+
return kTfLiteOk;
619+
}
620+
598621
} // namespace tflite

tensorflow/lite/kernels/kernel_util.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ limitations under the License.
1717

1818
#include <stdint.h>
1919

20+
#include <cstddef>
2021
#include <limits>
2122
#ifndef TF_LITE_STATIC_MEMORY
2223
#include <string>
2324
#endif // TF_LITE_STATIC_MEMORY
2425

26+
#include "absl/types/span.h"
2527
#include "tensorflow/lite/core/c/builtin_op_data.h"
2628
#include "tensorflow/lite/core/c/common.h"
2729
#ifndef NDEBUG
@@ -341,6 +343,30 @@ bool IsMobilePlatform();
341343
// Returns whether there is unspecified dimension in the tensor's dim signature.
342344
bool HasUnspecifiedDimension(const TfLiteTensor* tensor);
343345

346+
/**
347+
* Calculates the product of the given dimensions. Returns an error if any of
348+
* the dimensions is negative or if the product overflows.
349+
* @param context The context to use for error reporting.
350+
* @param dims The dimensions to multiply.
351+
* @param error_message The error message to use if an error is encountered.
352+
* @param product The output parameter to store the product.
353+
*/
354+
TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
355+
absl::Span<const int> dims,
356+
const char* error_message, size_t& product);
357+
358+
/**
359+
* Calculates the product of the given dimensions. Returns an error if any of
360+
* the dimensions is negative or if the product overflows.
361+
* @param context The context to use for error reporting.
362+
* @param dims The dimensions to multiply.
363+
* @param error_message The error message to use if an error is encountered.
364+
* @param product The output parameter to store the product.
365+
*/
366+
TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
367+
absl::Span<const int> dims,
368+
const char* error_message, int& product);
369+
344370
} // namespace tflite
345371

346372
#endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_

tensorflow/lite/tools/flatbuffer_utils_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import subprocess
1919
import sys
2020

21-
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
22-
from tflite_micro.tensorflow.lite.tools import flatbuffer_utils
23-
from tflite_micro.tensorflow.lite.tools import test_utils
21+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
22+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import flatbuffer_utils
23+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils
2424
from tensorflow.python.framework import test_util
2525
from tensorflow.python.platform import test
2626

tensorflow/lite/tools/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
import flatbuffers
21-
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
21+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
2222

2323
TFLITE_SCHEMA_VERSION = 3
2424

tensorflow/lite/tools/visualize_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import os
1717
import re
1818

19-
from tflite_micro.tensorflow.lite.tools import test_utils
20-
from tflite_micro.tensorflow.lite.tools import visualize
19+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils
20+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import visualize
2121
from tensorflow.python.framework import test_util
2222
from tensorflow.python.platform import test
2323

0 commit comments

Comments
 (0)