@@ -23,6 +23,7 @@ limitations under the License.
2323#include " fixedpoint/fixedpoint.h"
2424#include " tensorflow/lite/kernels/internal/common.h"
2525#include " tensorflow/lite/kernels/internal/compatibility.h"
26+ #include " tensorflow/lite/kernels/internal/reference/broadcast_loop.h"
2627
2728namespace tflite {
2829
@@ -39,7 +40,7 @@ inline void Add(const ArithmeticParams& params,
3940 const int flat_size =
4041 MatchingElementsSize (input1_shape, input2_shape, output_shape);
4142 for (int i = 0 ; i < flat_size; ++i) {
42- output_data[i] = ActivationFunctionWithMinMax (
43+ output_data[i] = ActivationFunctionWithMinMax<T> (
4344 input1_data[i] + input2_data[i], activation_min, activation_max);
4445 }
4546}
@@ -328,6 +329,20 @@ BroadcastAdd6DSlow(const ArithmeticParams& params,
328329 constexpr int kMaxBroadcastDim = 6 ;
329330 T activation_min, activation_max;
330331 GetActivationParams (params, &activation_min, &activation_max);
332+ const int broadcast_rank = std::max (
333+ output_shape.DimensionsCount (),
334+ std::max (input1_shape.DimensionsCount (), input2_shape.DimensionsCount ()));
335+ if (broadcast_rank > kMaxBroadcastDim ) {
336+ ForEachBroadcastedElement (
337+ input1_shape, input2_shape, output_shape,
338+ [&](int output_index, int input1_index, int input2_index) {
339+ output_data[output_index] = ActivationFunctionWithMinMax (
340+ static_cast <T>(input1_data[input1_index] +
341+ input2_data[input2_index]),
342+ activation_min, activation_max);
343+ });
344+ return ;
345+ }
331346
332347 // In Tensorflow, the dimensions are canonically named (batch_number, row,
333348 // col, channel), with extents (batches, height, width, depth), with the
@@ -421,6 +436,19 @@ BroadcastAdd6DSlow(const ArithmeticParams& params,
421436 const RuntimeShape& input2_shape, const T* input2_data,
422437 const RuntimeShape& output_shape, T* output_data) {
423438 constexpr int kMaxBroadcastDim = 6 ;
439+ const int broadcast_rank = std::max (
440+ output_shape.DimensionsCount (),
441+ std::max (input1_shape.DimensionsCount (), input2_shape.DimensionsCount ()));
442+ if (broadcast_rank > kMaxBroadcastDim ) {
443+ ForEachBroadcastedElement (
444+ input1_shape, input2_shape, output_shape,
445+ [&](int output_index, int input1_index, int input2_index) {
446+ AddElementwise (1 , params, input1_data + input1_index,
447+ input2_data + input2_index,
448+ output_data + output_index);
449+ });
450+ return ;
451+ }
424452
425453 // In Tensorflow, the dimensions are canonically named (batch_number, row,
426454 // col, channel), with extents (batches, height, width, depth), with the
0 commit comments