33#include < cassert>
44#include < functional>
55#include < limits>
6+ #include < type_traits>
67
78#include " mlx/backend/common/reduce.h"
89#include " mlx/backend/cpu/encoder.h"
@@ -96,19 +97,29 @@ void strided_reduce(
9697};
9798
// Scalar strided reduction.
//
// Folds `size` elements of `x`, spaced `stride` elements apart
// (x[0], x[stride], x[2 * stride], ...), into `accumulator` using `op`, and
// returns the final accumulated value.
//
//   x           - pointer to the first element to reduce
//   size        - number of elements to fold; size <= 0 returns `accumulator`
//                 unchanged
//   stride      - distance, in elements, between consecutive inputs
//   op          - binary functor invoked as op(accumulator, element)
//   accumulator - starting value; its type U may differ from the input type T
template <typename T, typename U, typename Op>
U strided_reduce(const T* x, int size, size_t stride, Op op, U accumulator) {
  for (int i = 0; i < size; i++) {
    // Index from the base pointer instead of bumping `x` each iteration: a
    // trailing `x += stride` after the last element would form a pointer
    // beyond one-past-the-end of the buffer whenever the caller passes an
    // offset base, which is undefined pointer arithmetic.
    accumulator = op(accumulator, x[i * stride]);
  }
  return accumulator;
}
107+
108+ template <typename T, typename U, typename Op>
109+ U contiguous_reduce (const T* x, int size, Op op, U accumulator, U init) {
100110 constexpr int N = std::min (simd::max_size<T>, simd::max_size<U>);
101111 simd::Simd<U, N> accumulator_v (init);
102112 while (size >= N) {
103113 accumulator_v = op (accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
104114 x += N;
105115 size -= N;
106116 }
107- * accumulator = op (* accumulator, op (accumulator_v));
117+ accumulator = op (accumulator, op (accumulator_v));
108118 while (size-- > 0 ) {
109- * accumulator = op (* accumulator, *x);
119+ accumulator = op (accumulator, *x);
110120 x++;
111121 }
122+ return accumulator;
112123}
113124
114125// Helper for the ndimensional strided loop
@@ -135,27 +146,25 @@ void nd_loop(
135146 loop_inner (0 , 0 );
136147}
137148
138- template <typename T, typename U , typename Op>
149+ template <typename T, typename OutT , typename Op, typename AccT = OutT >
139150void reduction_op (
140151 const array& x,
141152 array& out,
142153 const std::vector<int >& axes,
143- U init) {
154+ AccT init) {
144155 ReductionPlan plan = get_reduction_plan (x, axes);
145156
146157 auto in_ptr = x.data <T>();
147- auto out_ptr = out.data <U >();
158+ auto out_ptr = out.data <OutT >();
148159 if (plan.type == ContiguousAllReduce) {
149- *out_ptr = init;
150- contiguous_reduce (in_ptr, out_ptr, x.size (), Op{}, init);
160+ *out_ptr = contiguous_reduce (in_ptr, x.size (), Op{}, init, init);
151161 return ;
152162 }
153163
154164 if (plan.type == ContiguousReduce && plan.shape .size () == 1 ) {
155165 int reduction_size = plan.shape [0 ];
156166 for (int i = 0 ; i < out.size (); i++, out_ptr++, in_ptr += reduction_size) {
157- *out_ptr = init;
158- contiguous_reduce (in_ptr, out_ptr, reduction_size, Op{}, init);
167+ *out_ptr = contiguous_reduce (in_ptr, reduction_size, Op{}, init, init);
159168 }
160169 return ;
161170 }
@@ -170,24 +179,25 @@ void reduction_op(
170179 if (plan.shape .size () == 0 ) {
171180 for (int i = 0 ; i < out.size (); i++, out_ptr++) {
172181 int offset = elem_to_loc (i, shape, strides);
173- *out_ptr = init;
174- contiguous_reduce ( in_ptr + offset, out_ptr, reduction_size, Op{}, init);
182+ *out_ptr = contiguous_reduce (
183+ in_ptr + offset, reduction_size, Op{}, init , init);
175184 }
176185 } else {
177186 for (int i = 0 ; i < out.size (); i++, out_ptr++) {
178187 int offset = elem_to_loc (i, shape, strides);
179- *out_ptr = init;
188+ AccT val = init;
180189 nd_loop (
181190 [&](int extra_offset) {
182- contiguous_reduce (
191+ val = contiguous_reduce (
183192 in_ptr + offset + extra_offset,
184- out_ptr,
185193 reduction_size,
186194 Op{},
195+ val,
187196 init);
188197 },
189198 plan.shape ,
190199 plan.strides );
200+ *out_ptr = val;
191201 }
192202 }
193203 return ;
@@ -199,8 +209,15 @@ void reduction_op(
199209 plan.shape .pop_back ();
200210 plan.strides .pop_back ();
201211 for (int i = 0 ; i < out.size (); i += reduction_stride) {
202- std::fill_n (out_ptr, reduction_stride, init);
203- strided_reduce (in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
212+ if constexpr (std::is_same_v<OutT, AccT>) {
213+ std::fill_n (out_ptr, reduction_stride, init);
214+ strided_reduce (in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
215+ } else {
216+ for (size_t j = 0 ; j < reduction_stride; j++) {
217+ out_ptr[j] = strided_reduce (
218+ in_ptr + j, reduction_size, reduction_stride, Op{}, init);
219+ }
220+ }
204221 in_ptr += reduction_stride * reduction_size;
205222 out_ptr += reduction_stride;
206223 }
@@ -218,26 +235,55 @@ void reduction_op(
218235 if (plan.shape .size () == 0 ) {
219236 for (int i = 0 ; i < out.size (); i += reduction_stride) {
220237 int offset = elem_to_loc (i, shape, strides);
221- std::fill_n (out_ptr, reduction_stride, init);
222- strided_reduce (
223- in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
238+ if constexpr (std::is_same_v<OutT, AccT>) {
239+ std::fill_n (out_ptr, reduction_stride, init);
240+ strided_reduce (
241+ in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
242+ } else {
243+ for (size_t j = 0 ; j < reduction_stride; j++) {
244+ out_ptr[j] = strided_reduce (
245+ in_ptr + offset + j,
246+ reduction_size,
247+ reduction_stride,
248+ Op{},
249+ init);
250+ }
251+ }
224252 out_ptr += reduction_stride;
225253 }
226254 } else {
227255 for (int i = 0 ; i < out.size (); i += reduction_stride) {
228256 int offset = elem_to_loc (i, shape, strides);
229- std::fill_n (out_ptr, reduction_stride, init);
230- nd_loop (
231- [&](int extra_offset) {
232- strided_reduce (
233- in_ptr + offset + extra_offset,
234- out_ptr,
235- reduction_size,
236- reduction_stride,
237- Op{});
238- },
239- plan.shape ,
240- plan.strides );
257+ if constexpr (std::is_same_v<OutT, AccT>) {
258+ std::fill_n (out_ptr, reduction_stride, init);
259+ nd_loop (
260+ [&](int extra_offset) {
261+ strided_reduce (
262+ in_ptr + offset + extra_offset,
263+ out_ptr,
264+ reduction_size,
265+ reduction_stride,
266+ Op{});
267+ },
268+ plan.shape ,
269+ plan.strides );
270+ } else {
271+ for (size_t j = 0 ; j < reduction_stride; j++) {
272+ AccT val = init;
273+ nd_loop (
274+ [&](int extra_offset) {
275+ val = strided_reduce (
276+ in_ptr + offset + extra_offset + j,
277+ reduction_size,
278+ reduction_stride,
279+ Op{},
280+ val);
281+ },
282+ plan.shape ,
283+ plan.strides );
284+ out_ptr[j] = val;
285+ }
286+ }
241287 out_ptr += reduction_stride;
242288 }
243289 }
@@ -249,7 +295,7 @@ void reduction_op(
249295
250296 for (int i = 0 ; i < out.size (); i++, out_ptr++) {
251297 int offset = elem_to_loc (i, shape, strides);
252- U val = init;
298+ AccT val = init;
253299 nd_loop (
254300 [&](int extra_offset) {
255301 val = Op{}(val, *(in_ptr + offset + extra_offset));
@@ -404,9 +450,9 @@ void reduce_dispatch_and_or(
404450 Reduce::ReduceType rtype,
405451 const std::vector<int >& axes) {
406452 if (rtype == Reduce::And) {
407- reduction_op<InT, bool , AndReduce>(in, out, axes, true );
453+ reduction_op<InT, bool , AndReduce, bool >(in, out, axes, true );
408454 } else {
409- reduction_op<InT, bool , OrReduce>(in, out, axes, false );
455+ reduction_op<InT, bool , OrReduce, bool >(in, out, axes, false );
410456 }
411457}
412458
@@ -417,16 +463,19 @@ void reduce_dispatch_sum_prod(
417463 Reduce::ReduceType rtype,
418464 const std::vector<int >& axes) {
419465 if (rtype == Reduce::Sum) {
420- if constexpr (std::is_integral_v<InT> && sizeof (InT) <= 4 ) {
421- reduction_op<InT, int32_t , SumReduce>(in, out, axes, 0 );
466+ if constexpr (
467+ std::is_same_v<InT, float16_t > || std::is_same_v<InT, bfloat16_t >) {
468+ reduction_op<InT, InT, SumReduce, float >(in, out, axes, 0 .0f );
469+ } else if constexpr (std::is_integral_v<InT> && sizeof (InT) <= 4 ) {
470+ reduction_op<InT, int32_t , SumReduce, int32_t >(in, out, axes, 0 );
422471 } else {
423- reduction_op<InT, InT, SumReduce>(in, out, axes, 0 );
472+ reduction_op<InT, InT, SumReduce, InT >(in, out, axes, 0 );
424473 }
425474 } else {
426475 if constexpr (std::is_integral_v<InT> && sizeof (InT) <= 4 ) {
427- reduction_op<InT, int32_t , ProdReduce>(in, out, axes, 1 );
476+ reduction_op<InT, int32_t , ProdReduce, int32_t >(in, out, axes, 1 );
428477 } else {
429- reduction_op<InT, InT, ProdReduce>(in, out, axes, 1 );
478+ reduction_op<InT, InT, ProdReduce, InT >(in, out, axes, 1 );
430479 }
431480 }
432481}
@@ -439,10 +488,10 @@ void reduce_dispatch_min_max(
439488 const std::vector<int >& axes) {
440489 if (rtype == Reduce::Max) {
441490 auto init = Limits<InT>::min;
442- reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
491+ reduction_op<InT, InT, MaxReduce, InT >(in, out, axes, init);
443492 } else {
444493 auto init = Limits<InT>::max;
445- reduction_op<InT, InT, MinReduce>(in, out, axes, init);
494+ reduction_op<InT, InT, MinReduce, InT >(in, out, axes, init);
446495 }
447496}
448497
0 commit comments