diff --git a/backends/vulkan/runtime/utils/VecUtils.h b/backends/vulkan/runtime/utils/VecUtils.h index d84eb54d2b9..7bf57f0976e 100644 --- a/backends/vulkan/runtime/utils/VecUtils.h +++ b/backends/vulkan/runtime/utils/VecUtils.h @@ -465,24 +465,8 @@ inline ivec4 make_whcn_ivec4(const std::vector& arr) { } /* - * Wrapper around std::accumulate that accumulates values of a container of - * integral types into int64_t. Taken from `multiply_integers` in - * - */ -template < - typename C, - std::enable_if_t::value, int> = 0> -inline int64_t multiply_integers(const C& container) { - return std::accumulate( - container.begin(), - container.end(), - static_cast(1), - std::multiplies<>()); -} - -/* - * Product of integer elements referred to by iterators; accumulates into the - * int64_t datatype. Taken from `multiply_integers` in + * Computes the product of integral values referred to by iterators, + * accumulating into int64_t with overflow checking. Throws on overflow. */ template < typename Iter, @@ -491,11 +475,30 @@ template < typename std::iterator_traits::value_type>::value, int> = 0> inline int64_t multiply_integers(Iter begin, Iter end) { - // std::accumulate infers return type from `init` type, so if the `init` type - // is not large enough to hold the result, computation can overflow. We use - // `int64_t` here to avoid this. - return std::accumulate( - begin, end, static_cast(1), std::multiplies<>()); + int64_t result = 1; + for (Iter it = begin; it != end; ++it) { + const int64_t val = static_cast(*it); + VK_CHECK_COND(val >= 0, "Negative value in multiply_integers"); + if (val == 0) { + return 0; + } + VK_CHECK_COND( + result <= std::numeric_limits::max() / val, + "Integer overflow in multiply_integers"); + result *= val; + } + return result; +} + +/* + * Computes the product of integral values in a container, accumulating into + * int64_t with overflow checking. Throws on overflow. + */ +template < + typename C, + std::enable_if_t::value, int> = 0> +inline int64_t multiply_integers(const C& container) { + return multiply_integers(container.begin(), container.end()); } class WorkgroupSize final {