diff --git a/src/sort/merge_sort.jl b/src/sort/merge_sort.jl
index 5fb7b20..b81cdbf 100644
--- a/src/sort/merge_sort.jl
+++ b/src/sort/merge_sort.jl
@@ -149,6 +149,28 @@ function merge_sort!(
 )
     # Simple sanity checks
     @argcheck block_size > 0
+
+    # Hoist `by` transform: broadcast it once to produce a key array, then sort
+    # (key, value) pairs. Without hoisting, by(elem) fires inside every binary-search
+    # comparison in the O(n log²n) merge hot-path — once per element per merge step.
+    if by !== identity
+        # The early return below skips the `temp` checks that follow this branch,
+        # so run them here to keep rejecting a mismatched buffer up front.
+        if !isnothing(temp)
+            @argcheck length(temp) == length(v)
+            @argcheck eltype(temp) === eltype(v)
+        end
+
+        # Named `sort_keys` rather than `keys` so the local does not shadow `Base.keys`.
+        sort_keys = by.(v)
+        merge_sort_by_key!(
+            sort_keys, v, backend;
+            lt, rev, order, block_size,
+            temp_values=temp,  # temp was for v swap buffer; maps to temp_values here
+        )
+        return v
+    end
+
     if !isnothing(temp)
         @argcheck length(temp) == length(v)
         @argcheck eltype(temp) === eltype(v)
diff --git a/src/sort/sort.jl b/src/sort/sort.jl
index baba4dc..02acc5a 100644
--- a/src/sort/sort.jl
+++ b/src/sort/sort.jl
@@ -208,10 +208,15 @@ function _sortperm_impl!(
     temp::Union{Nothing, AbstractArray}=nothing,
 )
     if use_gpu_algorithm(backend, prefer_threads)
-        merge_sortperm_lowmem!(
+        # merge_sortperm! copies keys alongside indices in shared memory so comparisons
+        # never touch global memory during the binary-search step.
+        # merge_sortperm_lowmem! avoids the key copy but its comparator does two global
+        # loads per comparison, making it O(n log²n) in global traffic at large n.
+        merge_sortperm!(
             ix, v, backend;
             lt, by, rev, order,
-            block_size, temp,
+            block_size,
+            temp_ix=temp,  # old `temp` was the index buffer; maps directly to temp_ix
         )
     else
         sample_sortperm!(
diff --git a/test/generic/sort.jl b/test/generic/sort.jl
index ee48c00..ecf11dd 100644
--- a/test/generic/sort.jl
+++ b/test/generic/sort.jl
@@ -48,6 +48,81 @@ if !IS_CPU_BACKEND || !prefer_threads
                    block_size=64, temp=array_from_host(1:10_000, Int32))
     @test issorted(Array(v))
 end
+
+@testset "sort_by_transform" begin
+    # Tests for the by= hoisting optimisation: by(elem) is broadcast once before
+    # sorting rather than being called inside every merge comparison.
+    # Checks exact output match against Base.sort so we catch ordering regressions.
+    Random.seed!(42)
+
+    # Exact match against Base.sort for common by= functions
+    for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Float32, Float64, Int32))
+        n = 10_000
+        v_h = T <: AbstractFloat ? randn(T, n) : rand(T(-100):T(100), n)
+        for (kw, base_kw) in (
+            ((by=abs,), (by=abs,)),
+            ((by=abs, rev=true), (by=abs, rev=true)),
+            ((by=x->x^2,), (by=x->x^2,)),
+        )
+            v = array_from_host(v_h)
+            tmp = copy(v)
+            AK.merge_sort!(tmp; kw...)
+            @test Array(tmp) == sort(v_h; base_kw...)
+        end
+    end
+
+    # rev=true and lt=(>) are not hoisted (no by=) — verify they still pass
+    n = 10_000
+    v_h = randn(Float32, n)
+    for kw in ((rev=true,), (lt=(>),))
+        tmp = array_from_host(v_h)
+        AK.merge_sort!(tmp; kw...)
+        @test Array(tmp) == sort(v_h; kw...)
+    end
+
+    # Edge sizes under by= hoisting
+    for n in (1, 2, 513, 1025)
+        v_h = randn(Float32, n)
+        v = array_from_host(v_h)
+        tmp = copy(v)
+        AK.merge_sort!(tmp; by=abs)
+        @test Array(tmp) == sort(v_h; by=abs)
+    end
+
+    # temp kwarg still forwarded correctly through hoisting path
+    n = 20_000
+    v_h = randn(Float32, n)
+    v = array_from_host(v_h)
+    tmp = copy(v)
+    temp = array_from_host(zeros(Float32, n))
+    AK.merge_sort!(tmp; by=abs, temp)
+    @test Array(tmp) == sort(v_h; by=abs)
+
+    # sort! (public API) routes through the same hoisting path
+    n = 10_000
+    v_h = randn(Float32, n)
+    v = array_from_host(v_h)
+    tmp = copy(v)
+    AK.sort!(tmp; by=abs)
+    @test Array(tmp) == sort(v_h; by=abs)
+
+    # by= with a type-changing transform (Float32 → Bool key)
+    n = 10_000
+    v_h = randn(Float32, n)
+    v = array_from_host(v_h)
+    tmp = copy(v)
+    AK.merge_sort!(tmp; by=x->x>0)
+    @test Array(tmp) == sort(v_h; by=x->x>0)
+
+    # identity path unchanged: verify no regression from the early-return guard
+    n = 10_000
+    v_h = rand(Float32, n)
+    v = array_from_host(v_h)
+    tmp = copy(v)
+    AK.merge_sort!(tmp)
+    @test Array(tmp) == sort(v_h)
+end
+
 else # CPU backend
 @testset "sample_sort" begin
     Random.seed!(0)
@@ -453,3 +528,116 @@ end
     vh = Array(v)
     @test issorted(vh[ixh])
 end
+
+
+if !IS_CPU_BACKEND || !prefer_threads
+@testset "sortperm_extended" begin
+    # Helper: ix is a valid permutation of 1:n that produces a sorted order
+    function is_valid_perm(vh, ixh; kwargs...)
+        n = length(vh)
+        length(ixh) == n &&
+            isperm(ixh) &&
+            issorted(vh[ixh]; kwargs...)
+    end
+
+    # ── Element types ────────────────────────────────────────────────────────
+    Random.seed!(123)
+
+    for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Int16, UInt16, Int64, UInt64, Float64, UInt8))
+        for _ in 1:50
+            n = rand(1:50_000)
+            v = array_from_host(rand(T, n))
+            ix = array_from_host(zeros(Int, n))
+            AK.sortperm!(ix, v)
+            vh, ixh = Array(v), Array(ix)
+            @test is_valid_perm(vh, ixh)
+        end
+    end
+
+    # ── Edge sizes ───────────────────────────────────────────────────────────
+    for n in (1, 2, 3, 511, 512, 513, 1023, 1024, 1025, 2047, 2048, 2049)
+        v = array_from_host(rand(Float32, n))
+        ix = array_from_host(zeros(Int, n))
+        AK.sortperm!(ix, v)
+        vh, ixh = Array(v), Array(ix)
+        @test is_valid_perm(vh, ixh)
+    end
+
+    # ── Data distributions ───────────────────────────────────────────────────
+    n = 2^14
+    Random.seed!(456)
+    base = rand(Float32, n)
+
+    for arr in (
+        sort(base),                 # already sorted
+        reverse(sort(base)),        # reverse sorted
+        fill(1f0, n),               # all same
+        Float32.(rand(1:4, n)),     # 4 unique values
+    )
+        v = array_from_host(arr)
+        ix = array_from_host(zeros(Int, n))
+        AK.sortperm!(ix, v)
+        vh, ixh = Array(v), Array(ix)
+        @test is_valid_perm(vh, ixh)
+    end
+
+    # ── Comparator options ───────────────────────────────────────────────────
+    n = 10_000
+    Random.seed!(789)
+
+    for (kw, check_kw) in (
+        ((rev=true,), (rev=true,)),
+        ((by=abs,), (by=abs,)),
+        ((by=abs, rev=true), (by=abs, rev=true)),
+        ((lt=(>),), (lt=(>),)),
+    )
+        v = array_from_host(randn(Float32, n))
+        ix = array_from_host(zeros(Int, n))
+        AK.sortperm!(ix, v; kw...)
+        vh, ixh = Array(v), Array(ix)
+        @test is_valid_perm(vh, ixh; check_kw...)
+    end
+
+    # ── temp kwarg: buffer reuse gives identical result ───────────────────────
+    n = 20_000
+    Random.seed!(321)
+    v1 = array_from_host(rand(Float32, n))
+    v2 = copy(v1)
+    ix1 = array_from_host(zeros(Int, n))
+    ix2 = array_from_host(zeros(Int, n))
+    temp = array_from_host(zeros(Int, n))
+    AK.sortperm!(ix1, v1; temp)
+    AK.sortperm!(ix2, v2; temp)
+    @test Array(ix1) == Array(ix2)
+
+    # ── Exact match against Base.sortperm ────────────────────────────────────
+    for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Int32, Float32, Float64))
+        n = 10_000
+        v_h = rand(T, n)
+        ref = sortperm(v_h)
+        v = array_from_host(v_h)
+        ix = array_from_host(zeros(Int, n))
+        AK.sortperm!(ix, v)
+        ixh = Int.(Array(ix))
+        @test v_h[ixh] == v_h[ref]
+    end
+
+    # ── Stability: equal keys must preserve original relative order ───────────
+    n = 10_000
+    v_h = Int32.(mod.(1:n, 10))  # values 0..9 cycling, 1000 of each
+    v = array_from_host(v_h)
+    ix = array_from_host(zeros(Int, n))
+    AK.sortperm!(ix, v)
+    ixh = Array(ix)
+    for k in 0:9
+        group = ixh[v_h[ixh] .== k]
+        @test issorted(group)  # within each equal-key group, indices must be ascending
+    end
+
+    # ── sortperm does not mutate the input ───────────────────────────────────
+    v = array_from_host(rand(Float32, 5_000))
+    vbak = copy(v)
+    AK.sortperm(v)
+    @test Array(v) == Array(vbak)
+end
+end