Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/operators/reduce-nd.c
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,16 @@ static enum xnn_status reshape_reduce_nd(
size_t num_reduction_elements;
if (normalized_reduction_axes[num_reduction_axes - 1] == num_input_dims - 1) {
if (workspace_size != NULL) {
const size_t num_output_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4];
*workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES;
size_t num_output_elements;
size_t tmp;
if (__builtin_mul_overflow(normalized_input_shape[0], normalized_input_shape[2], &tmp) ||
__builtin_mul_overflow(tmp, normalized_input_shape[4], &num_output_elements) ||
__builtin_mul_overflow(num_output_elements, (size_t)1 << log2_accumulator_element_size, &tmp)) {
xnn_log_error("failed to reshape %s operator: workspace size overflow",
xnn_operator_type_to_string_v2(reduce_op));
return xnn_status_invalid_parameter;
}
*workspace_size = tmp + XNN_EXTRA_BYTES;
}
num_reduction_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5];
const size_t axis_dim = normalized_input_shape[5];
Expand Down Expand Up @@ -283,8 +291,16 @@ static enum xnn_status reshape_reduce_nd(
// Reduction along the non-innermost dimension
const size_t channel_like_dim = normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1];
if (workspace_size != NULL) {
const size_t num_output_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5];
*workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES;
size_t num_output_elements;
size_t tmp;
if (__builtin_mul_overflow(normalized_input_shape[1], normalized_input_shape[3], &tmp) ||
__builtin_mul_overflow(tmp, normalized_input_shape[5], &num_output_elements) ||
__builtin_mul_overflow(num_output_elements, (size_t)1 << log2_accumulator_element_size, &tmp)) {
xnn_log_error("failed to reshape %s operator: workspace size overflow",
xnn_operator_type_to_string_v2(reduce_op));
return xnn_status_invalid_parameter;
}
*workspace_size = tmp + XNN_EXTRA_BYTES;
}
num_reduction_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4];
const size_t axis_dim = normalized_input_shape[4];
Expand Down