Skip to content

Commit 7abb847

Browse files
author
Jeremy E Kozdon
committed
Add back matshell
1 parent 0499a9c commit 7abb847

5 files changed

Lines changed: 123 additions & 46 deletions

File tree

src/PETSc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ include("viewer.jl")
2424
include("options.jl")
2525
include("vec.jl")
2626
include("mat.jl")
27-
# include("matshell.jl")
27+
include("matshell.jl")
2828
# include("ksp.jl")
2929
# include("ref.jl")
3030
# include("pc.jl")

src/matshell.jl

Lines changed: 84 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,99 @@
11
"""
2-
MatShell{T}(obj, m, n)
3-
4-
Create a `m×n` PETSc shell matrix object wrapping `obj`.
5-
6-
If `obj` is a `Function`, then the multiply action `obj(y,x)`; otherwise it calls `mul!(y, obj, x)`.
7-
This can be changed by defining `PETSc._mul!`.
8-
2+
MatShell(
3+
petsclib::PetscLib,
4+
obj::OType,
5+
comm::MPI.Comm,
6+
local_rows,
7+
local_cols,
8+
global_rows = LibPETSc.PETSC_DECIDE,
9+
global_cols = LibPETSc.PETSC_DECIDE,
10+
)
11+
12+
Create a `global_rows X global_cols` PETSc shell matrix object wrapping `obj`
13+
with local size `local_rows X local_cols`.
14+
15+
The `obj` will be registered as an `MATOP_MULT` function and if if `obj` is a
16+
`Function`, then the multiply action `obj(y,x)`; otherwise it calls `mul!(y,
17+
obj, x)`.
18+
19+
# External Links
20+
$(_doc_external("Mat/MatCreateShell"))
21+
$(_doc_external("Mat/MatShellSetOperation"))
22+
$(_doc_external("Mat/MATOP_MULT"))
923
"""
10-
mutable struct MatShell{T,A} <: AbstractMat{T}
24+
mutable struct MatShell{PetscLib, PetscScalar, OType} <:
25+
AbstractMat{PetscLib, PetscScalar}
1126
ptr::CMat
12-
obj::A
27+
obj::OType
1328
end
1429

30+
struct MatOp{PetscLib, PetscInt, Op} end
31+
32+
function (::MatOp{PetscLib, PetscInt, LibPETSc.MATOP_MULT})(
33+
M::CMat,
34+
cx::CVec,
35+
cy::CVec,
36+
)::PetscInt where {PetscLib, PetscInt}
37+
r_ctx = Ref{Ptr{Cvoid}}()
38+
LibPETSc.MatShellGetContext(PetscLib, M, r_ctx)
39+
ptr = r_ctx[]
40+
mat = unsafe_pointer_to_objref(ptr)
1541

16-
struct MatOp{T,Op} end
42+
PetscScalar = getlib(PetscLib).PetscScalar
43+
x = unsafe_localarray(VecPtr(PetscLib, cx); write = false)
44+
y = unsafe_localarray(VecPtr(PetscLib, cy); read = false)
1745

46+
_mul!(y, mat, x)
1847

19-
function _mul!(y,mat::MatShell{T,F},x) where {T, F<:Function}
48+
Base.finalize(y)
49+
Base.finalize(x)
50+
return PetscInt(0)
51+
end
52+
53+
function _mul!(
54+
y,
55+
mat::MatShell{PetscLib, PetscScalar, F},
56+
x,
57+
) where {PetscLib, PetscScalar, F <: Function}
2058
mat.obj(y, x)
2159
end
2260

23-
function _mul!(y,mat::MatShell{T},x) where {T}
61+
function _mul!(y, mat::MatShell, x) where {T}
2462
LinearAlgebra.mul!(y, mat.obj, x)
2563
end
2664

27-
MatShell{T}(obj, m, n) where {T} = MatShell{T}(obj, MPI.COMM_SELF, m, n, m, n)
28-
29-
30-
@for_libpetsc begin
31-
function MatShell{$PetscScalar}(obj::A, comm::MPI.Comm, m, n, M, N) where {A}
32-
mat = MatShell{$PetscScalar,A}(C_NULL, obj)
33-
# we use the MatShell object itsel
34-
ctx = pointer_from_objref(mat)
35-
@chk ccall((:MatCreateShell, $libpetsc), PetscErrorCode,
36-
(MPI.MPI_Comm,$PetscInt,$PetscInt,$PetscInt,$PetscInt,Ptr{Cvoid},Ptr{CMat}),
37-
comm, m, n, M, N, ctx, mat)
38-
39-
mulptr = @cfunction(MatOp{$PetscScalar, MATOP_MULT}(), $PetscInt, (CMat, CVec, CVec))
40-
@chk ccall((:MatShellSetOperation, $libpetsc), PetscErrorCode, (CMat, MatOperation, Ptr{Cvoid}), mat, MATOP_MULT, mulptr)
41-
return mat
42-
end
43-
44-
function (::MatOp{$PetscScalar, MATOP_MULT})(M::CMat,cx::CVec,cy::CVec)::$PetscInt
45-
r_ctx = Ref{Ptr{Cvoid}}()
46-
@chk ccall((:MatShellGetContext, $libpetsc), PetscErrorCode, (CMat, Ptr{Ptr{Cvoid}}), M, r_ctx)
47-
ptr = r_ctx[]
48-
mat = unsafe_pointer_to_objref(ptr)
49-
50-
x = unsafe_localarray($PetscScalar, cx; write=false)
51-
y = unsafe_localarray($PetscScalar, cy; read=false)
52-
53-
_mul!(y,mat,x)
54-
55-
Base.finalize(y)
56-
Base.finalize(x)
57-
return $PetscInt(0)
58-
end
59-
65+
# We have to use the macro here because of the @cfunction
66+
LibPETSc.@for_petsc function MatShell(
67+
petsclib::$PetscLib,
68+
obj::OType,
69+
comm::MPI.Comm,
70+
local_rows,
71+
local_cols,
72+
global_rows = LibPETSc.PETSC_DECIDE,
73+
global_cols = LibPETSc.PETSC_DECIDE,
74+
) where {OType}
75+
mat = MatShell{$PetscLib, $PetscScalar, OType}(C_NULL, obj)
76+
77+
# we use the MatShell object itself
78+
ctx = pointer_from_objref(mat)
79+
80+
LibPETSc.MatCreateShell(
81+
petsclib,
82+
comm,
83+
local_rows,
84+
local_cols,
85+
global_rows,
86+
global_cols,
87+
pointer_from_objref(mat),
88+
mat,
89+
)
90+
91+
mulptr = @cfunction(
92+
MatOp{$PetscLib, $PetscInt, LibPETSc.MATOP_MULT}(),
93+
$PetscInt,
94+
(CMat, CVec, CVec)
95+
)
96+
LibPETSc.MatShellSetOperation(petsclib, mat, LibPETSc.MATOP_MULT, mulptr)
97+
98+
return mat
6099
end

src/vec.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@ Base.eltype(
1818
) where {PetscLib, PetscScalar} = PetscScalar
1919
Base.size(v::AbstractVec) = (length(v),)
2020

21+
"""
22+
VecPtr(petsclib, v::Vector)
23+
24+
Container type for a PETSc Vec that is passed from a function.
25+
"""
26+
mutable struct VecPtr{PetscLib, PetscScalar} <:
27+
AbstractVec{PetscLib, PetscScalar}
28+
ptr::CVec
29+
end
30+
function VecPtr(petsclib::PetscLib, ptr::CVec) where {PetscLib <: PetscLibType}
31+
return VecPtr{PetscLib, petsclib.PetscScalar}(ptr)
32+
end
33+
VecPtr(::Type{PetscLib}, ptr::CVec) where {PetscLib <: PetscLibType} =
34+
VecPtr(getlib(PetscLib), ptr)
35+
2136
"""
2237
VecSeq(petsclib, v::Vector)
2338

test/matshell.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Test
2+
using PETSc
3+
using MPI
4+
5+
@testset "MatShell" begin
6+
for petsclib in PETSc.petsclibs
7+
PETSc.initialize(petsclib)
8+
PetscScalar = petsclib.PetscScalar
9+
10+
local_rows = 10
11+
local_cols = 5
12+
f!(x, y) = x .= [2y; 3y]
13+
x_jl = collect
14+
15+
matshell =
16+
PETSc.MatShell(petsclib, f!, MPI.COMM_SELF, local_rows, local_cols)
17+
x = PetscScalar.(collect(1:5))
18+
@test matshell * x == [2x; 3x]
19+
20+
PETSc.finalize(petsclib)
21+
end
22+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ include("init.jl")
22
include("options.jl")
33
include("vec.jl")
44
include("mat.jl")
5+
include("matshell.jl")

0 commit comments

Comments
 (0)