11# SparseMatrixMPI type and sparse matrix operations
22
3+ # ============================================================================
4+ # Merge-sort style helpers for sorted column index arrays
5+ # ============================================================================
6+
7+ """
8+ merge_sorted_unique!(result, a, b)
9+
10+ Merge two sorted arrays into a sorted array of unique elements.
11+ Returns the result array (which may be shorter than allocated).
12+ """
13+ function merge_sorted_unique! (result:: Vector{Int} , a:: Vector{Int} , b:: Vector{Int} )
14+ i, j, k = 1 , 1 , 0
15+ @inbounds while i <= length (a) && j <= length (b)
16+ if a[i] < b[j]
17+ k += 1
18+ result[k] = a[i]
19+ i += 1
20+ elseif a[i] > b[j]
21+ k += 1
22+ result[k] = b[j]
23+ j += 1
24+ else # equal
25+ k += 1
26+ result[k] = a[i]
27+ i += 1
28+ j += 1
29+ end
30+ end
31+ @inbounds while i <= length (a)
32+ k += 1
33+ result[k] = a[i]
34+ i += 1
35+ end
36+ @inbounds while j <= length (b)
37+ k += 1
38+ result[k] = b[j]
39+ j += 1
40+ end
41+ return resize! (result, k)
42+ end
43+
44+ """
45+ merge_sorted_unique(a, b)
46+
47+ Merge two sorted arrays into a new sorted array of unique elements.
48+ O(n+m) time, no sorting or Dict needed.
49+ """
50+ function merge_sorted_unique (a:: Vector{Int} , b:: Vector{Int} )
51+ result = Vector {Int} (undef, length (a) + length (b))
52+ return merge_sorted_unique! (result, a, b)
53+ end
54+
55+ """
56+ build_subset_mapping!(mapping, subset, superset)
57+
58+ Build a mapping from subset indices to superset positions.
59+ Both arrays must be sorted, and subset must be a subset of superset.
60+ mapping[i] = position of subset[i] in superset.
61+ O(|subset| + |superset|) time with linear scan.
62+ """
63+ function build_subset_mapping! (mapping:: Vector{Int} , subset:: Vector{Int} , superset:: Vector{Int} )
64+ j = 1 # position in superset
65+ @inbounds for i in 1 : length (subset)
66+ while superset[j] < subset[i]
67+ j += 1
68+ end
69+ # Now superset[j] == subset[i]
70+ mapping[i] = j
71+ end
72+ return mapping
73+ end
74+
75+ """
76+ build_subset_mapping(subset, superset)
77+
78+ Build a mapping from subset indices to superset positions.
79+ Returns a new vector where mapping[i] = position of subset[i] in superset.
80+ """
81+ function build_subset_mapping (subset:: Vector{Int} , superset:: Vector{Int} )
82+ mapping = Vector {Int} (undef, length (subset))
83+ return build_subset_mapping! (mapping, subset, superset)
84+ end
85+
86+ # ============================================================================
87+
388"""
489 compute_structural_hash(row_partition, col_indices, AT, comm) -> Blake3Hash
590
@@ -47,8 +132,8 @@ function compress_AT(AT::SparseMatrixCSC{T,Int}, col_indices::Vector{Int}) where
47132 if isempty (col_indices)
48133 return SparseMatrixCSC (0 , AT. n, AT. colptr, Int[], T[])
49134 end
50- global_to_local = Dict (g => l for (l, g) in enumerate (col_indices))
51- compressed_rowval = [global_to_local[r] for r in AT. rowval]
135+ # col_indices is sorted, use binary search instead of Dict
136+ compressed_rowval = [searchsortedfirst (col_indices, r) for r in AT. rowval]
52137 return SparseMatrixCSC (length (col_indices), AT. n, AT. colptr, compressed_rowval, AT. nzval)
53138end
54139
@@ -130,28 +215,25 @@ function _rebuild_AT_with_insertions(AT::SparseMatrixCSC{T,Int}, col_indices::Ve
130215 # Build expanded col_indices (merge existing and new, maintain sorted order)
131216 expanded_col_indices = sort (unique (vcat (col_indices, collect (new_global_cols))))
132217
133- # Build global->local mapping for expanded col_indices
134- global_to_local = Dict (g => l for (l, g) in enumerate (expanded_col_indices))
135-
136218 # Collect all entries: (AT_col, AT_row, val) = (local_row, local_col_in_expanded, val)
137219 # Using a Dict to handle duplicates (later values win)
138220 entries = Dict {Tuple{Int,Int},T} ()
139221
140222 # Add existing entries from AT (reindex to expanded col_indices)
141- old_global_to_local = Dict (g => l for (l, g) in enumerate (col_indices))
223+ # expanded_col_indices is sorted, use binary search
142224 for at_col in 1 : n_local_rows
143225 for k in AT. colptr[at_col]: (AT. colptr[at_col+ 1 ]- 1 )
144226 old_local_col = AT. rowval[k]
145227 global_col = col_indices[old_local_col]
146- new_local_col = global_to_local[ global_col]
228+ new_local_col = searchsortedfirst (expanded_col_indices, global_col)
147229 entries[(at_col, new_local_col)] = AT. nzval[k]
148230 end
149231 end
150232
151233 # Add new insertions (may overwrite existing)
152234 for (global_i, global_j, val) in insertions
153235 local_row = global_i - row_offset + 1 # AT column
154- local_col = global_to_local[ global_j] # AT row
236+ local_col = searchsortedfirst (expanded_col_indices, global_j) # AT row
155237 entries[(local_row, local_col)] = val
156238 end
157239
@@ -821,10 +903,11 @@ function Base.:*(A::SparseMatrixMPI{T}, B::SparseMatrixMPI{T}) where T
821903 end
822904
823905 compressed_result_AT = compress_AT_cached (result_AT, compress_map, length (result_col_indices))
824- result_hash = compute_structural_hash (A. row_partition, result_col_indices, compressed_result_AT, comm)
906+ # Hash computed lazily by _ensure_hash if/when needed for plan caching
907+ # This avoids expensive hash computation for results that won't be reused in matrix-matrix multiply
825908
826909 # C = A * B has rows from A and columns from B
827- return SparseMatrixMPI {T} (result_hash , A. row_partition, B. col_partition, result_col_indices, transpose (compressed_result_AT),
910+ return SparseMatrixMPI {T} (nothing , A. row_partition, B. col_partition, result_col_indices, transpose (compressed_result_AT),
828911 nothing )
829912end
830913
@@ -834,44 +917,24 @@ end
834917Add two distributed sparse matrices. The result has A's row partition.
835918"""
836919function Base.:+ (A:: SparseMatrixMPI{T} , B:: SparseMatrixMPI{T} ) where T
837- comm = MPI. COMM_WORLD
838-
839- # Repartition B to match A's row partition
840920 B_repart = repartition (B, A. row_partition)
841921
842- # Both A and B_repart now have:
843- # - Same row_partition (A.row_partition)
844- # - Local CSC storage with compressed column indices
845- # - col_indices mapping compressed → global
846-
847- # Compute union of column indices
848- union_indices = sort (unique (vcat (A. col_indices, B_repart. col_indices)))
849- union_size = length (union_indices)
850-
851- # Build mappings: compressed → union for both A and B_repart
852- global_to_union = Dict (g => l for (l, g) in enumerate (union_indices))
853- A_col_to_union = [global_to_union[g] for g in A. col_indices]
854- B_col_to_union = [global_to_union[g] for g in B_repart. col_indices]
922+ if A. col_indices == B_repart. col_indices
923+ return SparseMatrixMPI {T} (nothing , A. row_partition, A. col_partition,
924+ A. col_indices, transpose (A. A. parent + B_repart. A. parent), nothing )
925+ end
855926
856- # Reindex both to union space
857- A_union = reindex_to_union_cached (A. A. parent, A_col_to_union, union_size)
858- B_union = reindex_to_union_cached (B_repart. A. parent, B_col_to_union, union_size)
927+ # Merge sorted col_indices and build mappings via linear scan (no Dict)
928+ union_cols = merge_sorted_unique (A. col_indices, B_repart. col_indices)
929+ A_map = build_subset_mapping (A. col_indices, union_cols)
930+ B_map = build_subset_mapping (B_repart. col_indices, union_cols)
859931
860932 # Add in union space
861- result_union = A_union + B_union
862-
863- # Convert result back to compressed indices
864- result_col_indices_local = isempty (result_union. rowval) ? Int[] : unique (sort (result_union. rowval))
865- result_col_indices = isempty (result_col_indices_local) ? Int[] : union_indices[result_col_indices_local]
933+ result = reindex_to_union_cached (A. A. parent, A_map, length (union_cols)) +
934+ reindex_to_union_cached (B_repart. A. parent, B_map, length (union_cols))
866935
867- # Compress result
868- compressed_result = compress_AT (result_union, result_col_indices_local)
869-
870- # Compute hash from actual result
871- result_hash = compute_structural_hash (A. row_partition, result_col_indices, compressed_result, comm)
872-
873- return SparseMatrixMPI {T} (result_hash, A. row_partition, A. col_partition,
874- result_col_indices, transpose (compressed_result), nothing )
936+ return SparseMatrixMPI {T} (nothing , A. row_partition, A. col_partition,
937+ union_cols, transpose (result), nothing )
875938end
876939
877940"""
@@ -880,44 +943,24 @@ end
880943Subtract two distributed sparse matrices. The result has A's row partition.
881944"""
882945function Base.:- (A:: SparseMatrixMPI{T} , B:: SparseMatrixMPI{T} ) where T
883- comm = MPI. COMM_WORLD
884-
885- # Repartition B to match A's row partition
886946 B_repart = repartition (B, A. row_partition)
887947
888- # Both A and B_repart now have:
889- # - Same row_partition (A.row_partition)
890- # - Local CSC storage with compressed column indices
891- # - col_indices mapping compressed → global
892-
893- # Compute union of column indices
894- union_indices = sort (unique (vcat (A. col_indices, B_repart. col_indices)))
895- union_size = length (union_indices)
896-
897- # Build mappings: compressed → union for both A and B_repart
898- global_to_union = Dict (g => l for (l, g) in enumerate (union_indices))
899- A_col_to_union = [global_to_union[g] for g in A. col_indices]
900- B_col_to_union = [global_to_union[g] for g in B_repart. col_indices]
948+ if A. col_indices == B_repart. col_indices
949+ return SparseMatrixMPI {T} (nothing , A. row_partition, A. col_partition,
950+ A. col_indices, transpose (A. A. parent - B_repart. A. parent), nothing )
951+ end
901952
902- # Reindex both to union space
903- A_union = reindex_to_union_cached (A. A. parent, A_col_to_union, union_size)
904- B_union = reindex_to_union_cached (B_repart. A. parent, B_col_to_union, union_size)
953+ # Merge sorted col_indices and build mappings via linear scan (no Dict)
954+ union_cols = merge_sorted_unique (A. col_indices, B_repart. col_indices)
955+ A_map = build_subset_mapping (A. col_indices, union_cols)
956+ B_map = build_subset_mapping (B_repart. col_indices, union_cols)
905957
906958 # Subtract in union space
907- result_union = A_union - B_union
908-
909- # Convert result back to compressed indices
910- result_col_indices_local = isempty (result_union. rowval) ? Int[] : unique (sort (result_union. rowval))
911- result_col_indices = isempty (result_col_indices_local) ? Int[] : union_indices[result_col_indices_local]
912-
913- # Compress result
914- compressed_result = compress_AT (result_union, result_col_indices_local)
959+ result = reindex_to_union_cached (A. A. parent, A_map, length (union_cols)) -
960+ reindex_to_union_cached (B_repart. A. parent, B_map, length (union_cols))
915961
916- # Compute hash from actual result
917- result_hash = compute_structural_hash (A. row_partition, result_col_indices, compressed_result, comm)
918-
919- return SparseMatrixMPI {T} (result_hash, A. row_partition, A. col_partition,
920- result_col_indices, transpose (compressed_result), nothing )
962+ return SparseMatrixMPI {T} (nothing , A. row_partition, A. col_partition,
963+ union_cols, transpose (result), nothing )
921964end
922965
923966"""
@@ -1122,7 +1165,11 @@ function TransposePlan(A::SparseMatrixMPI{T}) where T
11221165 local_src_indices = Int[]
11231166 local_dst_indices = Int[]
11241167
1125- recv_rank_to_idx = Dict (r => i for (i, r) in enumerate (recv_rank_ids))
1168+ # Map rank -> index in recv_rank_ids using Vector (nranks is small)
1169+ recv_rank_to_idx = zeros (Int, nranks)
1170+ for (i, r) in enumerate (recv_rank_ids)
1171+ recv_rank_to_idx[r+ 1 ] = i
1172+ end
11261173
11271174 for (ent_idx, (_, _, src_rank, src_idx)) in enumerate (entries)
11281175 dst_idx = entry_to_nzval_idx[ent_idx]
@@ -1131,7 +1178,7 @@ function TransposePlan(A::SparseMatrixMPI{T}) where T
11311178 push! (local_dst_indices, dst_idx)
11321179 else
11331180 # Use indexed assignment: src_idx is the position in recv_buf from src_rank
1134- recv_perm[recv_rank_to_idx[src_rank]][src_idx] = dst_idx
1181+ recv_perm[recv_rank_to_idx[src_rank+ 1 ]][src_idx] = dst_idx
11351182 end
11361183 end
11371184
@@ -2182,9 +2229,8 @@ function triu(A::SparseMatrixMPI{T}, k::Integer=0) where T
21822229 else
21832230 local_used = unique (sort (new_rowval))
21842231 new_col_indices = col_indices[local_used] # convert to global
2185- # Compress: map old local indices to new compressed indices
2186- local_to_compressed = Dict (old => new for (new, old) in enumerate (local_used))
2187- compressed_rowval = [local_to_compressed[r] for r in new_rowval]
2232+ # Compress: local_used is sorted, use binary search instead of Dict
2233+ compressed_rowval = [searchsortedfirst (local_used, r) for r in new_rowval]
21882234 compressed_AT = SparseMatrixCSC (length (new_col_indices), size (A. A. parent, 2 ),
21892235 new_colptr, compressed_rowval, new_nzval)
21902236 end
@@ -2256,9 +2302,8 @@ function tril(A::SparseMatrixMPI{T}, k::Integer=0) where T
22562302 else
22572303 local_used = unique (sort (new_rowval))
22582304 new_col_indices = col_indices[local_used] # convert to global
2259- # Compress: map old local indices to new compressed indices
2260- local_to_compressed = Dict (old => new for (new, old) in enumerate (local_used))
2261- compressed_rowval = [local_to_compressed[r] for r in new_rowval]
2305+ # Compress: local_used is sorted, use binary search instead of Dict
2306+ compressed_rowval = [searchsortedfirst (local_used, r) for r in new_rowval]
22622307 compressed_AT = SparseMatrixCSC (length (new_col_indices), size (A. A. parent, 2 ),
22632308 new_colptr, compressed_rowval, new_nzval)
22642309 end
0 commit comments