@@ -18,10 +18,15 @@ const MtlVectorMPI{T} = LinearAlgebraMPI.VectorMPI{T,MtlVector{T}}
 
 Convert a CPU VectorMPI to a Metal GPU VectorMPI.
 """
-function LinearAlgebraMPI.mtl(v::LinearAlgebraMPI.VectorMPI{T}) where T
+function LinearAlgebraMPI.mtl(v::LinearAlgebraMPI.VectorMPI{T,Vector{T}}) where T
     return adapt(MtlArray, v)
 end
 
+# No-op for already-GPU vectors
+function LinearAlgebraMPI.mtl(v::LinearAlgebraMPI.VectorMPI{T,<:MtlVector}) where T
+    return v
+end
+
 """
     cpu(v::LinearAlgebraMPI.VectorMPI{T,<:MtlVector})
 
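A usage sketch (not part of the commit): round-tripping a distributed vector. The `VectorMPI(::Vector)` constructor is assumed here for illustration; only `mtl`, `cpu`, and the no-op method are defined in this file.

    using Metal, LinearAlgebraMPI
    v = LinearAlgebraMPI.VectorMPI(rand(Float32, 100))  # hypothetical CPU constructor
    v_gpu = LinearAlgebraMPI.mtl(v)         # CPU -> Metal via adapt(MtlArray, v)
    LinearAlgebraMPI.mtl(v_gpu) === v_gpu   # true: the no-op method makes mtl idempotent
    v_cpu = LinearAlgebraMPI.cpu(v_gpu)     # back to a Vector{Float32}-backed VectorMPI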
@@ -49,7 +54,9 @@ function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,Vector{T}
     # Convert structure arrays to GPU (used by unified SpMV kernel)
     rowptr_target = MtlVector(A.rowptr)
     colval_target = MtlVector(A.colval)
-    return LinearAlgebraMPI.SparseMatrixMPI{T,Ti,MtlVector{T}}(
+    # Use typeof() to get the concrete GPU type (e.g., MtlVector{T, PrivateStorage})
+    AV = typeof(nzval_gpu)
+    return LinearAlgebraMPI.SparseMatrixMPI{T,Ti,AV}(
         A.structural_hash,
         A.row_partition,
         A.col_partition,
@@ -66,6 +73,11 @@ function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,Vector{T}
     )
 end
 
+# No-op for already-GPU sparse matrices
+function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,<:MtlVector}) where {T,Ti}
+    return A
+end
+
 """
     cpu(A::LinearAlgebraMPI.SparseMatrixMPI{T,Ti,<:MtlVector})
 
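Why `typeof()` here (a sketch of standard Metal.jl behavior, not part of the commit): the `MtlVector{T}` alias leaves the storage-mode parameter free, so hard-coding it as the array-type parameter would not match the concrete array Metal actually allocates.

    using Metal
    nzval_gpu = MtlVector(rand(Float32, 8))  # stand-in for the converted A.nzval
    AV = typeof(nzval_gpu)                   # concrete, e.g. MtlVector{Float32, Metal.PrivateStorage}
    AV <: MtlVector{Float32}                 # true
    AV == MtlVector{Float32}                 # false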
@@ -101,14 +113,21 @@ Convert a CPU MatrixMPI to a Metal GPU MatrixMPI.
 """
 function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.MatrixMPI{T,Matrix{T}}) where T
     A_gpu = MtlMatrix(A.A)
-    return LinearAlgebraMPI.MatrixMPI{T,MtlMatrix{T}}(
+    # Use typeof() to get the concrete GPU type (e.g., MtlMatrix{T, PrivateStorage})
+    AM = typeof(A_gpu)
+    return LinearAlgebraMPI.MatrixMPI{T,AM}(
         A.structural_hash,
         A.row_partition,
         A.col_partition,
         A_gpu
     )
 end
 
+# No-op for already-GPU matrices
+function LinearAlgebraMPI.mtl(A::LinearAlgebraMPI.MatrixMPI{T,<:MtlMatrix}) where T
+    return A
+end
+
 """
     cpu(A::LinearAlgebraMPI.MatrixMPI{T,<:MtlMatrix})
 
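The dense case follows the same pattern and gets the same idempotence guarantee (sketch; the `MatrixMPI(::Matrix)` constructor is assumed analogous to the vector one):

    A = LinearAlgebraMPI.MatrixMPI(rand(Float32, 10, 10))  # hypothetical CPU constructor
    A_gpu = LinearAlgebraMPI.mtl(A)
    LinearAlgebraMPI.mtl(A_gpu) === A_gpu   # true: no-op on already-GPU matrices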
@@ -154,39 +173,259 @@ end
 # ============================================================================
 
 """
-    _zeros_like(::Type{MtlVector{T}}, dims...) where T
+    _zeros_like(::Type{<:MtlVector{T}}, dims...) where T
 
 Create a zero MtlVector of the specified dimensions.
 Used by Base.zeros(VectorMPI{T,MtlVector{T}}, n).
+Accepts concrete types like MtlVector{T, PrivateStorage}.
 """
-LinearAlgebraMPI._zeros_like(::Type{MtlVector{T}}, dims...) where T = Metal.zeros(T, dims...)
+LinearAlgebraMPI._zeros_like(::Type{<:MtlVector{T}}, dims...) where T = Metal.zeros(T, dims...)
 
 """
-    _zeros_like(::Type{MtlMatrix{T}}, dims...) where T
+    _zeros_like(::Type{<:MtlMatrix{T}}, dims...) where T
 
 Create a zero MtlMatrix of the specified dimensions.
 Used by Base.zeros(MatrixMPI{T,MtlMatrix{T}}, m, n).
+Accepts concrete types like MtlMatrix{T, PrivateStorage}.
 """
-LinearAlgebraMPI._zeros_like(::Type{MtlMatrix{T}}, dims...) where T = Metal.zeros(T, dims...)
+LinearAlgebraMPI._zeros_like(::Type{<:MtlMatrix{T}}, dims...) where T = Metal.zeros(T, dims...)
 
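Why the loosened `::Type{<:MtlVector{T}}` signature matters (a sketch of standard Metal.jl behavior):

    using Metal
    v = Metal.zeros(Float32, 4)
    typeof(v)                        # e.g. MtlVector{Float32, Metal.PrivateStorage}
    typeof(v) <: MtlVector{Float32}  # true: matched by the new method
    typeof(v) == MtlVector{Float32}  # false: the old ::Type{MtlVector{T}} method missed concrete types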
 # ============================================================================
 # MatrixPlan Index Array Support
 # ============================================================================
 
 """
-    _index_array_type(::Type{MtlVector{T}}, ::Type{Ti}) where {T,Ti}
+    _index_array_type(::Type{<:MtlVector{T}}, ::Type{Ti}) where {T,Ti}
 
 Map MtlVector{T} value array type to MtlVector{Ti} index array type.
 Used by MatrixPlan to store symbolic index arrays on GPU.
+Accepts concrete types like MtlVector{T, PrivateStorage}.
 """
-LinearAlgebraMPI._index_array_type(::Type{MtlVector{T}}, ::Type{Ti}) where {T,Ti} = MtlVector{Ti}
+LinearAlgebraMPI._index_array_type(::Type{<:MtlVector{T}}, ::Type{Ti}) where {T,Ti} = MtlVector{Ti}
 
 """
-    _to_target_backend(v::Vector{Ti}, ::Type{MtlVector{T}}) where {Ti,T}
+    _to_target_backend(v::Vector{Ti}, ::Type{<:MtlVector}) where Ti
 
 Convert a CPU index vector to Metal GPU.
 Used by SparseMatrixMPI constructors to create GPU structure arrays.
+Accepts concrete types like MtlVector{T, PrivateStorage}.
+"""
+LinearAlgebraMPI._to_target_backend(v::Vector{Ti}, ::Type{<:MtlVector}) where Ti = MtlVector(v)
+
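How a MatrixPlan would use the two hooks together (sketch; the call sites live in the main package, not in this diff):

    using Metal
    AV = typeof(MtlVector(rand(Float32, 4)))                 # concrete value-array type
    LinearAlgebraMPI._index_array_type(AV, Int32)            # MtlVector{Int32}
    LinearAlgebraMPI._to_target_backend(Int32[1, 3, 5], AV)  # Int32 indices on the GPU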
+# ============================================================================
+# GPU map_rows_gpu implementation via Metal kernels
+# ============================================================================
+
+using StaticArrays
+
 """
-LinearAlgebraMPI._to_target_backend(v::Vector{Ti}, ::Type{<:MtlVector}) where {Ti} = MtlVector(v)
+    _map_rows_gpu_kernel(f, arg1::MtlMatrix, rest::MtlMatrix...)
+
+GPU-accelerated row-wise map for Metal arrays.
+Each thread processes one row, applying `f` to the corresponding rows of all input matrices.
+Returns a Metal matrix with the same number of rows.
+"""
+function LinearAlgebraMPI._map_rows_gpu_kernel(f, arg1::MtlMatrix{T}, rest::MtlMatrix...) where T
+    n = size(arg1, 1)
+
+    # For very small problems, fall back to CPU (kernel launch overhead dominates)
+    if n < 256
+        # CPU fallback for small arrays
+        arg1_cpu = Array(arg1)
+        rest_cpu = map(Array, rest)
+        result_cpu = LinearAlgebraMPI._map_rows_cpu_kernel(f, arg1_cpu, rest_cpu...)
+        return MtlMatrix(result_cpu)
+    end
+
+    # Get output size by evaluating f on the first row
+    first_rows = (SVector{size(arg1,2),T}(ntuple(j -> arg1[1,j], size(arg1,2))),)
+    for m in rest
+        first_rows = (first_rows..., SVector{size(m,2),T}(ntuple(j -> m[1,j], size(m,2))))
+    end
+    sample_out = f(first_rows...)
+
+    if sample_out isa SVector
+        out_cols = length(sample_out)
+    elseif sample_out isa SMatrix
+        out_cols = length(sample_out)  # Flatten matrix output
+    else
+        out_cols = 1  # Scalar output
+    end
+
+    # Allocate output
+    output = Metal.zeros(T, n, out_cols)
+
+    # Launch the kernel matching the number of arguments
+    _map_rows_kernel_dispatch(f, output, arg1, rest...)
+
+    return output
+end
+
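A usage sketch (requires an Apple GPU; `f` must be GPU-compilable to take the kernel path, and n < 256 exercises the CPU fallback instead):

    using Metal, StaticArrays
    X = MtlMatrix(rand(Float32, 1024, 3))
    sqnorm(row) = row[1]^2 + row[2]^2 + row[3]^2          # SVector{3,Float32} -> scalar
    Y = LinearAlgebraMPI._map_rows_gpu_kernel(sqnorm, X)  # 1024x1 MtlMatrix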
+"""
+Dispatch to the appropriate kernel for the given number of arguments.
+"""
+function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}) where T
+    n = size(arg1, 1)
+    ncols1 = size(arg1, 2)
+    out_cols = size(output, 2)
+
+    kernel = @metal launch=false _map_rows_kernel_1arg(f, output, arg1, Val(ncols1), Val(out_cols))
+    threads = min(n, 256)
+    groups = cld(n, threads)
+    kernel(f, output, arg1, Val(ncols1), Val(out_cols); threads=threads, groups=groups)
+    Metal.synchronize()
+end
+
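The compile-then-launch pattern above is standard Metal.jl, not package-specific. A minimal self-contained version, independent of LinearAlgebraMPI:

    using Metal

    function saxpy_kernel(y, x, a)
        i = thread_position_in_grid_1d()
        if i <= length(y)
            @inbounds y[i] += a * x[i]
        end
        return nothing
    end

    x = MtlVector(rand(Float32, 10_000))
    y = MtlVector(rand(Float32, 10_000))
    k = @metal launch=false saxpy_kernel(y, x, 2f0)  # compile without launching
    threads = min(length(y), 256)                    # threads per threadgroup
    groups = cld(length(y), threads)                 # number of threadgroups
    k(y, x, 2f0; threads, groups)                    # launch with explicit configuration
    Metal.synchronize()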
+function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}) where T
+    n = size(arg1, 1)
+    ncols1 = size(arg1, 2)
+    ncols2 = size(arg2, 2)
+    out_cols = size(output, 2)
+
+    kernel = @metal launch=false _map_rows_kernel_2args(f, output, arg1, arg2, Val(ncols1), Val(ncols2), Val(out_cols))
+    threads = min(n, 256)
+    groups = cld(n, threads)
+    kernel(f, output, arg1, arg2, Val(ncols1), Val(ncols2), Val(out_cols); threads=threads, groups=groups)
+    Metal.synchronize()
+end
+
+function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}) where T
+    n = size(arg1, 1)
+    ncols1 = size(arg1, 2)
+    ncols2 = size(arg2, 2)
+    ncols3 = size(arg3, 2)
+    out_cols = size(output, 2)
+
+    kernel = @metal launch=false _map_rows_kernel_3args(f, output, arg1, arg2, arg3, Val(ncols1), Val(ncols2), Val(ncols3), Val(out_cols))
+    threads = min(n, 256)
+    groups = cld(n, threads)
+    kernel(f, output, arg1, arg2, arg3, Val(ncols1), Val(ncols2), Val(ncols3), Val(out_cols); threads=threads, groups=groups)
+    Metal.synchronize()
+end
+
+function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}, arg4::MtlMatrix{T}) where T
+    n = size(arg1, 1)
+    ncols1 = size(arg1, 2)
+    ncols2 = size(arg2, 2)
+    ncols3 = size(arg3, 2)
+    ncols4 = size(arg4, 2)
+    out_cols = size(output, 2)
+
+    kernel = @metal launch=false _map_rows_kernel_4args(f, output, arg1, arg2, arg3, arg4, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(out_cols))
+    threads = min(n, 256)
+    groups = cld(n, threads)
+    kernel(f, output, arg1, arg2, arg3, arg4, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(out_cols); threads=threads, groups=groups)
+    Metal.synchronize()
+end
+
+function _map_rows_kernel_dispatch(f, output::MtlMatrix{T}, arg1::MtlMatrix{T}, arg2::MtlMatrix{T}, arg3::MtlMatrix{T}, arg4::MtlMatrix{T}, arg5::MtlMatrix{T}) where T
+    n = size(arg1, 1)
+    ncols1 = size(arg1, 2)
+    ncols2 = size(arg2, 2)
+    ncols3 = size(arg3, 2)
+    ncols4 = size(arg4, 2)
+    ncols5 = size(arg5, 2)
+    out_cols = size(output, 2)
+
+    kernel = @metal launch=false _map_rows_kernel_5args(f, output, arg1, arg2, arg3, arg4, arg5, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(ncols5), Val(out_cols))
+    threads = min(n, 256)
+    groups = cld(n, threads)
+    kernel(f, output, arg1, arg2, arg3, arg4, arg5, Val(ncols1), Val(ncols2), Val(ncols3), Val(ncols4), Val(ncols5), Val(out_cols); threads=threads, groups=groups)
+    Metal.synchronize()
+end
+
+# ============================================================================
+# Metal kernels
+# ============================================================================
+
+function _map_rows_kernel_1arg(f, output, arg1, ::Val{NC1}, ::Val{OCols}) where {NC1, OCols}
+    i = thread_position_in_grid_1d()
+    n = size(arg1, 1)
+    if i <= n
+        T = eltype(arg1)
+        row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
+        result = f(row1)
+        _write_result!(output, i, result, Val(OCols))
+    end
+    return nothing
+end
+
+function _map_rows_kernel_2args(f, output, arg1, arg2, ::Val{NC1}, ::Val{NC2}, ::Val{OCols}) where {NC1, NC2, OCols}
+    i = thread_position_in_grid_1d()
+    n = size(arg1, 1)
+    if i <= n
+        T = eltype(arg1)
+        row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
+        row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
+        result = f(row1, row2)
+        _write_result!(output, i, result, Val(OCols))
+    end
+    return nothing
+end
+
+function _map_rows_kernel_3args(f, output, arg1, arg2, arg3, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{OCols}) where {NC1, NC2, NC3, OCols}
+    i = thread_position_in_grid_1d()
+    n = size(arg1, 1)
+    if i <= n
+        T = eltype(arg1)
+        row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
+        row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
+        row3 = SVector{NC3,T}(ntuple(j -> @inbounds(arg3[i,j]), Val(NC3)))
+        result = f(row1, row2, row3)
+        _write_result!(output, i, result, Val(OCols))
+    end
+    return nothing
+end
+
+function _map_rows_kernel_4args(f, output, arg1, arg2, arg3, arg4, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{NC4}, ::Val{OCols}) where {NC1, NC2, NC3, NC4, OCols}
+    i = thread_position_in_grid_1d()
+    n = size(arg1, 1)
+    if i <= n
+        T = eltype(arg1)
+        row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
+        row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
+        row3 = SVector{NC3,T}(ntuple(j -> @inbounds(arg3[i,j]), Val(NC3)))
+        row4 = SVector{NC4,T}(ntuple(j -> @inbounds(arg4[i,j]), Val(NC4)))
+        result = f(row1, row2, row3, row4)
+        _write_result!(output, i, result, Val(OCols))
+    end
+    return nothing
+end
+
+function _map_rows_kernel_5args(f, output, arg1, arg2, arg3, arg4, arg5, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{NC4}, ::Val{NC5}, ::Val{OCols}) where {NC1, NC2, NC3, NC4, NC5, OCols}
+    i = thread_position_in_grid_1d()
+    n = size(arg1, 1)
+    if i <= n
+        T = eltype(arg1)
+        row1 = SVector{NC1,T}(ntuple(j -> @inbounds(arg1[i,j]), Val(NC1)))
+        row2 = SVector{NC2,T}(ntuple(j -> @inbounds(arg2[i,j]), Val(NC2)))
+        row3 = SVector{NC3,T}(ntuple(j -> @inbounds(arg3[i,j]), Val(NC3)))
+        row4 = SVector{NC4,T}(ntuple(j -> @inbounds(arg4[i,j]), Val(NC4)))
+        row5 = SVector{NC5,T}(ntuple(j -> @inbounds(arg5[i,j]), Val(NC5)))
+        result = f(row1, row2, row3, row4, row5)
+        _write_result!(output, i, result, Val(OCols))
+    end
+    return nothing
+end
+
+# Helper to write result (scalar, SVector, or SMatrix) to output row
+@inline function _write_result!(output, i, result::Number, ::Val{1})
+    @inbounds output[i, 1] = result
+    return nothing
+end
+
+@inline function _write_result!(output, i, result::SVector{N,T}, ::Val{N}) where {N,T}
+    for j in 1:N
+        @inbounds output[i, j] = result[j]
+    end
+    return nothing
+end
+
+@inline function _write_result!(output, i, result::SMatrix{M,N,T}, ::Val{MN}) where {M,N,T,MN}
+    for j in 1:MN
+        @inbounds output[i, j] = result[j]
+    end
+    return nothing
+end
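Note on the SMatrix method: StaticArrays linear indexing is column-major, so an M×N result is flattened columns-first into M*N output columns, e.g.:

    using StaticArrays
    m = SMatrix{2,2,Float32}(1, 2, 3, 4)
    collect(m[j] for j in 1:4)  # Float32[1.0, 2.0, 3.0, 4.0], columns first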

 end # module