226226Base. findfirst (A:: AnyGPUArray{Bool} ) = findfirst (identity, A)
227227Base. findlast (A:: AnyGPUArray{Bool} ) = findlast (identity, A)
228228
229- function findminmax (binop, A:: AnyGPUArray ; init, dims)
229+ function findminmax (binop, f, A:: AnyGPUArray ; init, dims)
230230 indices = EachIndex (A)
231231 dummy_index = firstindex (A)
232232
@@ -237,21 +237,25 @@ function findminmax(binop, A::AnyGPUArray; init, dims)
237237 isequal (x, y) && return (x, min (i, j))
238238 return t1
239239 end
240+
241+ fA = f .(A)
240242
241243 if dims == Colon ()
242- res = mapreduce (tuple, reduction, A , indices; init = (init, dummy_index))
244+ res = mapreduce (tuple, reduction, fA , indices; init = (init, dummy_index))
243245
244246 # out of consistency with Base.findarray, return a CartesianIndex
245247 # when the input is a multidimensional array
246248 return (res[1 ], ndims (A) == 1 ? res[2 ] : CartesianIndices (A)[res[2 ]])
247249 else
248- res = mapreduce (tuple, reduction, A , indices;
250+ res = mapreduce (tuple, reduction, fA , indices;
249251 init = (init, dummy_index), dims= dims)
250252 vals = map (x-> x[1 ], res)
251253 inds = map (x-> ndims (A) == 1 ? x[2 ] : CartesianIndices (A)[x[2 ]], res)
252254 return (vals, inds)
253255 end
254256end
255257
256- Base. findmax (a:: AnyGPUArray ; dims= :) = findminmax (Base. isless, a; init= typemin (eltype (a)), dims)
257- Base. findmin (a:: AnyGPUArray ; dims= :) = findminmax (Base. isgreater, a; init= typemax (eltype (a)), dims)
258+ Base. findmax (a:: AnyGPUArray ; dims= :) = findminmax (Base. isless, identity, a; init= typemin (eltype (a)), dims)
259+ Base. findmin (a:: AnyGPUArray ; dims= :) = findminmax (Base. isgreater, identity, a; init= typemax (eltype (a)), dims)
260+ Base. findmax (f:: Function , a:: AnyGPUArray ; dims= :) = findminmax (Base. isless, f, a; init= typemin (eltype (a)), dims)
261+ Base. findmin (f:: Function , a:: AnyGPUArray ; dims= :) = findminmax (Base. isgreater, f, a; init= typemax (eltype (a)), dims)
0 commit comments