Accumulate CPU half-precision sums in float32 #3488
sofinvalery wants to merge 1 commit into ml-explore:main
Conversation
sofinvalery force-pushed from 6aee023 to 10188e2
Made it inline. Feels cleaner.
zcbenz left a comment:
Allocating a new array would have a heavy performance penalty; the correct way would be to refactor strided_reduce/contiguous_reduce to accumulate in float32 rather than in the output type.
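For reference, a minimal sketch of what that refactor could look like (illustrative only; the name and signature are simplified, not the actual MLX strided_reduce/contiguous_reduce): the helper accumulates in a separate AccT such as float, and the caller casts back to the output type once at the end, so no temporary array is needed.

#include <cstddef>

template <typename T, typename AccT, typename Op>
AccT contiguous_reduce_sketch(const T* x, size_t size, Op op, AccT acc) {
  // Widen each element to AccT (e.g. float) before accumulating;
  // the caller writes static_cast<T>(acc) into the output buffer.
  for (size_t i = 0; i < size; ++i) {
    acc = op(acc, static_cast<AccT>(x[i]));
  }
  return acc;
}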
sofinvalery force-pushed from 10188e2 to f869884
Refactored.
  for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
-   *out_ptr = init;
-   contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
+   *out_ptr = contiguous_reduce(in_ptr, reduction_size, Op{}, init, init);
Is it supposed to be:
-   *out_ptr = contiguous_reduce(in_ptr, reduction_size, Op{}, init, init);
+   *out_ptr = contiguous_reduce(in_ptr, reduction_size, Op{}, *out_ptr, init);
  constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
  simd::Simd<U, N> accumulator_v(init);
  while (size >= N) {
    accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
When building for macOS 14 it seems that we can't simply convert Simd<float16> to Simd<float>:
/Users/runner/actions-runner/_work/mlx/mlx/mlx/backend/cpu/simd/accelerate_simd.h:63:40: error: no matching function for call to 'convert'
63 | Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
| ^~~~~~~~~~~~~~~~~~~~~~
/Users/runner/actions-runner/_work/mlx/mlx/mlx/backend/cpu/reduce.cpp:113:39: note: in instantiation of function template specialization 'mlx::core::simd::Simd<float, 8>::Simd<__fp16>' requested here
113 | accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
| ^
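If the vector-width conversion really isn't available on that toolchain, one possible fallback (a sketch only, not MLX's simd API; convert_lanes is a hypothetical helper) is to widen lane by lane into a float buffer before accumulating:

template <typename T, typename U, int N>
void convert_lanes(const T* in, U* out) {
  // Scalar per-lane widening (T -> U, e.g. float16 -> float); slower than a
  // vector convert, but avoids the missing Simd<float16> -> Simd<float> path.
  for (int i = 0; i < N; ++i) {
    out[i] = static_cast<U>(in[i]);
  }
}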
      const std::vector<int>& axes) {
    if (rtype == Reduce::And) {
-     reduction_op<InT, bool, AndReduce>(in, out, axes, true);
+     reduction_op<InT, bool, AndReduce, bool>(in, out, axes, true);
I prefer omitting the optional AccT so it would be obvious when the code is accumulating in a different type.
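A sketch of the call-site effect (hypothetical signature, simplified from the real reduction_op): with AccT defaulted to the output type, only the places that genuinely accumulate in a wider type have to spell it out.

template <typename InT, typename OutT, typename Op, typename AccT = OutT>
void reduction_op_sketch(/* in, out, axes, init */);

// reduction_op_sketch<InT, bool, AndReduce>(...);              // default: accumulates in bool
// reduction_op_sketch<InT, float16_t, SumReduce, float>(...);  // explicit: accumulates in float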
Proposed changes
Accumulate CPU float16 and bfloat16 sum reductions in float32, while preserving the output dtype. This fixes precision loss in ops that use sum().
Fixes #3326.
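To make the precision loss concrete (a standalone illustration, not MLX code; _Float16 is a Clang/GCC extension used only for this demo): a float16 running sum stops changing once it reaches 2048, because 2048 + 1 rounds back to 2048 in half precision, while a float32 accumulator stays exact.

#include <cstdio>

int main() {
  _Float16 acc_half = (_Float16)0.0f;
  float acc_float = 0.0f;
  for (int i = 0; i < 4096; ++i) {
    acc_half = acc_half + (_Float16)1.0f;  // stalls at 2048 in half precision
    acc_float += 1.0f;                     // exact in float32
  }
  std::printf("half: %g  float: %g\n", (float)acc_half, acc_float);  // 2048 vs 4096
  return 0;
}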
Checklist
Put an x in the boxes that apply.