Skip to content

Commit 4754d87

Browse files
Sébastien Loisel
authored and committed
GPU backend support via Metal extension
- Add Metal extension with GPU kernels for map_rows_gpu
- Fix CPU→GPU transfer in execute_plan! (use copyto! instead of broadcast)
- Fix SparseMatrixMPI column extraction to preserve GPU backend
- Add GPU-aware VectorPlan with CPU staging for MPI communication
- Parameterize all MPI types with array type (AV/AM) for CPU/GPU dispatch
- Update tests to run on both CPU and GPU backends
1 parent 4daf1a0 commit 4754d87

22 files changed

Lines changed: 5447 additions & 362 deletions

CLAUDE.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,55 @@ Sparse matrices remain on CPU (Julia's `SparseMatrixCSC` doesn't support GPU arr
8686
- `ext/LinearAlgebraMPIMetalExt.jl` - Metal extension with `mtl()` and `cpu()` functions
8787
- Loaded automatically when `using Metal` before `using LinearAlgebraMPI`
8888

89+
### Writing Unified CPU/GPU Functions
90+
91+
**IMPORTANT:** Never write separate CPU and GPU code paths with `if AV <: Vector` or `if A.nzval isa Vector` branches. Use unified helper functions instead.
92+
93+
**No mixed CPU/GPU operations:** Operations between CPU and GPU arrays are forbidden. Both operands must be on the same backend. Functions should error on mixed backends:
94+
```julia
95+
if (A_is_cpu != B_is_cpu)
96+
error("Mixed CPU/GPU operations not supported")
97+
end
98+
```
99+
100+
**Helper functions for unified code:**
101+
102+
1. `_values_to_backend(cpu_values::Vector, template)` - Convert CPU values to template's backend:
103+
- CPU template: returns `cpu_values` directly (no copy)
104+
- GPU template: `copyto!(similar(template, T, length(cpu_values)), cpu_values)`
105+
106+
2. `_to_target_backend(v::Vector, ::Type{AV})` - Convert CPU index vector to target type:
107+
- `Type{Vector{T}}`: returns `v` directly (no copy)
108+
- `Type{MtlVector{T}}`: returns GPU copy
109+
110+
**Pattern for result construction (unified):**
111+
```julia
112+
# Values: use helper (no-op for CPU, copy for GPU)
113+
nzval = _values_to_backend(nzval_cpu, A.nzval)
114+
115+
# Structure arrays: cache once, reuse forever
116+
# For CPU: caches reference to original (no allocation)
117+
# For GPU: caches GPU copy (allocated once)
118+
if plan.cached_rowptr_target === nothing
119+
plan.cached_rowptr_target = _to_target_backend(plan.rowptr, AV)
120+
end
121+
if plan.cached_colval_target === nothing
122+
plan.cached_colval_target = _to_target_backend(plan.colval, AV)
123+
end
124+
```
125+
126+
**Pattern for values that change each call (e.g., gathered B values in A*B):**
127+
```julia
128+
# First call: create cache (reference for CPU, GPU buffer for GPU)
129+
if plan.cached_values === nothing
130+
plan.cached_values = _values_to_backend(plan.cpu_values, A.nzval)
131+
end
132+
# Sync: no-op for CPU (cache === source), copy for GPU
133+
if !(plan.cached_values === plan.cpu_values)
134+
copyto!(plan.cached_values, plan.cpu_values)
135+
end
136+
```
137+
89138
## Architecture
90139

91140
LinearAlgebraMPI implements distributed sparse and dense matrix operations using MPI for parallel computing across multiple ranks. Supports both `Float64` and `ComplexF64` element types.

ext/LinearAlgebraMPIMetalExt.jl

Lines changed: 250 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ const MtlVectorMPI{T} = LinearAlgebraMPI.VectorMPI{T,MtlVector{T}}
1818
1919
Convert a CPU VectorMPI to a Metal GPU VectorMPI.
2020
"""
21-
function LinearAlgebraMPI.mtl(v::LinearAlgebraMPI.VectorMPI{T}) where T
21+
function LinearAlgebraMPI.mtl(v::LinearAlgebraMPI.VectorMPI{T,Vector{T}}) where T
2222
return adapt(MtlArray, v)
2323
end
2424

25+
# No-op for already-GPU vectors
26+
function LinearAlgebraMPI.mtl(v::LinearAlgebraMPI.VectorMPI{T,<:MtlVector}) where T
27+
return v
28+
end
29+
2530
"""
2631
cpu(v::LinearAlgebraMPI.VectorMPI{T,<:MtlVector})
2732
@@ -49,7 +54,9 @@ function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,Vector{T}
4954
# Convert structure arrays to GPU (used by unified SpMV kernel)
5055
rowptr_target = MtlVector(A.rowptr)
5156
colval_target = MtlVector(A.colval)
52-
return LinearAlgebraMPI.SparseMatrixMPI{T,Ti,MtlVector{T}}(
57+
# Use typeof() to get the concrete GPU type (e.g., MtlVector{T, PrivateStorage})
58+
AV = typeof(nzval_gpu)
59+
return LinearAlgebraMPI.SparseMatrixMPI{T,Ti,AV}(
5360
A.structural_hash,
5461
A.row_partition,
5562
A.col_partition,
@@ -66,6 +73,11 @@ function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,Vector{T}
6673
)
6774
end
6875

76+
# No-op for already-GPU sparse matrices
77+
function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,<:MtlVector}) where {T,Ti}
78+
return A
79+
end
80+
6981
"""
7082
cpu(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,<:MtlVector})
7183
@@ -101,14 +113,21 @@ Convert a CPU MatrixMPI to a Metal GPU MatrixMPI.
101113
"""
102114
function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.MatrixMPI{T,Matrix{T}}) where T
103115
A_gpu = MtlMatrix(A.A)
104-
return LinearAlgebraMPI.MatrixMPI{T,MtlMatrix{T}}(
116+
# Use typeof() to get the concrete GPU type (e.g., MtlMatrix{T, PrivateStorage})
117+
AM = typeof(A_gpu)
118+
return LinearAlgebraMPI.MatrixMPI{T,AM}(
105119
A.structural_hash,
106120
A.row_partition,
107121
A.col_partition,
108122
A_gpu
109123
)
110124
end
111125

126+
# No-op for already-GPU matrices
127+
function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.MatrixMPI{T,<:MtlMatrix}) where T
128+
return A
129+
end
130+
112131
"""
113132
cpu(A::LinearAlgebraMPI.MatrixMPI{T,<:MtlMatrix})
114133
@@ -154,39 +173,259 @@ end
154173
# ============================================================================
155174

156175
"""
157-
_zeros_like(::Type{MtlVector{T}}, dims...) where T
176+
_zeros_like(::Type{<:MtlVector{T}}, dims...) where T
158177
159178
Create a zero MtlVector of the specified dimensions.
160179
Used by Base.zeros(VectorMPI{T,MtlVector{T}}, n).
180+
Accepts concrete types like MtlVector{T, PrivateStorage}.
161181
"""
162-
LinearAlgebraMPI._zeros_like(::Type{MtlVector{T}}, dims...) where T = Metal.zeros(T, dims...)
182+
LinearAlgebraMPI._zeros_like(::Type{<:MtlVector{T}}, dims...) where T = Metal.zeros(T, dims...)
163183

164184
"""
165-
_zeros_like(::Type{MtlMatrix{T}}, dims...) where T
185+
_zeros_like(::Type{<:MtlMatrix{T}}, dims...) where T
166186
167187
Create a zero MtlMatrix of the specified dimensions.
168188
Used by Base.zeros(MatrixMPI{T,MtlMatrix{T}}, m, n).
189+
Accepts concrete types like MtlMatrix{T, PrivateStorage}.
169190
"""
170-
LinearAlgebraMPI._zeros_like(::Type{MtlMatrix{T}}, dims...) where T = Metal.zeros(T, dims...)
191+
LinearAlgebraMPI._zeros_like(::Type{<:MtlMatrix{T}}, dims...) where T = Metal.zeros(T, dims...)
171192

172193
# ============================================================================
173194
# MatrixPlan Index Array Support
174195
# ============================================================================
175196

176197
"""
177-
_index_array_type(::Type{MtlVector{T}}, ::Type{Ti}) where {T,Ti}
198+
_index_array_type(::Type{<:MtlVector{T}}, ::Type{Ti}) where {T,Ti}
178199
179200
Map MtlVector{T} value array type to MtlVector{Ti} index array type.
180201
Used by MatrixPlan to store symbolic index arrays on GPU.
202+
Accepts concrete types like MtlVector{T, PrivateStorage}.
181203
"""
182-
LinearAlgebraMPI._index_array_type(::Type{MtlVector{T}}, ::Type{Ti}) where {T,Ti} = MtlVector{Ti}
204+
LinearAlgebraMPI._index_array_type(::Type{<:MtlVector{T}}, ::Type{Ti}) where {T,Ti} = MtlVector{Ti}
183205

184206
"""
185-
_to_target_backend(v::Vector{Ti}, ::Type{MtlVector{T}}) where {Ti,T}
207+
_to_target_backend(v::Vector{Ti}, ::Type{<:MtlVector}) where Ti
186208
187209
Convert a CPU index vector to Metal GPU.
188210
Used by SparseMatrixMPI constructors to create GPU structure arrays.
211+
Accepts concrete types like MtlVector{T, PrivateStorage}.
212+
"""
213+
LinearAlgebraMPI._to_target_backend(v::Vector{Ti}, ::Type{<:MtlVector}) where Ti = MtlVector(v)
214+
215+
# ============================================================================
216+
# GPU map_rows_gpu implementation via Metal kernels
217+
# ============================================================================
218+
219+
using StaticArrays
220+
189221
"""
190-
LinearAlgebraMPI._to_target_backend(v::Vector{Ti}, ::Type{<:MtlVector}) where {Ti} = MtlVector(v)
222+
_map_rows_gpu_kernel(f, arg1::MtlMatrix, rest::MtlMatrix...)
223+
224+
GPU-accelerated row-wise map for Metal arrays.
225+
Each thread processes one row, applying `f` to the corresponding rows of all input matrices.
226+
Returns a Metal matrix with the same number of rows.
227+
"""
228+
function LinearAlgebraMPI._map_rows_gpu_kernel(f, arg1::MtlMatrix{T}, rest::MtlMatrix...) where T
229+
n = size(arg1, 1)
230+
231+
# For very small problems, fall back to CPU (kernel launch overhead dominates)
232+
if n < 256
233+
# CPU fallback for small arrays
234+
arg1_cpu = Array(arg1)
235+
rest_cpu = map(Array, rest)
236+
result_cpu = LinearAlgebraMPI._map_rows_cpu_kernel(f, arg1_cpu, rest_cpu...)
237+
return MtlMatrix(result_cpu)
238+
end
239+
240+
# Get output size by evaluating f on first row
241+
first_rows = (SVector{size(arg1,2),T}(ntuple(j -> arg1[1,j], size(arg1,2))),)
242+
for m in rest
243+
first_rows = (first_rows..., SVector{size(m,2),T}(ntuple(j -> m[1,j], size(m,2))))
244+
end
245+
sample_out = f(first_rows...)
246+
247+
if sample_out isa SVector
248+
out_cols = length(sample_out)
249+
elseif sample_out isa SMatrix
250+
out_cols = length(sample_out) # Flatten matrix output
251+
else
252+
out_cols = 1 # Scalar output
253+
end
254+
255+
# Allocate output
256+
output = Metal.zeros(T, n, out_cols)
257+
258+
# Create kernel
259+
_map_rows_kernel_dispatch(f, output, arg1, rest...)
260+
261+
return output
262+
end
263+
264+
"""
265+
Dispatch to appropriate kernel based on number of arguments.
266+
"""
267+
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}) where T
268+
n = size(arg1, 1)
269+
ncols1 = size(arg1, 2)
270+
out_cols = size(output, 2)
271+
272+
kernel = @metal launch=false _map_rows_kernel_1arg(f, output, arg1, Val(ncols1), Val(out_cols))
273+
threads = min(n, 256)
274+
groups = cld(n, threads)
275+
kernel(f, output, arg1, Val(ncols1), Val(out_cols); threads=threads, groups=groups)
276+
Metal.synchronize()
277+
end
278+
279+
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}) where T
280+
n = size(arg1, 1)
281+
ncols1 = size(arg1, 2)
282+
ncols2 = size(arg2, 2)
283+
out_cols = size(output, 2)
284+
285+
kernel = @metal launch=false _map_rows_kernel_2args(f, output, arg1, arg2, Val(ncols1), Val(ncols2), Val(out_cols))
286+
threads = min(n, 256)
287+
groups = cld(n, threads)
288+
kernel(f, output, arg1, arg2, Val(ncols1), Val(ncols2), Val(out_cols); threads=threads, groups=groups)
289+
Metal.synchronize()
290+
end
291+
292+
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}) where T
293+
n = size(arg1, 1)
294+
ncols1 = size(arg1, 2)
295+
ncols2 = size(arg2, 2)
296+
ncols3 = size(arg3, 2)
297+
out_cols = size(output, 2)
298+
299+
kernel = @metal launch=false _map_rows_kernel_3args(f, output, arg1, arg2, arg3, Val(ncols1), Val(ncols2), Val(ncols3), Val(out_cols))
300+
threads = min(n, 256)
301+
groups = cld(n, threads)
302+
kernel(f, output, arg1, arg2, arg3, Val(ncols1), Val(ncols2), Val(ncols3), Val(out_cols); threads=threads, groups=groups)
303+
Metal.synchronize()
304+
end
305+
306+
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}, arg4::MtlMatrix{T}) where T
307+
n = size(arg1, 1)
308+
ncols1 = size(arg1, 2)
309+
ncols2 = size(arg2, 2)
310+
ncols3 = size(arg3, 2)
311+
ncols4 = size(arg4, 2)
312+
out_cols = size(output, 2)
313+
314+
kernel = @metal launch=false _map_rows_kernel_4args(f, output, arg1, arg2, arg3, arg4, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(out_cols))
315+
threads = min(n, 256)
316+
groups = cld(n, threads)
317+
kernel(f, output, arg1, arg2, arg3, arg4, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(out_cols); threads=threads, groups=groups)
318+
Metal.synchronize()
319+
end
320+
321+
function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}, arg4::MtlMatrix{T}, arg5::MtlMatrix{T}) where T
322+
n = size(arg1, 1)
323+
ncols1 = size(arg1, 2)
324+
ncols2 = size(arg2, 2)
325+
ncols3 = size(arg3, 2)
326+
ncols4 = size(arg4, 2)
327+
ncols5 = size(arg5, 2)
328+
out_cols = size(output, 2)
329+
330+
kernel = @metal launch=false _map_rows_kernel_5args(f, output, arg1, arg2, arg3, arg4, arg5, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(ncols5), Val(out_cols))
331+
threads = min(n, 256)
332+
groups = cld(n, threads)
333+
kernel(f, output, arg1, arg2, arg3, arg4, arg5, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(ncols5), Val(out_cols); threads=threads, groups=groups)
334+
Metal.synchronize()
335+
end
336+
337+
# ============================================================================
338+
# Metal kernels
339+
# ============================================================================
340+
341+
function _map_rows_kernel_1arg(f, output, arg1, ::Val{NC1}, ::Val{OCols}) where {NC1, OCols}
342+
i = thread_position_in_grid_1d()
343+
n = size(arg1, 1)
344+
if i <= n
345+
T = eltype(arg1)
346+
row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
347+
result = f(row1)
348+
_write_result!(output, i, result, Val(OCols))
349+
end
350+
return nothing
351+
end
352+
353+
function _map_rows_kernel_2args(f, output, arg1, arg2, ::Val{NC1}, ::Val{NC2}, ::Val{OCols}) where {NC1, NC2, OCols}
354+
i = thread_position_in_grid_1d()
355+
n = size(arg1, 1)
356+
if i <= n
357+
T = eltype(arg1)
358+
row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
359+
row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
360+
result = f(row1, row2)
361+
_write_result!(output, i, result, Val(OCols))
362+
end
363+
return nothing
364+
end
365+
366+
function _map_rows_kernel_3args(f, output, arg1, arg2, arg3, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{OCols}) where {NC1, NC2, NC3, OCols}
367+
i = thread_position_in_grid_1d()
368+
n = size(arg1, 1)
369+
if i <= n
370+
T = eltype(arg1)
371+
row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
372+
row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
373+
row3 = SVector{NC3,T}(ntuple(j -> @inbounds(arg3[i,j]), Val(NC3)))
374+
result = f(row1, row2, row3)
375+
_write_result!(output, i, result, Val(OCols))
376+
end
377+
return nothing
378+
end
379+
380+
function _map_rows_kernel_4args(f, output, arg1, arg2, arg3, arg4, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{NC4}, ::Val{OCols}) where {NC1, NC2, NC3, NC4, OCols}
381+
i = thread_position_in_grid_1d()
382+
n = size(arg1, 1)
383+
if i <= n
384+
T = eltype(arg1)
385+
row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
386+
row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
387+
row3 = SVector{NC3,T}(ntuple(j -> @inbounds(arg3[i,j]), Val(NC3)))
388+
row4 = SVector{NC4,T}(ntuple(j -> @inbounds(arg4[i,j]), Val(NC4)))
389+
result = f(row1, row2, row3, row4)
390+
_write_result!(output, i, result, Val(OCols))
391+
end
392+
return nothing
393+
end
394+
395+
function _map_rows_kernel_5args(f, output, arg1, arg2, arg3, arg4, arg5, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{NC4}, ::Val{NC5}, ::Val{OCols}) where {NC1, NC2, NC3, NC4, NC5, OCols}
396+
i = thread_position_in_grid_1d()
397+
n = size(arg1, 1)
398+
if i <= n
399+
T = eltype(arg1)
400+
row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
401+
row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
402+
row3 = SVector{NC3,T}(ntuple(j -> @inbounds(arg3[i,j]), Val(NC3)))
403+
row4 = SVector{NC4,T}(ntuple(j -> @inbounds(arg4[i,j]), Val(NC4)))
404+
row5 = SVector{NC5,T}(ntuple(j -> @inbounds(arg5[i,j]), Val(NC5)))
405+
result = f(row1, row2, row3, row4, row5)
406+
_write_result!(output, i, result, Val(OCols))
407+
end
408+
return nothing
409+
end
410+
411+
# Helper to write result (scalar, SVector, or SMatrix) to output row
412+
@inline function _write_result!(output, i, result::Number, ::Val{1})
413+
@inbounds output[i, 1] = result
414+
return nothing
415+
end
416+
417+
@inline function _write_result!(output, i, result::SVector{N,T}, ::Val{N}) where {N,T}
418+
for j in 1:N
419+
@inbounds output[i, j] = result[j]
420+
end
421+
return nothing
422+
end
423+
424+
@inline function _write_result!(output, i, result::SMatrix{M,N,T}, ::Val{MN}) where {M,N,T,MN}
425+
for j in 1:MN
426+
@inbounds output[i, j] = result[j]
427+
end
428+
return nothing
429+
end
191430

192431
end # module

0 commit comments

Comments
 (0)