Skip to content

Commit db43eca

Browse files
committed
Guard linalg for objects not declaring matrix/vector interface
1 parent 4b3df28 commit db43eca

2 files changed

Lines changed: 203 additions & 45 deletions

File tree

src/linalg.jl

Lines changed: 169 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -80,48 +80,45 @@ function Base.copy(Ψ::GradVector{num_controls,T}) where {num_controls,T}
8080
end
8181

8282

83-
function Base.length::GradVector)
83+
# === Vector interface for GradVector ===
84+
#
85+
# The following methods are part of the vector interface and are only
86+
# meaningful when `supports_vector_interface` is true for the state type T.
87+
# Each method delegates to a private `_name(::Val{supports}, ...)` function:
88+
# the Val{true} method contains the implementation, and the Val{false} method
89+
# throws an error.
90+
91+
function _length(::Val{true}, Ψ::GradVector)
8492
return length.state) * (1 + length.grad_states))
8593
end
8694

87-
88-
function Base.size::GradVector{num_controls,T}) where {num_controls,T}
89-
return ((num_controls + 1) * length.state),)
95+
function _length(::Val{false}, Ψ::GradVector)
96+
error("$(typeof(Ψ)) does not support the vector interface")
9097
end
9198

92-
93-
function Base.size(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
94-
return (num_controls + 1) .* size(O.G)
99+
function Base.length::T) where {T<:GradVector}
100+
return _length(Val(supports_vector_interface(T)), Ψ)
95101
end
96102

97103

98-
function Base.size(
99-
O::GradgenOperator{num_controls,GT,CGT},
100-
dim::Integer
101-
) where {num_controls,GT,CGT}
102-
return (num_controls + 1) * size(O.G, dim)
104+
function _size(::Val{true}, Ψ::GradVector{num_controls,T}) where {num_controls,T}
105+
return ((num_controls + 1) * length.state),)
103106
end
104107

105-
106-
function Base.similar::GradVector{num_controls,T}) where {num_controls,T}
107-
state_sim = similar.state)
108-
grad_states_sim = [similar(ϕ) for ϕ Ψ.grad_states]
109-
return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim)
108+
function _size(::Val{false}, Ψ::GradVector)
109+
error("$(typeof(Ψ)) does not support the vector interface")
110110
end
111111

112-
Base.similar::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ))
113-
114-
Base.similar::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims)
112+
function Base.size::T) where {T<:GradVector}
113+
return _size(Val(supports_vector_interface(T)), Ψ)
114+
end
115115

116-
# These definitions of `similar` exist to make ExponentialUtilities happy, but
117-
# it's not clear at all that `similar` with a custom shape really makes sense
118-
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} =
119-
Matrix{T}(undef, dims...)
120116

121-
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} =
122-
Vector{T}(undef, dims[1])
123-
124-
function Base.getindex::GradVector{num_controls,T}, k::Int) where {num_controls,T}
117+
function _getindex(
118+
::Val{true},
119+
Ψ::GradVector{num_controls,T},
120+
k::Int
121+
) where {num_controls,T}
125122
N = length.state)
126123
L = num_controls
127124
block = (k - 1) ÷ N + 1
@@ -133,7 +130,21 @@ function Base.getindex(Ψ::GradVector{num_controls,T}, k::Int) where {num_contro
133130
end
134131
end
135132

136-
function Base.setindex!::GradVector{num_controls,T}, v, k::Int) where {num_controls,T}
133+
function _getindex(::Val{false}, Ψ::GradVector, k::Int)
134+
error("$(typeof(Ψ)) does not support the vector interface")
135+
end
136+
137+
function Base.getindex::T, k::Int) where {T<:GradVector}
138+
return _getindex(Val(supports_vector_interface(T)), Ψ, k)
139+
end
140+
141+
142+
function _setindex!(
143+
::Val{true},
144+
Ψ::GradVector{num_controls,T},
145+
v,
146+
k::Int
147+
) where {num_controls,T}
137148
N = length.state)
138149
L = num_controls
139150
block = (k - 1) ÷ N + 1
@@ -146,28 +157,122 @@ function Base.setindex!(Ψ::GradVector{num_controls,T}, v, k::Int) where {num_co
146157
return Ψ
147158
end
148159

149-
function Base.iterate::GradVector, k = 1)
160+
function _setindex!(::Val{false}, Ψ::GradVector, v, k::Int)
161+
error("$(typeof(Ψ)) does not support the vector interface")
162+
end
163+
164+
function Base.setindex!::T, v, k::Int) where {T<:GradVector}
165+
return _setindex!(Val(supports_vector_interface(T)), Ψ, v, k)
166+
end
167+
168+
169+
function _iterate(::Val{true}, Ψ::GradVector, k)
150170
k > length(Ψ) && return nothing
151171
return (Ψ[k], k + 1)
152172
end
153173

174+
function _iterate(::Val{false}, Ψ::GradVector, k)
175+
error("$(typeof(Ψ)) does not support the vector interface")
176+
end
177+
178+
function Base.iterate::T, k = 1) where {T<:GradVector}
179+
return _iterate(Val(supports_vector_interface(T)), Ψ, k)
180+
end
181+
182+
183+
function Base.similar::GradVector{num_controls,T}) where {num_controls,T}
184+
state_sim = similar.state)
185+
grad_states_sim = [similar(ϕ) for ϕ Ψ.grad_states]
186+
return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim)
187+
end
188+
189+
# similar(Ψ, S) calls length(Ψ), which will error if !supports_vector_interface
190+
Base.similar::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ))
191+
192+
# similar(Ψ, dims) calls eltype(Ψ) but not length/size, so no vector interface needed
193+
Base.similar::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims)
194+
195+
# These definitions of `similar` exist to make ExponentialUtilities happy, but
196+
# it's not clear at all that `similar` with a custom shape really makes sense
197+
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} =
198+
Matrix{T}(undef, dims...)
199+
200+
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} =
201+
Vector{T}(undef, dims[1])
202+
203+
204+
function Base.fill!::GradVector, v)
205+
Base.fill!.state, v)
206+
for i = 1:length.grad_states)
207+
Base.fill!.grad_states[i], v)
208+
end
209+
return Ψ
210+
end
211+
212+
213+
# === Matrix interface for GradgenOperator ===
214+
#
215+
# The following methods are part of the matrix interface and are only
216+
# meaningful when `supports_matrix_interface` is true for both component types.
217+
# Each method delegates to a private `_name(::Val{supports}, ...)` function:
218+
# the Val{true} method contains the implementation, and the Val{false} method
219+
# throws an error.
220+
221+
function _size(
222+
::Val{true},
223+
O::GradgenOperator{num_controls,GT,CGT}
224+
) where {num_controls,GT,CGT}
225+
return (num_controls + 1) .* size(O.G)
226+
end
227+
228+
function _size(::Val{false}, O::GradgenOperator)
229+
error("$(typeof(O)) does not support the matrix interface")
230+
end
231+
232+
function Base.size(O::T) where {T<:GradgenOperator}
233+
return _size(Val(supports_matrix_interface(T)), O)
234+
end
235+
236+
237+
function _size(
238+
::Val{true},
239+
O::GradgenOperator{num_controls,GT,CGT},
240+
dim::Integer
241+
) where {num_controls,GT,CGT}
242+
return (num_controls + 1) * size(O.G, dim)
243+
end
244+
245+
function _size(::Val{false}, O::GradgenOperator, dim::Integer)
246+
error("$(typeof(O)) does not support the matrix interface")
247+
end
248+
249+
function Base.size(O::T, dim::Integer) where {T<:GradgenOperator}
250+
return _size(Val(supports_matrix_interface(T)), O, dim)
251+
end
252+
253+
154254
# As for an `Operator`, we implement `similar` to return a standard `Array`
155-
# because `GradgenOperator` does not `setindex!`, so it's arguable not a
156-
# "mutable array"even if its components are mutable.
255+
# because `GradgenOperator` does not `setindex!`, so it's arguably not a
256+
# "mutable array" even if its components are mutable.
257+
# similar(O) and similar(O, S) call size(O), which will error if
258+
# !supports_matrix_interface. The dims-based variants need no guard.
157259
Base.similar(G::GradgenOperator) = Array{eltype(G)}(undef, size(G))
158260

159261
Base.similar(O::GradgenOperator, ::Type{S}) where {S} = Array{S}(undef, size(O))
160262
Base.similar(O::GradgenOperator, dims::Tuple{Vararg{Int}}) = Array{eltype(O)}(undef, dims)
161263
Base.similar(O::GradgenOperator, ::Type{S}, dims::Tuple{Vararg{Int}}) where {S} =
162264
Array{S}(undef, dims)
163265

266+
164267
function Base.eltype(
165268
::Type{GradgenOperator{num_controls,GT,CGT}}
166269
) where {num_controls,GT,CGT}
167270
return promote_type(eltype(GT), eltype(CGT))
168271
end
169272

170-
function Base.getindex(
273+
274+
function _getindex(
275+
::Val{true},
171276
O::GradgenOperator{num_controls,GT,CGT},
172277
row::Int,
173278
col::Int
@@ -188,9 +293,29 @@ function Base.getindex(
188293
end
189294
end
190295

191-
Base.length(O::GradgenOperator) = prod(size(O))
296+
function _getindex(::Val{false}, O::GradgenOperator, row::Int, col::Int)
297+
error("$(typeof(O)) does not support the matrix interface")
298+
end
299+
300+
function Base.getindex(O::T, row::Int, col::Int) where {T<:GradgenOperator}
301+
return _getindex(Val(supports_matrix_interface(T)), O, row, col)
302+
end
303+
304+
305+
function _length(::Val{true}, O::GradgenOperator)
306+
return prod(size(O))
307+
end
308+
309+
function _length(::Val{false}, O::GradgenOperator)
310+
error("$(typeof(O)) does not support the matrix interface")
311+
end
312+
313+
function Base.length(O::T) where {T<:GradgenOperator}
314+
return _length(Val(supports_matrix_interface(T)), O)
315+
end
316+
192317

193-
function Base.iterate(O::GradgenOperator, k = 1)
318+
function _iterate(::Val{true}, O::GradgenOperator, k)
194319
n = length(O)
195320
k > n && return nothing
196321
n_rows = size(O, 1)
@@ -199,6 +324,15 @@ function Base.iterate(O::GradgenOperator, k = 1)
199324
return (O[i, j], k + 1)
200325
end
201326

327+
function _iterate(::Val{false}, O::GradgenOperator, k)
328+
error("$(typeof(O)) does not support the matrix interface")
329+
end
330+
331+
function Base.iterate(O::T, k = 1) where {T<:GradgenOperator}
332+
return _iterate(Val(supports_matrix_interface(T)), O, k)
333+
end
334+
335+
202336
function Base.eltype(::Type{GradVector{num_controls,T}}) where {num_controls,T}
203337
return eltype(T)
204338
end
@@ -209,15 +343,6 @@ function Base.copyto!(dest::GradgenOperator, src::GradgenOperator)
209343
end
210344

211345

212-
function Base.fill!::GradVector, v)
213-
Base.fill!.state, v)
214-
for i = 1:length.grad_states)
215-
Base.fill!.grad_states[i], v)
216-
end
217-
return Ψ
218-
end
219-
220-
221346
function Base.zero::GradVector{num_controls,T}) where {num_controls,T}
222347
return GradVector{num_controls,T}(zero.state), [zero(ϕ) for ϕ Ψ.grad_states])
223348
end

test/test_interface.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ end
103103
# iterate visits elements in column-major order, consistent with vec(Array(op))
104104
@test all(collect(op) .≈ vec(dense))
105105

106-
# 3-arg mul! agrees with 5-arg mul!(Phi, G, Psi, 1, 0)
106+
# 3-arg mul! agrees with 5-arg mul!(Phi, G, Psi, true, false)
107107
Psi = GradVector(rand(ComplexF64, N), L)
108108
Phi1 = GradVector(zeros(ComplexF64, N), L)
109109
Phi2 = GradVector(zeros(ComplexF64, N), L)
@@ -211,4 +211,37 @@ end
211211
# check_state still passes via the basic (non-vector) state interface
212212
@test check_state(gradvec)
213213

214+
# Vector interface methods must throw an error when not supported
215+
@test_throws "does not support the vector interface" gradvec[1]
216+
@test_throws "does not support the vector interface" (gradvec[1] = 0.0)
217+
@test_throws "does not support the vector interface" size(gradvec)
218+
@test_throws "does not support the vector interface" length(gradvec)
219+
@test_throws "does not support the vector interface" iterate(gradvec)
220+
221+
end
222+
223+
224+
225+
# A wrapper type with no supports_matrix_interface declaration (defaults to false)
226+
struct NonMatrixOp
227+
data::Matrix{ComplexF64}
228+
end
229+
230+
@testset "GradgenOperator without Matrix Interface" begin
231+
232+
N = 5
233+
L = 2
234+
G = NonMatrixOp(rand(ComplexF64, N, N))
235+
mu = [NonMatrixOp(rand(ComplexF64, N, N)) for _ = 1:L]
236+
op = GradgenOperator{L,NonMatrixOp,NonMatrixOp}(G, mu)
237+
238+
@test !supports_matrix_interface(typeof(op))
239+
240+
# Matrix interface methods must throw an error when not supported
241+
@test_throws "does not support the matrix interface" op[1, 1]
242+
@test_throws "does not support the matrix interface" size(op)
243+
@test_throws "does not support the matrix interface" size(op, 1)
244+
@test_throws "does not support the matrix interface" length(op)
245+
@test_throws "does not support the matrix interface" iterate(op)
246+
214247
end

0 commit comments

Comments
 (0)