@@ -12,6 +12,11 @@ function LinearAlgebra.mul!(Φ::GradVector, G::GradgenOperator, Ψ::GradVector,
1212end
1313
1414
15+ function LinearAlgebra. mul! (Φ:: GradVector , G:: GradgenOperator , Ψ:: GradVector )
16+ return LinearAlgebra. mul! (Φ, G, Ψ, true , false )
17+ end
18+
19+
1520function LinearAlgebra. lmul! (c, Ψ:: GradVector )
1621 LinearAlgebra. lmul! (c, Ψ. state)
1722 for i ∈ eachindex (Ψ. grad_states)
@@ -48,6 +53,11 @@ function LinearAlgebra.dot(Ψ::GradVector, Φ::GradVector)
4853end
4954
5055
56+ function LinearAlgebra. dot (Ψ:: GradVector , G:: GradgenOperator , Φ:: GradVector )
57+ return LinearAlgebra. dot (Ψ, G * Φ)
58+ end
59+
60+
5161LinearAlgebra. ishermitian (G:: GradgenOperator ) = false
5262
5363
@@ -70,41 +80,126 @@ function Base.copy(Ψ::GradVector{num_controls,T}) where {num_controls,T}
7080end
7181
7282
73- function Base. length (Ψ:: GradVector )
83+ # === Vector interface for GradVector ===
84+ #
85+ # The following methods are part of the vector interface and are only
86+ # meaningful when `supports_vector_interface` is true for the state type T.
87+ # Each method delegates to a private `_name(::Val{supports}, ...)` function:
88+ # the Val{true} method contains the implementation, and the Val{false} method
89+ # throws an error.
90+
91+ function _length (:: Val{true} , Ψ:: GradVector )
7492 return length (Ψ. state) * (1 + length (Ψ. grad_states))
7593end
7694
95+ function _length (:: Val{false} , Ψ:: GradVector )
96+ error (" $(typeof (Ψ)) does not support the vector interface" )
97+ end
7798
78- function Base. size (O :: GradgenOperator{num_controls,GT,CGT} ) where {num_controls,GT,CGT }
79- return (num_controls + 1 ) .* size (O . G )
99+ function Base. length (Ψ :: T ) where {T <: GradVector }
100+ return _length ( Val ( supports_vector_interface (T)), Ψ )
80101end
81102
82103
83- function Base. size (
84- O:: GradgenOperator{num_controls,GT,CGT} ,
85- dim:: Integer
86- ) where {num_controls,GT,CGT}
87- return (num_controls + 1 ) * size (O. G, dim)
104+ function _size (:: Val{true} , Ψ:: GradVector{num_controls,T} ) where {num_controls,T}
105+ return ((num_controls + 1 ) * length (Ψ. state),)
88106end
89107
108+ function _size (:: Val{false} , Ψ:: GradVector )
109+ error (" $(typeof (Ψ)) does not support the vector interface" )
110+ end
90111
91- function Base. similar (Ψ:: GradVector{num_controls,T} ) where {num_controls,T }
92- return GradVector {num_controls,T} ( similar (Ψ . state), [ similar (ϕ) for ϕ ∈ Ψ . grad_states] )
112+ function Base. size (Ψ:: T ) where {T <: GradVector }
113+ return _size ( Val ( supports_vector_interface (T)), Ψ )
93114end
94115
95- function Base. similar (G:: GradgenOperator{num_controls,GT,CGT} ) where {num_controls,GT,CGT}
96- return GradgenOperator {num_controls,GT,CGT} (similar (G. G), similar (G. control_deriv_ops))
116+
117+ function _getindex (
118+ :: Val{true} ,
119+ Ψ:: GradVector{num_controls,T} ,
120+ k:: Int
121+ ) where {num_controls,T}
122+ N = length (Ψ. state)
123+ L = num_controls
124+ block = (k - 1 ) ÷ N + 1
125+ local_k = (k - 1 ) % N + 1
126+ if block <= L
127+ return Ψ. grad_states[block][local_k]
128+ else
129+ return Ψ. state[local_k]
130+ end
97131end
98132
99- function Base . eltype (O :: GradgenOperator{num_controls,GT,CGT} ) where {num_controls,GT,CGT}
100- return promote_type ( eltype (GT), eltype (CGT) )
133+ function _getindex ( :: Val{false} , Ψ :: GradVector , k :: Int )
134+ error ( " $( typeof (Ψ)) does not support the vector interface " )
101135end
102136
103- function Base. copyto! (dest:: GradgenOperator , src:: GradgenOperator )
104- copyto! (dest. G, src. G)
105- copyto! (dest. control_deriv_ops, src. control_deriv_ops)
137+ function Base. getindex (Ψ:: T , k:: Int ) where {T<: GradVector }
138+ return _getindex (Val (supports_vector_interface (T)), Ψ, k)
139+ end
140+
141+
142+ function _setindex! (
143+ :: Val{true} ,
144+ Ψ:: GradVector{num_controls,T} ,
145+ v,
146+ k:: Int
147+ ) where {num_controls,T}
148+ N = length (Ψ. state)
149+ L = num_controls
150+ block = (k - 1 ) ÷ N + 1
151+ local_k = (k - 1 ) % N + 1
152+ if block <= L
153+ Ψ. grad_states[block][local_k] = v
154+ else
155+ Ψ. state[local_k] = v
156+ end
157+ return Ψ
158+ end
159+
160+ function _setindex! (:: Val{false} , Ψ:: GradVector , v, k:: Int )
161+ error (" $(typeof (Ψ)) does not support the vector interface" )
106162end
107163
164+ function Base. setindex! (Ψ:: T , v, k:: Int ) where {T<: GradVector }
165+ return _setindex! (Val (supports_vector_interface (T)), Ψ, v, k)
166+ end
167+
168+
169+ function _iterate (:: Val{true} , Ψ:: GradVector , k)
170+ k > length (Ψ) && return nothing
171+ return (Ψ[k], k + 1 )
172+ end
173+
174+ function _iterate (:: Val{false} , Ψ:: GradVector , k)
175+ error (" $(typeof (Ψ)) does not support the vector interface" )
176+ end
177+
178+ function Base. iterate (Ψ:: T , k = 1 ) where {T<: GradVector }
179+ return _iterate (Val (supports_vector_interface (T)), Ψ, k)
180+ end
181+
182+
183+ function Base. similar (Ψ:: GradVector{num_controls,T} ) where {num_controls,T}
184+ state_sim = similar (Ψ. state)
185+ grad_states_sim = [similar (ϕ) for ϕ ∈ Ψ. grad_states]
186+ return GradVector {num_controls,typeof(state_sim)} (state_sim, grad_states_sim)
187+ end
188+
189+ # similar(Ψ, S) calls length(Ψ), which will error if !supports_vector_interface
190+ Base. similar (Ψ:: GradVector , :: Type{S} ) where {S} = Vector {S} (undef, length (Ψ))
191+
192+ # similar(Ψ, dims) calls eltype(Ψ) but not length/size, so no vector interface needed
193+ Base. similar (Ψ:: GradVector , dims:: Tuple{Vararg{Int}} ) = Array {eltype(Ψ)} (undef, dims)
194+
195+ # These definitions of `similar` exist to make ExponentialUtilities happy, but
196+ # it's not clear at all that `similar` with a custom shape really makes sense
197+ Base. similar (:: GradVector , :: Type{T} , dims:: Tuple{Int,Int} ) where {T} =
198+ Matrix {T} (undef, dims... )
199+
200+ Base. similar (:: GradVector , :: Type{T} , dims:: Tuple{Int} ) where {T} =
201+ Vector {T} (undef, dims[1 ])
202+
108203
109204function Base. fill! (Ψ:: GradVector , v)
110205 Base. fill! (Ψ. state, v)
@@ -115,6 +210,139 @@ function Base.fill!(Ψ::GradVector, v)
115210end
116211
117212
213+ # === Matrix interface for GradgenOperator ===
214+ #
215+ # The following methods are part of the matrix interface and are only
216+ # meaningful when `supports_matrix_interface` is true for both component types.
217+ # Each method delegates to a private `_name(::Val{supports}, ...)` function:
218+ # the Val{true} method contains the implementation, and the Val{false} method
219+ # throws an error.
220+
221+ function _size (
222+ :: Val{true} ,
223+ O:: GradgenOperator{num_controls,GT,CGT}
224+ ) where {num_controls,GT,CGT}
225+ return (num_controls + 1 ) .* size (O. G)
226+ end
227+
228+ function _size (:: Val{false} , O:: GradgenOperator )
229+ error (" $(typeof (O)) does not support the matrix interface" )
230+ end
231+
232+ function Base. size (O:: T ) where {T<: GradgenOperator }
233+ return _size (Val (supports_matrix_interface (T)), O)
234+ end
235+
236+
237+ function _size (
238+ :: Val{true} ,
239+ O:: GradgenOperator{num_controls,GT,CGT} ,
240+ dim:: Integer
241+ ) where {num_controls,GT,CGT}
242+ return (num_controls + 1 ) * size (O. G, dim)
243+ end
244+
245+ function _size (:: Val{false} , O:: GradgenOperator , dim:: Integer )
246+ error (" $(typeof (O)) does not support the matrix interface" )
247+ end
248+
249+ function Base. size (O:: T , dim:: Integer ) where {T<: GradgenOperator }
250+ return _size (Val (supports_matrix_interface (T)), O, dim)
251+ end
252+
253+
254+ # As for an `Operator`, we implement `similar` to return a standard `Array`
255+ # because `GradgenOperator` does not `setindex!`, so it's arguably not a
256+ # "mutable array" even if its components are mutable.
257+ # similar(O) and similar(O, S) call size(O), which will error if
258+ # !supports_matrix_interface. The dims-based variants need no guard.
259+ Base. similar (G:: GradgenOperator ) = Array {eltype(G)} (undef, size (G))
260+
261+ Base. similar (O:: GradgenOperator , :: Type{S} ) where {S} = Array {S} (undef, size (O))
262+ Base. similar (O:: GradgenOperator , dims:: Tuple{Vararg{Int}} ) = Array {eltype(O)} (undef, dims)
263+ Base. similar (O:: GradgenOperator , :: Type{S} , dims:: Tuple{Vararg{Int}} ) where {S} =
264+ Array {S} (undef, dims)
265+
266+
267+ function Base. eltype (
268+ :: Type{GradgenOperator{num_controls,GT,CGT}}
269+ ) where {num_controls,GT,CGT}
270+ return promote_type (eltype (GT), eltype (CGT))
271+ end
272+
273+
274+ function _getindex (
275+ :: Val{true} ,
276+ O:: GradgenOperator{num_controls,GT,CGT} ,
277+ row:: Int ,
278+ col:: Int
279+ ) where {num_controls,GT,CGT}
280+ T = eltype (O)
281+ N, M = size (O. G)
282+ L = num_controls
283+ block_row = (row - 1 ) ÷ N + 1
284+ block_col = (col - 1 ) ÷ M + 1
285+ local_row = (row - 1 ) % N + 1
286+ local_col = (col - 1 ) % M + 1
287+ if block_row == block_col
288+ return convert (T, O. G[local_row, local_col])
289+ elseif block_col == L + 1 && block_row <= L
290+ return convert (T, O. control_deriv_ops[block_row][local_row, local_col])
291+ else
292+ return zero (T)
293+ end
294+ end
295+
296+ function _getindex (:: Val{false} , O:: GradgenOperator , row:: Int , col:: Int )
297+ error (" $(typeof (O)) does not support the matrix interface" )
298+ end
299+
300+ function Base. getindex (O:: T , row:: Int , col:: Int ) where {T<: GradgenOperator }
301+ return _getindex (Val (supports_matrix_interface (T)), O, row, col)
302+ end
303+
304+
305+ function _length (:: Val{true} , O:: GradgenOperator )
306+ return prod (size (O))
307+ end
308+
309+ function _length (:: Val{false} , O:: GradgenOperator )
310+ error (" $(typeof (O)) does not support the matrix interface" )
311+ end
312+
313+ function Base. length (O:: T ) where {T<: GradgenOperator }
314+ return _length (Val (supports_matrix_interface (T)), O)
315+ end
316+
317+
318+ function _iterate (:: Val{true} , O:: GradgenOperator , k)
319+ n = length (O)
320+ k > n && return nothing
321+ n_rows = size (O, 1 )
322+ i = (k - 1 ) % n_rows + 1
323+ j = (k - 1 ) ÷ n_rows + 1
324+ return (O[i, j], k + 1 )
325+ end
326+
327+ function _iterate (:: Val{false} , O:: GradgenOperator , k)
328+ error (" $(typeof (O)) does not support the matrix interface" )
329+ end
330+
331+ function Base. iterate (O:: T , k = 1 ) where {T<: GradgenOperator }
332+ return _iterate (Val (supports_matrix_interface (T)), O, k)
333+ end
334+
335+
336+ function Base. eltype (:: Type{GradVector{num_controls,T}} ) where {num_controls,T}
337+ return eltype (T)
338+ end
339+
340+ function Base. copyto! (dest:: GradgenOperator , src:: GradgenOperator )
341+ copyto! (dest. G, src. G)
342+ copyto! (dest. control_deriv_ops, src. control_deriv_ops)
343+ end
344+
345+
118346function Base. zero (Ψ:: GradVector{num_controls,T} ) where {num_controls,T}
119347 return GradVector {num_controls,T} (zero (Ψ. state), [zero (ϕ) for ϕ ∈ Ψ. grad_states])
120348end
0 commit comments