Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tensorflow/lite/micro/kernels/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,28 @@ TfLiteStatus GatherEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetEvalOutput(context, node, kOutputTensor);

if (coords->type == kTfLiteInt32) {
// The reference Gather() below indexes `input` with the values in `coords`
// and only guards them with TFLITE_DCHECK, which is compiled out in release
// (-DNDEBUG) builds. Validate the index values at runtime so that an
// out-of-range index supplied as a model's runtime input fails closed with
// kTfLiteError instead of reading out of bounds. This mirrors the runtime
// index validation performed by the full TFLite GATHER kernel.
const TfLiteIntArray* input_dims = input->dims;
int axis = params->axis;
if (axis < 0) {
axis += input_dims->size;
}
TF_LITE_ENSURE(context, axis >= 0 && axis < input_dims->size);
const int32_t axis_size = input_dims->data[axis];
const int32_t* coords_data = tflite::micro::GetTensorData<int32_t>(coords);
int num_coords = 1;
for (int i = 0; i < coords->dims->size; ++i) {
num_coords *= coords->dims->data[i];
}
for (int i = 0; i < num_coords; ++i) {
TF_LITE_ENSURE(context, coords_data[i] >= 0 && coords_data[i] < axis_size);
}

switch (input->type) {
case kTfLiteFloat32:
return Gather<float, int32_t>(params, input, coords, output);
Expand Down
62 changes: 62 additions & 0 deletions tensorflow/lite/micro/kernels/gather_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,40 @@ void TestGather(int* input_dims, const InType* input_data, int* positions_dims,
}
}

// Builds a GATHER op with the given (out-of-range) index value and checks that
// Invoke() fails closed with kTfLiteError instead of reading the input tensor
// out of bounds. The reference kernel only guards index values with
// TFLITE_DCHECK, which is a no-op in release (-DNDEBUG) builds.
template <typename InType, typename PosType>
void TestGatherFailsForOutOfRangeIndex(
int* input_dims, const InType* input_data, int* positions_dims,
const PosType* positions_data, int* output_dims, InType* output_data,
const int axis = 0, const int batch_dims = 0) {
TfLiteIntArray* in_dims = IntArrayFromInts(input_dims);
TfLiteIntArray* pos_dims = IntArrayFromInts(positions_dims);
TfLiteIntArray* out_dims = IntArrayFromInts(output_dims);
TfLiteGatherParams params = {axis, batch_dims};

constexpr int tensors_size = 3;
TfLiteTensor tensors[tensors_size] = {
CreateTensor(input_data, in_dims),
CreateTensor(positions_data, pos_dims),
CreateTensor(output_data, out_dims, true),
};
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

const TFLMRegistration registration = Register_GATHER();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, &params);
// Prepare does not inspect index values, so it still succeeds.
EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
// Invoke must reject the out-of-range index rather than read out of bounds.
EXPECT_EQ(kTfLiteError, runner.Invoke());
}

} // namespace
} // namespace testing
} // namespace tflite
Expand Down Expand Up @@ -458,4 +492,32 @@ TEST(GatherTest, GatherOp_BatchDimsEqualIndexDims) {
output_data, golden_dims, golden_data, axis, batch_dims);
}

TEST(GatherTest, GatherOp_IndexGreaterThanAxisSizeFailsClosed) {
// axis 0 has size 3, so index 3 is out of range and must be rejected at
// runtime rather than reading past the end of the input tensor.
int input_dims[] = {2, 3, 4};
const float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
int positions_dims[] = {1, 1};
const int32_t positions_data[] = {3};
float output_data[4];
int output_dims[] = {2, 0, 0};
tflite::testing::TestGatherFailsForOutOfRangeIndex<float, int32_t>(
input_dims, input_data, positions_dims, positions_data, output_dims,
output_data);
}

TEST(GatherTest, GatherOp_NegativeIndexFailsClosed) {
// A negative index must be rejected at runtime rather than reading before the
// start of the input tensor.
int input_dims[] = {2, 3, 4};
const float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
int positions_dims[] = {1, 1};
const int32_t positions_data[] = {-1};
float output_data[4];
int output_dims[] = {2, 0, 0};
tflite::testing::TestGatherFailsForOutOfRangeIndex<float, int32_t>(
input_dims, input_data, positions_dims, positions_data, output_dims,
output_data);
}

TF_LITE_MICRO_TESTS_MAIN
Loading