Skip to content

Commit 359f785

Browse files
Sébastien LoiselSébastien Loisel
authored andcommitted
Fix map_rows_gpu docstring and add to API docs
1 parent 0976924 commit 359f785

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ SparseMatrixMPI_local
4545

4646
```@docs
4747
map_rows
48+
map_rows_gpu
4849
```
4950

5051
## Linear System Solvers

src/LinearAlgebraMPI.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -926,9 +926,9 @@ end
926926
_map_rows_cpu_kernel(f, args...) = _map_rows_gpu_kernel(f, args...)
927927

928928
"""
929-
map_rows(f, A...)
929+
map_rows_gpu(f, A...)
930930
931-
Apply function `f` to corresponding rows of distributed vectors/matrices.
931+
Apply function `f` to corresponding rows of distributed vectors/matrices (GPU-native).
932932
933933
Each argument in `A...` must be either a `VectorMPI` or `MatrixMPI`. All inputs
934934
are repartitioned to match the partition of the first argument before applying `f`.
@@ -937,6 +937,9 @@ This implementation uses GPU-friendly broadcasting: matrices are converted to
937937
Vector{SVector} via transpose+reinterpret, then f is broadcast over all arguments.
938938
This avoids GPU->CPU->GPU round-trips when the underlying arrays are on GPU.
939939
940+
**Important**: The function `f` must be isbits-compatible (no captured non-isbits data)
941+
for GPU execution. Use [`map_rows`](@ref) for functions with arbitrary closures.
942+
940943
For each row index i, `f` is called with:
941944
- For `VectorMPI`: the scalar element at index i
942945
- For `MatrixMPI`: an SVector containing the i-th row
@@ -956,21 +959,23 @@ The result type depends on what `f` returns:
956959
# Element-wise product of two vectors
957960
u = VectorMPI([1.0, 2.0, 3.0])
958961
v = VectorMPI([4.0, 5.0, 6.0])
959-
w = map_rows((a, b) -> a * b, u, v) # VectorMPI([4.0, 10.0, 18.0])
962+
w = map_rows_gpu((a, b) -> a * b, u, v) # VectorMPI([4.0, 10.0, 18.0])
960963
961964
# Row norms of a matrix
962965
A = MatrixMPI(randn(5, 3))
963-
norms = map_rows(r -> norm(r), A) # VectorMPI of row norms
966+
norms = map_rows_gpu(r -> norm(r), A) # VectorMPI of row norms
964967
965968
# Return SVector to build a matrix
966969
A = MatrixMPI(randn(3, 2))
967-
result = map_rows(r -> SVector(sum(r), prod(r)), A) # 3×2 MatrixMPI
970+
result = map_rows_gpu(r -> SVector(sum(r), prod(r)), A) # 3×2 MatrixMPI
968971
969972
# Mixed inputs: matrix rows combined with vector elements
970973
A = MatrixMPI(randn(4, 3))
971974
w = VectorMPI([1.0, 2.0, 3.0, 4.0])
972-
result = map_rows((row, wi) -> sum(row) * wi, A, w) # VectorMPI
975+
result = map_rows_gpu((row, wi) -> sum(row) * wi, A, w) # VectorMPI
973976
```
977+
978+
See also: [`map_rows`](@ref) for CPU fallback version (handles arbitrary closures)
974979
"""
975980
function map_rows_gpu(f, A...)
976981
isempty(A) && error("map_rows_gpu requires at least one argument")

0 commit comments

Comments
 (0)