From 64a1465679d6716410012ffabc9d1c68cb98a81c Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 25 Sep 2025 14:19:30 -0500 Subject: [PATCH 1/3] Findall --- Project.toml | 2 ++ src/accumulate.jl | 8 ++++++++ src/indexing.jl | 36 ++++++++++++++++++++++++++++++++++++ src/oneAPI.jl | 4 +++- 4 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 src/accumulate.jl create mode 100644 src/indexing.jl diff --git a/Project.toml b/Project.toml index e657f638..a39ed566 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "2.3.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" @@ -31,6 +32,7 @@ oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36" [compat] AbstractFFTs = "1.5.0" +AcceleratedKernels = "0.4.3" Adapt = "4" CEnum = "0.4, 0.5" ExprTools = "0.1" diff --git a/src/accumulate.jl b/src/accumulate.jl new file mode 100644 index 00000000..2f7bf416 --- /dev/null +++ b/src/accumulate.jl @@ -0,0 +1,8 @@ +Base.accumulate!(op, B::oneArray, A::oneArray; init=zero(eltype(A)), kwargs...) = + AK.accumulate!(op, B, A, oneAPIBackend(); init, kwargs...) + +Base.accumulate(op, A::oneArray; init=zero(eltype(A)), kwargs...) = + AK.accumulate(op, A, oneAPIBackend(); init, kwargs...) + +Base.cumsum(src::oneArray; kwargs...) = AK.cumsum(src, oneAPIBackend(); kwargs...) +Base.cumprod(src::oneArray; kwargs...) = AK.cumprod(src, oneAPIBackend(); kwargs...) \ No newline at end of file diff --git a/src/indexing.jl b/src/indexing.jl new file mode 100644 index 00000000..0e52ecd4 --- /dev/null +++ b/src/indexing.jl @@ -0,0 +1,36 @@ +Base.to_index(::oneArray, I::AbstractArray{Bool}) = findall(I) + +if VERSION >= v"1.11.0-DEV.1157" + Base.to_indices(x::oneArray, I::Tuple{AbstractArray{Bool}}) = + (Base.to_index(x, I[1]),) +end + +function _ker!(ys, bools, indices) + i = get_global_id() + + @inbounds if i ≤ length(bools) && bools[i] + ii = CartesianIndices(bools)[i] + b = indices[i] # new position + ys[b] = ii + end + return +end + +function Base.findall(bools::oneArray{Bool}) + I = keytype(bools) + + indices = cumsum(reshape(bools, prod(size(bools)))) + oneL0.synchronize() + + n = isempty(indices) ? 0 : @allowscalar indices[end] + + ys = oneArray{I}(undef, n) + + if n > 0 + @oneapi items=length(bools) _ker!(ys, bools, indices) + end + oneL0.synchronize() + unsafe_free!(indices) + + return ys +end \ No newline at end of file diff --git a/src/oneAPI.jl b/src/oneAPI.jl index bbb4b3f3..b7f8b527 100644 --- a/src/oneAPI.jl +++ b/src/oneAPI.jl @@ -58,7 +58,7 @@ export SYCL include("../lib/mkl/oneMKL.jl") export oneMKL end - +import AcceleratedKernels as AK # integrations and specialized functionality include("broadcast.jl") include("mapreduce.jl") @@ -68,6 +68,8 @@ include("utils.jl") include("oneAPIKernels.jl") import .oneAPIKernels: oneAPIBackend +include("accumulate.jl") +include("indexing.jl") export oneAPIBackend function __init__() From f3f847e33bebfeec6b88e6acbd7bc44c15427590 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 25 Sep 2025 14:25:30 -0500 Subject: [PATCH 2/3] Add test --- test/indexing.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 test/indexing.jl diff --git a/test/indexing.jl b/test/indexing.jl new file mode 100644 index 00000000..4f1bfcc7 --- /dev/null +++ b/test/indexing.jl @@ -0,0 +1,20 @@ +using Test +using oneAPI + +@testset "findall" begin + bools1d = oneArray([true, false, true, false, true]) + @test Array(findall(bools1d)) == findall(Bool[true, false, true, false, true]) + + bools2d = oneArray(Bool[true false; false true; true false]) + @test Array(findall(bools2d)) == findall(Bool[true false; false true; true false]) + + all_false = oneArray(fill(false, 4)) + @test Array(findall(all_false)) == Int[] + + all_true = oneArray(fill(true, 3, 2)) + @test Array(findall(all_true)) == findall(fill(true, 3, 2)) + + data = oneArray(collect(1:6)) + mask = oneArray(Bool[true, false, true, false, false, true]) + @test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])] +end From 7ebe854e96ae16ed5bd2600ea6df36af66e1bd29 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 25 Sep 2025 14:26:06 -0500 Subject: [PATCH 3/3] Format --- src/accumulate.jl | 6 +++--- src/indexing.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/accumulate.jl b/src/accumulate.jl index 2f7bf416..206b4ea1 100644 --- a/src/accumulate.jl +++ b/src/accumulate.jl @@ -1,8 +1,8 @@ -Base.accumulate!(op, B::oneArray, A::oneArray; init=zero(eltype(A)), kwargs...) = +Base.accumulate!(op, B::oneArray, A::oneArray; init = zero(eltype(A)), kwargs...) = AK.accumulate!(op, B, A, oneAPIBackend(); init, kwargs...) -Base.accumulate(op, A::oneArray; init=zero(eltype(A)), kwargs...) = +Base.accumulate(op, A::oneArray; init = zero(eltype(A)), kwargs...) = AK.accumulate(op, A, oneAPIBackend(); init, kwargs...) Base.cumsum(src::oneArray; kwargs...) = AK.cumsum(src, oneAPIBackend(); kwargs...) -Base.cumprod(src::oneArray; kwargs...) = AK.cumprod(src, oneAPIBackend(); kwargs...) \ No newline at end of file +Base.cumprod(src::oneArray; kwargs...) = AK.cumprod(src, oneAPIBackend(); kwargs...) diff --git a/src/indexing.jl b/src/indexing.jl index 0e52ecd4..661deaaf 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -27,10 +27,10 @@ function Base.findall(bools::oneArray{Bool}) ys = oneArray{I}(undef, n) if n > 0 - @oneapi items=length(bools) _ker!(ys, bools, indices) + @oneapi items = length(bools) _ker!(ys, bools, indices) end oneL0.synchronize() unsafe_free!(indices) return ys -end \ No newline at end of file +end