 _map_rows_cpu_kernel(f, args...) = _map_rows_gpu_kernel(f, args...)

 """
-    map_rows(f, A...)
+    map_rows_gpu(f, A...)

-Apply function `f` to corresponding rows of distributed vectors/matrices.
+Apply function `f` to corresponding rows of distributed vectors/matrices (GPU-native).

 Each argument in `A...` must be either a `VectorMPI` or `MatrixMPI`. All inputs
 are repartitioned to match the partition of the first argument before applying `f`.
@@ -937,6 +937,9 @@ This implementation uses GPU-friendly broadcasting: matrices are converted to
 Vector{SVector} via transpose+reinterpret, then f is broadcast over all arguments.
 This avoids GPU->CPU->GPU round-trips when the underlying arrays are on GPU.

+**Important**: The function `f` must be isbits-compatible (no captured non-isbits data)
+for GPU execution. Use [`map_rows`](@ref) for functions with arbitrary closures.
+
 For each row index i, `f` is called with:
 - For `VectorMPI`: the scalar element at index i
 - For `MatrixMPI`: an SVector containing the i-th row
@@ -956,21 +959,23 @@ The result type depends on what `f` returns:
 # Element-wise product of two vectors
 u = VectorMPI([1.0, 2.0, 3.0])
 v = VectorMPI([4.0, 5.0, 6.0])
-w = map_rows((a, b) -> a * b, u, v) # VectorMPI([4.0, 10.0, 18.0])
+w = map_rows_gpu((a, b) -> a * b, u, v) # VectorMPI([4.0, 10.0, 18.0])

 # Row norms of a matrix
 A = MatrixMPI(randn(5, 3))
-norms = map_rows(r -> norm(r), A) # VectorMPI of row norms
+norms = map_rows_gpu(r -> norm(r), A) # VectorMPI of row norms

 # Return SVector to build a matrix
 A = MatrixMPI(randn(3, 2))
-result = map_rows(r -> SVector(sum(r), prod(r)), A) # 3×2 MatrixMPI
+result = map_rows_gpu(r -> SVector(sum(r), prod(r)), A) # 3×2 MatrixMPI

 # Mixed inputs: matrix rows combined with vector elements
 A = MatrixMPI(randn(4, 3))
 w = VectorMPI([1.0, 2.0, 3.0, 4.0])
-result = map_rows((row, wi) -> sum(row) * wi, A, w) # VectorMPI
+result = map_rows_gpu((row, wi) -> sum(row) * wi, A, w) # VectorMPI
 ```
+
+See also: [`map_rows`](@ref) for CPU fallback version (handles arbitrary closures)
 """
 function map_rows_gpu(f, A...)
     isempty(A) && error("map_rows_gpu requires at least one argument")
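
A note on the isbits requirement added to the docstring: Base's `isbits` gives a rough way to see which closures would qualify. This is only a sketch. `make_scaler` and `make_lookup` are hypothetical helpers, and treating `isbits(f)` as a stand-in for what the GPU path actually requires is an assumption, not something the diff states.

```julia
# Hypothetical helpers (not part of the diff): each returns a closure.
# Closures are built inside functions so the values are really captured.
make_scaler(s::Float64) = x -> s * x          # captures a plain Float64
make_lookup(t::Vector{Float64}) = i -> t[i]   # captures a heap-allocated Vector

f_ok  = make_scaler(2.0)
f_bad = make_lookup([1.0, 2.0])

# Assumption: isbits(f) approximates "isbits-compatible" in the docstring.
println("f_ok  isbits: ", isbits(f_ok))    # closure over a Float64
println("f_bad isbits: ", isbits(f_bad))   # closure over a Vector
```

Under that assumption, `f_ok` would be a candidate for `map_rows_gpu`, while `f_bad` would go through `map_rows` instead.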
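
The transpose+reinterpret step described in the docstring can be illustrated on plain CPU arrays. A minimal sketch, assuming StaticArrays.jl is installed and Julia 1.6 or later (for `reinterpret(reshape, ...)`). It does not use `MatrixMPI` or any of the package's internals, and all variable names are illustrative.

```julia
using StaticArrays

# A plain 2×3 matrix standing in for the local block of a distributed matrix.
A = [1.0 2.0 3.0;
     4.0 5.0 6.0]

# Julia arrays are column-major, so transpose first: each column of At is a row of A.
At = permutedims(A)                                   # 3×2 Matrix{Float64}

# View each contiguous column of At as one SVector (no copy of the data).
rows = reinterpret(reshape, SVector{3,Float64}, At)   # 2-element vector of SVector{3,Float64}

# Broadcast f over the rows together with a plain vector, one call per row index.
w = [10.0, 100.0]
f = (r, wi) -> sum(r) * wi
result = f.(rows, w)
println(result)
```

Because `f` receives one `SVector` per row, the same broadcast form applies whether the underlying storage is a CPU array or, presumably, a GPU array.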