|
38 | 38 |
|
39 | 39 | @kernel function _mapreduce_gpu_kernel!( |
40 | 40 | f, op, initop, |
41 | | - dims_red, strides, offsets, |
42 | | - arrays |
| 41 | + dims_red, strides, offsets, ops, arrays |
43 | 42 | ) |
44 | 43 |
|
45 | 44 | I_out = @index(Global, Cartesian) |
|
50 | 49 | # Initialize accumulator from current output value (or apply initop) |
51 | 50 | out = arrays[1] |
52 | 51 | out_I_parent = Is_parent[1] |
53 | | - @inbounds acc = _gpu_init_acc(initop, out[ParentIndex(out_I_parent)]) |
| 52 | + @inbounds acc = _gpu_init_acc(initop, ops[1](out[out_I_parent])) |
54 | 53 |
|
55 | 54 | inputs = Base.tail(arrays) |
56 | 55 | inputs_I_parent = Base.tail(Is_parent) |
57 | 56 | inputs_strides = Base.tail(strides) |
| 57 | + inputs_ops = Base.tail(ops) |
58 | 58 |
|
59 | 59 | for I_red in CartesianIndices(dims_red) |
60 | 60 | # Compute parent index for current reduction index |
61 | 61 | Is_red_parent = cartesian2parent.(inputs_strides, Ref(I_red)) |
62 | 62 | # Get values from each input array, apply map function, and accumulate |
63 | | - vals = getindex.(inputs, ParentIndex.(inputs_I_parent .+ Is_red_parent)) |
| 63 | + vals = inputs_ops.(getindex.(inputs, inputs_I_parent .+ Is_red_parent)) |
64 | 64 | acc = _gpu_accum(op, acc, f(vals...)) |
65 | 65 | end |
66 | 66 | # Write back result to output array |
67 | | - @inbounds out[ParentIndex(out_I_parent)] = acc |
| 67 | + @inbounds out[out_I_parent] = ops[1](acc) |
68 | 68 | end |
69 | 69 |
|
70 | 70 | # GPU-compatible _mapreduce: avoids scalar indexing (first(A), out[ParentIndex(1)]) |
@@ -120,7 +120,8 @@ function Strided._mapreduce_block!( |
120 | 120 |
|
121 | 121 | backend = KernelAbstractions.get_backend(parent(out)) |
122 | 122 | kernel! = _mapreduce_gpu_kernel!(backend) |
123 | | - kernel!(f, op, initop, dims_red, strides, offsets, arrays; ndrange = dims_out) |
| 123 | + ops = getproperty.(arrays, :op) |
| 124 | + kernel!(f, op, initop, dims_red, strides, offsets, ops, parent.(arrays); ndrange = dims_out) |
124 | 125 |
|
125 | 126 | return nothing |
126 | 127 | end |
|
0 commit comments