Skip to content

Commit 99ad09b

Browse files
authored
Allow passing a function to findmin and findmax (#673)
1 parent a6c8d08 commit 99ad09b

2 files changed

Lines changed: 25 additions & 5 deletions

File tree

src/host/indexing.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ end
226226
Base.findfirst(A::AnyGPUArray{Bool}) = findfirst(identity, A)
227227
Base.findlast(A::AnyGPUArray{Bool}) = findlast(identity, A)
228228

229-
function findminmax(binop, A::AnyGPUArray; init, dims)
229+
function findminmax(binop, f, A::AnyGPUArray; init, dims)
230230
indices = EachIndex(A)
231231
dummy_index = firstindex(A)
232232

@@ -237,21 +237,25 @@ function findminmax(binop, A::AnyGPUArray; init, dims)
237237
isequal(x, y) && return (x, min(i, j))
238238
return t1
239239
end
240+
241+
fA = f.(A)
240242

241243
if dims == Colon()
242-
res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index))
244+
res = mapreduce(tuple, reduction, fA, indices; init = (init, dummy_index))
243245

244246
# out of consistency with Base.findarray, return a CartesianIndex
245247
# when the input is a multidimensional array
246248
return (res[1], ndims(A) == 1 ? res[2] : CartesianIndices(A)[res[2]])
247249
else
248-
res = mapreduce(tuple, reduction, A, indices;
250+
res = mapreduce(tuple, reduction, fA, indices;
249251
init = (init, dummy_index), dims=dims)
250252
vals = map(x->x[1], res)
251253
inds = map(x->ndims(A) == 1 ? x[2] : CartesianIndices(A)[x[2]], res)
252254
return (vals, inds)
253255
end
254256
end
255257

256-
Base.findmax(a::AnyGPUArray; dims=:) = findminmax(Base.isless, a; init=typemin(eltype(a)), dims)
257-
Base.findmin(a::AnyGPUArray; dims=:) = findminmax(Base.isgreater, a; init=typemax(eltype(a)), dims)
258+
Base.findmax(a::AnyGPUArray; dims=:) = findminmax(Base.isless, identity, a; init=typemin(eltype(a)), dims)
259+
Base.findmin(a::AnyGPUArray; dims=:) = findminmax(Base.isgreater, identity, a; init=typemax(eltype(a)), dims)
260+
Base.findmax(f::Function, a::AnyGPUArray; dims=:) = findminmax(Base.isless, f, a; init=typemin(eltype(a)), dims)
261+
Base.findmin(f::Function, a::AnyGPUArray; dims=:) = findminmax(Base.isgreater, f, a; init=typemax(eltype(a)), dims)

test/testsuite/indexing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ end
205205
@test isequal(findmax(x), findmax(AT(x)))
206206
@test isequal(findmax(x; dims=1), Array.(findmax(AT(x); dims=1)))
207207
end
208+
let x = randn(Float32, 100)
209+
@test findmax(abs, x) == findmax(abs, AT(x))
210+
@test findmax(abs, x; dims=1) == Array.(findmax(abs, AT(x); dims=1))
211+
212+
x[32] = x[33] = x[55] = x[66] = NaN32
213+
@test isequal(findmax(abs, x), findmax(abs, AT(x)))
214+
@test isequal(findmax(abs, x; dims=1), Array.(findmax(abs, AT(x); dims=1)))
215+
end
208216
let x = rand(Float32, 10, 10)
209217
@test findmax(x) == findmax(AT(x))
210218
@test findmax(x; dims=1) == Array.(findmax(AT(x); dims=1))
@@ -235,6 +243,14 @@ end
235243
@test isequal(findmin(x), findmin(AT(x)))
236244
@test isequal(findmin(x; dims=1), Array.(findmin(AT(x); dims=1)))
237245
end
246+
let x = randn(Float32, 100)
247+
@test findmin(abs, x) == findmin(abs, AT(x))
248+
@test findmin(abs, x; dims=1) == Array.(findmin(abs, AT(x); dims=1))
249+
250+
x[32] = x[33] = x[55] = x[66] = NaN32
251+
@test isequal(findmin(abs, x), findmin(abs, AT(x)))
252+
@test isequal(findmin(abs, x; dims=1), Array.(findmin(abs, AT(x); dims=1)))
253+
end
238254
let x = rand(Float32, 10, 10)
239255
@test findmin(x) == findmin(AT(x))
240256
@test findmin(x; dims=1) == Array.(findmin(AT(x); dims=1))

0 commit comments

Comments
 (0)