Skip to content

Commit 3b95e72

Browse files
committed
Guarantee discretized controls are Vector{Float64}.
1 parent d4dcd14 commit 3b95e72

3 files changed

Lines changed: 79 additions & 13 deletions

File tree

src/amplitudes.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,27 @@ struct LockedAmplitude{ST}
2828
shape::ST
2929

3030
function LockedAmplitude(shape::ST; check = true) where {ST}
31-
if check && !(shape isa Vector{Float64})
32-
try
33-
shape(0.0)
34-
catch
35-
msg = "A LockedAmplitude shape must either be a Vector{Float64} or a callable"
36-
error(msg)
31+
if shape isa AbstractVector
32+
if !(shape isa Vector{Float64})
33+
try
34+
shape = Vector{Float64}(shape)
35+
catch
36+
msg = "A LockedAmplitude shape that is a vector must be convertible to Vector{Float64}"
37+
error(msg)
38+
end
39+
end
40+
return new{Vector{Float64}}(shape)
41+
else
42+
if check
43+
try
44+
shape(0.0)
45+
catch
46+
msg = "A LockedAmplitude shape must either be a Vector{Float64} or a callable"
47+
error(msg)
48+
end
3749
end
50+
return new{ST}(shape)
3851
end
39-
return new{ST}(shape)
4052
end
4153
end
4254

@@ -123,6 +135,22 @@ end
123135

124136

125137
function ShapedAmplitude(control; shape, check = true)
138+
if control isa AbstractVector && !(control isa Vector{Float64})
139+
try
140+
control = Vector{Float64}(control)
141+
catch
142+
msg = "A ShapedAmplitude control that is a vector must be convertible to Vector{Float64}"
143+
error(msg)
144+
end
145+
end
146+
if shape isa AbstractVector && !(shape isa Vector{Float64})
147+
try
148+
shape = Vector{Float64}(shape)
149+
catch
150+
msg = "A ShapedAmplitude shape that is a vector must be convertible to Vector{Float64}"
151+
error(msg)
152+
end
153+
end
126154
if check
127155
if !(control isa Vector{Float64})
128156
try
@@ -159,7 +187,7 @@ end
159187
function ShapedAmplitude(control, tlist::Vector{Float64}; shape)
160188
control = discretize_on_midpoints(control, tlist)
161189
shape = discretize_on_midpoints(shape, tlist)
162-
return ShapedAmplitude{typeof(control),typeof(shape)}(control, shape)
190+
return ShapedAmplitude{Vector{Float64},Vector{Float64}}(control, shape)
163191
end
164192

165193

src/controls.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ function discretize(control::Function, tlist; via_midpoints = true)
4545
vals_on_midpoints = discretize_on_midpoints(control, tlist)
4646
return discretize(vals_on_midpoints, tlist)
4747
else
48-
return [control(t) for t in tlist]
48+
return Float64[control(t) for t in tlist]
4949
end
5050
end
5151

5252
function discretize(control::Vector, tlist)
5353
if length(control) == length(tlist)
54-
return copy(control)
54+
return Vector{Float64}(control)
5555
elseif length(control) == length(tlist) - 1
5656
# convert `control` on intervals to values on `tlist`
5757
# cf. pulse_onto_tlist in Python krotov package
58-
vals = zeros(eltype(control), length(control) + 1)
58+
vals = zeros(Float64, length(control) + 1)
5959
vals[1] = control[1]
6060
vals[end] = control[end]
6161
for i = 2:(length(vals)-1)
@@ -193,9 +193,9 @@ end
193193

194194
function discretize_on_midpoints(control::Vector, tlist)
195195
if length(control) == length(tlist) - 1
196-
return copy(control)
196+
return Vector{Float64}(control)
197197
elseif length(control) == length(tlist)
198-
vals = Vector{eltype(control)}(undef, length(tlist) - 1)
198+
vals = Vector{Float64}(undef, length(tlist) - 1)
199199
vals[1] = control[1]
200200
vals[end] = control[end]
201201
for i = 2:(length(vals)-1)

test/test_amplitudes.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,44 @@ end
6363
end
6464

6565

66+
@testset "Vector{Float64} conversion" begin
67+
tlist = collect(range(0, 10, length = 101))
68+
S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman)
69+
ϵ(t) = 1.0
70+
71+
# LockedAmplitude: callable returning Int is discretized to Vector{Float64}
72+
ampl = LockedAmplitude(t -> 1, tlist)
73+
@test ampl.shape isa Vector{Float64}
74+
@test check_amplitude(ampl; tlist)
75+
76+
# LockedAmplitude: Vector{Int} is converted to Vector{Float64}
77+
S_int = ones(Int, length(tlist) - 1)
78+
ampl = LockedAmplitude(S_int)
79+
@test ampl.shape isa Vector{Float64}
80+
@test check_amplitude(ampl; tlist)
81+
82+
# ShapedAmplitude: callable shape returning Int
83+
ampl = ShapedAmplitude(ϵ; shape = t -> 1)
84+
@test check_amplitude(ampl; tlist)
85+
t = t_mid(tlist, 20)
86+
@test evaluate(ampl, tlist, 20) ϵ(t) * 1
87+
88+
# ShapedAmplitude: Vector{Int} control and shape are converted
89+
ϵ_int = ones(Int, length(tlist) - 1)
90+
S_int = ones(Int, length(tlist) - 1)
91+
ampl = ShapedAmplitude(ϵ_int; shape = S_int)
92+
@test ampl.control isa Vector{Float64}
93+
@test ampl.shape isa Vector{Float64}
94+
@test check_amplitude(ampl; tlist)
95+
96+
# ShapedAmplitude tlist constructor: callable returning Int is discretized to Vector{Float64}
97+
ampl = ShapedAmplitude(t -> 1, tlist; shape = t -> 1)
98+
@test ampl.control isa Vector{Float64}
99+
@test ampl.shape isa Vector{Float64}
100+
101+
end
102+
103+
66104
@testset "ShapedAmplitude mixed" begin
67105
tlist = collect(range(0, 10, length = 101))
68106
S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman)

0 commit comments

Comments
 (0)