Skip to content

Commit 06be7db

Browse files
timholyamontoison
authored andcommitted
Improve inferrability
There is no way to make the `S` keyword argument of `LinearOperator{T}` inferrable. This moves `S` to the second position, so anyone who wishes to write inferrable code can use `LinearOperator{T, Vector{T}}(args...)`. More inferrability & performance work
1 parent 04479cf commit 06be7db

8 files changed

Lines changed: 71 additions & 50 deletions

File tree

ext/LinearOperatorsLDLFactorizationsExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function LinearOperators.opLDL(M::AbstractMatrix; check::Bool = false)
1313
tprod! = @closure (res, u, α, β) -> LinearOperators.tmulFact!(res, LDL, u, α, β) # M.' = conj(M)
1414
ctprod! = @closure (res, w, α, β) -> LinearOperators.mulFact!(res, LDL, w, α, β)
1515
S = eltype(LDL)
16-
return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
16+
return LinearOperator{S, Vector{S}}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
1717
#TODO: use iterative refinement.
1818
end
1919

@@ -31,7 +31,7 @@ function LinearOperators.opLDL(
3131
tprod! = @closure (res, u) -> ldiv!(res, LDL, u) # M.' = conj(M)
3232
ctprod! = @closure (res, w) -> ldiv!(res, LDL, w)
3333
S = eltype(LDL)
34-
return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
34+
return LinearOperator{S, Vector{S}}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
3535
end
3636

3737
end # module

src/abstract.jl

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ to combine or otherwise alter them. They can be combined with
4343
other operators, with matrices and with scalars. Operators may
4444
be transposed and conjugate-transposed using the usual Julia syntax.
4545
"""
46-
mutable struct LinearOperator{T, I <: Integer, F, Ft, Fct, S} <: AbstractLinearOperator{T}
46+
mutable struct LinearOperator{T, S, I <: Integer, F, Ft, Fct} <: AbstractLinearOperator{T}
4747
nrow::I
4848
ncol::I
4949
symmetric::Bool
@@ -61,7 +61,7 @@ mutable struct LinearOperator{T, I <: Integer, F, Ft, Fct, S} <: AbstractLinearO
6161
allocated5::Bool # true for 5-args mul!, false for 3-args mul! until the vectors are allocated
6262
end
6363

64-
function LinearOperator{T}(
64+
function LinearOperator{T, S}(
6565
nrow::I,
6666
ncol::I,
6767
symmetric::Bool,
@@ -71,16 +71,15 @@ function LinearOperator{T}(
7171
ctprod!::Fct,
7272
nprod::I,
7373
ntprod::I,
74-
nctprod::I;
75-
S::Type = Vector{T},
76-
) where {T, I <: Integer, F, Ft, Fct}
74+
nctprod::I
75+
) where {T, S, I <: Integer, F, Ft, Fct}
7776
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
7877
nargs = get_nargs(prod!)
7978
args5 = (nargs == 4)
8079
(args5 == false) || (nargs != 2) || throw(LinearOperatorException("Invalid number of arguments"))
8180
allocated5 = args5 ? true : false
8281
use_prod5! = args5 ? true : false
83-
return LinearOperator{T, I, F, Ft, Fct, S}(
82+
return LinearOperator{T, S, I, F, Ft, Fct}(
8483
nrow,
8584
ncol,
8685
symmetric,
@@ -99,19 +98,27 @@ function LinearOperator{T}(
9998
)
10099
end
101100

102-
LinearOperator{T}(
101+
# backward compatibility (not inferrable; use LinearOperator{T, S} if you want something inferrable)
102+
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, nprod, ntprod, nctprod; S::Type = Vector{T}) where {T} =
103+
LinearOperator{T, S}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, nprod, ntprod, nctprod)
104+
105+
LinearOperator{T, S}(
103106
nrow::I,
104107
ncol::I,
105108
symmetric::Bool,
106109
hermitian::Bool,
107110
prod!,
108111
tprod!,
109-
ctprod!;
110-
S::Type = Vector{T},
111-
) where {T, I <: Integer} =
112-
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S)
112+
ctprod!
113+
) where {T, S, I <: Integer} =
114+
LinearOperator{T, S}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0)
115+
116+
# backward compatibility (not inferrable)
117+
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!; S::Type = Vector{T}) where {T} =
118+
LinearOperator{T, S}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!)
113119

114120
# create operator from other operators with +, *, vcat,...
121+
# TODO: this is not a type, so it should not be uppercase
115122
function CompositeLinearOperator(
116123
T::Type,
117124
nrow::I,
@@ -121,13 +128,13 @@ function CompositeLinearOperator(
121128
prod!::F,
122129
tprod!::Ft,
123130
ctprod!::Fct,
124-
args5::Bool;
125-
S::Type = Vector{T},
126-
) where {I <: Integer, F, Ft, Fct}
131+
args5::Bool,
132+
::Type{S},
133+
) where {S, I <: Integer, F, Ft, Fct}
127134
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
128135
allocated5 = true
129136
use_prod5! = true
130-
return LinearOperator{T, I, F, Ft, Fct, S}(
137+
return LinearOperator{T, S, I, F, Ft, Fct}(
131138
nrow,
132139
ncol,
133140
symmetric,
@@ -146,6 +153,20 @@ function CompositeLinearOperator(
146153
)
147154
end
148155

156+
# backward compatibility (not inferrable)
157+
CompositeLinearOperator(
158+
T::Type,
159+
nrow::I,
160+
ncol::I,
161+
symmetric::Bool,
162+
hermitian::Bool,
163+
prod!::F,
164+
tprod!::Ft,
165+
ctprod!::Fct,
166+
args5::Bool;
167+
S::Type = Vector{T}) where {I <: Integer, F, Ft, Fct} =
168+
CompositeLinearOperator(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, args5, S)
169+
149170
nprod(op::AbstractLinearOperator) = op.nprod
150171
ntprod(op::AbstractLinearOperator) = op.ntprod
151172
nctprod(op::AbstractLinearOperator) = op.nctprod

src/cat.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ function hcat(A::AbstractLinearOperator, B::AbstractLinearOperator)
4848
S = promote_type(storage_type(A), storage_type(B))
4949
isconcretetype(S) ||
5050
throw(LinearOperatorException("storage types cannot be promoted to a concrete type"))
51-
CompositeLinearOperator(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, args5, S = S)
51+
CompositeLinearOperator(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, args5, S)
5252
end
5353

5454
function hcat(ops::AbstractLinearOperator...)
@@ -107,7 +107,7 @@ function vcat(A::AbstractLinearOperator, B::AbstractLinearOperator)
107107
S = promote_type(storage_type(A), storage_type(B))
108108
isconcretetype(S) ||
109109
throw(LinearOperatorException("storage types cannot be promoted to a concrete type"))
110-
CompositeLinearOperator(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, args5, S = S)
110+
CompositeLinearOperator(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, args5, S)
111111
end
112112

113113
function vcat(ops::AbstractLinearOperator...)

src/constructors.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function LinearOperator(
1616
prod! = @closure (res, v, α, β) -> mul!(res, M, v, α, β)
1717
tprod! = @closure (res, u, α, β) -> mul!(res, transpose(M), u, α, β)
1818
ctprod! = @closure (res, w, α, β) -> mul!(res, adjoint(M), w, α, β)
19-
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
19+
LinearOperator{T, S}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!)
2020
end
2121

2222
"""
@@ -43,7 +43,7 @@ end
4343

4444
"""
4545
LinearOperator(M::Hermitian{T}, S = Vector{T}) where {T}
46-
46+
4747
Constructs a linear operator from a Hermitian matrix. If
4848
its elements are real, it is also symmetric.
4949
Change `S` to use LinearOperators on GPU.
@@ -57,7 +57,7 @@ end
5757
LinearOperator(type::Type{T}, nrow, ncol, symmetric, hermitian, prod!,
5858
[tprod!=nothing, ctprod!=nothing],
5959
S = Vector{T}) where {T}
60-
60+
6161
Construct a linear operator from functions where the type is specified as the first argument.
6262
Change `S` to use LinearOperators on GPU.
6363
Notice that the linear operator does not enforce the type, so using a wrong type can
@@ -67,29 +67,29 @@ A = [im 1.0; 0.0 1.0] # Complex matrix
6767
function mulOp!(res, M, v, α, β)
6868
mul!(res, M, v, α, β)
6969
end
70-
op = LinearOperator(Float64, 2, 2, false, false,
71-
(res, v, α, β) -> mulOp!(res, A, v, α, β),
72-
(res, u, α, β) -> mulOp!(res, transpose(A), u, α, β),
70+
op = LinearOperator(Float64, 2, 2, false, false,
71+
(res, v, α, β) -> mulOp!(res, A, v, α, β),
72+
(res, u, α, β) -> mulOp!(res, transpose(A), u, α, β),
7373
(res, w, α, β) -> mulOp!(res, A', w, α, β))
7474
Matrix(op) # InexactError
7575
```
7676
The error is caused because `Matrix(op)` tries to create a Float64 matrix with the
7777
contents of the complex matrix `A`.
7878
7979
Using `*` may generate a vector that contains `NaN` values.
80-
This can also happen if you use the 3-args `mul!` function with a preallocated vector such as
80+
This can also happen if you use the 3-args `mul!` function with a preallocated vector such as
8181
`Vector{Float64}(undef, n)`.
8282
To fix this issue you will have to deal with the cases `β == 0` and `β != 0` separately:
8383
```
8484
d1 = [2.0; 3.0]
8585
function mulSquareOpDiagonal!(res, d, v, α, β::T) where T
8686
if β == zero(T)
8787
res .= α .* d .* v
88-
else
88+
else
8989
res .= α .* d .* v .+ β .* res
9090
end
9191
end
92-
op = LinearOperator(Float64, 2, 2, true, true,
92+
op = LinearOperator(Float64, 2, 2, true, true,
9393
(res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β))
9494
```
9595
@@ -98,7 +98,7 @@ In this case, using the 5-args `mul!` will generate storage vectors.
9898
9999
```
100100
A = rand(2, 2)
101-
op = LinearOperator(Float64, 2, 2, false, false,
101+
op = LinearOperator(Float64, 2, 2, false, false,
102102
(res, v) -> mul!(res, A, v),
103103
(res, w) -> mul!(res, A', w))
104104
```
@@ -114,7 +114,7 @@ function LinearOperator(
114114
prod!,
115115
tprod! = nothing,
116116
ctprod! = nothing;
117-
S = Vector{T},
117+
S::Type = Vector{T},
118118
) where {T, I <: Integer}
119-
return LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
119+
return LinearOperator{T, S}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!)
120120
end

src/kron.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function kron(A::AbstractLinearOperator, B::AbstractLinearOperator)
4141
symm = issymmetric(A) && issymmetric(B)
4242
herm = ishermitian(A) && ishermitian(B)
4343
nrow, ncol = m * p, n * q
44-
return LinearOperator{T}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!)
44+
return LinearOperator{T, Vector{T}}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!)
4545
end
4646

4747
kron(A::AbstractMatrix, B::AbstractLinearOperator) = kron(LinearOperator(A), B)

src/linalg.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ function opInverse(M::AbstractMatrix{T}; symm = false, herm = false) where {T}
2828
prod! = @closure (res, v, α, β) -> mulFact!(res, M, v, α, β)
2929
tprod! = @closure (res, u, α, β) -> mulFact!(res, transpose(M), u, α, β)
3030
ctprod! = @closure (res, w, α, β) -> mulFact!(res, adjoint(M), w, α, β)
31-
LinearOperator{T}(size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!)
31+
LinearOperator{T, Vector{T}}(size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!)
3232
end
3333

3434
"""
3535
opCholesky(M, [check=false])
3636
3737
Inverse of a Hermitian and positive definite matrix as a linear operator
38-
using its Cholesky factorization.
38+
using its Cholesky factorization.
3939
The factorization is computed only once.
4040
The optional `check` argument will perform cheap hermicity and definiteness
4141
checks.
@@ -53,7 +53,7 @@ function opCholesky(M::AbstractMatrix; check::Bool = false)
5353
tprod! = @closure (res, u, α, β) -> tmulFact!(res, LL, u, α, β) # M.' = conj(M)
5454
ctprod! = @closure (res, w, α, β) -> mulFact!(res, LL, w, α, β)
5555
S = eltype(LL)
56-
LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
56+
LinearOperator{S, Vector{S}}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
5757
#TODO: use iterative refinement.
5858
end
5959

@@ -64,7 +64,7 @@ Inverse of a symmetric matrix as a linear operator using its LDLᵀ factorizatio
6464
if it exists. The factorization is computed only once. The optional `check`
6565
argument will perform a cheap hermicity check.
6666
67-
If M is sparse and real, then only the upper triangle should be stored in order to use
67+
If M is sparse and real, then only the upper triangle should be stored in order to use
6868
[`LDLFactorizations.jl`](https://github.com/JuliaSmoothOptimizers/LDLFactorizations.jl):
6969
7070
using LDLFactorizations
@@ -91,7 +91,7 @@ The result is `x -> (I - 2 h hᵀ) x`.
9191
function opHouseholder(h::AbstractVector{T}) where {T}
9292
n = length(h)
9393
prod! = @closure (res, v, α, β) -> mulHouseholder!(res, h, v, α, β) # tprod will be inferred
94-
LinearOperator{T}(n, n, isreal(h), true, prod!, nothing, prod!)
94+
LinearOperator{T, Vector{T}}(n, n, isreal(h), true, prod!, nothing, prod!)
9595
end
9696

9797
function mulHermitian!(res, d, L, v, α, β::T) where {T}
@@ -113,7 +113,7 @@ function opHermitian(d::AbstractVector{S}, A::AbstractMatrix{T}) where {S, T}
113113
L = tril(A, -1)
114114
U = promote_type(S, T)
115115
prod! = @closure (res, v, α, β) -> mulHermitian!(res, d, L, v, α, β)
116-
LinearOperator{U}(m, m, isreal(A), true, prod!, nothing, nothing)
116+
LinearOperator{U, Vector{U}}(m, m, isreal(A), true, prod!, nothing, nothing)
117117
end
118118

119119
"""

src/operations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function *(op1::AbstractLinearOperator, op2::AbstractLinearOperator)
150150
tprod! = @closure (res, u, α, β) -> prod_op!(res, transpose(op2), transpose(op1), utmp, u, α, β)
151151
ctprod! = @closure (res, w, α, β) -> prod_op!(res, adjoint(op2), adjoint(op1), wtmp, w, α, β)
152152
args5 = (has_args5(op1) && has_args5(op2))
153-
CompositeLinearOperator(T, m1, n2, false, false, prod!, tprod!, ctprod!, args5, S = S)
153+
CompositeLinearOperator(T, m1, n2, false, false, prod!, tprod!, ctprod!, args5, S)
154154
end
155155

156156
## Matrix times operator.
@@ -213,7 +213,7 @@ function +(op1::AbstractLinearOperator, op2::AbstractLinearOperator)
213213
S = promote_type(storage_type(op1), storage_type(op2))
214214
isconcretetype(S) ||
215215
throw(LinearOperatorException("storage types cannot be promoted to a concrete type"))
216-
return CompositeLinearOperator(T, m1, n1, symm, herm, prod!, tprod!, ctprod!, args5, S = S)
216+
return CompositeLinearOperator(T, m1, n1, symm, herm, prod!, tprod!, ctprod!, args5, S)
217217
end
218218

219219
# Operator + matrix.

src/special-operators.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Change `S` to use LinearOperators on GPU.
5252
"""
5353
function opEye(T::Type, n::Int; S = Vector{T})
5454
prod! = @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, n)
55-
LinearOperator{T}(n, n, true, true, prod!, prod!, prod!, S = S)
55+
LinearOperator{T, S}(n, n, true, true, prod!, prod!, prod!)
5656
end
5757

5858
opEye(n::Int) = opEye(Float64, n)
@@ -71,7 +71,7 @@ function opEye(T::Type, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer}
7171
return opEye(T, nrow; S = S)
7272
end
7373
prod! = @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, min(nrow, ncol))
74-
return LinearOperator{T}(nrow, ncol, false, false, prod!, prod!, prod!, S = S)
74+
return LinearOperator{T, S}(nrow, ncol, false, false, prod!, prod!, prod!)
7575
end
7676

7777
opEye(nrow::I, ncol::I) where {I <: Integer} = opEye(Float64, nrow, ncol)
@@ -94,7 +94,7 @@ Change `S` to use LinearOperators on GPU.
9494
"""
9595
function opOnes(T::Type, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer}
9696
prod! = @closure (res, v, α, β) -> mulOpOnes!(res, v, α, β)
97-
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
97+
LinearOperator{T, S}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!)
9898
end
9999

100100
opOnes(nrow::I, ncol::I) where {I <: Integer} = opOnes(Float64, nrow, ncol)
@@ -117,7 +117,7 @@ Change `S` to use LinearOperators on GPU.
117117
"""
118118
function opZeros(T::Type, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer}
119119
prod! = @closure (res, v, α, β) -> mulOpZeros!(res, v, α, β)
120-
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
120+
LinearOperator{T, S}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!)
121121
end
122122

123123
opZeros(nrow::I, ncol::I) where {I <: Integer} = opZeros(Float64, nrow, ncol)
@@ -138,7 +138,7 @@ Diagonal operator with the vector `d` on its main diagonal.
138138
function opDiagonal(d::AbstractVector{T}) where {T}
139139
prod! = @closure (res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β)
140140
ctprod! = @closure (res, w, α, β) -> mulSquareOpDiagonal!(res, conj.(d), w, α, β)
141-
LinearOperator{T}(length(d), length(d), true, isreal(d), prod!, prod!, ctprod!, S = typeof(d))
141+
LinearOperator{T, typeof(d)}(length(d), length(d), true, isreal(d), prod!, prod!, ctprod!)
142142
end
143143

144144
function mulOpDiagonal!(res, d, v, α, β::T, n_min) where {T}
@@ -161,11 +161,11 @@ function opDiagonal(nrow::I, ncol::I, d::AbstractVector{T}) where {T, I <: Integ
161161
prod! = @closure (res, v, α, β) -> mulOpDiagonal!(res, d, v, α, β, n_min)
162162
tprod! = @closure (res, u, α, β) -> mulOpDiagonal!(res, d, u, α, β, n_min)
163163
ctprod! = @closure (res, w, α, β) -> mulOpDiagonal!(res, conj.(d), w, α, β, n_min)
164-
LinearOperator{T}(nrow, ncol, false, false, prod!, tprod!, ctprod!, S = typeof(d))
164+
LinearOperator{T, typeof(d)}(nrow, ncol, false, false, prod!, tprod!, ctprod!)
165165
end
166166

167167
function mulRestrict!(res, I, v, α, β)
168-
res .= v[I]
168+
res .= view(v, I)
169169
end
170170

171171
function multRestrict!(res, I, u, α, β)
@@ -190,9 +190,9 @@ function opRestriction(Idx::LinearOperatorIndexType{I}, ncol::I; S = nothing) wh
190190
prod! = @closure (res, v, α, β) -> mulRestrict!(res, Idx, v, α, β)
191191
tprod! = @closure (res, u, α, β) -> multRestrict!(res, Idx, u, α, β)
192192
if isnothing(S)
193-
return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!)
193+
return LinearOperator{I, Vector{I}}(nrow, ncol, false, false, prod!, tprod!, tprod!)
194194
else
195-
return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!; S = S)
195+
return LinearOperator{I, S}(nrow, ncol, false, false, prod!, tprod!, tprod!)
196196
end
197197
end
198198

@@ -291,5 +291,5 @@ function BlockDiagonalOperator(ops...; S = promote_type(storage_type.(ops)...))
291291
symm = all((issymmetric(op) for op ops))
292292
herm = all((ishermitian(op) for op ops))
293293
args5 = all((has_args5(op) for op ops))
294-
CompositeLinearOperator(T, nrow, ncol, symm, herm, prod!, tprod!, ctprod!, args5, S = S)
294+
CompositeLinearOperator(T, nrow, ncol, symm, herm, prod!, tprod!, ctprod!, args5, S)
295295
end

0 commit comments

Comments
 (0)