Skip to content

Commit 4b3df28

Browse files
committed
Implement matrix and vector interfaces
This properly defines `supports_vector_interface` and `supports_matrix_interface` for `GradVector` and `GradgenOperator`, respectively, and implement the full required interface, as checked by `check_operator` and `check_state`.
1 parent 856448a commit 4b3df28

4 files changed

Lines changed: 262 additions & 13 deletions

File tree

src/grad_vector.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import QuantumControl.QuantumPropagators: _exp_prop_convert_state
2-
import QuantumControl.QuantumPropagators.Interfaces: supports_inplace
2+
import QuantumControl.QuantumPropagators.Interfaces:
3+
supports_inplace, supports_vector_interface
34

45

56
@doc raw"""Extended state-vector for the dynamic gradient.
@@ -68,8 +69,8 @@ in-place operations.
6869
6970
Returns `Ψ̃`.
7071
"""
71-
function resetgradvec!(Ψ̃::GradVector)
72-
if supports_inplace(Ψ̃)
72+
function resetgradvec!(Ψ̃::T) where {T<:GradVector}
73+
if supports_inplace(T)
7374
for i in eachindex(Ψ̃.grad_states)
7475
fill!(Ψ̃.grad_states[i], 0.0)
7576
end
@@ -89,4 +90,7 @@ end
8990

9091
_exp_prop_convert_state(::GradVector) = Vector{ComplexF64}
9192

92-
supports_inplace(Ψ̃::GradVector) = supports_inplace(Ψ̃.state)
93+
supports_inplace(::Type{GradVector{N,T}}) where {N,T} = supports_inplace(T)
94+
95+
supports_vector_interface(::Type{GradVector{N,T}}) where {N,T} =
96+
supports_vector_interface(T)

src/gradgen_operator.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using Random: GLOBAL_RNG
22
import QuantumControl.QuantumPropagators: _exp_prop_convert_operator
33
import QuantumControl.QuantumPropagators.Controls: get_controls
44
import QuantumControl.QuantumPropagators.SpectralRange: random_state
5-
import QuantumControl.QuantumPropagators.Interfaces: supports_inplace
5+
import QuantumControl.QuantumPropagators.Interfaces:
6+
supports_inplace, supports_matrix_interface
67

78

89
"""Static generator for the dynamic gradient.
@@ -40,4 +41,8 @@ end
4041

4142
_exp_prop_convert_operator(::GradgenOperator) = Matrix{ComplexF64}
4243

43-
supports_inplace(::GradgenOperator) = true
44+
supports_inplace(::Type{GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} =
45+
(supports_inplace(GT) && supports_inplace(CGT))
46+
47+
supports_matrix_interface(::Type{<:GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} =
48+
supports_matrix_interface(GT) && supports_matrix_interface(CGT)

src/linalg.jl

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ function LinearAlgebra.mul!(Φ::GradVector, G::GradgenOperator, Ψ::GradVector,
1212
end
1313

1414

15+
function LinearAlgebra.mul!::GradVector, G::GradgenOperator, Ψ::GradVector)
16+
return LinearAlgebra.mul!(Φ, G, Ψ, true, false)
17+
end
18+
19+
1520
function LinearAlgebra.lmul!(c, Ψ::GradVector)
1621
LinearAlgebra.lmul!(c, Ψ.state)
1722
for i eachindex.grad_states)
@@ -48,6 +53,11 @@ function LinearAlgebra.dot(Ψ::GradVector, Φ::GradVector)
4853
end
4954

5055

56+
function LinearAlgebra.dot::GradVector, G::GradgenOperator, Φ::GradVector)
57+
return LinearAlgebra.dot(Ψ, G * Φ)
58+
end
59+
60+
5161
LinearAlgebra.ishermitian(G::GradgenOperator) = false
5262

5363

@@ -75,6 +85,11 @@ function Base.length(Ψ::GradVector)
7585
end
7686

7787

88+
function Base.size::GradVector{num_controls,T}) where {num_controls,T}
89+
return ((num_controls + 1) * length.state),)
90+
end
91+
92+
7893
function Base.size(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
7994
return (num_controls + 1) .* size(O.G)
8095
end
@@ -89,17 +104,105 @@ end
89104

90105

91106
function Base.similar::GradVector{num_controls,T}) where {num_controls,T}
92-
return GradVector{num_controls,T}(similar.state), [similar(ϕ) for ϕ Ψ.grad_states])
107+
state_sim = similar.state)
108+
grad_states_sim = [similar(ϕ) for ϕ Ψ.grad_states]
109+
return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim)
110+
end
111+
112+
Base.similar::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ))
113+
114+
Base.similar::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims)
115+
116+
# These definitions of `similar` exist to make ExponentialUtilities happy, but
117+
# it's not clear at all that `similar` with a custom shape really makes sense
118+
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} =
119+
Matrix{T}(undef, dims...)
120+
121+
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} =
122+
Vector{T}(undef, dims[1])
123+
124+
function Base.getindex::GradVector{num_controls,T}, k::Int) where {num_controls,T}
125+
N = length.state)
126+
L = num_controls
127+
block = (k - 1) ÷ N + 1
128+
local_k = (k - 1) % N + 1
129+
if block <= L
130+
return Ψ.grad_states[block][local_k]
131+
else
132+
return Ψ.state[local_k]
133+
end
134+
end
135+
136+
function Base.setindex!::GradVector{num_controls,T}, v, k::Int) where {num_controls,T}
137+
N = length.state)
138+
L = num_controls
139+
block = (k - 1) ÷ N + 1
140+
local_k = (k - 1) % N + 1
141+
if block <= L
142+
Ψ.grad_states[block][local_k] = v
143+
else
144+
Ψ.state[local_k] = v
145+
end
146+
return Ψ
93147
end
94148

95-
function Base.similar(G::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
96-
return GradgenOperator{num_controls,GT,CGT}(similar(G.G), similar(G.control_deriv_ops))
149+
function Base.iterate::GradVector, k = 1)
150+
k > length(Ψ) && return nothing
151+
return (Ψ[k], k + 1)
97152
end
98153

99-
function Base.eltype(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
154+
# As for an `Operator`, we implement `similar` to return a standard `Array`
155+
# because `GradgenOperator` does not `setindex!`, so it's arguable not a
156+
# "mutable array"even if its components are mutable.
157+
Base.similar(G::GradgenOperator) = Array{eltype(G)}(undef, size(G))
158+
159+
Base.similar(O::GradgenOperator, ::Type{S}) where {S} = Array{S}(undef, size(O))
160+
Base.similar(O::GradgenOperator, dims::Tuple{Vararg{Int}}) = Array{eltype(O)}(undef, dims)
161+
Base.similar(O::GradgenOperator, ::Type{S}, dims::Tuple{Vararg{Int}}) where {S} =
162+
Array{S}(undef, dims)
163+
164+
function Base.eltype(
165+
::Type{GradgenOperator{num_controls,GT,CGT}}
166+
) where {num_controls,GT,CGT}
100167
return promote_type(eltype(GT), eltype(CGT))
101168
end
102169

170+
function Base.getindex(
171+
O::GradgenOperator{num_controls,GT,CGT},
172+
row::Int,
173+
col::Int
174+
) where {num_controls,GT,CGT}
175+
T = eltype(O)
176+
N, M = size(O.G)
177+
L = num_controls
178+
block_row = (row - 1) ÷ N + 1
179+
block_col = (col - 1) ÷ M + 1
180+
local_row = (row - 1) % N + 1
181+
local_col = (col - 1) % M + 1
182+
if block_row == block_col
183+
return convert(T, O.G[local_row, local_col])
184+
elseif block_col == L + 1 && block_row <= L
185+
return convert(T, O.control_deriv_ops[block_row][local_row, local_col])
186+
else
187+
return zero(T)
188+
end
189+
end
190+
191+
Base.length(O::GradgenOperator) = prod(size(O))
192+
193+
function Base.iterate(O::GradgenOperator, k = 1)
194+
n = length(O)
195+
k > n && return nothing
196+
n_rows = size(O, 1)
197+
i = (k - 1) % n_rows + 1
198+
j = (k - 1) ÷ n_rows + 1
199+
return (O[i, j], k + 1)
200+
end
201+
202+
function Base.eltype(::Type{GradVector{num_controls,T}}) where {num_controls,T}
203+
return eltype(T)
204+
end
205+
103206
function Base.copyto!(dest::GradgenOperator, src::GradgenOperator)
104207
copyto!(dest.G, src.G)
105208
copyto!(dest.control_deriv_ops, src.control_deriv_ops)

test/test_interface.jl

Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ using QuantumPropagators.Generators: hamiltonian
33
using QuantumPropagators.Controls: get_controls
44
using QuantumControlTestUtils.RandomObjects: random_matrix, random_state_vector
55
using QuantumControl.Interfaces: check_generator
6-
using QuantumPropagators.Interfaces: check_state
7-
using QuantumGradientGenerators: GradGenerator, GradVector
6+
using QuantumPropagators.Interfaces:
7+
check_state, check_operator, supports_matrix_interface, supports_vector_interface
8+
using QuantumGradientGenerators: GradGenerator, GradVector, GradgenOperator
89
using StaticArrays: SVector, SMatrix
9-
using LinearAlgebra: norm
10+
using LinearAlgebra: norm, dot, mul!, I
1011

1112

1213
@testset "GradVector Interface" begin
@@ -75,3 +76,139 @@ end
7576
@test check_generator(G̃_of_t; state = Ψ̃, tlist, for_gradient_optimization = false)
7677

7778
end
79+
80+
81+
@testset "GradgenOperator Matrix Interface" begin
82+
83+
N = 5
84+
L = 2
85+
G = Matrix{ComplexF64}(I, N, N)
86+
mu = [rand(ComplexF64, N, N) for _ = 1:L]
87+
op = GradgenOperator{L,Matrix{ComplexF64},Matrix{ComplexF64}}(G, mu)
88+
state = GradVector(rand(ComplexF64, N), L)
89+
90+
# supports_matrix_interface reports true for matrix-backed GradgenOperator
91+
@test supports_matrix_interface(typeof(op))
92+
93+
# check_operator passes the full matrix interface check including for_expval
94+
@test check_operator(op; state, for_expval = true)
95+
96+
# getindex is consistent with the dense Array representation
97+
dense = Array(op)
98+
@test all(op[i, j] dense[i, j] for i = 1:size(op, 1), j = 1:size(op, 2))
99+
100+
# length
101+
@test length(op) == prod(size(op))
102+
103+
# iterate visits elements in column-major order, consistent with vec(Array(op))
104+
@test all(collect(op) .≈ vec(dense))
105+
106+
# 3-arg mul! agrees with 5-arg mul!(Phi, G, Psi, 1, 0)
107+
Psi = GradVector(rand(ComplexF64, N), L)
108+
Phi1 = GradVector(zeros(ComplexF64, N), L)
109+
Phi2 = GradVector(zeros(ComplexF64, N), L)
110+
mul!(Phi1, op, Psi)
111+
mul!(Phi2, op, Psi, true, false)
112+
@test norm(Phi1 - Phi2) < 1e-14
113+
114+
# 3-arg dot(Psi, op, Phi) matches dot(Psi, op * Phi)
115+
Psi2 = GradVector(rand(ComplexF64, N), L)
116+
@test dot(state, op, Psi2) dot(state, op * Psi2)
117+
118+
# similar(op) returns a dense Array of the same eltype and size (matching Operator pattern)
119+
op_sim = similar(op)
120+
@test op_sim isa Array{eltype(op)}
121+
@test size(op_sim) == size(op)
122+
123+
# similar(op, S) returns a dense Array of type S with matching size
124+
@test similar(op, Float64) isa Array{Float64}
125+
@test size(similar(op, Float64)) == size(op)
126+
127+
# similar(op, dims) returns a dense Array with given dims
128+
@test similar(op, (3, 4)) isa Array{eltype(op)}
129+
@test size(similar(op, (3, 4))) == (3, 4)
130+
131+
# similar(op, S, dims) returns a dense Array of type S with given dims
132+
@test similar(op, Float64, (3, 4)) isa Array{Float64}
133+
@test size(similar(op, Float64, (3, 4))) == (3, 4)
134+
135+
end
136+
137+
138+
@testset "GradVector Vector Interface" begin
139+
140+
N = 5
141+
L = 2
142+
Psi = rand(ComplexF64, N)
143+
gradvec = GradVector(Psi, L)
144+
145+
# supports_vector_interface is true for Vector-backed GradVector
146+
@test supports_vector_interface(typeof(gradvec))
147+
148+
# check_state passes full vector interface check
149+
@test check_state(gradvec)
150+
151+
# size is 1D with total length
152+
@test size(gradvec) == (N * (L + 1),)
153+
@test size(gradvec) == (length(gradvec),)
154+
155+
# getindex is consistent with convert_gradvec_to_dense layout:
156+
# [grad_states[1]; grad_states[2]; ...; grad_states[L]; state]
157+
dense = convert(Vector{ComplexF64}, gradvec)
158+
@test all(gradvec[k] == dense[k] for k = 1:length(gradvec))
159+
160+
# iterate visits elements consistent with getindex
161+
@test all(collect(gradvec) .== dense)
162+
163+
# setindex! round-trips through getindex
164+
gradvec2 = GradVector(copy(Psi), L)
165+
for k = 1:length(gradvec2)
166+
gradvec2[k] = gradvec[k]
167+
end
168+
@test all(gradvec2[k] == gradvec[k] for k = 1:length(gradvec))
169+
170+
# similar(gradvec, S) returns a mutable Vector{S} with same length
171+
@test similar(gradvec, ComplexF32) isa Vector{ComplexF32}
172+
@test length(similar(gradvec, ComplexF32)) == length(gradvec)
173+
174+
# similar(gradvec, dims) returns a plain Array with same eltype and given dims
175+
@test similar(gradvec, (3, 4)) isa Array{eltype(gradvec)}
176+
@test size(similar(gradvec, (3, 4))) == (3, 4)
177+
178+
end
179+
180+
181+
@testset "GradVector Vector Interface (Static)" begin
182+
183+
N = 5
184+
L = 2
185+
Psi = SVector{N,ComplexF64}(rand(ComplexF64, N))
186+
gradvec = GradVector(Psi, L)
187+
188+
# SVector-backed GradVector: supports_vector_interface follows the component type
189+
@test supports_vector_interface(typeof(gradvec))
190+
191+
# check_state passes (SVector is inplace=false, so setindex! is not checked)
192+
@test check_state(gradvec)
193+
194+
# getindex is consistent with the dense layout
195+
dense = convert(Vector{ComplexF64}, gradvec)
196+
@test all(gradvec[k] == dense[k] for k = 1:length(gradvec))
197+
198+
end
199+
200+
201+
@testset "GradVector without Vector Interface" begin
202+
203+
N = 5
204+
L = 2
205+
# Matrix is not an AbstractVector, so supports_vector_interface returns false
206+
Psi = rand(ComplexF64, N, N)
207+
gradvec = GradVector(Psi, L)
208+
209+
@test !supports_vector_interface(typeof(gradvec))
210+
211+
# check_state still passes via the basic (non-vector) state interface
212+
@test check_state(gradvec)
213+
214+
end

0 commit comments

Comments
 (0)