Skip to content

Commit 93391c3

Browse files
committed
Merge #18 from branch matrix-vector-interfaces
2 parents 856448a + db43eca commit 93391c3

4 files changed

Lines changed: 433 additions & 26 deletions

File tree

src/grad_vector.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import QuantumControl.QuantumPropagators: _exp_prop_convert_state
2-
import QuantumControl.QuantumPropagators.Interfaces: supports_inplace
2+
import QuantumControl.QuantumPropagators.Interfaces:
3+
supports_inplace, supports_vector_interface
34

45

56
@doc raw"""Extended state-vector for the dynamic gradient.
@@ -68,8 +69,8 @@ in-place operations.
6869
6970
Returns `Ψ̃`.
7071
"""
71-
function resetgradvec!(Ψ̃::GradVector)
72-
if supports_inplace(Ψ̃)
72+
function resetgradvec!(Ψ̃::T) where {T<:GradVector}
73+
if supports_inplace(T)
7374
for i in eachindex(Ψ̃.grad_states)
7475
fill!(Ψ̃.grad_states[i], 0.0)
7576
end
@@ -89,4 +90,7 @@ end
8990

9091
_exp_prop_convert_state(::GradVector) = Vector{ComplexF64}
9192

92-
supports_inplace(Ψ̃::GradVector) = supports_inplace(Ψ̃.state)
93+
supports_inplace(::Type{GradVector{N,T}}) where {N,T} = supports_inplace(T)
94+
95+
supports_vector_interface(::Type{GradVector{N,T}}) where {N,T} =
96+
supports_vector_interface(T)

src/gradgen_operator.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using Random: GLOBAL_RNG
22
import QuantumControl.QuantumPropagators: _exp_prop_convert_operator
33
import QuantumControl.QuantumPropagators.Controls: get_controls
44
import QuantumControl.QuantumPropagators.SpectralRange: random_state
5-
import QuantumControl.QuantumPropagators.Interfaces: supports_inplace
5+
import QuantumControl.QuantumPropagators.Interfaces:
6+
supports_inplace, supports_matrix_interface
67

78

89
"""Static generator for the dynamic gradient.
@@ -40,4 +41,8 @@ end
4041

4142
_exp_prop_convert_operator(::GradgenOperator) = Matrix{ComplexF64}
4243

43-
supports_inplace(::GradgenOperator) = true
44+
supports_inplace(::Type{GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} =
45+
(supports_inplace(GT) && supports_inplace(CGT))
46+
47+
supports_matrix_interface(::Type{<:GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} =
48+
supports_matrix_interface(GT) && supports_matrix_interface(CGT)

src/linalg.jl

Lines changed: 245 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ function LinearAlgebra.mul!(Φ::GradVector, G::GradgenOperator, Ψ::GradVector,
1212
end
1313

1414

15+
function LinearAlgebra.mul!::GradVector, G::GradgenOperator, Ψ::GradVector)
16+
return LinearAlgebra.mul!(Φ, G, Ψ, true, false)
17+
end
18+
19+
1520
function LinearAlgebra.lmul!(c, Ψ::GradVector)
1621
LinearAlgebra.lmul!(c, Ψ.state)
1722
for i eachindex.grad_states)
@@ -48,6 +53,11 @@ function LinearAlgebra.dot(Ψ::GradVector, Φ::GradVector)
4853
end
4954

5055

56+
function LinearAlgebra.dot::GradVector, G::GradgenOperator, Φ::GradVector)
57+
return LinearAlgebra.dot(Ψ, G * Φ)
58+
end
59+
60+
5161
LinearAlgebra.ishermitian(G::GradgenOperator) = false
5262

5363

@@ -70,41 +80,126 @@ function Base.copy(Ψ::GradVector{num_controls,T}) where {num_controls,T}
7080
end
7181

7282

73-
function Base.length::GradVector)
83+
# === Vector interface for GradVector ===
84+
#
85+
# The following methods are part of the vector interface and are only
86+
# meaningful when `supports_vector_interface` is true for the state type T.
87+
# Each method delegates to a private `_name(::Val{supports}, ...)` function:
88+
# the Val{true} method contains the implementation, and the Val{false} method
89+
# throws an error.
90+
91+
function _length(::Val{true}, Ψ::GradVector)
7492
return length.state) * (1 + length.grad_states))
7593
end
7694

95+
function _length(::Val{false}, Ψ::GradVector)
96+
error("$(typeof(Ψ)) does not support the vector interface")
97+
end
7798

78-
function Base.size(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
79-
return (num_controls + 1) .* size(O.G)
99+
function Base.length::T) where {T<:GradVector}
100+
return _length(Val(supports_vector_interface(T)), Ψ)
80101
end
81102

82103

83-
function Base.size(
84-
O::GradgenOperator{num_controls,GT,CGT},
85-
dim::Integer
86-
) where {num_controls,GT,CGT}
87-
return (num_controls + 1) * size(O.G, dim)
104+
function _size(::Val{true}, Ψ::GradVector{num_controls,T}) where {num_controls,T}
105+
return ((num_controls + 1) * length.state),)
88106
end
89107

108+
function _size(::Val{false}, Ψ::GradVector)
109+
error("$(typeof(Ψ)) does not support the vector interface")
110+
end
90111

91-
function Base.similar::GradVector{num_controls,T}) where {num_controls,T}
92-
return GradVector{num_controls,T}(similar.state), [similar(ϕ) for ϕ Ψ.grad_states])
112+
function Base.size::T) where {T<:GradVector}
113+
return _size(Val(supports_vector_interface(T)), Ψ)
93114
end
94115

95-
function Base.similar(G::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
96-
return GradgenOperator{num_controls,GT,CGT}(similar(G.G), similar(G.control_deriv_ops))
116+
117+
function _getindex(
118+
::Val{true},
119+
Ψ::GradVector{num_controls,T},
120+
k::Int
121+
) where {num_controls,T}
122+
N = length.state)
123+
L = num_controls
124+
block = (k - 1) ÷ N + 1
125+
local_k = (k - 1) % N + 1
126+
if block <= L
127+
return Ψ.grad_states[block][local_k]
128+
else
129+
return Ψ.state[local_k]
130+
end
97131
end
98132

99-
function Base.eltype(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
100-
return promote_type(eltype(GT), eltype(CGT))
133+
function _getindex(::Val{false}, Ψ::GradVector, k::Int)
134+
error("$(typeof(Ψ)) does not support the vector interface")
101135
end
102136

103-
function Base.copyto!(dest::GradgenOperator, src::GradgenOperator)
104-
copyto!(dest.G, src.G)
105-
copyto!(dest.control_deriv_ops, src.control_deriv_ops)
137+
function Base.getindex::T, k::Int) where {T<:GradVector}
138+
return _getindex(Val(supports_vector_interface(T)), Ψ, k)
139+
end
140+
141+
142+
function _setindex!(
143+
::Val{true},
144+
Ψ::GradVector{num_controls,T},
145+
v,
146+
k::Int
147+
) where {num_controls,T}
148+
N = length.state)
149+
L = num_controls
150+
block = (k - 1) ÷ N + 1
151+
local_k = (k - 1) % N + 1
152+
if block <= L
153+
Ψ.grad_states[block][local_k] = v
154+
else
155+
Ψ.state[local_k] = v
156+
end
157+
return Ψ
158+
end
159+
160+
function _setindex!(::Val{false}, Ψ::GradVector, v, k::Int)
161+
error("$(typeof(Ψ)) does not support the vector interface")
106162
end
107163

164+
function Base.setindex!::T, v, k::Int) where {T<:GradVector}
165+
return _setindex!(Val(supports_vector_interface(T)), Ψ, v, k)
166+
end
167+
168+
169+
function _iterate(::Val{true}, Ψ::GradVector, k)
170+
k > length(Ψ) && return nothing
171+
return (Ψ[k], k + 1)
172+
end
173+
174+
function _iterate(::Val{false}, Ψ::GradVector, k)
175+
error("$(typeof(Ψ)) does not support the vector interface")
176+
end
177+
178+
function Base.iterate::T, k = 1) where {T<:GradVector}
179+
return _iterate(Val(supports_vector_interface(T)), Ψ, k)
180+
end
181+
182+
183+
function Base.similar::GradVector{num_controls,T}) where {num_controls,T}
184+
state_sim = similar.state)
185+
grad_states_sim = [similar(ϕ) for ϕ Ψ.grad_states]
186+
return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim)
187+
end
188+
189+
# similar(Ψ, S) calls length(Ψ), which will error if !supports_vector_interface
190+
Base.similar::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ))
191+
192+
# similar(Ψ, dims) calls eltype(Ψ) but not length/size, so no vector interface needed
193+
Base.similar::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims)
194+
195+
# These definitions of `similar` exist to make ExponentialUtilities happy, but
196+
# it's not clear at all that `similar` with a custom shape really makes sense
197+
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} =
198+
Matrix{T}(undef, dims...)
199+
200+
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} =
201+
Vector{T}(undef, dims[1])
202+
108203

109204
function Base.fill!::GradVector, v)
110205
Base.fill!.state, v)
@@ -115,6 +210,139 @@ function Base.fill!(Ψ::GradVector, v)
115210
end
116211

117212

213+
# === Matrix interface for GradgenOperator ===
214+
#
215+
# The following methods are part of the matrix interface and are only
216+
# meaningful when `supports_matrix_interface` is true for both component types.
217+
# Each method delegates to a private `_name(::Val{supports}, ...)` function:
218+
# the Val{true} method contains the implementation, and the Val{false} method
219+
# throws an error.
220+
221+
function _size(
222+
::Val{true},
223+
O::GradgenOperator{num_controls,GT,CGT}
224+
) where {num_controls,GT,CGT}
225+
return (num_controls + 1) .* size(O.G)
226+
end
227+
228+
function _size(::Val{false}, O::GradgenOperator)
229+
error("$(typeof(O)) does not support the matrix interface")
230+
end
231+
232+
function Base.size(O::T) where {T<:GradgenOperator}
233+
return _size(Val(supports_matrix_interface(T)), O)
234+
end
235+
236+
237+
function _size(
238+
::Val{true},
239+
O::GradgenOperator{num_controls,GT,CGT},
240+
dim::Integer
241+
) where {num_controls,GT,CGT}
242+
return (num_controls + 1) * size(O.G, dim)
243+
end
244+
245+
function _size(::Val{false}, O::GradgenOperator, dim::Integer)
246+
error("$(typeof(O)) does not support the matrix interface")
247+
end
248+
249+
function Base.size(O::T, dim::Integer) where {T<:GradgenOperator}
250+
return _size(Val(supports_matrix_interface(T)), O, dim)
251+
end
252+
253+
254+
# As for an `Operator`, we implement `similar` to return a standard `Array`
255+
# because `GradgenOperator` does not `setindex!`, so it's arguably not a
256+
# "mutable array" even if its components are mutable.
257+
# similar(O) and similar(O, S) call size(O), which will error if
258+
# !supports_matrix_interface. The dims-based variants need no guard.
259+
Base.similar(G::GradgenOperator) = Array{eltype(G)}(undef, size(G))
260+
261+
Base.similar(O::GradgenOperator, ::Type{S}) where {S} = Array{S}(undef, size(O))
262+
Base.similar(O::GradgenOperator, dims::Tuple{Vararg{Int}}) = Array{eltype(O)}(undef, dims)
263+
Base.similar(O::GradgenOperator, ::Type{S}, dims::Tuple{Vararg{Int}}) where {S} =
264+
Array{S}(undef, dims)
265+
266+
267+
function Base.eltype(
268+
::Type{GradgenOperator{num_controls,GT,CGT}}
269+
) where {num_controls,GT,CGT}
270+
return promote_type(eltype(GT), eltype(CGT))
271+
end
272+
273+
274+
function _getindex(
275+
::Val{true},
276+
O::GradgenOperator{num_controls,GT,CGT},
277+
row::Int,
278+
col::Int
279+
) where {num_controls,GT,CGT}
280+
T = eltype(O)
281+
N, M = size(O.G)
282+
L = num_controls
283+
block_row = (row - 1) ÷ N + 1
284+
block_col = (col - 1) ÷ M + 1
285+
local_row = (row - 1) % N + 1
286+
local_col = (col - 1) % M + 1
287+
if block_row == block_col
288+
return convert(T, O.G[local_row, local_col])
289+
elseif block_col == L + 1 && block_row <= L
290+
return convert(T, O.control_deriv_ops[block_row][local_row, local_col])
291+
else
292+
return zero(T)
293+
end
294+
end
295+
296+
function _getindex(::Val{false}, O::GradgenOperator, row::Int, col::Int)
297+
error("$(typeof(O)) does not support the matrix interface")
298+
end
299+
300+
function Base.getindex(O::T, row::Int, col::Int) where {T<:GradgenOperator}
301+
return _getindex(Val(supports_matrix_interface(T)), O, row, col)
302+
end
303+
304+
305+
function _length(::Val{true}, O::GradgenOperator)
306+
return prod(size(O))
307+
end
308+
309+
function _length(::Val{false}, O::GradgenOperator)
310+
error("$(typeof(O)) does not support the matrix interface")
311+
end
312+
313+
function Base.length(O::T) where {T<:GradgenOperator}
314+
return _length(Val(supports_matrix_interface(T)), O)
315+
end
316+
317+
318+
function _iterate(::Val{true}, O::GradgenOperator, k)
319+
n = length(O)
320+
k > n && return nothing
321+
n_rows = size(O, 1)
322+
i = (k - 1) % n_rows + 1
323+
j = (k - 1) ÷ n_rows + 1
324+
return (O[i, j], k + 1)
325+
end
326+
327+
function _iterate(::Val{false}, O::GradgenOperator, k)
328+
error("$(typeof(O)) does not support the matrix interface")
329+
end
330+
331+
function Base.iterate(O::T, k = 1) where {T<:GradgenOperator}
332+
return _iterate(Val(supports_matrix_interface(T)), O, k)
333+
end
334+
335+
336+
function Base.eltype(::Type{GradVector{num_controls,T}}) where {num_controls,T}
337+
return eltype(T)
338+
end
339+
340+
function Base.copyto!(dest::GradgenOperator, src::GradgenOperator)
341+
copyto!(dest.G, src.G)
342+
copyto!(dest.control_deriv_ops, src.control_deriv_ops)
343+
end
344+
345+
118346
function Base.zero::GradVector{num_controls,T}) where {num_controls,T}
119347
return GradVector{num_controls,T}(zero.state), [zero(ϕ) for ϕ Ψ.grad_states])
120348
end

0 commit comments

Comments
 (0)