Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/sort/merge_sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,20 @@ function merge_sort!(
)
# Simple sanity checks
@argcheck block_size > 0

# Hoist `by` transform: broadcast it once to produce a key array, then sort
# (key, value) pairs. Without hoisting, by(elem) fires inside every binary-search
# comparison in the O(n log²n) merge hot-path — once per element per merge step.
if by !== identity
keys = by.(v)
merge_sort_by_key!(
keys, v, backend;
lt, rev, order, block_size,
temp_values=temp, # temp was for v swap buffer; maps to temp_values here
)
return v
end

if !isnothing(temp)
@argcheck length(temp) == length(v)
@argcheck eltype(temp) === eltype(v)
Expand Down
9 changes: 7 additions & 2 deletions src/sort/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,15 @@ function _sortperm_impl!(
temp::Union{Nothing, AbstractArray}=nothing,
)
if use_gpu_algorithm(backend, prefer_threads)
merge_sortperm_lowmem!(
# merge_sortperm! copies keys alongside indices in shared memory so comparisons
# never touch global memory during the binary-search step.
# merge_sortperm_lowmem! avoids the key copy but its comparator does two global
# loads per comparison, making it O(n log²n) in global traffic at large n.
merge_sortperm!(
ix, v, backend;
lt, by, rev, order,
block_size, temp,
block_size,
temp_ix=temp, # old `temp` was the index buffer; maps directly to temp_ix
)
else
sample_sortperm!(
Expand Down
186 changes: 186 additions & 0 deletions test/generic/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,79 @@ if !IS_CPU_BACKEND || !prefer_threads
block_size=64, temp=array_from_host(1:10_000, Int32))
@test issorted(Array(v))
end

@testset "sort_by_transform" begin
# Tests for the by= hoisting optimisation: by(elem) is broadcast once before
# sorting rather than being called inside every merge comparison.
# Checks exact output match against Base.sort so we catch ordering regressions.
Random.seed!(42)

# Exact match against Base.sort for common by= functions
for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Float32, Float64, Int32))
n = 10_000
v_h = T <: AbstractFloat ? randn(T, n) : rand(T(-100):T(100), n)
for (kw, base_kw) in (
((by=abs,), (by=abs,)),
((by=abs, rev=true), (by=abs, rev=true)),
((by=x->x^2,), (by=x->x^2,)),
)
v = array_from_host(v_h)
tmp = copy(v)
AK.merge_sort!(tmp; kw...)
@test Array(tmp) == sort(v_h; base_kw...)
end
end

# rev=true and lt=(>) are not hoisted (no by=) — verify they still pass
n = 10_000
v_h = randn(Float32, n)
v = array_from_host(v_h); tmp = copy(v)
AK.merge_sort!(tmp; rev=true)
@test Array(tmp) == sort(v_h; rev=true)

# Edge sizes under by= hoisting
for n in (1, 2, 513, 1025)
v_h = randn(Float32, n)
v = array_from_host(v_h)
tmp = copy(v)
AK.merge_sort!(tmp; by=abs)
@test Array(tmp) == sort(v_h; by=abs)
end

# temp kwarg still forwarded correctly through hoisting path
n = 20_000
v_h = randn(Float32, n)
v = array_from_host(v_h)
tmp = copy(v)
temp = array_from_host(zeros(Float32, n))
AK.merge_sort!(tmp; by=abs, temp)
@test Array(tmp) == sort(v_h; by=abs)

# sort! (public API) routes through the same hoisting path
n = 10_000
v_h = randn(Float32, n)
v = array_from_host(v_h)
tmp = copy(v)
AK.sort!(tmp; by=abs)
@test Array(tmp) == sort(v_h; by=abs)

# by= with a type-changing transform (Float32 → Bool key)
n = 10_000
v_h = randn(Float32, n)
v = array_from_host(v_h)
tmp = copy(v)
AK.merge_sort!(tmp; by=x->x>0)
@test Array(tmp) == sort(v_h; by=x->x>0)

# identity path unchanged: verify no regression from the early-return guard
n = 10_000
v_h = rand(Float32, n)
v = array_from_host(v_h)
tmp = copy(v)
AK.merge_sort!(tmp)
@test Array(tmp) == sort(v_h)
end

else # CPU backend
@testset "sample_sort" begin
Random.seed!(0)
Expand Down Expand Up @@ -453,3 +526,116 @@ end
vh = Array(v)
@test issorted(vh[ixh])
end


if !IS_CPU_BACKEND || !prefer_threads
@testset "sortperm_extended" begin
# Helper: ix is a valid permutation of 1:n that produces a sorted order
function is_valid_perm(vh, ixh; kwargs...)
n = length(vh)
length(ixh) == n &&
sort(Int.(ixh)) == collect(1:n) &&
issorted(vh[ixh]; kwargs...)
end

# ── Element types ────────────────────────────────────────────────────────
Random.seed!(123)

for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Int16, UInt16, Int64, UInt64, Float64, UInt8))
for _ in 1:50
n = rand(1:50_000)
v = array_from_host(rand(T, n))
ix = array_from_host(zeros(Int, n))
AK.sortperm!(ix, v)
vh, ixh = Array(v), Array(ix)
@test is_valid_perm(vh, ixh)
end
end

# ── Edge sizes ───────────────────────────────────────────────────────────
for n in (1, 2, 3, 511, 512, 513, 1023, 1024, 1025, 2047, 2048, 2049)
v = array_from_host(rand(Float32, n))
ix = array_from_host(zeros(Int, n))
AK.sortperm!(ix, v)
vh, ixh = Array(v), Array(ix)
@test is_valid_perm(vh, ixh)
end

# ── Data distributions ───────────────────────────────────────────────────
n = 2^14
Random.seed!(456)
base = rand(Float32, n)

for arr in (
sort(base), # already sorted
reverse(sort(base)), # reverse sorted
fill(1f0, n), # all same
Float32.(rand(1:4, n)), # 4 unique values
)
v = array_from_host(arr)
ix = array_from_host(zeros(Int, n))
AK.sortperm!(ix, v)
vh, ixh = Array(v), Array(ix)
@test is_valid_perm(vh, ixh)
end

# ── Comparator options ───────────────────────────────────────────────────
n = 10_000
Random.seed!(789)

for (kw, check_kw) in (
((rev=true,), (rev=true,)),
((by=abs,), (by=abs,)),
((by=abs, rev=true), (by=abs, rev=true)),
((lt=(>),), (lt=(>),)),
)
v = array_from_host(randn(Float32, n))
ix = array_from_host(zeros(Int, n))
AK.sortperm!(ix, v; kw...)
vh, ixh = Array(v), Array(ix)
@test is_valid_perm(vh, ixh; check_kw...)
end

# ── temp kwarg: buffer reuse gives identical result ───────────────────────
n = 20_000
Random.seed!(321)
v1 = array_from_host(rand(Float32, n))
v2 = copy(v1)
ix1 = array_from_host(zeros(Int, n))
ix2 = array_from_host(zeros(Int, n))
temp = array_from_host(zeros(Int, n))
AK.sortperm!(ix1, v1; temp)
AK.sortperm!(ix2, v2; temp)
@test Array(ix1) == Array(ix2)

# ── Exact match against Base.sortperm ────────────────────────────────────
for T in filter(T -> T !== Float64 || KernelAbstractions.supports_float64(BACKEND), (Int32, Float32, Float64))
n = 10_000
v_h = rand(T, n)
ref = sortperm(v_h)
v = array_from_host(v_h)
ix = array_from_host(zeros(Int, n))
AK.sortperm!(ix, v)
ixh = Int.(Array(ix))
@test v_h[ixh] == v_h[ref]
end

# ── Stability: equal keys must preserve original relative order ───────────
n = 10_000
v_h = Int32.(mod.(1:n, 10)) # values 0..9 cycling, 1000 of each
v = array_from_host(v_h)
ix = array_from_host(zeros(Int, n))
AK.sortperm!(ix, v)
ixh = Array(ix)
for k in 0:9
group = ixh[v_h[ixh] .== k]
@test issorted(group) # within each equal-key group, indices must be ascending
end

# ── sortperm does not mutate the input ───────────────────────────────────
v = array_from_host(rand(Float32, 5_000))
vbak = copy(v)
AK.sortperm(v)
@test Array(v) == Array(vbak)
end
end
Loading