Commit 528524f

Sébastien Loisel authored and committed
Performance optimizations for block ops, spdiagm, and matrix multiply
- Use let blocks in cat/blockdiag to avoid boxing overhead in loops
- Use copyto! and block copy instead of element-by-element loops
- Add fast path for spdiagm with single main diagonal
- Vectorize spdiagm with multiple diagonals
- Cache product col_indices, compress_map, and structural hash in MatrixPlan
- Fix MUMPS OMP_NUM_THREADS handling to parse environment variable
1 parent 547853a commit 528524f

3 files changed

Lines changed: 126 additions & 63 deletions

src/blocks.jl

Lines changed: 24 additions & 26 deletions
@@ -125,13 +125,13 @@ function Base.cat(As::SparseMatrixMPI{T}...; dims) where T
             triplets = _gather_rows_from_sparse(A, rows_needed)
 
             # Add triplets with offsets applied
-            for (row_in_block, col_in_block, val) in triplets
-                output_row = row_offsets[bi] + row_in_block
-                local_output_row = output_row - my_out_row_start + 1
-                global_col = col_offset + col_in_block
-                push!(local_I, local_output_row)
-                push!(local_J, global_col)
-                push!(local_V, val)
+            # Use let block to capture loop variables and avoid boxing overhead
+            let row_off = row_offsets[bi], out_start = my_out_row_start, col_off = col_offset
+                for (row_in_block, col_in_block, val) in triplets
+                    push!(local_I, row_off + row_in_block - out_start + 1)
+                    push!(local_J, col_off + col_in_block)
+                    push!(local_V, val)
+                end
             end
         end
     end
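The `let` rebinding above follows the captured-variable pattern from the Julia performance tips ("performance of captured variables"); a minimal standalone illustration with hypothetical function names, not code from this repository:

function abmult_boxed(r::Int)
    # `r` is reassigned before being captured, so the closure may see a boxed variable
    if r < 0
        r = -r
    end
    return x -> x * r
end

function abmult_let(r::Int)
    if r < 0
        r = -r
    end
    # Rebinding through `let` gives the closure a fresh, effectively-constant binding
    return let r = r
        x -> x * r
    end
end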
@@ -264,12 +264,11 @@ function Base.cat(As::MatrixMPI{T}...; dims) where T
             gathered_rows = _gather_dense_rows(A, rows_needed)
 
             # Place into local matrix (only if we have overlap)
-            if has_overlap
-                for (i, row_in_block) in enumerate(rows_needed)
-                    output_row = row_offsets[bi] + row_in_block
-                    local_row = output_row - my_out_row_start + 1
-                    local_matrix[local_row, col_start:col_end] = gathered_rows[i, :]
-                end
+            # Use block copy instead of element-by-element loop to avoid boxing overhead
+            if has_overlap && !isempty(rows_needed)
+                first_local_row = row_offsets[bi] + first(rows_needed) - my_out_row_start + 1
+                last_local_row = row_offsets[bi] + last(rows_needed) - my_out_row_start + 1
+                local_matrix[first_local_row:last_local_row, col_start:col_end] = gathered_rows
             end
         end
     end
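The dense case relies on slice assignment copying a whole contiguous block at once. A self-contained sketch with made-up sizes, not the package's data structures:

gathered = collect(reshape(1.0:12.0, 3, 4))  # 3 gathered rows, already in output order
dest = zeros(10, 4)
first_row, last_row = 5, 7                   # contiguous destination rows

# per-row loop (old style)
for (i, r) in enumerate(first_row:last_row)
    dest[r, :] = gathered[i, :]
end

# equivalent single block assignment (new style)
dest[first_row:last_row, :] = gathered
@assert dest[5:7, :] == gathered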
@@ -393,13 +392,12 @@ function _vcat_vectors(vs::VectorMPI{T}...) where T
         # Copy elements from repartitioned vector to output
         first_in_vec = max(1, my_out_start - offset)
         last_in_vec = min(vec_len, my_out_end - offset)
+        n_copy = last_in_vec - first_in_vec + 1
 
-        for idx_in_vec in first_in_vec:last_in_vec
-            global_out_idx = offset + idx_in_vec
-            local_out_idx = global_out_idx - my_out_start + 1
-            local_v_idx = idx_in_vec - my_v_start + 1
-            local_v[local_out_idx] = v_repart.v[local_v_idx]
-        end
+        # Use copyto! instead of element-by-element loop to avoid boxing overhead
+        dst_start = offset + first_in_vec - my_out_start + 1
+        src_start = first_in_vec - my_v_start + 1
+        copyto!(local_v, dst_start, v_repart.v, src_start, n_copy)
     end
 end
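The replacement uses the five-argument `copyto!` from Base, which copies `n_copy` elements starting at `src_start` in the source into the destination starting at `dst_start`. A small usage sketch with illustrative values:

src  = collect(1.0:10.0)
dest = zeros(10)
copyto!(dest, 3, src, 6, 4)   # copies src[6:9] into dest[3:6]
@assert dest[3:6] == src[6:9]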

@@ -512,13 +510,13 @@ function blockdiag(As::SparseMatrixMPI{T}...) where T
         triplets = _gather_rows_from_sparse(A, rows_needed)
 
         # Add triplets with offsets applied
-        for (row_in_block, col_in_block, val) in triplets
-            output_row = row_offsets[k] + row_in_block
-            local_output_row = output_row - my_out_row_start + 1
-            global_col = col_offset + col_in_block
-            push!(local_I, local_output_row)
-            push!(local_J, global_col)
-            push!(local_V, val)
+        # Use let block to capture loop variables and avoid boxing overhead
+        let row_off = row_offsets[k], out_start = my_out_row_start, col_off = col_offset
+            for (row_in_block, col_in_block, val) in triplets
+                push!(local_I, row_off + row_in_block - out_start + 1)
+                push!(local_J, col_off + col_in_block)
+                push!(local_V, val)
+            end
         end
     end
 
src/mumps_factorization.jl

Lines changed: 4 additions & 4 deletions
@@ -290,10 +290,10 @@ function _get_or_create_analysis_plan(A::SparseMatrixMPI{T}, symmetric::Bool) wh
     set_icntl!(mumps, 21, 0; displaylevel=0) # Centralized solution on host
     set_icntl!(mumps, 7, 5; displaylevel=0) # METIS ordering (better fill-in)
 
-    # Set OpenMP threads for MUMPS to match Julia's thread count if OMP_NUM_THREADS not set
-    if !haskey(ENV, "OMP_NUM_THREADS")
-        set_icntl!(mumps, 16, Threads.nthreads(); displaylevel=0)
-    end
+    # Enable OpenMP threading in MUMPS
+    # ICNTL(16) = number of OpenMP threads (0 = use OMP_NUM_THREADS)
+    omp_threads = parse(Int, get(ENV, "OMP_NUM_THREADS", "1"))
+    set_icntl!(mumps, 16, omp_threads; displaylevel=0)
 
     # Set matrix dimension
     mumps.n = MUMPS_INT(n)
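Note that `parse(Int, ...)` throws if `OMP_NUM_THREADS` is set to something that is not a plain integer (for example an empty string). A possible hardening, not part of this commit, would fall back to 1 thread via `tryparse`:

raw = get(ENV, "OMP_NUM_THREADS", "1")
omp_threads = something(tryparse(Int, strip(raw)), 1)   # 1 if the value is unset or malformed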

src/sparse.jl

Lines changed: 98 additions & 33 deletions
@@ -884,30 +884,44 @@ function Base.:*(A::SparseMatrixMPI{T}, B::SparseMatrixMPI{T}) where T
     # result is (ncols_B, local_nrows_A) = shape of C.AT
     result_AT = plan.AT * A.A.parent
 
-    # Always compute col_indices and hash from the actual result structure.
-    # NOTE: We cannot cache product_structural_hash because two multiplications with
-    # the same INPUT structural hashes can produce results with DIFFERENT structures.
-    # This happens when the input matrices have the same compressed local structure
-    # but different global positions of nonzeros.
-    result_col_indices = isempty(result_AT.rowval) ? Int[] : unique(sort(result_AT.rowval))
-
-    # Build compress_map: compress_map[global_col] = local_col
-    if isempty(result_col_indices)
-        compress_map = Int[]
+    # Use cached col_indices and compress_map if available, otherwise compute and cache
+    if plan.product_col_indices !== nothing
+        result_col_indices = plan.product_col_indices
+        compress_map = plan.product_compress_map
     else
-        max_col = maximum(result_col_indices)
-        compress_map = zeros(Int, max_col)
-        for (local_idx, global_idx) in enumerate(result_col_indices)
-            compress_map[global_idx] = local_idx
+        result_col_indices = isempty(result_AT.rowval) ? Int[] : unique(sort(result_AT.rowval))
+
+        # Build compress_map: compress_map[global_col] = local_col
+        if isempty(result_col_indices)
+            compress_map = Int[]
+        else
+            max_col = maximum(result_col_indices)
+            compress_map = zeros(Int, max_col)
+            for (local_idx, global_idx) in enumerate(result_col_indices)
+                compress_map[global_idx] = local_idx
+            end
         end
+
+        # Cache for future use with same structural pattern
+        plan.product_col_indices = result_col_indices
+        plan.product_compress_map = compress_map
+        plan.product_row_partition = A.row_partition
     end
 
     compressed_result_AT = compress_AT_cached(result_AT, compress_map, length(result_col_indices))
-    # Hash computed lazily by _ensure_hash if/when needed for plan caching
-    # This avoids expensive hash computation for results that won't be reused in matrix-matrix multiply
+
+    # Use cached structural hash if available, otherwise compute and cache
+    # This is important for chained operations like (P' * A) * P where the intermediate
+    # result needs a hash for the next multiply's plan lookup
+    result_hash = plan.product_structural_hash
+    if result_hash === nothing
+        # Compute hash for the result structure
+        result_hash = compute_structural_hash(A.row_partition, result_col_indices, compressed_result_AT, MPI.COMM_WORLD)
+        plan.product_structural_hash = result_hash
+    end
 
     # C = A * B has rows from A and columns from B
-    return SparseMatrixMPI{T}(nothing, A.row_partition, B.col_partition, result_col_indices, transpose(compressed_result_AT),
+    return SparseMatrixMPI{T}(result_hash, A.row_partition, B.col_partition, result_col_indices, transpose(compressed_result_AT),
                               nothing)
 end
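The cached compress_map is the inverse of col_indices: it maps a global column index to its local (compressed) position, with 0 meaning the column does not appear locally. A standalone sketch of that mapping in plain Julia, not the package's types:

col_indices = [3, 7, 8, 12]                  # sorted global columns present locally
compress_map = zeros(Int, maximum(col_indices))
for (local_idx, global_idx) in enumerate(col_indices)
    compress_map[global_idx] = local_idx
end
@assert compress_map[7] == 2                 # global column 7 is local column 2
@assert compress_map[5] == 0                 # global column 5 is absent locally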

@@ -2563,6 +2577,11 @@ A = spdiagm(0 => v1, 1 => v2) # Main diagonal and first superdiagonal
 function spdiagm(kv::Pair{<:Integer,<:VectorMPI}...)
     isempty(kv) && error("spdiagm requires at least one diagonal")
 
+    # Fast path for single main diagonal
+    if length(kv) == 1 && first(kv)[1] == 0
+        return spdiagm(first(kv)[2])
+    end
+
     comm = MPI.COMM_WORLD
     rank = MPI.Comm_rank(comm)
     nranks = MPI.Comm_size(comm)
@@ -2592,7 +2611,7 @@ function spdiagm(kv::Pair{<:Integer,<:VectorMPI}...)
         repartitioned[k] = repartition(v, target)
     end
 
-    # Step 3: Build local triplets (local_row, global_col, value)
+    # Step 3: Build local triplets using vectorized operations
     local_I = Int[]
     local_J = Int[]
     local_V = T[]
@@ -2601,25 +2620,45 @@ function spdiagm(kv::Pair{<:Integer,<:VectorMPI}...)
         vec_len = length(v)
         v_repart = repartitioned[k]
         my_v_start = v_repart.partition[rank+1]
+        my_v_end = v_repart.partition[rank+2] - 1
+        local_v_len = my_v_end - my_v_start + 1
 
-        for local_row_idx in 1:local_nrows
-            global_row = my_row_start + local_row_idx - 1
+        # Compute which local rows have valid diagonal entries
+        # For k >= 0: global_row i -> col i+k, uses v[i]
+        # For k < 0: global_row i -> col i+k, uses v[i+k]
+        if k >= 0
+            # v[i] goes to (i, i+k), so row = vec_idx, col = vec_idx + k
+            # We have v[my_v_start:my_v_end], these go to rows my_v_start:my_v_end
+            first_row = max(my_row_start, my_v_start)
+            last_row = min(my_row_end, my_v_end)
+        else
+            # v[i] goes to (i-k, i), so row = vec_idx - k, col = vec_idx
+            # Row r uses v[r+k], col = r+k
+            first_row = max(my_row_start, my_v_start - k)
+            last_row = min(my_row_end, my_v_end - k)
+        end
 
+        if first_row <= last_row
+            nentries = last_row - first_row + 1
+            rows = first_row:last_row
+            cols = rows .+ k
+
+            # Filter to valid column range
+            valid = (1 .<= cols .<= n)
+            valid_rows = rows[valid]
+            valid_cols = cols[valid]
+
+            # Local row indices and vector indices
+            local_rows = valid_rows .- my_row_start .+ 1
             if k >= 0
-                vec_idx = global_row
-                col = global_row + k
+                v_indices = valid_rows .- my_v_start .+ 1
             else
-                vec_idx = global_row + k
-                col = vec_idx
+                v_indices = (valid_rows .+ k) .- my_v_start .+ 1
             end
 
-            if 1 <= vec_idx <= vec_len && 1 <= col <= n
-                # Index into repartitioned vector
-                local_v_idx = vec_idx - my_v_start + 1
-                push!(local_I, local_row_idx)
-                push!(local_J, col)
-                push!(local_V, v_repart.v[local_v_idx])
-            end
+            append!(local_I, local_rows)
+            append!(local_J, valid_cols)
+            append!(local_V, v_repart.v[v_indices])
         end
     end
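The vectorized placement can be sanity-checked in serial with standard SparseArrays; the sketch below (hypothetical helper, no MPI) mirrors the row/column arithmetic: for offset k, v[i] lands at (i, i+k) when k >= 0 and at (i-k, i) when k < 0.

using SparseArrays

function diag_triplets(v::AbstractVector, k::Integer, n::Integer)
    idx = collect(eachindex(v))
    rows = k >= 0 ? idx : idx .- k
    cols = rows .+ k
    valid = (1 .<= rows .<= n) .& (1 .<= cols .<= n)   # drop entries outside the n×n matrix
    return rows[valid], cols[valid], v[valid]
end

v = [1.0, 2.0, 3.0]
Is, Js, Vs = diag_triplets(v, 1, 4)          # first superdiagonal of a 4×4 matrix
A = sparse(Is, Js, Vs, 4, 4)
@assert A[1, 2] == 1.0 && A[2, 3] == 2.0 && A[3, 4] == 3.0 && nnz(A) == 3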

@@ -2722,8 +2761,34 @@ v = VectorMPI([1.0, 2.0, 3.0])
 A = spdiagm(v) # 3×3 diagonal matrix
 ```
 """
-function spdiagm(v::VectorMPI)
-    return spdiagm(0 => v)
+function spdiagm(v::VectorMPI{T}) where T
+    # Ultra-fast path for main diagonal: build CSC structure directly
+    n = length(v)
+    comm = MPI.COMM_WORLD
+    rank = MPI.Comm_rank(comm)
+
+    my_start = v.partition[rank+1]
+    local_n = length(v.v)
+
+    # Build AT (transpose) as CSC directly - no sparse() call needed
+    # AT has size (n_cols, local_n_rows): n global columns, local_n local rows
+    # Column j of AT corresponds to local row j, which has one entry at global column my_start+j-1
+    #
+    # IMPORTANT: AT.rowval stores LOCAL/compressed column indices (1, 2, 3, ...)
+    # col_indices maps these local indices to global column indices
+    colptr = collect(1:(local_n+1)) # Each column has exactly 1 entry
+    rowval = collect(1:local_n) # LOCAL column indices (compressed)
+    nzval = copy(v.v)
+
+    AT_local = SparseMatrixCSC(n, local_n, colptr, rowval, nzval)
+
+    # col_indices maps local column index -> global column index
+    col_indices = collect(my_start:(my_start + local_n - 1))
+    row_partition = v.partition # Use same partition as input vector
+    col_partition = v.partition # Square matrix, same column partition
+
+    return SparseMatrixMPI{T}(nothing, row_partition, col_partition, col_indices,
+                              transpose(AT_local), nothing)
 end
 
 """
