Skip to content

Commit 15b38ca

Browse files
Sébastien Loisel
authored and committed
Optimize single-rank performance to match native Julia
- Replace Dict with merge-sort style operations for sorted col_indices
- Add merge_sorted_unique() and build_subset_mapping() helpers for O(n+m) merging
- Use searchsortedfirst() instead of Dict for index lookups
- Simplify A+B/A-B to ~20 lines each (was ~60 lines)
- Fix norm(v,1) and sum(v) to use BLAS-optimized local operations
- Make structural hash computation lazy for A*B, A+B, A-B results
- Add benchmark script tools/benchmark_single_rank.jl

Performance improvements (1 MPI rank vs native Julia):
- Regular sparse A+B: 4.6x → 1.02x
- Hypersparse A+B: 2.9x → 1.45x
- norm(v,1): 28x → 1.02x
- sum(v): 6.4x → 1.02x
- sparse A*B: 3.1x → 1.38x
1 parent eae3ed9 commit 15b38ca

5 files changed

Lines changed: 507 additions & 95 deletions

File tree

src/indexing.jl

Lines changed: 6 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -1135,9 +1135,8 @@ function Base.getindex(A::SparseMatrixMPI{T}, row_rng::UnitRange{Int}, col_rng::
11351135
# rowval_list contains positions into new_col_indices
11361136
if !isempty(rowval_list)
11371137
unique_positions = sort(unique(rowval_list))
1138-
# Map positions to compressed 1-based indices
1139-
pos_to_compressed = Dict(p => i for (i, p) in enumerate(unique_positions))
1140-
compressed_rowval = [pos_to_compressed[r] for r in rowval_list]
1138+
# Map positions to compressed indices: unique_positions is sorted, use binary search
1139+
compressed_rowval = [searchsortedfirst(unique_positions, r) for r in rowval_list]
11411140
# final_col_indices maps compressed index to global column in result
11421141
# new_col_indices contains the shifted global column indices
11431142
final_col_indices = new_col_indices[unique_positions]
@@ -3594,10 +3593,11 @@ function Base.getindex(A::SparseMatrixMPI{T}, i::Int, col_idx::VectorMPI{Int}) w
35943593
col_indices = A.col_indices
35953594

35963595
row_data = zeros(T, length(col_indices_result))
3597-
col_idx_map = Dict(j => c for (c, j) in enumerate(col_indices_result))
3596+
# col_indices_result is sorted, use binary search instead of Dict
35983597
for (local_j, global_j) in enumerate(col_indices)
3599-
if haskey(col_idx_map, global_j)
3600-
row_data[col_idx_map[global_j]] = local_A[local_j, local_row]
3598+
idx = searchsortedfirst(col_indices_result, global_j)
3599+
if idx <= length(col_indices_result) && col_indices_result[idx] == global_j
3600+
row_data[idx] = local_A[local_j, local_row]
36013601
end
36023602
end
36033603
else

src/sparse.jl

Lines changed: 127 additions & 82 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,90 @@
11
# SparseMatrixMPI type and sparse matrix operations
22

3+
# ============================================================================
4+
# Merge-sort style helpers for sorted column index arrays
5+
# ============================================================================
6+
7+
"""
8+
merge_sorted_unique!(result, a, b)
9+
10+
Merge two sorted arrays into a sorted array of unique elements.
11+
Returns the result array (which may be shorter than allocated).
12+
"""
13+
function merge_sorted_unique!(result::Vector{Int}, a::Vector{Int}, b::Vector{Int})
14+
i, j, k = 1, 1, 0
15+
@inbounds while i <= length(a) && j <= length(b)
16+
if a[i] < b[j]
17+
k += 1
18+
result[k] = a[i]
19+
i += 1
20+
elseif a[i] > b[j]
21+
k += 1
22+
result[k] = b[j]
23+
j += 1
24+
else # equal
25+
k += 1
26+
result[k] = a[i]
27+
i += 1
28+
j += 1
29+
end
30+
end
31+
@inbounds while i <= length(a)
32+
k += 1
33+
result[k] = a[i]
34+
i += 1
35+
end
36+
@inbounds while j <= length(b)
37+
k += 1
38+
result[k] = b[j]
39+
j += 1
40+
end
41+
return resize!(result, k)
42+
end
43+
44+
"""
45+
merge_sorted_unique(a, b)
46+
47+
Merge two sorted arrays into a new sorted array of unique elements.
48+
O(n+m) time, no sorting or Dict needed.
49+
"""
50+
function merge_sorted_unique(a::Vector{Int}, b::Vector{Int})
51+
result = Vector{Int}(undef, length(a) + length(b))
52+
return merge_sorted_unique!(result, a, b)
53+
end
54+
55+
"""
56+
build_subset_mapping!(mapping, subset, superset)
57+
58+
Build a mapping from subset indices to superset positions.
59+
Both arrays must be sorted, and subset must be a subset of superset.
60+
mapping[i] = position of subset[i] in superset.
61+
O(|subset| + |superset|) time with linear scan.
62+
"""
63+
function build_subset_mapping!(mapping::Vector{Int}, subset::Vector{Int}, superset::Vector{Int})
64+
j = 1 # position in superset
65+
@inbounds for i in 1:length(subset)
66+
while superset[j] < subset[i]
67+
j += 1
68+
end
69+
# Now superset[j] == subset[i]
70+
mapping[i] = j
71+
end
72+
return mapping
73+
end
74+
75+
"""
76+
build_subset_mapping(subset, superset)
77+
78+
Build a mapping from subset indices to superset positions.
79+
Returns a new vector where mapping[i] = position of subset[i] in superset.
80+
"""
81+
function build_subset_mapping(subset::Vector{Int}, superset::Vector{Int})
82+
mapping = Vector{Int}(undef, length(subset))
83+
return build_subset_mapping!(mapping, subset, superset)
84+
end
85+
86+
# ============================================================================
87+
388
"""
489
compute_structural_hash(row_partition, col_indices, AT, comm) -> Blake3Hash
590
@@ -47,8 +132,8 @@ function compress_AT(AT::SparseMatrixCSC{T,Int}, col_indices::Vector{Int}) where
47132
if isempty(col_indices)
48133
return SparseMatrixCSC(0, AT.n, AT.colptr, Int[], T[])
49134
end
50-
global_to_local = Dict(g => l for (l, g) in enumerate(col_indices))
51-
compressed_rowval = [global_to_local[r] for r in AT.rowval]
135+
# col_indices is sorted, use binary search instead of Dict
136+
compressed_rowval = [searchsortedfirst(col_indices, r) for r in AT.rowval]
52137
return SparseMatrixCSC(length(col_indices), AT.n, AT.colptr, compressed_rowval, AT.nzval)
53138
end
54139

@@ -130,28 +215,25 @@ function _rebuild_AT_with_insertions(AT::SparseMatrixCSC{T,Int}, col_indices::Ve
130215
# Build expanded col_indices (merge existing and new, maintain sorted order)
131216
expanded_col_indices = sort(unique(vcat(col_indices, collect(new_global_cols))))
132217

133-
# Build global->local mapping for expanded col_indices
134-
global_to_local = Dict(g => l for (l, g) in enumerate(expanded_col_indices))
135-
136218
# Collect all entries: (AT_col, AT_row, val) = (local_row, local_col_in_expanded, val)
137219
# Using a Dict to handle duplicates (later values win)
138220
entries = Dict{Tuple{Int,Int},T}()
139221

140222
# Add existing entries from AT (reindex to expanded col_indices)
141-
old_global_to_local = Dict(g => l for (l, g) in enumerate(col_indices))
223+
# expanded_col_indices is sorted, use binary search
142224
for at_col in 1:n_local_rows
143225
for k in AT.colptr[at_col]:(AT.colptr[at_col+1]-1)
144226
old_local_col = AT.rowval[k]
145227
global_col = col_indices[old_local_col]
146-
new_local_col = global_to_local[global_col]
228+
new_local_col = searchsortedfirst(expanded_col_indices, global_col)
147229
entries[(at_col, new_local_col)] = AT.nzval[k]
148230
end
149231
end
150232

151233
# Add new insertions (may overwrite existing)
152234
for (global_i, global_j, val) in insertions
153235
local_row = global_i - row_offset + 1 # AT column
154-
local_col = global_to_local[global_j] # AT row
236+
local_col = searchsortedfirst(expanded_col_indices, global_j) # AT row
155237
entries[(local_row, local_col)] = val
156238
end
157239

@@ -821,10 +903,11 @@ function Base.:*(A::SparseMatrixMPI{T}, B::SparseMatrixMPI{T}) where T
821903
end
822904

823905
compressed_result_AT = compress_AT_cached(result_AT, compress_map, length(result_col_indices))
824-
result_hash = compute_structural_hash(A.row_partition, result_col_indices, compressed_result_AT, comm)
906+
# Hash computed lazily by _ensure_hash if/when needed for plan caching
907+
# This avoids expensive hash computation for results that won't be reused in matrix-matrix multiply
825908

826909
# C = A * B has rows from A and columns from B
827-
return SparseMatrixMPI{T}(result_hash, A.row_partition, B.col_partition, result_col_indices, transpose(compressed_result_AT),
910+
return SparseMatrixMPI{T}(nothing, A.row_partition, B.col_partition, result_col_indices, transpose(compressed_result_AT),
828911
nothing)
829912
end
830913

@@ -834,44 +917,24 @@ end
834917
Add two distributed sparse matrices. The result has A's row partition.
835918
"""
836919
function Base.:+(A::SparseMatrixMPI{T}, B::SparseMatrixMPI{T}) where T
837-
comm = MPI.COMM_WORLD
838-
839-
# Repartition B to match A's row partition
840920
B_repart = repartition(B, A.row_partition)
841921

842-
# Both A and B_repart now have:
843-
# - Same row_partition (A.row_partition)
844-
# - Local CSC storage with compressed column indices
845-
# - col_indices mapping compressed → global
846-
847-
# Compute union of column indices
848-
union_indices = sort(unique(vcat(A.col_indices, B_repart.col_indices)))
849-
union_size = length(union_indices)
850-
851-
# Build mappings: compressed → union for both A and B_repart
852-
global_to_union = Dict(g => l for (l, g) in enumerate(union_indices))
853-
A_col_to_union = [global_to_union[g] for g in A.col_indices]
854-
B_col_to_union = [global_to_union[g] for g in B_repart.col_indices]
922+
if A.col_indices == B_repart.col_indices
923+
return SparseMatrixMPI{T}(nothing, A.row_partition, A.col_partition,
924+
A.col_indices, transpose(A.A.parent + B_repart.A.parent), nothing)
925+
end
855926

856-
# Reindex both to union space
857-
A_union = reindex_to_union_cached(A.A.parent, A_col_to_union, union_size)
858-
B_union = reindex_to_union_cached(B_repart.A.parent, B_col_to_union, union_size)
927+
# Merge sorted col_indices and build mappings via linear scan (no Dict)
928+
union_cols = merge_sorted_unique(A.col_indices, B_repart.col_indices)
929+
A_map = build_subset_mapping(A.col_indices, union_cols)
930+
B_map = build_subset_mapping(B_repart.col_indices, union_cols)
859931

860932
# Add in union space
861-
result_union = A_union + B_union
862-
863-
# Convert result back to compressed indices
864-
result_col_indices_local = isempty(result_union.rowval) ? Int[] : unique(sort(result_union.rowval))
865-
result_col_indices = isempty(result_col_indices_local) ? Int[] : union_indices[result_col_indices_local]
933+
result = reindex_to_union_cached(A.A.parent, A_map, length(union_cols)) +
934+
reindex_to_union_cached(B_repart.A.parent, B_map, length(union_cols))
866935

867-
# Compress result
868-
compressed_result = compress_AT(result_union, result_col_indices_local)
869-
870-
# Compute hash from actual result
871-
result_hash = compute_structural_hash(A.row_partition, result_col_indices, compressed_result, comm)
872-
873-
return SparseMatrixMPI{T}(result_hash, A.row_partition, A.col_partition,
874-
result_col_indices, transpose(compressed_result), nothing)
936+
return SparseMatrixMPI{T}(nothing, A.row_partition, A.col_partition,
937+
union_cols, transpose(result), nothing)
875938
end
876939

877940
"""
@@ -880,44 +943,24 @@ end
880943
Subtract two distributed sparse matrices. The result has A's row partition.
881944
"""
882945
function Base.:-(A::SparseMatrixMPI{T}, B::SparseMatrixMPI{T}) where T
883-
comm = MPI.COMM_WORLD
884-
885-
# Repartition B to match A's row partition
886946
B_repart = repartition(B, A.row_partition)
887947

888-
# Both A and B_repart now have:
889-
# - Same row_partition (A.row_partition)
890-
# - Local CSC storage with compressed column indices
891-
# - col_indices mapping compressed → global
892-
893-
# Compute union of column indices
894-
union_indices = sort(unique(vcat(A.col_indices, B_repart.col_indices)))
895-
union_size = length(union_indices)
896-
897-
# Build mappings: compressed → union for both A and B_repart
898-
global_to_union = Dict(g => l for (l, g) in enumerate(union_indices))
899-
A_col_to_union = [global_to_union[g] for g in A.col_indices]
900-
B_col_to_union = [global_to_union[g] for g in B_repart.col_indices]
948+
if A.col_indices == B_repart.col_indices
949+
return SparseMatrixMPI{T}(nothing, A.row_partition, A.col_partition,
950+
A.col_indices, transpose(A.A.parent - B_repart.A.parent), nothing)
951+
end
901952

902-
# Reindex both to union space
903-
A_union = reindex_to_union_cached(A.A.parent, A_col_to_union, union_size)
904-
B_union = reindex_to_union_cached(B_repart.A.parent, B_col_to_union, union_size)
953+
# Merge sorted col_indices and build mappings via linear scan (no Dict)
954+
union_cols = merge_sorted_unique(A.col_indices, B_repart.col_indices)
955+
A_map = build_subset_mapping(A.col_indices, union_cols)
956+
B_map = build_subset_mapping(B_repart.col_indices, union_cols)
905957

906958
# Subtract in union space
907-
result_union = A_union - B_union
908-
909-
# Convert result back to compressed indices
910-
result_col_indices_local = isempty(result_union.rowval) ? Int[] : unique(sort(result_union.rowval))
911-
result_col_indices = isempty(result_col_indices_local) ? Int[] : union_indices[result_col_indices_local]
912-
913-
# Compress result
914-
compressed_result = compress_AT(result_union, result_col_indices_local)
959+
result = reindex_to_union_cached(A.A.parent, A_map, length(union_cols)) -
960+
reindex_to_union_cached(B_repart.A.parent, B_map, length(union_cols))
915961

916-
# Compute hash from actual result
917-
result_hash = compute_structural_hash(A.row_partition, result_col_indices, compressed_result, comm)
918-
919-
return SparseMatrixMPI{T}(result_hash, A.row_partition, A.col_partition,
920-
result_col_indices, transpose(compressed_result), nothing)
962+
return SparseMatrixMPI{T}(nothing, A.row_partition, A.col_partition,
963+
union_cols, transpose(result), nothing)
921964
end
922965

923966
"""
@@ -1122,7 +1165,11 @@ function TransposePlan(A::SparseMatrixMPI{T}) where T
11221165
local_src_indices = Int[]
11231166
local_dst_indices = Int[]
11241167

1125-
recv_rank_to_idx = Dict(r => i for (i, r) in enumerate(recv_rank_ids))
1168+
# Map rank -> index in recv_rank_ids using Vector (nranks is small)
1169+
recv_rank_to_idx = zeros(Int, nranks)
1170+
for (i, r) in enumerate(recv_rank_ids)
1171+
recv_rank_to_idx[r+1] = i
1172+
end
11261173

11271174
for (ent_idx, (_, _, src_rank, src_idx)) in enumerate(entries)
11281175
dst_idx = entry_to_nzval_idx[ent_idx]
@@ -1131,7 +1178,7 @@ function TransposePlan(A::SparseMatrixMPI{T}) where T
11311178
push!(local_dst_indices, dst_idx)
11321179
else
11331180
# Use indexed assignment: src_idx is the position in recv_buf from src_rank
1134-
recv_perm[recv_rank_to_idx[src_rank]][src_idx] = dst_idx
1181+
recv_perm[recv_rank_to_idx[src_rank+1]][src_idx] = dst_idx
11351182
end
11361183
end
11371184

@@ -2182,9 +2229,8 @@ function triu(A::SparseMatrixMPI{T}, k::Integer=0) where T
21822229
else
21832230
local_used = unique(sort(new_rowval))
21842231
new_col_indices = col_indices[local_used] # convert to global
2185-
# Compress: map old local indices to new compressed indices
2186-
local_to_compressed = Dict(old => new for (new, old) in enumerate(local_used))
2187-
compressed_rowval = [local_to_compressed[r] for r in new_rowval]
2232+
# Compress: local_used is sorted, use binary search instead of Dict
2233+
compressed_rowval = [searchsortedfirst(local_used, r) for r in new_rowval]
21882234
compressed_AT = SparseMatrixCSC(length(new_col_indices), size(A.A.parent, 2),
21892235
new_colptr, compressed_rowval, new_nzval)
21902236
end
@@ -2256,9 +2302,8 @@ function tril(A::SparseMatrixMPI{T}, k::Integer=0) where T
22562302
else
22572303
local_used = unique(sort(new_rowval))
22582304
new_col_indices = col_indices[local_used] # convert to global
2259-
# Compress: map old local indices to new compressed indices
2260-
local_to_compressed = Dict(old => new for (new, old) in enumerate(local_used))
2261-
compressed_rowval = [local_to_compressed[r] for r in new_rowval]
2305+
# Compress: local_used is sorted, use binary search instead of Dict
2306+
compressed_rowval = [searchsortedfirst(local_used, r) for r in new_rowval]
22622307
compressed_AT = SparseMatrixCSC(length(new_col_indices), size(A.A.parent, 2),
22632308
new_colptr, compressed_rowval, new_nzval)
22642309
end

src/vectors.jl

Lines changed: 12 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -560,18 +560,21 @@ function LinearAlgebra.norm(v::VectorMPI{T}, p::Real=2) where T
560560
comm = MPI.COMM_WORLD
561561

562562
if p == 2
563-
local_sum = sum(abs2, v.v; init=zero(real(T)))
563+
# Use BLAS-optimized local norm, then reduce
564+
local_nrm = isempty(v.v) ? zero(real(T)) : norm(v.v)
565+
local_sum = local_nrm * local_nrm
564566
global_sum = MPI.Allreduce(local_sum, MPI.SUM, comm)
565567
return sqrt(global_sum)
566568
elseif p == 1
567-
local_sum = sum(abs, v.v; init=zero(real(T)))
569+
# Use BLAS-optimized local norm(v, 1) = asum
570+
local_sum = isempty(v.v) ? zero(real(T)) : norm(v.v, 1)
568571
return MPI.Allreduce(local_sum, MPI.SUM, comm)
569572
elseif p == Inf
570-
local_max = isempty(v.v) ? zero(real(T)) : maximum(abs, v.v)
573+
local_max = isempty(v.v) ? zero(real(T)) : norm(v.v, Inf)
571574
return MPI.Allreduce(local_max, MPI.MAX, comm)
572575
else
573-
# General p-norm
574-
local_sum = sum(x -> abs(x)^p, v.v; init=zero(real(T)))
576+
# General p-norm - no BLAS optimization available
577+
local_sum = isempty(v.v) ? zero(real(T)) : sum(x -> abs(x)^p, v.v)
575578
global_sum = MPI.Allreduce(local_sum, MPI.SUM, comm)
576579
return global_sum^(1 / p)
577580
end
@@ -636,7 +639,8 @@ Compute the sum of all elements in the distributed vector.
636639
"""
637640
function Base.sum(v::VectorMPI{T}) where T
638641
comm = MPI.COMM_WORLD
639-
local_sum = sum(v.v; init=zero(T))
642+
# Use native sum without init for better performance; handle empty with ternary
643+
local_sum = isempty(v.v) ? zero(T) : sum(v.v)
640644
return MPI.Allreduce(local_sum, MPI.SUM, comm)
641645
end
642646

@@ -647,7 +651,8 @@ Compute the product of all elements in the distributed vector.
647651
"""
648652
function Base.prod(v::VectorMPI{T}) where T
649653
comm = MPI.COMM_WORLD
650-
local_prod = prod(v.v; init=one(T))
654+
# Use native prod without init for better performance; handle empty with ternary
655+
local_prod = isempty(v.v) ? one(T) : prod(v.v)
651656
return MPI.Allreduce(local_prod, MPI.PROD, comm)
652657
end
653658

tools/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1,8 +1,11 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Blake3Hash = "8f478455-a32d-4928-b0e4-72b19a7d5574"
34
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
45
LinearAlgebraMPI = "5bdd2be4-ae34-42ef-8b36-f4c85d48f377"
56
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
7+
MUMPS = "55d2b088-9f4e-11e9-26c0-150b02ea6a46"
8+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
69
SafePETSc = "50acdc01-ce88-4ca7-bd87-6916c254362e"
710
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
811

0 commit comments

Comments (0)