Skip to content

Commit 45a992f

Browse files
committed
Merge #110 from branch refactor-amplitudes
2 parents df3491e + 952dd50 commit 45a992f

3 files changed

Lines changed: 342 additions & 78 deletions

File tree

src/amplitudes.jl

Lines changed: 138 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,46 @@ using ..Controls: t_mid
1515
ampl = LockedAmplitude(shape)
1616
```
1717
18-
wraps around `shape`, which must be either a vector of values defined on the
19-
midpoints of a time grid or a callable `shape(t)`.
18+
wraps around `shape`, which must be either a `Vector{Float64}` of values defined
19+
on the midpoints of a time grid or a callable `shape(t)`.
2020
2121
```julia
2222
ampl = LockedAmplitude(shape, tlist)
2323
```
2424
2525
discretizes `shape` to the midpoints of `tlist`.
2626
"""
27-
abstract type LockedAmplitude end
28-
29-
30-
function LockedAmplitude(shape)
31-
if shape isa Vector{Float64}
32-
return LockedPulseAmplitude(shape)
33-
else
34-
return LockedContinuousAmplitude(shape)
27+
struct LockedAmplitude{ST}
28+
shape::ST
29+
30+
function LockedAmplitude(shape::ST; check = true) where {ST}
31+
if shape isa AbstractVector
32+
if !(shape isa Vector{Float64})
33+
try
34+
shape = Vector{Float64}(shape)
35+
catch
36+
msg = "A LockedAmplitude shape that is a vector must be convertible to Vector{Float64}"
37+
error(msg)
38+
end
39+
end
40+
return new{Vector{Float64}}(shape)
41+
else
42+
if check
43+
try
44+
shape(0.0)
45+
catch
46+
msg = "A LockedAmplitude shape must either be a Vector{Float64} or a callable"
47+
error(msg)
48+
end
49+
end
50+
return new{ST}(shape)
51+
end
3552
end
3653
end
3754

3855

3956
function LockedAmplitude(shape, tlist)
40-
return LockedPulseAmplitude(discretize_on_midpoints(shape, tlist))
57+
return LockedAmplitude(discretize_on_midpoints(shape, tlist); check = false)
4158
end
4259

4360

@@ -46,53 +63,29 @@ function Base.show(io::IO, ampl::LockedAmplitude)
4663
end
4764

4865

49-
struct LockedPulseAmplitude <: LockedAmplitude
50-
shape::Vector{Float64}
51-
end
52-
53-
54-
Base.Array(ampl::LockedPulseAmplitude) = ampl.shape
55-
56-
57-
struct LockedContinuousAmplitude <: LockedAmplitude
58-
59-
shape
60-
61-
function LockedContinuousAmplitude(shape)
62-
try
63-
S_t = shape(0.0)
64-
catch
65-
error("A LockedAmplitude shape must either be a vector of values or a callable")
66-
end
67-
return new(shape)
68-
end
69-
70-
end
71-
72-
(ampl::LockedContinuousAmplitude)(t::Float64) = ampl.shape(t)
66+
Base.Array(ampl::LockedAmplitude{Vector{Float64}}) = ampl.shape
7367

7468
get_controls(ampl::LockedAmplitude) = ()
7569

7670
function substitute(ampl::LockedAmplitude, replacements)
7771
return get(replacements, ampl, ampl)
7872
end
7973

80-
function evaluate(ampl::LockedPulseAmplitude, tlist, n; _...)
74+
function evaluate(ampl::LockedAmplitude{Vector{Float64}}, tlist, n::Int; _...)
8175
return ampl.shape[n]
8276
end
8377

84-
function evaluate(ampl::LockedPulseAmplitude, t; _...)
85-
error(
86-
"A LockedAmplitude initialized on a `tlist` can only be evaluated with arguments `tlist, n`."
87-
)
78+
function evaluate(ampl::LockedAmplitude{Vector{Float64}}, t::Float64; _...)
79+
msg = "A LockedAmplitude initialized from a vector can only be evaluated with (tlist, n)."
80+
error(msg)
8881
end
8982

90-
function evaluate(ampl::LockedContinuousAmplitude, tlist, n; _...)
91-
return ampl(t_mid(tlist, n))
83+
function evaluate(ampl::LockedAmplitude, tlist, n::Int; _...)
84+
return ampl.shape(t_mid(tlist, n))
9285
end
9386

94-
function evaluate(ampl::LockedContinuousAmplitude, t; _...)
95-
return ampl(t)
87+
function evaluate(ampl::LockedAmplitude, t::Float64; _...)
88+
return ampl.shape(t)
9689
end
9790

9891

@@ -124,63 +117,138 @@ ampl = ShapedAmplitude(control; shape=shape)
124117
```
125118
126119
produces an amplitude ``a(t) = S(t) ϵ(t)``, where ``S(t)`` corresponds to
127-
`shape` and ``ϵ(t)`` corresponds to `control`. Both `control` and `shape`
128-
should be either a vector of values defined on the midpoints of a time grid or
129-
a callable `control(t)`, respectively `shape(t)`. In the latter case, `ampl`
130-
will also be callable.
120+
`shape` and ``ϵ(t)`` corresponds to `control`. Each of `control` and `shape`
121+
must be either a `Vector{Float64}` of values defined on the midpoints of a time
122+
grid or a callable `control(t)`, respectively `shape(t)`. If both are callables,
123+
`ampl` will also be callable. If both are vectors, they must have the same length.
131124
132125
```julia
133126
ampl = ShapedAmplitude(control, tlist; shape=shape)
134127
```
135128
136129
discretizes `control` and `shape` to the midpoints of `tlist`.
137130
"""
138-
abstract type ShapedAmplitude <: ControlAmplitude end
131+
struct ShapedAmplitude{CT,ST} <: ControlAmplitude
132+
control::CT
133+
shape::ST
134+
end
139135

140-
function ShapedAmplitude(control; shape)
141-
if (control isa Vector{Float64}) && (shape isa Vector{Float64})
142-
return ShapedPulseAmplitude(control, shape)
143-
else
136+
137+
function ShapedAmplitude(control; shape, check = true)
138+
if control isa AbstractVector && !(control isa Vector{Float64})
144139
try
145-
ϵ_t = control(0.0)
140+
control = Vector{Float64}(control)
146141
catch
147-
error(
148-
"A ShapedAmplitude control must either be a vector of values or a callable"
149-
)
142+
msg = "A ShapedAmplitude control that is a vector must be convertible to Vector{Float64}"
143+
error(msg)
150144
end
145+
end
146+
if shape isa AbstractVector && !(shape isa Vector{Float64})
151147
try
152-
S_t = shape(0.0)
148+
shape = Vector{Float64}(shape)
153149
catch
154-
error("A ShapedAmplitude shape must either be a vector of values or a callable")
150+
msg = "A ShapedAmplitude shape that is a vector must be convertible to Vector{Float64}"
151+
error(msg)
155152
end
156-
return ShapedContinuousAmplitude(control, shape)
157153
end
154+
if check
155+
if !(control isa Vector{Float64})
156+
try
157+
control(0.0)
158+
catch
159+
msg = "A ShapedAmplitude control must either be a Vector{Float64} or a callable"
160+
error(msg)
161+
end
162+
end
163+
if !(shape isa Vector{Float64})
164+
try
165+
shape(0.0)
166+
catch
167+
msg = "A ShapedAmplitude shape must either be a Vector{Float64} or a callable"
168+
error(msg)
169+
end
170+
end
171+
if (control isa Vector{Float64}) && (shape isa Vector{Float64})
172+
if length(control) length(shape)
173+
msg = "ShapedAmplitude control and shape vectors must have the same length"
174+
error(msg)
175+
end
176+
end
177+
end
178+
return ShapedAmplitude{typeof(control),typeof(shape)}(control, shape)
158179
end
159180

181+
160182
function Base.show(io::IO, ampl::ShapedAmplitude)
161183
print(io, "ShapedAmplitude(::$(typeof(ampl.control)); shape::$(typeof(ampl.shape)))")
162184
end
163185

164-
function ShapedAmplitude(control, tlist; shape)
186+
187+
function ShapedAmplitude(control, tlist::Vector{Float64}; shape)
165188
control = discretize_on_midpoints(control, tlist)
166189
shape = discretize_on_midpoints(shape, tlist)
167-
return ShapedPulseAmplitude(control, shape)
190+
return ShapedAmplitude{Vector{Float64},Vector{Float64}}(control, shape)
168191
end
169192

170-
struct ShapedPulseAmplitude <: ShapedAmplitude
171-
control::Vector{Float64}
172-
shape::Vector{Float64}
193+
194+
function substitute(ampl::ShapedAmplitude, replacements)
195+
ampl in keys(replacements) && return replacements[ampl]
196+
control = substitute(ampl.control, replacements)
197+
return ShapedAmplitude(control; shape = ampl.shape, check = true)
173198
end
174199

175-
Base.Array(ampl::ShapedPulseAmplitude) = ampl.control .* ampl.shape
200+
201+
Base.Array(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}) =
202+
ampl.control .* ampl.shape
203+
204+
(ampl::ShapedAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t)
176205

177206

178-
struct ShapedContinuousAmplitude <: ShapedAmplitude
179-
control
180-
shape
207+
# Vector shape: index directly; control may be vector or callable (evaluated via Controls.evaluate)
208+
function evaluate(
209+
ampl::ShapedAmplitude{CT,Vector{Float64}},
210+
tlist::Vector{Float64},
211+
n::Int;
212+
vals_dict = IdDict()
213+
) where {CT}
214+
S_t = ampl.shape[n]
215+
ϵ_t = evaluate(ampl.control, tlist, n; vals_dict)
216+
return S_t * ϵ_t
181217
end
182218

183-
(ampl::ShapedContinuousAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t)
219+
# Callable shape: call directly with t_mid; control may be vector or callable
220+
function evaluate(
221+
ampl::ShapedAmplitude,
222+
tlist::Vector{Float64},
223+
n::Int;
224+
vals_dict = IdDict()
225+
)
226+
S_t = ampl.shape(t_mid(tlist, n))
227+
ϵ_t = evaluate(ampl.control, tlist, n; vals_dict)
228+
return S_t * ϵ_t
229+
end
230+
231+
# Callable shape and callable control: call both directly at t
232+
function evaluate(ampl::ShapedAmplitude, t::Float64; vals_dict = IdDict())
233+
S_t = ampl.shape(t)
234+
ϵ_t = evaluate(ampl.control, t; vals_dict)
235+
return S_t * ϵ_t
236+
end
237+
238+
function evaluate(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}, t::Float64; _...)
239+
msg = "A ShapedAmplitude with vector control and shape can only be evaluated with (tlist, n)."
240+
error(msg)
241+
end
242+
243+
function evaluate(ampl::ShapedAmplitude{Vector{Float64},ST}, t::Float64; _...) where {ST}
244+
msg = "A ShapedAmplitude with a vector control can only be evaluated with (tlist, n)."
245+
error(msg)
246+
end
247+
248+
function evaluate(ampl::ShapedAmplitude{CT,Vector{Float64}}, t::Float64; _...) where {CT}
249+
msg = "A ShapedAmplitude with a vector shape can only be evaluated with (tlist, n)."
250+
error(msg)
251+
end
184252

185253

186254
function evaluate(ampl::ShapedAmplitude, args...; vals_dict = IdDict())

src/controls.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ function discretize(control::Function, tlist; via_midpoints = true)
4545
vals_on_midpoints = discretize_on_midpoints(control, tlist)
4646
return discretize(vals_on_midpoints, tlist)
4747
else
48-
return [control(t) for t in tlist]
48+
return Float64[control(t) for t in tlist]
4949
end
5050
end
5151

5252
function discretize(control::Vector, tlist)
5353
if length(control) == length(tlist)
54-
return copy(control)
54+
return Vector{Float64}(control)
5555
elseif length(control) == length(tlist) - 1
5656
# convert `control` on intervals to values on `tlist`
5757
# cf. pulse_onto_tlist in Python krotov package
58-
vals = zeros(eltype(control), length(control) + 1)
58+
vals = zeros(Float64, length(control) + 1)
5959
vals[1] = control[1]
6060
vals[end] = control[end]
6161
for i = 2:(length(vals)-1)
@@ -193,9 +193,9 @@ end
193193

194194
function discretize_on_midpoints(control::Vector, tlist)
195195
if length(control) == length(tlist) - 1
196-
return copy(control)
196+
return Vector{Float64}(control)
197197
elseif length(control) == length(tlist)
198-
vals = Vector{eltype(control)}(undef, length(tlist) - 1)
198+
vals = Vector{Float64}(undef, length(tlist) - 1)
199199
vals[1] = control[1]
200200
vals[end] = control[end]
201201
for i = 2:(length(vals)-1)

0 commit comments

Comments
 (0)