Skip to content

Commit 0aa111a

Browse files
committed
Deal with Metal.jl: pass parent arrays and per-array ops to the GPU mapreduce kernel
1 parent 88e8e22 commit 0aa111a

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

ext/StridedGPUArraysExt.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ end
3838

3939
@kernel function _mapreduce_gpu_kernel!(
4040
f, op, initop,
41-
dims_red, strides, offsets,
42-
arrays
41+
dims_red, strides, offsets, ops, arrays
4342
)
4443

4544
I_out = @index(Global, Cartesian)
@@ -50,21 +49,22 @@ end
5049
# Initialize accumulator from current output value (or apply initop)
5150
out = arrays[1]
5251
out_I_parent = Is_parent[1]
53-
@inbounds acc = _gpu_init_acc(initop, out[ParentIndex(out_I_parent)])
52+
@inbounds acc = _gpu_init_acc(initop, ops[1](out[out_I_parent]))
5453

5554
inputs = Base.tail(arrays)
5655
inputs_I_parent = Base.tail(Is_parent)
5756
inputs_strides = Base.tail(strides)
57+
inputs_ops = Base.tail(ops)
5858

5959
for I_red in CartesianIndices(dims_red)
6060
# Compute parent index for current reduction index
6161
Is_red_parent = cartesian2parent.(inputs_strides, Ref(I_red))
6262
# Get values from each input array, apply map function, and accumulate
63-
vals = getindex.(inputs, ParentIndex.(inputs_I_parent .+ Is_red_parent))
63+
vals = inputs_ops.(getindex.(inputs, inputs_I_parent .+ Is_red_parent))
6464
acc = _gpu_accum(op, acc, f(vals...))
6565
end
6666
# Write back result to output array
67-
@inbounds out[ParentIndex(out_I_parent)] = acc
67+
@inbounds out[out_I_parent] = ops[1](acc)
6868
end
6969

7070
# GPU-compatible _mapreduce: avoids scalar indexing (first(A), out[ParentIndex(1)])
@@ -120,7 +120,8 @@ function Strided._mapreduce_block!(
120120

121121
backend = KernelAbstractions.get_backend(parent(out))
122122
kernel! = _mapreduce_gpu_kernel!(backend)
123-
kernel!(f, op, initop, dims_red, strides, offsets, arrays; ndrange = dims_out)
123+
ops = getproperty.(arrays, :op)
124+
kernel!(f, op, initop, dims_red, strides, offsets, ops, parent.(arrays); ndrange = dims_out)
124125

125126
return nothing
126127
end

0 commit comments

Comments
 (0)