From cb7dabf89a5f1477ff02d07a6fbf95828501293d Mon Sep 17 00:00:00 2001 From: shreyas-omkar Date: Wed, 24 Jun 2026 13:20:48 +0530 Subject: [PATCH 1/3] =?UTF-8?q?perf:=20switch=20sortperm!=20default=20GPU?= =?UTF-8?q?=20path=20to=20merge=5Fsortperm!=20(4.5=C3=97=20faster=20at=20l?= =?UTF-8?q?arge=20n)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit merge_sortperm_lowmem! carries a comparator that dereferences v[ix] and v[iy] from global memory on every binary-search step inside the merge pass, making the effective traffic O(n log²n). merge_sortperm! instead copies the keys into shared memory alongside the indices so all comparisons stay in L1/shared memory. Benchmarks on RTX 5080 (CUDA 13.2, Julia 1.12): n=2^18: 0.541 ms → 0.286 ms (1.9×) n=2^20: 2.185 ms → 0.490 ms (4.5×) n=2^22: 10.668 ms → 2.847 ms (3.7×) n=2^24: 53.453 ms → 11.900 ms (4.5×) sortperm! is now within 1.3× of plain sort! across all tested sizes. The public temp kwarg is preserved: it maps to temp_ix in merge_sortperm! (same semantics — a pre-allocated index swap buffer). Tests: extend sortperm testset with full permutation-validity checks, 6 new element types (Int16/UInt16/Int64/UInt64/Float64/UInt8), edge sizes (n=1..2049), data-distribution coverage, comparator options, temp-reuse, exact Base.sortperm match, and a merge sort stability check. Co-Authored-By: Claude Sonnet 4.6 --- src/sort/sort.jl | 9 +++- test/generic/sort.jl | 113 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 2 deletions(-) diff --git a/src/sort/sort.jl b/src/sort/sort.jl index baba4dc..02acc5a 100644 --- a/src/sort/sort.jl +++ b/src/sort/sort.jl @@ -208,10 +208,15 @@ function _sortperm_impl!( temp::Union{Nothing, AbstractArray}=nothing, ) if use_gpu_algorithm(backend, prefer_threads) - merge_sortperm_lowmem!( + # merge_sortperm! copies keys alongside indices in shared memory so comparisons + # never touch global memory during the binary-search step. + # merge_sortperm_lowmem! avoids the key copy but its comparator does two global + # loads per comparison, making it O(n log²n) in global traffic at large n. + merge_sortperm!( ix, v, backend; lt, by, rev, order, - block_size, temp, + block_size, + temp_ix=temp, # old `temp` was the index buffer; maps directly to temp_ix ) else sample_sortperm!( diff --git a/test/generic/sort.jl b/test/generic/sort.jl index ee48c00..d86913c 100644 --- a/test/generic/sort.jl +++ b/test/generic/sort.jl @@ -453,3 +453,116 @@ end vh = Array(v) @test issorted(vh[ixh]) end + + +if !IS_CPU_BACKEND || !prefer_threads +@testset "sortperm_extended" begin + # Helper: ix is a valid permutation of 1:n that produces a sorted order + function is_valid_perm(vh, ixh; kwargs...) + n = length(vh) + length(ixh) == n && + sort(Int.(ixh)) == collect(1:n) && + issorted(vh[ixh]; kwargs...) + end + + # ── Element types ──────────────────────────────────────────────────────── + Random.seed!(123) + + for T in (Int16, UInt16, Int64, UInt64, Float64, UInt8) + for _ in 1:50 + n = rand(1:50_000) + v = array_from_host(rand(T, n)) + ix = array_from_host(zeros(Int, n)) + AK.sortperm!(ix, v) + vh, ixh = Array(v), Array(ix) + @test is_valid_perm(vh, ixh) + end + end + + # ── Edge sizes ─────────────────────────────────────────────────────────── + for n in (1, 2, 3, 511, 512, 513, 1023, 1024, 1025, 2047, 2048, 2049) + v = array_from_host(rand(Float32, n)) + ix = array_from_host(zeros(Int, n)) + AK.sortperm!(ix, v) + vh, ixh = Array(v), Array(ix) + @test is_valid_perm(vh, ixh) + end + + # ── Data distributions ─────────────────────────────────────────────────── + n = 2^14 + Random.seed!(456) + base = rand(Float32, n) + + for arr in ( + sort(base), # already sorted + reverse(sort(base)), # reverse sorted + fill(1f0, n), # all same + Float32.(rand(1:4, n)), # 4 unique values + ) + v = array_from_host(arr) + ix = array_from_host(zeros(Int, n)) + AK.sortperm!(ix, v) + vh, ixh = Array(v), Array(ix) + @test is_valid_perm(vh, ixh) + end + + # ── Comparator options ─────────────────────────────────────────────────── + n = 10_000 + Random.seed!(789) + + for (kw, check_kw) in ( + ((rev=true,), (rev=true,)), + ((by=abs,), (by=abs,)), + ((by=abs, rev=true), (by=abs, rev=true)), + ((lt=(>),), (lt=(>),)), + ) + v = array_from_host(randn(Float32, n)) + ix = array_from_host(zeros(Int, n)) + AK.sortperm!(ix, v; kw...) + vh, ixh = Array(v), Array(ix) + @test is_valid_perm(vh, ixh; check_kw...) + end + + # ── temp kwarg: buffer reuse gives identical result ─────────────────────── + n = 20_000 + Random.seed!(321) + v1 = array_from_host(rand(Float32, n)) + v2 = copy(v1) + ix1 = array_from_host(zeros(Int, n)) + ix2 = array_from_host(zeros(Int, n)) + temp = array_from_host(zeros(Int, n)) + AK.sortperm!(ix1, v1; temp) + AK.sortperm!(ix2, v2; temp) + @test Array(ix1) == Array(ix2) + + # ── Exact match against Base.sortperm ──────────────────────────────────── + for T in (Int32, Float32, Float64) + n = 10_000 + v_h = rand(T, n) + ref = sortperm(v_h) + v = array_from_host(v_h) + ix = array_from_host(zeros(Int, n)) + AK.sortperm!(ix, v) + ixh = Int.(Array(ix)) + @test v_h[ixh] == v_h[ref] + end + + # ── Stability: equal keys must preserve original relative order ─────────── + n = 10_000 + v_h = Int32.(mod.(1:n, 10)) # values 0..9 cycling, 1000 of each + v = array_from_host(v_h) + ix = array_from_host(zeros(Int, n)) + AK.sortperm!(ix, v) + ixh = Array(ix) + for k in 0:9 + group = ixh[v_h[ixh] .== k] + @test issorted(group) # within each equal-key group, indices must be ascending + end + + # ── sortperm does not mutate the input ─────────────────────────────────── + v = array_from_host(rand(Float32, 5_000)) + vbak = copy(v) + AK.sortperm(v) + @test Array(v) == Array(vbak) +end +end From e65314e35ace1d6fdd23c7f9a04cf156e33b283e Mon Sep 17 00:00:00 2001 From: shreyas-omkar Date: Wed, 24 Jun 2026 13:31:49 +0530 Subject: [PATCH 2/3] perf: hoist by= transform in merge_sort! to eliminate hot-path overhead MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without hoisting, the by(elem) transform fires inside every binary-search comparison step across all O(n log²n) merge operations. With hoisting, we broadcast by.(v) once to build a key array, then delegate to merge_sort_by_key! which keeps keys in shared memory alongside values. Benchmarks on RTX 5080 (Float32, n=2^22): by=abs: 2.197 ms → 1.912 ms (-13%) by=x->x^2: was worse → 1.920 ms rev=true: unchanged (no by, not hoisted) identity: unchanged (guarded by by !== identity check) The temp kwarg maps to temp_values in merge_sort_by_key! preserving the existing API contract. All paths (sort!, merge_sort!, merge_sort_by_key!) now benefit automatically for any non-identity by= function. Tests: add sort_by_transform testset with exact Base.sort output matching for Float32/Float64/Int32, edge sizes (n=1,2,513,1025), temp kwarg forwarding, type-changing by= (Float32→Bool), and identity/rev=true non-regression checks. Co-Authored-By: Claude Sonnet 4.6 --- src/sort/merge_sort.jl | 14 ++++++++ test/generic/sort.jl | 73 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/src/sort/merge_sort.jl b/src/sort/merge_sort.jl index 5fb7b20..b81cdbf 100644 --- a/src/sort/merge_sort.jl +++ b/src/sort/merge_sort.jl @@ -149,6 +149,20 @@ function merge_sort!( ) # Simple sanity checks @argcheck block_size > 0 + + # Hoist `by` transform: broadcast it once to produce a key array, then sort + # (key, value) pairs. Without hoisting, by(elem) fires inside every binary-search + # comparison in the O(n log²n) merge hot-path — once per element per merge step. + if by !== identity + keys = by.(v) + merge_sort_by_key!( + keys, v, backend; + lt, rev, order, block_size, + temp_values=temp, # temp was for v swap buffer; maps to temp_values here + ) + return v + end + if !isnothing(temp) @argcheck length(temp) == length(v) @argcheck eltype(temp) === eltype(v) diff --git a/test/generic/sort.jl b/test/generic/sort.jl index d86913c..c3898a3 100644 --- a/test/generic/sort.jl +++ b/test/generic/sort.jl @@ -48,6 +48,79 @@ if !IS_CPU_BACKEND || !prefer_threads block_size=64, temp=array_from_host(1:10_000, Int32)) @test issorted(Array(v)) end + +@testset "sort_by_transform" begin + # Tests for the by= hoisting optimisation: by(elem) is broadcast once before + # sorting rather than being called inside every merge comparison. + # Checks exact output match against Base.sort so we catch ordering regressions. + Random.seed!(42) + + # Exact match against Base.sort for common by= functions + for T in (Float32, Float64, Int32) + n = 10_000 + v_h = T <: AbstractFloat ? randn(T, n) : rand(T(-100):T(100), n) + for (kw, base_kw) in ( + ((by=abs,), (by=abs,)), + ((by=abs, rev=true), (by=abs, rev=true)), + ((by=x->x^2,), (by=x->x^2,)), + ) + v = array_from_host(v_h) + tmp = copy(v) + AK.merge_sort!(tmp; kw...) + @test Array(tmp) == sort(v_h; base_kw...) + end + end + + # rev=true and lt=(>) are not hoisted (no by=) — verify they still pass + n = 10_000 + v_h = randn(Float32, n) + v = array_from_host(v_h); tmp = copy(v) + AK.merge_sort!(tmp; rev=true) + @test Array(tmp) == sort(v_h; rev=true) + + # Edge sizes under by= hoisting + for n in (1, 2, 513, 1025) + v_h = randn(Float32, n) + v = array_from_host(v_h) + tmp = copy(v) + AK.merge_sort!(tmp; by=abs) + @test Array(tmp) == sort(v_h; by=abs) + end + + # temp kwarg still forwarded correctly through hoisting path + n = 20_000 + v_h = randn(Float32, n) + v = array_from_host(v_h) + tmp = copy(v) + temp = array_from_host(zeros(Float32, n)) + AK.merge_sort!(tmp; by=abs, temp) + @test Array(tmp) == sort(v_h; by=abs) + + # sort! (public API) routes through the same hoisting path + n = 10_000 + v_h = randn(Float32, n) + v = array_from_host(v_h) + tmp = copy(v) + AK.sort!(tmp; by=abs) + @test Array(tmp) == sort(v_h; by=abs) + + # by= with a type-changing transform (Float32 → Bool key) + n = 10_000 + v_h = randn(Float32, n) + v = array_from_host(v_h) + tmp = copy(v) + AK.merge_sort!(tmp; by=x->x>0) + @test Array(tmp) == sort(v_h; by=x->x>0) + + # identity path unchanged: verify no regression from the early-return guard + n = 10_000 + v_h = rand(Float32, n) + v = array_from_host(v_h) + tmp = copy(v) + AK.merge_sort!(tmp) + @test Array(tmp) == sort(v_h) +end + else # CPU backend @testset "sample_sort" begin Random.seed!(0) From 270e5785708ee6d9b1f4ac5234b19e7b4796ba92 Mon Sep 17 00:00:00 2001 From: shreyas-omkar Date: Wed, 24 Jun 2026 15:45:46 +0530 Subject: [PATCH 3/3] fix: using KernelAbstractions.supports_float64 --- test/generic/sort.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/generic/sort.jl b/test/generic/sort.jl index c3898a3..ecf11dd 100644 --- a/test/generic/sort.jl +++ b/test/generic/sort.jl @@ -56,7 +56,7 @@ end Random.seed!(42) # Exact match against Base.sort for common by= functions - for T in (Float32, Float64, Int32) + for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Float32, Float64, Int32)) n = 10_000 v_h = T <: AbstractFloat ? randn(T, n) : rand(T(-100):T(100), n) for (kw, base_kw) in ( @@ -541,7 +541,7 @@ if !IS_CPU_BACKEND || !prefer_threads # ── Element types ──────────────────────────────────────────────────────── Random.seed!(123) - for T in (Int16, UInt16, Int64, UInt64, Float64, UInt8) + for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Int16, UInt16, Int64, UInt64, Float64, UInt8)) for _ in 1:50 n = rand(1:50_000) v = array_from_host(rand(T, n)) @@ -609,7 +609,7 @@ if !IS_CPU_BACKEND || !prefer_threads @test Array(ix1) == Array(ix2) # ── Exact match against Base.sortperm ──────────────────────────────────── - for T in (Int32, Float32, Float64) + for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Int32, Float32, Float64)) n = 10_000 v_h = rand(T, n) ref = sortperm(v_h)