1212
1313#include < executorch/backends/vulkan/runtime/vk_api/Exception.h>
1414
15+ #include < c10/util/safe_numerics.h>
16+
1517#include < cmath>
1618#include < limits>
17- #include < numeric>
1819#include < type_traits>
1920
2021namespace vkcompute {
@@ -465,24 +466,8 @@ inline ivec4 make_whcn_ivec4(const std::vector<int64_t>& arr) {
465466}
466467
467468/*
468- * Wrapper around std::accumulate that accumulates values of a container of
469- * integral types into int64_t. Taken from `multiply_integers` in
470- * <c10/util/accumulate.h>
471- */
472- template <
473- typename C,
474- std::enable_if_t <std::is_integral<typename C::value_type>::value, int > = 0 >
475- inline int64_t multiply_integers (const C& container) {
476- return std::accumulate (
477- container.begin (),
478- container.end (),
479- static_cast <int64_t >(1 ),
480- std::multiplies<>());
481- }
482-
483- /*
484- * Product of integer elements referred to by iterators; accumulates into the
485- * int64_t datatype. Taken from `multiply_integers` in <c10/util/accumulate.h>
469+ * Computes the product of integral values referred to by iterators,
470+ * accumulating into int64_t with overflow checking. Throws on overflow.
486471 */
487472template <
488473 typename Iter,
@@ -491,11 +476,24 @@ template <
491476 typename std::iterator_traits<Iter>::value_type>::value,
492477 int > = 0 >
493478inline int64_t multiply_integers (Iter begin, Iter end) {
494- // std::accumulate infers return type from `init` type, so if the `init` type
495- // is not large enough to hold the result, computation can overflow. We use
496- // `int64_t` here to avoid this.
497- return std::accumulate (
498- begin, end, static_cast <int64_t >(1 ), std::multiplies<>());
479+ int64_t result = 1 ;
480+ for (Iter it = begin; it != end; ++it) {
481+ VK_CHECK_COND (
482+ !c10::mul_overflows (result, static_cast <int64_t >(*it), &result),
483+ " Integer overflow in multiply_integers" );
484+ }
485+ return result;
486+ }
487+
488+ /*
489+ * Computes the product of integral values in a container, accumulating into
490+ * int64_t with overflow checking. Throws on overflow.
491+ */
492+ template <
493+ typename C,
494+ std::enable_if_t <std::is_integral<typename C::value_type>::value, int > = 0 >
495+ inline int64_t multiply_integers (const C& container) {
496+ return multiply_integers (container.begin (), container.end ());
499497}
500498
501499class WorkgroupSize final {
0 commit comments