Commit f869884

Accumulate CPU half-precision sums in float32
1 parent 80bcd1c commit f869884

2 files changed

Lines changed: 122 additions & 42 deletions
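
The motivation, in brief: float16 carries only 11 significand bits, so a running fp16 sum stops growing once the increments fall below half a unit in the last place of the accumulator. A minimal standalone sketch of the failure mode (not part of this commit; it assumes a compiler and target with the `_Float16` extension, e.g. recent clang or gcc on AArch64):

#include <cstdio>

int main() {
  _Float16 half_acc = 0;
  float float_acc = 0.0f;
  for (int i = 0; i < 4096; i++) {
    // In [2048, 4096) adjacent fp16 values are 2 apart, so once the
    // accumulator reaches 2048, adding 1.0 rounds back to 2048 every time.
    half_acc = half_acc + (_Float16)1.0f;
    float_acc += 1.0f;
  }
  std::printf("fp16 accumulator: %.0f\n", (float)half_acc); // prints 2048
  std::printf("fp32 accumulator: %.0f\n", float_acc);       // prints 4096
}

Accumulating in float32 and narrowing once at the end, as this commit does for the CPU backend, sidesteps the problem.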

mlx/backend/cpu/reduce.cpp

Lines changed: 91 additions & 42 deletions
@@ -3,6 +3,7 @@
 #include <cassert>
 #include <functional>
 #include <limits>
+#include <type_traits>
 
 #include "mlx/backend/common/reduce.h"
 #include "mlx/backend/cpu/encoder.h"
@@ -96,19 +97,29 @@ void strided_reduce(
 };
 
 template <typename T, typename U, typename Op>
-void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
+U strided_reduce(const T* x, int size, size_t stride, Op op, U accumulator) {
+  for (int i = 0; i < size; i++) {
+    accumulator = op(accumulator, *x);
+    x += stride;
+  }
+  return accumulator;
+}
+
+template <typename T, typename U, typename Op>
+U contiguous_reduce(const T* x, int size, Op op, U accumulator, U init) {
   constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
   simd::Simd<U, N> accumulator_v(init);
   while (size >= N) {
     accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
     x += N;
     size -= N;
   }
-  *accumulator = op(*accumulator, op(accumulator_v));
+  accumulator = op(accumulator, op(accumulator_v));
   while (size-- > 0) {
-    *accumulator = op(*accumulator, *x);
+    accumulator = op(accumulator, *x);
     x++;
   }
+  return accumulator;
 }
 
 // Helper for the ndimensional strided loop
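
Two details worth noting in the reworked helpers: the new `strided_reduce` overload threads the accumulator through by value and returns it, so callers can keep the running value in the wider accumulator type rather than writing partials through an output pointer; and `contiguous_reduce` now takes both `accumulator` (the running value carried across calls) and `init` (the identity used to seed the SIMD lanes), which keeps repeated invocations from an `nd_loop` correct.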
@@ -135,27 +146,25 @@ void nd_loop(
   loop_inner(0, 0);
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename OutT, typename Op, typename AccT = OutT>
 void reduction_op(
     const array& x,
     array& out,
     const std::vector<int>& axes,
-    U init) {
+    AccT init) {
   ReductionPlan plan = get_reduction_plan(x, axes);
 
   auto in_ptr = x.data<T>();
-  auto out_ptr = out.data<U>();
+  auto out_ptr = out.data<OutT>();
   if (plan.type == ContiguousAllReduce) {
-    *out_ptr = init;
-    contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
+    *out_ptr = contiguous_reduce(in_ptr, x.size(), Op{}, init, init);
     return;
   }
 
   if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
     int reduction_size = plan.shape[0];
     for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
-      *out_ptr = init;
-      contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
+      *out_ptr = contiguous_reduce(in_ptr, reduction_size, Op{}, init, init);
     }
     return;
   }
@@ -170,24 +179,25 @@ void reduction_op(
     if (plan.shape.size() == 0) {
       for (int i = 0; i < out.size(); i++, out_ptr++) {
         int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
-        contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
+        *out_ptr = contiguous_reduce(
+            in_ptr + offset, reduction_size, Op{}, init, init);
       }
     } else {
       for (int i = 0; i < out.size(); i++, out_ptr++) {
         int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
+        AccT val = init;
         nd_loop(
             [&](int extra_offset) {
-              contiguous_reduce(
+              val = contiguous_reduce(
                   in_ptr + offset + extra_offset,
-                  out_ptr,
                   reduction_size,
                   Op{},
+                  val,
                   init);
             },
             plan.shape,
             plan.strides);
+        *out_ptr = val;
       }
     }
     return;
@@ -199,8 +209,15 @@ void reduction_op(
     plan.shape.pop_back();
     plan.strides.pop_back();
     for (int i = 0; i < out.size(); i += reduction_stride) {
-      std::fill_n(out_ptr, reduction_stride, init);
-      strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
+      if constexpr (std::is_same_v<OutT, AccT>) {
+        std::fill_n(out_ptr, reduction_stride, init);
+        strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
+      } else {
+        for (size_t j = 0; j < reduction_stride; j++) {
+          out_ptr[j] = strided_reduce(
+              in_ptr + j, reduction_size, reduction_stride, Op{}, init);
+        }
+      }
       in_ptr += reduction_stride * reduction_size;
       out_ptr += reduction_stride;
     }
@@ -218,26 +235,55 @@ void reduction_op(
     if (plan.shape.size() == 0) {
       for (int i = 0; i < out.size(); i += reduction_stride) {
         int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        strided_reduce(
-            in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
+        if constexpr (std::is_same_v<OutT, AccT>) {
+          std::fill_n(out_ptr, reduction_stride, init);
+          strided_reduce(
+              in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
+        } else {
+          for (size_t j = 0; j < reduction_stride; j++) {
+            out_ptr[j] = strided_reduce(
+                in_ptr + offset + j,
+                reduction_size,
+                reduction_stride,
+                Op{},
+                init);
+          }
+        }
         out_ptr += reduction_stride;
       }
     } else {
       for (int i = 0; i < out.size(); i += reduction_stride) {
         int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        nd_loop(
-            [&](int extra_offset) {
-              strided_reduce(
-                  in_ptr + offset + extra_offset,
-                  out_ptr,
-                  reduction_size,
-                  reduction_stride,
-                  Op{});
-            },
-            plan.shape,
-            plan.strides);
+        if constexpr (std::is_same_v<OutT, AccT>) {
+          std::fill_n(out_ptr, reduction_stride, init);
+          nd_loop(
+              [&](int extra_offset) {
+                strided_reduce(
+                    in_ptr + offset + extra_offset,
+                    out_ptr,
+                    reduction_size,
+                    reduction_stride,
+                    Op{});
+              },
+              plan.shape,
+              plan.strides);
+        } else {
+          for (size_t j = 0; j < reduction_stride; j++) {
+            AccT val = init;
+            nd_loop(
+                [&](int extra_offset) {
+                  val = strided_reduce(
+                      in_ptr + offset + extra_offset + j,
+                      reduction_size,
+                      reduction_stride,
+                      Op{},
+                      val);
+                },
+                plan.shape,
+                plan.strides);
+            out_ptr[j] = val;
+          }
+        }
         out_ptr += reduction_stride;
       }
     }
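
In all three strided cases the `if constexpr (std::is_same_v<OutT, AccT>)` test preserves the original vectorized in-place path (fill the output row with `init`, then reduce into it) whenever the output and accumulator types match, and falls back to a scalar per-column loop carrying one wide accumulator per output element only when they differ. Reductions that never needed a wider accumulator therefore compile to the same code as before.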
@@ -249,7 +295,7 @@ void reduction_op(
 
   for (int i = 0; i < out.size(); i++, out_ptr++) {
     int offset = elem_to_loc(i, shape, strides);
-    U val = init;
+    AccT val = init;
     nd_loop(
         [&](int extra_offset) {
           val = Op{}(val, *(in_ptr + offset + extra_offset));
@@ -404,9 +450,9 @@ void reduce_dispatch_and_or(
     Reduce::ReduceType rtype,
     const std::vector<int>& axes) {
   if (rtype == Reduce::And) {
-    reduction_op<InT, bool, AndReduce>(in, out, axes, true);
+    reduction_op<InT, bool, AndReduce, bool>(in, out, axes, true);
   } else {
-    reduction_op<InT, bool, OrReduce>(in, out, axes, false);
+    reduction_op<InT, bool, OrReduce, bool>(in, out, axes, false);
   }
 }
 
@@ -417,16 +463,19 @@ void reduce_dispatch_sum_prod(
     Reduce::ReduceType rtype,
     const std::vector<int>& axes) {
   if (rtype == Reduce::Sum) {
-    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
+    if constexpr (
+        std::is_same_v<InT, float16_t> || std::is_same_v<InT, bfloat16_t>) {
+      reduction_op<InT, InT, SumReduce, float>(in, out, axes, 0.0f);
+    } else if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t, SumReduce, int32_t>(in, out, axes, 0);
     } else {
-      reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
+      reduction_op<InT, InT, SumReduce, InT>(in, out, axes, 0);
     }
   } else {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
+      reduction_op<InT, int32_t, ProdReduce, int32_t>(in, out, axes, 1);
     } else {
-      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
+      reduction_op<InT, InT, ProdReduce, InT>(in, out, axes, 1);
    }
   }
 }
@@ -439,10 +488,10 @@ void reduce_dispatch_min_max(
     const std::vector<int>& axes) {
   if (rtype == Reduce::Max) {
     auto init = Limits<InT>::min;
-    reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
+    reduction_op<InT, InT, MaxReduce, InT>(in, out, axes, init);
   } else {
     auto init = Limits<InT>::max;
-    reduction_op<InT, InT, MinReduce>(in, out, axes, init);
+    reduction_op<InT, InT, MinReduce, InT>(in, out, axes, init);
   }
 }
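
The dispatch change is the crux of the commit: for float16 and bfloat16 inputs, `reduction_op` is now instantiated with `AccT = float` while `OutT` stays the input type, so sums are carried in float32 and narrowed exactly once on store. A minimal sketch of that accumulate-wide, store-narrow pattern in isolation (illustrative only, not the MLX implementation):

#include <cstddef>

// Accumulate in AccT, store in OutT. AccT defaults to OutT, so callers
// that never needed a wider accumulator are unaffected.
template <typename T, typename OutT, typename AccT = OutT>
void sum_reduce(const T* x, OutT* out, size_t size) {
  AccT acc = AccT(0);
  for (size_t i = 0; i < size; i++) {
    acc = acc + AccT(x[i]); // widen each element before adding
  }
  *out = OutT(acc); // single narrowing conversion at the end
}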
tests/ops_tests.cpp

Lines changed: 31 additions & 0 deletions
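
The new test below casts a ramp of values k/1000 to each half dtype, sums it on the CPU, and compares against a float32 reference. The tolerances (0.5f for float16, 2.0f for bfloat16) are not explained in the commit, but they track the spacing of the output type near the largest expected sum of roughly 499.5: adjacent float16 values there are 0.25 apart and adjacent bfloat16 values are 2.0 apart, so the bounds leave room for only a couple of final rounding steps, consistent with accumulation happening in float32.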
@@ -1029,6 +1029,37 @@ TEST_CASE("test reduction ops") {
   x = array({1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}, {2, 3});
   CHECK(array_equal(sum(x, 0), full({3}, 3.0f)).item<bool>());
   CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
+
+  auto check_half_sum = [](Dtype dtype, Shape shape, std::vector<int> axes) {
+    int size = 1;
+    for (auto dim : shape) {
+      size *= dim;
+    }
+    auto x = astype(
+        reshape(
+            divide(arange(size * 1.0, float32, Device::cpu), array(1000.0f)),
+            shape,
+            Device::cpu),
+        dtype,
+        Device::cpu);
+    auto out = sum(x, axes, false, Device::cpu);
+    auto expected =
+        sum(astype(x, float32, Device::cpu), axes, false, Device::cpu);
+    auto diff =
+        max(abs(subtract(
+                astype(out, float32, Device::cpu), expected, Device::cpu)),
+            Device::cpu)
+            .item<float>();
+    auto tolerance = dtype == float16 ? 0.5f : 2.0f;
+    CHECK_EQ(out.dtype(), dtype);
+    CHECK(diff <= tolerance);
+  };
+  check_half_sum(float16, {1000}, {0});
+  check_half_sum(bfloat16, {1000}, {0});
+  check_half_sum(float16, {100, 10}, {0});
+  check_half_sum(bfloat16, {100, 10}, {0});
+  check_half_sum(float16, {100, 10}, {1});
+  check_half_sum(bfloat16, {100, 10}, {1});
 }
 
 // Test unsigned sum
