Skip to content

Commit d4dcd14

Browse files
committed
Refactor LockedAmplitude and ShapedAmplitude
Both of these amplitudes could be made more flexible by getting rid of the type hierarchy that distinguished between, e.g., `ShapedContinousAmplitude` and `ShapedPulseAmplitude`. The shape and control can now be vectors or functions independently of each other. This enables substituting optimized pulses into existing amplitudes, among other benefits.
1 parent df3491e commit d4dcd14

2 files changed

Lines changed: 166 additions & 76 deletions

File tree

src/amplitudes.jl

Lines changed: 114 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,34 @@ using ..Controls: t_mid
1515
ampl = LockedAmplitude(shape)
1616
```
1717
18-
wraps around `shape`, which must be either a vector of values defined on the
19-
midpoints of a time grid or a callable `shape(t)`.
18+
wraps around `shape`, which must be either a `Vector{Float64}` of values defined
19+
on the midpoints of a time grid or a callable `shape(t)`.
2020
2121
```julia
2222
ampl = LockedAmplitude(shape, tlist)
2323
```
2424
2525
discretizes `shape` to the midpoints of `tlist`.
2626
"""
27-
abstract type LockedAmplitude end
28-
29-
30-
function LockedAmplitude(shape)
31-
if shape isa Vector{Float64}
32-
return LockedPulseAmplitude(shape)
33-
else
34-
return LockedContinuousAmplitude(shape)
27+
struct LockedAmplitude{ST}
28+
shape::ST
29+
30+
function LockedAmplitude(shape::ST; check = true) where {ST}
31+
if check && !(shape isa Vector{Float64})
32+
try
33+
shape(0.0)
34+
catch
35+
msg = "A LockedAmplitude shape must either be a Vector{Float64} or a callable"
36+
error(msg)
37+
end
38+
end
39+
return new{ST}(shape)
3540
end
3641
end
3742

3843

3944
function LockedAmplitude(shape, tlist)
40-
return LockedPulseAmplitude(discretize_on_midpoints(shape, tlist))
45+
return LockedAmplitude(discretize_on_midpoints(shape, tlist); check = false)
4146
end
4247

4348

@@ -46,53 +51,29 @@ function Base.show(io::IO, ampl::LockedAmplitude)
4651
end
4752

4853

49-
struct LockedPulseAmplitude <: LockedAmplitude
50-
shape::Vector{Float64}
51-
end
52-
53-
54-
Base.Array(ampl::LockedPulseAmplitude) = ampl.shape
55-
56-
57-
struct LockedContinuousAmplitude <: LockedAmplitude
58-
59-
shape
60-
61-
function LockedContinuousAmplitude(shape)
62-
try
63-
S_t = shape(0.0)
64-
catch
65-
error("A LockedAmplitude shape must either be a vector of values or a callable")
66-
end
67-
return new(shape)
68-
end
69-
70-
end
71-
72-
(ampl::LockedContinuousAmplitude)(t::Float64) = ampl.shape(t)
54+
Base.Array(ampl::LockedAmplitude{Vector{Float64}}) = ampl.shape
7355

7456
get_controls(ampl::LockedAmplitude) = ()
7557

7658
function substitute(ampl::LockedAmplitude, replacements)
7759
return get(replacements, ampl, ampl)
7860
end
7961

80-
function evaluate(ampl::LockedPulseAmplitude, tlist, n; _...)
62+
function evaluate(ampl::LockedAmplitude{Vector{Float64}}, tlist, n::Int; _...)
8163
return ampl.shape[n]
8264
end
8365

84-
function evaluate(ampl::LockedPulseAmplitude, t; _...)
85-
error(
86-
"A LockedAmplitude initialized on a `tlist` can only be evaluated with arguments `tlist, n`."
87-
)
66+
function evaluate(ampl::LockedAmplitude{Vector{Float64}}, t::Float64; _...)
67+
msg = "A LockedAmplitude initialized from a vector can only be evaluated with (tlist, n)."
68+
error(msg)
8869
end
8970

90-
function evaluate(ampl::LockedContinuousAmplitude, tlist, n; _...)
91-
return ampl(t_mid(tlist, n))
71+
function evaluate(ampl::LockedAmplitude, tlist, n::Int; _...)
72+
return ampl.shape(t_mid(tlist, n))
9273
end
9374

94-
function evaluate(ampl::LockedContinuousAmplitude, t; _...)
95-
return ampl(t)
75+
function evaluate(ampl::LockedAmplitude, t::Float64; _...)
76+
return ampl.shape(t)
9677
end
9778

9879

@@ -124,63 +105,122 @@ ampl = ShapedAmplitude(control; shape=shape)
124105
```
125106
126107
produces an amplitude ``a(t) = S(t) ϵ(t)``, where ``S(t)`` corresponds to
127-
`shape` and ``ϵ(t)`` corresponds to `control`. Both `control` and `shape`
128-
should be either a vector of values defined on the midpoints of a time grid or
129-
a callable `control(t)`, respectively `shape(t)`. In the latter case, `ampl`
130-
will also be callable.
108+
`shape` and ``ϵ(t)`` corresponds to `control`. Each of `control` and `shape`
109+
must be either a `Vector{Float64}` of values defined on the midpoints of a time
110+
grid or a callable `control(t)`, respectively `shape(t)`. If both are callables,
111+
`ampl` will also be callable. If both are vectors, they must have the same length.
131112
132113
```julia
133114
ampl = ShapedAmplitude(control, tlist; shape=shape)
134115
```
135116
136117
discretizes `control` and `shape` to the midpoints of `tlist`.
137118
"""
138-
abstract type ShapedAmplitude <: ControlAmplitude end
119+
struct ShapedAmplitude{CT,ST} <: ControlAmplitude
120+
control::CT
121+
shape::ST
122+
end
139123

140-
function ShapedAmplitude(control; shape)
141-
if (control isa Vector{Float64}) && (shape isa Vector{Float64})
142-
return ShapedPulseAmplitude(control, shape)
143-
else
144-
try
145-
ϵ_t = control(0.0)
146-
catch
147-
error(
148-
"A ShapedAmplitude control must either be a vector of values or a callable"
149-
)
124+
125+
function ShapedAmplitude(control; shape, check = true)
126+
if check
127+
if !(control isa Vector{Float64})
128+
try
129+
control(0.0)
130+
catch
131+
msg = "A ShapedAmplitude control must either be a Vector{Float64} or a callable"
132+
error(msg)
133+
end
150134
end
151-
try
152-
S_t = shape(0.0)
153-
catch
154-
error("A ShapedAmplitude shape must either be a vector of values or a callable")
135+
if !(shape isa Vector{Float64})
136+
try
137+
shape(0.0)
138+
catch
139+
msg = "A ShapedAmplitude shape must either be a Vector{Float64} or a callable"
140+
error(msg)
141+
end
142+
end
143+
if (control isa Vector{Float64}) && (shape isa Vector{Float64})
144+
if length(control) length(shape)
145+
msg = "ShapedAmplitude control and shape vectors must have the same length"
146+
error(msg)
147+
end
155148
end
156-
return ShapedContinuousAmplitude(control, shape)
157149
end
150+
return ShapedAmplitude{typeof(control),typeof(shape)}(control, shape)
158151
end
159152

153+
160154
function Base.show(io::IO, ampl::ShapedAmplitude)
161155
print(io, "ShapedAmplitude(::$(typeof(ampl.control)); shape::$(typeof(ampl.shape)))")
162156
end
163157

164-
function ShapedAmplitude(control, tlist; shape)
158+
159+
function ShapedAmplitude(control, tlist::Vector{Float64}; shape)
165160
control = discretize_on_midpoints(control, tlist)
166161
shape = discretize_on_midpoints(shape, tlist)
167-
return ShapedPulseAmplitude(control, shape)
162+
return ShapedAmplitude{typeof(control),typeof(shape)}(control, shape)
163+
end
164+
165+
166+
function substitute(ampl::ShapedAmplitude, replacements)
167+
ampl in keys(replacements) && return replacements[ampl]
168+
control = substitute(ampl.control, replacements)
169+
return ShapedAmplitude(control; shape = ampl.shape, check = true)
170+
end
171+
172+
173+
Base.Array(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}) =
174+
ampl.control .* ampl.shape
175+
176+
(ampl::ShapedAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t)
177+
178+
179+
# Vector shape: index directly; control may be vector or callable (evaluated via Controls.evaluate)
180+
function evaluate(
181+
ampl::ShapedAmplitude{CT,Vector{Float64}},
182+
tlist::Vector{Float64},
183+
n::Int;
184+
vals_dict = IdDict()
185+
) where {CT}
186+
S_t = ampl.shape[n]
187+
ϵ_t = evaluate(ampl.control, tlist, n; vals_dict)
188+
return S_t * ϵ_t
168189
end
169190

170-
struct ShapedPulseAmplitude <: ShapedAmplitude
171-
control::Vector{Float64}
172-
shape::Vector{Float64}
191+
# Callable shape: call directly with t_mid; control may be vector or callable
192+
function evaluate(
193+
ampl::ShapedAmplitude,
194+
tlist::Vector{Float64},
195+
n::Int;
196+
vals_dict = IdDict()
197+
)
198+
S_t = ampl.shape(t_mid(tlist, n))
199+
ϵ_t = evaluate(ampl.control, tlist, n; vals_dict)
200+
return S_t * ϵ_t
173201
end
174202

175-
Base.Array(ampl::ShapedPulseAmplitude) = ampl.control .* ampl.shape
203+
# Callable shape and callable control: call both directly at t
204+
function evaluate(ampl::ShapedAmplitude, t::Float64; vals_dict = IdDict())
205+
S_t = ampl.shape(t)
206+
ϵ_t = evaluate(ampl.control, t; vals_dict)
207+
return S_t * ϵ_t
208+
end
176209

210+
function evaluate(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}, t::Float64; _...)
211+
msg = "A ShapedAmplitude with vector control and shape can only be evaluated with (tlist, n)."
212+
error(msg)
213+
end
177214

178-
struct ShapedContinuousAmplitude <: ShapedAmplitude
179-
control
180-
shape
215+
function evaluate(ampl::ShapedAmplitude{Vector{Float64},ST}, t::Float64; _...) where {ST}
216+
msg = "A ShapedAmplitude with a vector control can only be evaluated with (tlist, n)."
217+
error(msg)
181218
end
182219

183-
(ampl::ShapedContinuousAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t)
220+
function evaluate(ampl::ShapedAmplitude{CT,Vector{Float64}}, t::Float64; _...) where {CT}
221+
msg = "A ShapedAmplitude with a vector shape can only be evaluated with (tlist, n)."
222+
error(msg)
223+
end
184224

185225

186226
function evaluate(ampl::ShapedAmplitude, args...; vals_dict = IdDict())

test/test_amplitudes.jl

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ using Test
22
using QuantumPropagators.Interfaces: check_amplitude, check_control
33
using QuantumPropagators.Shapes: flattop
44
using QuantumPropagators.Amplitudes: LockedAmplitude, ShapedAmplitude
5-
using QuantumPropagators.Controls: evaluate, get_controls, t_mid
6-
using QuantumPropagators.Controls: discretize_on_midpoints # DEBUG
5+
using QuantumPropagators.Controls: evaluate, get_controls, substitute, t_mid
6+
using QuantumPropagators.Controls: discretize_on_midpoints
77

88
@testset "LockedAmplitude" begin
99

@@ -25,6 +25,7 @@ using QuantumPropagators.Controls: discretize_on_midpoints # DEBUG
2525
@test length(get_controls(ampl)) == 0
2626
t = t_mid(tlist, 20)
2727
@test evaluate(ampl, tlist, 20) S(t)
28+
@test_throws Exception evaluate(ampl, t)
2829

2930
end
3031

@@ -57,5 +58,54 @@ end
5758
@test check_amplitude(ampl; tlist)
5859
t = t_mid(tlist, 20)
5960
@test evaluate(ampl, tlist, 20) ϵ(t) * S(t)
61+
@test_throws Exception evaluate(ampl, t)
62+
63+
end
64+
65+
66+
@testset "ShapedAmplitude mixed" begin
67+
tlist = collect(range(0, 10, length = 101))
68+
S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman)
69+
ϵ(t) = 1.0
70+
S_vec = discretize_on_midpoints(S, tlist)
71+
ϵ_vec = discretize_on_midpoints(ϵ, tlist)
72+
73+
# callable control, vector shape
74+
ampl = ShapedAmplitude(ϵ; shape = S_vec)
75+
@test startswith("$ampl", "ShapedAmplitude(")
76+
controls = get_controls(ampl)
77+
@test length(controls) == 1
78+
@test check_amplitude(ampl; tlist)
79+
t = t_mid(tlist, 20)
80+
@test evaluate(ampl, tlist, 20) ϵ(t) * S(t)
81+
@test_throws Exception evaluate(ampl, t)
82+
83+
# vector control, callable shape
84+
ampl = ShapedAmplitude(ϵ_vec; shape = S)
85+
@test startswith("$ampl", "ShapedAmplitude(")
86+
controls = get_controls(ampl)
87+
@test length(controls) == 1
88+
@test check_amplitude(ampl; tlist)
89+
t = t_mid(tlist, 20)
90+
@test evaluate(ampl, tlist, 20) ϵ(t) * S(t)
91+
@test_throws Exception evaluate(ampl, t)
92+
93+
end
94+
95+
96+
@testset "ShapedAmplitude substitute" begin
97+
tlist = collect(range(0, 10, length = 101))
98+
S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman)
99+
ϵ(t) = 1.0
100+
ϵ_vec = discretize_on_midpoints(ϵ, tlist)
101+
102+
ampl = ShapedAmplitude(ϵ; shape = S)
103+
ampl2 = substitute(ampl, IdDict=> ϵ_vec))
104+
@test ampl2 isa ShapedAmplitude
105+
@test get_controls(ampl2)[1] === ϵ_vec
106+
@test check_amplitude(ampl2; tlist)
107+
t = t_mid(tlist, 20)
108+
@test evaluate(ampl2, tlist, 20) ϵ(t) * S(t)
109+
@test_throws Exception evaluate(ampl2, t)
60110

61111
end

0 commit comments

Comments
 (0)