Skip to content

Commit c023a0d

Browse files
committed
Sync from upstream TF.
1 parent 7c26381 commit c023a0d

8 files changed

Lines changed: 101 additions & 24 deletions

File tree

tensorflow/lite/core/c/common.h

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ limitations under the License.
5656
#include <stdbool.h>
5757
#include <stddef.h>
5858
#include <stdint.h>
59+
#include <stdio.h>
5960

6061
#include "tensorflow/lite/core/c/c_api_types.h" // IWYU pragma: export
6162

@@ -277,13 +278,34 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
277278
} \
278279
} while (0)
279280

280-
#define TF_LITE_ENSURE_OK(context, status) \
281-
do { \
282-
const TfLiteStatus s = (status); \
283-
if ((s) != kTfLiteOk) { \
284-
return s; \
285-
} \
281+
#ifndef TF_LITE_STRIP_ERROR_STRINGS
282+
#define TF_LITE_VAR_ARG_HEAD(FIRST, ...) FIRST
283+
#define TF_LITE_STRINGIFY_HELPER(x) #x
284+
#define TF_LITE_STRINGIFY(x) TF_LITE_STRINGIFY_HELPER(x)
285+
// Checks that `status` evaluates to `kTfLiteOk`.
286+
//
287+
// Can take a printf style log message and its parameters after the status. The
288+
// message will be printed using `TF_LITE_KERNEL_LOG` in case of error.
289+
#define TF_LITE_ENSURE_OK(context, status, ...) \
290+
do { \
291+
const TfLiteStatus s = (status); \
292+
if (s != kTfLiteOk) { \
293+
if (sizeof(TF_LITE_VAR_ARG_HEAD("" __VA_ARGS__)) > sizeof("")) { \
294+
TF_LITE_MAYBE_KERNEL_LOG((context), __FILE__ ":" TF_LITE_STRINGIFY( \
295+
__LINE__) ": " __VA_ARGS__); \
296+
} \
297+
return s; \
298+
} \
286299
} while (0)
300+
#else
301+
#define TF_LITE_ENSURE_OK(context, status, ...) \
302+
do { \
303+
const TfLiteStatus s = (status); \
304+
if ((s) != kTfLiteOk) { \
305+
return s; \
306+
} \
307+
} while (0)
308+
#endif
287309

288310
// `std::unreachable` not available until CC23.
289311
#ifdef __GNUC__ // GCC, Clang, ICC
@@ -1060,6 +1082,13 @@ typedef struct TfLiteContext {
10601082
/// WARNING: This is an experimental interface that is subject to change.
10611083
TfLiteStatus (*ReleaseSubgraphContext)(struct TfLiteContext* context,
10621084
int subgraph_index);
1085+
#if defined(_WIN32)
1086+
/// Create a array of a given `size` (uninitialized entries).
1087+
TfLiteIntArray* (*TfLiteIntArrayCreate)(int size); // NOLINT
1088+
1089+
/// Free memory of array `a`.
1090+
void (*TfLiteIntArrayFree)(TfLiteIntArray* a); // NOLINT
1091+
#endif // defined(_WIN32)
10631092
} TfLiteContext;
10641093

10651094
/// `TfLiteOperator` is an external version of `TfLiteRegistration`

tensorflow/lite/kernels/internal/reference/concatenation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_
1818

1919
#include <algorithm>
20+
#include <cstddef>
2021

2122
#include "tensorflow/lite/kernels/internal/common.h"
2223
#include "tensorflow/lite/kernels/internal/compatibility.h"
@@ -109,7 +110,7 @@ inline void Concatenation<Int4>(const ConcatenationParams& params,
109110
// not garbage.
110111
// Note: output_shape.FlatSize() gives number of elements (nibbles).
111112
// Bytes needed: (elements + 1) / 2.
112-
memset(output_ptr, 0, (output_shape.FlatSize() + 1) / 2);
113+
memset(output_ptr, 0, (static_cast<size_t>(output_shape.FlatSize()) + 1) / 2);
113114

114115
int64_t output_offset = 0;
115116
for (int k = 0; k < outer_size; k++) {

tensorflow/lite/kernels/internal/types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,7 @@ struct TanhParams {
999999
int input_left_shift;
10001000
};
10011001

1002-
constexpr int kTransposeMaxDimensions = 6;
1002+
constexpr int kTransposeMaxDimensions = 8;
10031003

10041004
struct TransposeParams {
10051005
int8_t perm_count;

tensorflow/lite/kernels/kernel_util.cc

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#ifndef TF_LITE_STATIC_MEMORY
2626
#include <string>
2727

28+
#include "absl/types/span.h"
2829
#include "tensorflow/lite/array.h"
2930
#endif // TF_LITE_STATIC_MEMORY
3031

@@ -33,6 +34,7 @@ limitations under the License.
3334
#include "tensorflow/lite/core/c/common.h"
3435
#include "tensorflow/lite/kernels/internal/cppmath.h"
3536
#include "tensorflow/lite/kernels/internal/quantization_util.h"
37+
#include "tensorflow/lite/util.h"
3638

3739
#if defined(__APPLE__)
3840
#include "TargetConditionals.h"
@@ -101,9 +103,8 @@ inline TfLiteStatus GetMutableInputSafe(const TfLiteContext* context,
101103
const TfLiteNode* node, int index,
102104
const TfLiteTensor** tensor) {
103105
int tensor_index;
104-
TF_LITE_ENSURE_OK(
105-
context, ValidateTensorIndexingSafe(context, index, node->inputs->size,
106-
node->inputs->data, &tensor_index));
106+
TF_LITE_ENSURE_STATUS(ValidateTensorIndexingSafe(
107+
context, index, node->inputs->size, node->inputs->data, &tensor_index));
107108
*tensor = GetTensorAtIndex(context, tensor_index);
108109
return kTfLiteOk;
109110
}
@@ -140,9 +141,8 @@ TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
140141
TfLiteStatus GetOutputSafe(const TfLiteContext* context, const TfLiteNode* node,
141142
int index, TfLiteTensor** tensor) {
142143
int tensor_index;
143-
TF_LITE_ENSURE_OK(
144-
context, ValidateTensorIndexingSafe(context, index, node->outputs->size,
145-
node->outputs->data, &tensor_index));
144+
TF_LITE_ENSURE_STATUS(ValidateTensorIndexingSafe(
145+
context, index, node->outputs->size, node->outputs->data, &tensor_index));
146146
*tensor = GetTensorAtIndex(context, tensor_index);
147147
return kTfLiteOk;
148148
}
@@ -167,8 +167,8 @@ TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
167167
const TfLiteNode* node, int index,
168168
TfLiteTensor** tensor) {
169169
int tensor_index;
170-
TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
171-
context, index, node->temporaries->size,
170+
TF_LITE_ENSURE_STATUS(
171+
ValidateTensorIndexingSafe(context, index, node->temporaries->size,
172172
node->temporaries->data, &tensor_index));
173173
*tensor = GetTensorAtIndex(context, tensor_index);
174174
return kTfLiteOk;
@@ -188,8 +188,8 @@ TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
188188
const TfLiteNode* node, int index,
189189
TfLiteTensor** tensor) {
190190
int tensor_index;
191-
TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
192-
context, index, node->intermediates->size,
191+
TF_LITE_ENSURE_STATUS(
192+
ValidateTensorIndexingSafe(context, index, node->intermediates->size,
193193
node->intermediates->data, &tensor_index));
194194
*tensor = GetTensorAtIndex(context, tensor_index);
195195
return kTfLiteOk;
@@ -595,4 +595,25 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor) {
595595
return false;
596596
}
597597

598+
TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
599+
absl::Span<const int> dims,
600+
const char* error_message, size_t& product) {
601+
// The CheckedNumElements function already checks for negative dimensions, so
602+
// we don't do it here.
603+
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
604+
"%s", error_message);
605+
return kTfLiteOk;
606+
}
607+
608+
TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
609+
absl::Span<const int> dims,
610+
const char* error_message, int& product) {
611+
for (const int dim : dims) {
612+
TF_LITE_ENSURE_MSG(context, dim >= 0, "Encountered a negative dimension.");
613+
}
614+
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
615+
"%s", error_message);
616+
return kTfLiteOk;
617+
}
618+
598619
} // namespace tflite

tensorflow/lite/kernels/kernel_util.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ limitations under the License.
1717

1818
#include <stdint.h>
1919

20+
#include <cstddef>
2021
#include <limits>
2122
#ifndef TF_LITE_STATIC_MEMORY
2223
#include <string>
2324
#endif // TF_LITE_STATIC_MEMORY
2425

26+
#include "absl/types/span.h"
2527
#include "tensorflow/lite/core/c/builtin_op_data.h"
2628
#include "tensorflow/lite/core/c/common.h"
2729
#ifndef NDEBUG
@@ -341,6 +343,30 @@ bool IsMobilePlatform();
341343
// Returns whether there is unspecified dimension in the tensor's dim signature.
342344
bool HasUnspecifiedDimension(const TfLiteTensor* tensor);
343345

346+
/**
347+
* Calculates the product of the given dimensions. Returns an error if any of
348+
* the dimensions is negative or if the product overflows.
349+
* @param context The context to use for error reporting.
350+
* @param dims The dimensions to multiply.
351+
* @param error_message The error message to use if an error is encountered.
352+
* @param product The output parameter to store the product.
353+
*/
354+
TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
355+
absl::Span<const int> dims,
356+
const char* error_message, size_t& product);
357+
358+
/**
359+
* Calculates the product of the given dimensions. Returns an error if any of
360+
* the dimensions is negative or if the product overflows.
361+
* @param context The context to use for error reporting.
362+
* @param dims The dimensions to multiply.
363+
* @param error_message The error message to use if an error is encountered.
364+
* @param product The output parameter to store the product.
365+
*/
366+
TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
367+
absl::Span<const int> dims,
368+
const char* error_message, int& product);
369+
344370
} // namespace tflite
345371

346372
#endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_

tensorflow/lite/tools/flatbuffer_utils_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import subprocess
1919
import sys
2020

21-
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
22-
from tflite_micro.tensorflow.lite.tools import flatbuffer_utils
23-
from tflite_micro.tensorflow.lite.tools import test_utils
21+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
22+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import flatbuffer_utils
23+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils
2424
from tensorflow.python.framework import test_util
2525
from tensorflow.python.platform import test
2626

tensorflow/lite/tools/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
import flatbuffers
21-
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
21+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
2222

2323
TFLITE_SCHEMA_VERSION = 3
2424

tensorflow/lite/tools/visualize_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import os
1717
import re
1818

19-
from tflite_micro.tensorflow.lite.tools import test_utils
20-
from tflite_micro.tensorflow.lite.tools import visualize
19+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils
20+
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import visualize
2121
from tensorflow.python.framework import test_util
2222
from tensorflow.python.platform import test
2323

0 commit comments

Comments
 (0)