Skip to content

Commit b6e1790

Browse files
Sébastien LoiselSébastien Loisel
authored andcommitted
Fix GPU broadcast for VectorMPI .* MatrixMPI; remove CPU fallback
- Add _prepare_broadcast_arg for MatrixMPI to extract underlying .A matrix, enabling GPU broadcasts without non-bitstype wrapper - Remove n < 256 CPU fallback in Metal extension for unified GPU path - Fix scalar indexing in _map_rows_gpu_kernel (copy row to CPU first) Bump version to 0.1.9
1 parent 82c4a9f commit b6e1790

3 files changed

Lines changed: 11 additions & 13 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearAlgebraMPI"
22
uuid = "5bdd2be4-ae34-42ef-8b36-f4c85d48f377"
3-
version = "0.1.8"
3+
version = "0.1.9"
44
authors = ["S. Loisel"]
55

66
[deps]

ext/LinearAlgebraMPIMetalExt.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -228,19 +228,12 @@ Returns a Metal matrix with the same number of rows.
228228
function LinearAlgebraMPI._map_rows_gpu_kernel(f, arg1::MtlMatrix{T}, rest::MtlMatrix...) where T
229229
n = size(arg1, 1)
230230

231-
# For very small problems, fall back to CPU (kernel launch overhead dominates)
232-
if n < 256
233-
# CPU fallback for small arrays
234-
arg1_cpu = Array(arg1)
235-
rest_cpu = map(Array, rest)
236-
result_cpu = LinearAlgebraMPI._map_rows_cpu_kernel(f, arg1_cpu, rest_cpu...)
237-
return MtlMatrix(result_cpu)
238-
end
239-
240-
# Get output size by evaluating f on first row
241-
first_rows = (SVector{size(arg1,2),T}(ntuple(j -> arg1[1,j], size(arg1,2))),)
231+
# Get output size by evaluating f on first row (copy to CPU to avoid scalar indexing)
232+
arg1_row1 = Array(view(arg1, 1:1, :))[1, :]
233+
first_rows = (SVector{size(arg1,2),T}(arg1_row1...),)
242234
for m in rest
243-
first_rows = (first_rows..., SVector{size(m,2),T}(ntuple(j -> m[1,j], size(m,2))))
235+
m_row1 = Array(view(m, 1:1, :))[1, :]
236+
first_rows = (first_rows..., SVector{size(m,2),T}(m_row1...))
244237
end
245238
sample_out = f(first_rows...)
246239

src/dense.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,3 +1841,8 @@ end
18411841
Base.broadcasted(::typeof(*), α::Number, A::MatrixMPI) = α * A
18421842
Base.broadcasted(::typeof(*), A::MatrixMPI, α::Number) = A * α
18431843
Base.broadcasted(::typeof(/), A::MatrixMPI, α::Number) = A / α
1844+
1845+
# Handle MatrixMPI in VectorMPI broadcasts by extracting the underlying local matrix data
1846+
# This enables broadcasting VectorMPI .* MatrixMPI on GPU without passing the wrapper
1847+
# (wrapper contains non-bitstype fields that can't be passed to GPU kernels)
1848+
_prepare_broadcast_arg(m::MatrixMPI, ref_partition, comm) = m.A

0 commit comments

Comments
 (0)