Skip to content

Commit 6b20302

Browse files
Sébastien Loisel
authored and committed
Add Base.zeros for distributed types with GPU support
- Add zeros(VectorMPI{T,AV}, n) for creating zero distributed vectors
- Add zeros(MatrixMPI{T,AM}, m, n) for creating zero distributed matrices
- Add zeros(SparseMatrixMPI{T,Ti,AV}, m, n) for creating zero sparse matrices
- Export VectorMPI_CPU, MatrixMPI_CPU, SparseMatrixMPI_CPU type aliases
- Add _zeros_like helper with Metal extension support for GPU arrays
- Update CLAUDE.md with zeros usage examples

Bump version to 0.1.5
1 parent 3ae6d50 commit 6b20302

4 files changed

Lines changed: 178 additions & 1 deletion

File tree

CLAUDE.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,27 @@ GPU acceleration is supported via Metal.jl (macOS) as a package extension.
4040
- `SparseMatrixMPI{T,Ti,AV}` where `AV` is `Vector{T}` (CPU) or `MtlVector{T}` (GPU) for the `nzval` array
4141
- Type aliases: `VectorMPI_CPU{T}`, `MatrixMPI_CPU{T}`, `SparseMatrixMPI_CPU{T,Ti}` for CPU-backed types
4242

43+
### Creating Zero Arrays
44+
45+
Use `Base.zeros` with the full parametric type or type alias:
46+
47+
```julia
48+
# CPU zero arrays
49+
v = zeros(VectorMPI{Float64,Vector{Float64}}, 100)
50+
v = zeros(VectorMPI_CPU{Float64}, 100) # Equivalent using type alias
51+
52+
A = zeros(MatrixMPI{Float64,Matrix{Float64}}, 50, 30)
53+
A = zeros(MatrixMPI_CPU{Float64}, 50, 30)
54+
55+
S = zeros(SparseMatrixMPI{Float64,Int,Vector{Float64}}, 100, 100)
56+
S = zeros(SparseMatrixMPI_CPU{Float64,Int}, 100, 100)
57+
58+
# GPU zero arrays (requires Metal.jl loaded)
59+
using Metal
60+
v_gpu = zeros(VectorMPI{Float32,MtlVector{Float32}}, 100)
61+
A_gpu = zeros(MatrixMPI{Float32,MtlMatrix{Float32}}, 50, 30)
62+
```
63+
4364
### CPU Staging
4465

4566
MPI communication always uses CPU buffers (no Metal-aware MPI exists). GPU data is staged through CPU:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearAlgebraMPI"
22
uuid = "5bdd2be4-ae34-42ef-8b36-f4c85d48f377"
3-
version = "0.1.4"
3+
version = "0.1.5"
44
authors = ["S. Loisel"]
55

66
[deps]

ext/LinearAlgebraMPIMetalExt.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,26 @@ function LinearAlgebraMPI._convert_vector_to_backend(v::LinearAlgebraMPI.VectorM
149149
return LinearAlgebraMPI.mtl(v)
150150
end
151151

152+
# ============================================================================
153+
# Base.zeros Support
154+
# ============================================================================
155+
156+
"""
    _zeros_like(::Type{MtlVector{T}}, dims...) where {T}

Allocate a zero-filled Metal array with element type `T` and the given `dims` by
forwarding to `Metal.zeros`.

This is the Metal method of the backend-allocation hook that
`Base.zeros(VectorMPI{T,MtlVector{T}}, n)` relies on.
"""
function LinearAlgebraMPI._zeros_like(::Type{MtlVector{T}}, dims...) where {T}
    return Metal.zeros(T, dims...)
end
163+
164+
"""
    _zeros_like(::Type{MtlMatrix{T}}, dims...) where {T}

Allocate a zero-filled Metal array with element type `T` and the given `dims` by
forwarding to `Metal.zeros`.

This is the Metal method of the backend-allocation hook that
`Base.zeros(MatrixMPI{T,MtlMatrix{T}}, m, n)` relies on.
"""
function LinearAlgebraMPI._zeros_like(::Type{MtlMatrix{T}}, dims...) where {T}
    return Metal.zeros(T, dims...)
end
171+
152172
# ============================================================================
153173
# MatrixPlan Index Array Support
154174
# ============================================================================

src/LinearAlgebraMPI.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import LinearAlgebra
1010
import LinearAlgebra: tr, diag, triu, tril, Transpose, Adjoint, norm, opnorm, mul!, ldlt, BLAS, issymmetric, UniformScaling, dot, Symmetric
1111

1212
export SparseMatrixMPI, MatrixMPI, VectorMPI, clear_plan_cache!, uniform_partition, repartition
13+
export VectorMPI_CPU, MatrixMPI_CPU, SparseMatrixMPI_CPU # Type aliases for CPU-backed types
1314
export SparseMatrixCSR # Type alias for Transpose{SparseMatrixCSC} (CSR storage format)
1415
export map_rows # Row-wise map over distributed vectors/matrices
1516
export VectorMPI_local, MatrixMPI_local, SparseMatrixMPI_local # Local constructors
@@ -959,6 +960,141 @@ function map_rows(f, A...)
959960
end
960961

961962

963+
# ============================================================================
964+
# Base.zeros for Distributed Types
965+
# ============================================================================
966+
967+
# Backend-aware zero allocation, dispatched on the requested storage type.
# These two methods cover the plain CPU `Array` backends.
function _zeros_like(::Type{Vector{T}}, dims...) where {T}
    return zeros(T, dims...)
end

function _zeros_like(::Type{Matrix{T}}, dims...) where {T}
    return zeros(T, dims...)
end
971+
972+
# For GPU arrays, extensions will define additional methods
973+
974+
"""
    Base.zeros(::Type{VectorMPI{T,AV}}, n::Integer; comm=MPI.COMM_WORLD) where {T,AV<:AbstractVector{T}}

Return a distributed vector of `n` zeros with element type `T` whose per-rank
storage has type `AV`.

The global length is split with `uniform_partition(n, nranks)`; each rank allocates
only its own slice through `_zeros_like(AV, local_length)`.

# Keywords
- `comm::MPI.Comm`: communicator whose size and rank determine the partition and
  the local slice (default `MPI.COMM_WORLD`).

# Examples
```julia
# CPU zero vector
v = zeros(VectorMPI{Float64,Vector{Float64}}, 100)

# Using type alias
v = zeros(VectorMPI_CPU{Float64}, 100)

# GPU zero vector (requires Metal.jl loaded)
using Metal
v = zeros(VectorMPI{Float32,MtlVector{Float32}}, 100)
```
"""
function Base.zeros(::Type{VectorMPI{T,AV}}, n::Integer;
                    comm::MPI.Comm=MPI.COMM_WORLD) where {T,AV<:AbstractVector{T}}
    nprocs = MPI.Comm_size(comm)
    me = MPI.Comm_rank(comm)

    bounds = uniform_partition(n, nprocs)
    # The local length is the gap between this rank's two consecutive partition
    # entries (`me` is 0-based, the partition vector is 1-based).
    nlocal = bounds[me + 2] - bounds[me + 1]

    storage = _zeros_like(AV, nlocal)
    partition_hash = compute_partition_hash(bounds)

    return VectorMPI{T,AV}(partition_hash, bounds, storage)
end
1007+
1008+
"""
    Base.zeros(::Type{MatrixMPI{T,AM}}, m::Integer, n::Integer; comm=MPI.COMM_WORLD) where {T,AM<:AbstractMatrix{T}}

Return a distributed `m × n` matrix of zeros with element type `T` whose per-rank
storage has type `AM`.

Rows are split with `uniform_partition(m, nranks)`; each rank allocates a
`local_nrows × n` block through `_zeros_like`. A uniform partition of the `n`
columns is stored alongside it.

# Keywords
- `comm::MPI.Comm`: communicator whose size and rank determine the partitions and
  the local block (default `MPI.COMM_WORLD`).

# Examples
```julia
# CPU zero matrix
A = zeros(MatrixMPI{Float64,Matrix{Float64}}, 100, 50)

# Using type alias
A = zeros(MatrixMPI_CPU{Float64}, 100, 50)

# GPU zero matrix (requires Metal.jl loaded)
using Metal
A = zeros(MatrixMPI{Float32,MtlMatrix{Float32}}, 100, 50)
```
"""
function Base.zeros(::Type{MatrixMPI{T,AM}}, m::Integer, n::Integer;
                    comm::MPI.Comm=MPI.COMM_WORLD) where {T,AM<:AbstractMatrix{T}}
    nprocs = MPI.Comm_size(comm)
    me = MPI.Comm_rank(comm)

    rows = uniform_partition(m, nprocs)
    cols = uniform_partition(n, nprocs)  # Used for transpose operations

    # Each rank owns a contiguous band of rows and every column of that band.
    nrows_here = rows[me + 2] - rows[me + 1]
    block = _zeros_like(AM, nrows_here, n)

    # The structural hash slot starts as `nothing`; it is computed lazily.
    return MatrixMPI{T,AM}(nothing, rows, cols, block)
end
1042+
1043+
"""
    Base.zeros(::Type{SparseMatrixMPI{T,Ti,AV}}, m::Integer, n::Integer; comm=MPI.COMM_WORLD) where {T,Ti<:Integer,AV<:AbstractVector{T}}

Create a distributed zero sparse matrix of size `m × n`.

A zero sparse matrix has no stored entries, so on every rank the result has:
- a `rowptr` of length `local_nrows + 1` filled with ones (every local row is empty)
- empty `colval` and `nzval`
- empty `col_indices` and a compressed column count of `0`

The cached symmetry flag is set to `m == n`: a square zero matrix is symmetric,
while a non-square matrix never is.

# Keywords
- `comm::MPI.Comm`: communicator whose size and rank determine the partitions and
  the local row count (default `MPI.COMM_WORLD`).

# Examples
```julia
# CPU zero sparse matrix
A = zeros(SparseMatrixMPI{Float64,Int,Vector{Float64}}, 100, 100)

# Using type alias
A = zeros(SparseMatrixMPI_CPU{Float64,Int}, 100, 100)
```
"""
function Base.zeros(::Type{SparseMatrixMPI{T,Ti,AV}}, m::Integer, n::Integer;
                    comm::MPI.Comm=MPI.COMM_WORLD) where {T,Ti<:Integer,AV<:AbstractVector{T}}
    nranks = MPI.Comm_size(comm)
    rank = MPI.Comm_rank(comm)

    row_partition = uniform_partition(m, nranks)
    col_partition = uniform_partition(n, nranks)
    local_nrows = row_partition[rank + 2] - row_partition[rank + 1]

    # Empty sparse structure
    rowptr = ones(Ti, local_nrows + 1)  # All rows have 0 entries
    colval = Ti[]
    nzval = _zeros_like(AV, 0)          # Empty but correct type
    col_indices = Int[]                 # No columns referenced

    # For CPU, rowptr_target/colval_target are the same as rowptr/colval
    # For GPU, they would be GPU copies (but empty arrays don't matter)
    rowptr_target = rowptr
    colval_target = colval

    return SparseMatrixMPI{T,Ti,AV}(
        nothing,        # Hash computed lazily
        row_partition,
        col_partition,
        col_indices,
        rowptr,
        colval,
        nzval,
        local_nrows,
        0,              # ncols_compressed = 0 (no columns referenced)
        nothing,        # cached_transpose
        m == n,         # cached_symmetric: only a square zero matrix is symmetric
        rowptr_target,
        colval_target
    )
end
1097+
9621098
# ============================================================================
9631099
# Precompilation Workload
9641100
# ============================================================================

0 commit comments

Comments
 (0)