@@ -884,30 +884,44 @@ function Base.:*(A::SparseMatrixMPI{T}, B::SparseMatrixMPI{T}) where T
884884 # result is (ncols_B, local_nrows_A) = shape of C.AT
885885 result_AT = plan. AT ⊛ A. A. parent
886886
887- # Always compute col_indices and hash from the actual result structure.
888- # NOTE: We cannot cache product_structural_hash because two multiplications with
889- # the same INPUT structural hashes can produce results with DIFFERENT structures.
890- # This happens when the input matrices have the same compressed local structure
891- # but different global positions of nonzeros.
892- result_col_indices = isempty (result_AT. rowval) ? Int[] : unique (sort (result_AT. rowval))
893-
894- # Build compress_map: compress_map[global_col] = local_col
895- if isempty (result_col_indices)
896- compress_map = Int[]
887+ # Use cached col_indices and compress_map if available, otherwise compute and cache
888+ if plan. product_col_indices !== nothing
889+ result_col_indices = plan. product_col_indices
890+ compress_map = plan. product_compress_map
897891 else
898- max_col = maximum (result_col_indices)
899- compress_map = zeros (Int, max_col)
900- for (local_idx, global_idx) in enumerate (result_col_indices)
901- compress_map[global_idx] = local_idx
892+ result_col_indices = isempty (result_AT. rowval) ? Int[] : unique (sort (result_AT. rowval))
893+
894+ # Build compress_map: compress_map[global_col] = local_col
895+ if isempty (result_col_indices)
896+ compress_map = Int[]
897+ else
898+ max_col = maximum (result_col_indices)
899+ compress_map = zeros (Int, max_col)
900+ for (local_idx, global_idx) in enumerate (result_col_indices)
901+ compress_map[global_idx] = local_idx
902+ end
902903 end
904+
905+ # Cache for future use with same structural pattern
906+ plan. product_col_indices = result_col_indices
907+ plan. product_compress_map = compress_map
908+ plan. product_row_partition = A. row_partition
903909 end
904910
905911 compressed_result_AT = compress_AT_cached (result_AT, compress_map, length (result_col_indices))
906- # Hash computed lazily by _ensure_hash if/when needed for plan caching
907- # This avoids expensive hash computation for results that won't be reused in matrix-matrix multiply
912+
913+ # Use cached structural hash if available, otherwise compute and cache
914+ # This is important for chained operations like (P' * A) * P where the intermediate
915+ # result needs a hash for the next multiply's plan lookup
916+ result_hash = plan. product_structural_hash
917+ if result_hash === nothing
918+ # Compute hash for the result structure
919+ result_hash = compute_structural_hash (A. row_partition, result_col_indices, compressed_result_AT, MPI. COMM_WORLD)
920+ plan. product_structural_hash = result_hash
921+ end
908922
909923 # C = A * B has rows from A and columns from B
910- return SparseMatrixMPI {T} (nothing , A. row_partition, B. col_partition, result_col_indices, transpose (compressed_result_AT),
924+ return SparseMatrixMPI {T} (result_hash , A. row_partition, B. col_partition, result_col_indices, transpose (compressed_result_AT),
911925 nothing )
912926end
913927
@@ -2563,6 +2577,11 @@ A = spdiagm(0 => v1, 1 => v2) # Main diagonal and first superdiagonal
25632577function spdiagm (kv:: Pair{<:Integer,<:VectorMPI} ...)
25642578 isempty (kv) && error (" spdiagm requires at least one diagonal" )
25652579
2580+ # Fast path for single main diagonal
2581+ if length (kv) == 1 && first (kv)[1 ] == 0
2582+ return spdiagm (first (kv)[2 ])
2583+ end
2584+
25662585 comm = MPI. COMM_WORLD
25672586 rank = MPI. Comm_rank (comm)
25682587 nranks = MPI. Comm_size (comm)
@@ -2592,7 +2611,7 @@ function spdiagm(kv::Pair{<:Integer,<:VectorMPI}...)
25922611 repartitioned[k] = repartition (v, target)
25932612 end
25942613
2595- # Step 3: Build local triplets (local_row, global_col, value)
2614+ # Step 3: Build local triplets using vectorized operations
25962615 local_I = Int[]
25972616 local_J = Int[]
25982617 local_V = T[]
@@ -2601,25 +2620,45 @@ function spdiagm(kv::Pair{<:Integer,<:VectorMPI}...)
26012620 vec_len = length (v)
26022621 v_repart = repartitioned[k]
26032622 my_v_start = v_repart. partition[rank+ 1 ]
2623+ my_v_end = v_repart. partition[rank+ 2 ] - 1
2624+ local_v_len = my_v_end - my_v_start + 1
26042625
2605- for local_row_idx in 1 : local_nrows
2606- global_row = my_row_start + local_row_idx - 1
2626+ # Compute which local rows have valid diagonal entries
2627+ # For k >= 0: global_row i -> col i+k, uses v[i]
2628+ # For k < 0: global_row i -> col i+k, uses v[i+k]
2629+ if k >= 0
2630+ # v[i] goes to (i, i+k), so row = vec_idx, col = vec_idx + k
2631+ # We have v[my_v_start:my_v_end], these go to rows my_v_start:my_v_end
2632+ first_row = max (my_row_start, my_v_start)
2633+ last_row = min (my_row_end, my_v_end)
2634+ else
2635+ # v[i] goes to (i-k, i), so row = vec_idx - k, col = vec_idx
2636+ # Row r uses v[r+k], col = r+k
2637+ first_row = max (my_row_start, my_v_start - k)
2638+ last_row = min (my_row_end, my_v_end - k)
2639+ end
26072640
2641+ if first_row <= last_row
2642+ nentries = last_row - first_row + 1
2643+ rows = first_row: last_row
2644+ cols = rows .+ k
2645+
2646+ # Filter to valid column range
2647+ valid = (1 .<= cols .<= n)
2648+ valid_rows = rows[valid]
2649+ valid_cols = cols[valid]
2650+
2651+ # Local row indices and vector indices
2652+ local_rows = valid_rows .- my_row_start .+ 1
26082653 if k >= 0
2609- vec_idx = global_row
2610- col = global_row + k
2654+ v_indices = valid_rows .- my_v_start .+ 1
26112655 else
2612- vec_idx = global_row + k
2613- col = vec_idx
2656+ v_indices = (valid_rows .+ k) .- my_v_start .+ 1
26142657 end
26152658
2616- if 1 <= vec_idx <= vec_len && 1 <= col <= n
2617- # Index into repartitioned vector
2618- local_v_idx = vec_idx - my_v_start + 1
2619- push! (local_I, local_row_idx)
2620- push! (local_J, col)
2621- push! (local_V, v_repart. v[local_v_idx])
2622- end
2659+ append! (local_I, local_rows)
2660+ append! (local_J, valid_cols)
2661+ append! (local_V, v_repart. v[v_indices])
26232662 end
26242663 end
26252664
@@ -2722,8 +2761,34 @@ v = VectorMPI([1.0, 2.0, 3.0])
27222761A = spdiagm(v) # 3×3 diagonal matrix
27232762```
27242763"""
2725- function spdiagm (v:: VectorMPI )
2726- return spdiagm (0 => v)
2764+ function spdiagm (v:: VectorMPI{T} ) where T
2765+ # Ultra-fast path for main diagonal: build CSC structure directly
2766+ n = length (v)
2767+ comm = MPI. COMM_WORLD
2768+ rank = MPI. Comm_rank (comm)
2769+
2770+ my_start = v. partition[rank+ 1 ]
2771+ local_n = length (v. v)
2772+
2773+ # Build AT (transpose) as CSC directly - no sparse() call needed
2774+ # AT has size (n_cols, local_n_rows): n global columns, local_n local rows
2775+ # Column j of AT corresponds to local row j, which has one entry at global column my_start+j-1
2776+ #
2777+ # IMPORTANT: AT.rowval stores LOCAL/compressed column indices (1, 2, 3, ...)
2778+ # col_indices maps these local indices to global column indices
2779+ colptr = collect (1 : (local_n+ 1 )) # Each column has exactly 1 entry
2780+ rowval = collect (1 : local_n) # LOCAL column indices (compressed)
2781+ nzval = copy (v. v)
2782+
2783+ AT_local = SparseMatrixCSC (n, local_n, colptr, rowval, nzval)
2784+
2785+ # col_indices maps local column index -> global column index
2786+ col_indices = collect (my_start: (my_start + local_n - 1 ))
2787+ row_partition = v. partition # Use same partition as input vector
2788+ col_partition = v. partition # Square matrix, same column partition
2789+
2790+ return SparseMatrixMPI {T} (nothing , row_partition, col_partition, col_indices,
2791+ transpose (AT_local), nothing )
27272792end
27282793
27292794"""
0 commit comments