diff --git a/src/amplitudes.jl b/src/amplitudes.jl index 2a41dfe..9d46a46 100644 --- a/src/amplitudes.jl +++ b/src/amplitudes.jl @@ -15,8 +15,8 @@ using ..Controls: t_mid ampl = LockedAmplitude(shape) ``` -wraps around `shape`, which must be either a vector of values defined on the -midpoints of a time grid or a callable `shape(t)`. +wraps around `shape`, which must be either a `Vector{Float64}` of values defined +on the midpoints of a time grid or a callable `shape(t)`. ```julia ampl = LockedAmplitude(shape, tlist) @@ -24,20 +24,37 @@ ampl = LockedAmplitude(shape, tlist) discretizes `shape` to the midpoints of `tlist`. """ -abstract type LockedAmplitude end - - -function LockedAmplitude(shape) - if shape isa Vector{Float64} - return LockedPulseAmplitude(shape) - else - return LockedContinuousAmplitude(shape) +struct LockedAmplitude{ST} + shape::ST + + function LockedAmplitude(shape::ST; check = true) where {ST} + if shape isa AbstractVector + if !(shape isa Vector{Float64}) + try + shape = Vector{Float64}(shape) + catch + msg = "A LockedAmplitude shape that is a vector must be convertible to Vector{Float64}" + error(msg) + end + end + return new{Vector{Float64}}(shape) + else + if check + try + shape(0.0) + catch + msg = "A LockedAmplitude shape must either be a Vector{Float64} or a callable" + error(msg) + end + end + return new{ST}(shape) + end end end function LockedAmplitude(shape, tlist) - return LockedPulseAmplitude(discretize_on_midpoints(shape, tlist)) + return LockedAmplitude(discretize_on_midpoints(shape, tlist); check = false) end @@ -46,30 +63,7 @@ function Base.show(io::IO, ampl::LockedAmplitude) end -struct LockedPulseAmplitude <: LockedAmplitude - shape::Vector{Float64} -end - - -Base.Array(ampl::LockedPulseAmplitude) = ampl.shape - - -struct LockedContinuousAmplitude <: LockedAmplitude - - shape - - function LockedContinuousAmplitude(shape) - try - S_t = shape(0.0) - catch - error("A LockedAmplitude shape must either be a vector of values or a callable") - end - return new(shape) - end - -end - -(ampl::LockedContinuousAmplitude)(t::Float64) = ampl.shape(t) +Base.Array(ampl::LockedAmplitude{Vector{Float64}}) = ampl.shape get_controls(ampl::LockedAmplitude) = () @@ -77,22 +71,21 @@ function substitute(ampl::LockedAmplitude, replacements) return get(replacements, ampl, ampl) end -function evaluate(ampl::LockedPulseAmplitude, tlist, n; _...) +function evaluate(ampl::LockedAmplitude{Vector{Float64}}, tlist, n::Int; _...) return ampl.shape[n] end -function evaluate(ampl::LockedPulseAmplitude, t; _...) - error( - "A LockedAmplitude initialized on a `tlist` can only be evaluated with arguments `tlist, n`." - ) +function evaluate(ampl::LockedAmplitude{Vector{Float64}}, t::Float64; _...) + msg = "A LockedAmplitude initialized from a vector can only be evaluated with (tlist, n)." + error(msg) end -function evaluate(ampl::LockedContinuousAmplitude, tlist, n; _...) - return ampl(t_mid(tlist, n)) +function evaluate(ampl::LockedAmplitude, tlist, n::Int; _...) + return ampl.shape(t_mid(tlist, n)) end -function evaluate(ampl::LockedContinuousAmplitude, t; _...) - return ampl(t) +function evaluate(ampl::LockedAmplitude, t::Float64; _...) + return ampl.shape(t) end @@ -124,10 +117,10 @@ ampl = ShapedAmplitude(control; shape=shape) ``` produces an amplitude ``a(t) = S(t) ϵ(t)``, where ``S(t)`` corresponds to -`shape` and ``ϵ(t)`` corresponds to `control`. Both `control` and `shape` -should be either a vector of values defined on the midpoints of a time grid or -a callable `control(t)`, respectively `shape(t)`. In the latter case, `ampl` -will also be callable. +`shape` and ``ϵ(t)`` corresponds to `control`. Each of `control` and `shape` +must be either a `Vector{Float64}` of values defined on the midpoints of a time +grid or a callable `control(t)`, respectively `shape(t)`. If both are callables, +`ampl` will also be callable. If both are vectors, they must have the same length. ```julia ampl = ShapedAmplitude(control, tlist; shape=shape) @@ -135,52 +128,127 @@ ampl = ShapedAmplitude(control, tlist; shape=shape) discretizes `control` and `shape` to the midpoints of `tlist`. """ -abstract type ShapedAmplitude <: ControlAmplitude end +struct ShapedAmplitude{CT,ST} <: ControlAmplitude + control::CT + shape::ST +end -function ShapedAmplitude(control; shape) - if (control isa Vector{Float64}) && (shape isa Vector{Float64}) - return ShapedPulseAmplitude(control, shape) - else + +function ShapedAmplitude(control; shape, check = true) + if control isa AbstractVector && !(control isa Vector{Float64}) try - ϵ_t = control(0.0) + control = Vector{Float64}(control) catch - error( - "A ShapedAmplitude control must either be a vector of values or a callable" - ) + msg = "A ShapedAmplitude control that is a vector must be convertible to Vector{Float64}" + error(msg) end + end + if shape isa AbstractVector && !(shape isa Vector{Float64}) try - S_t = shape(0.0) + shape = Vector{Float64}(shape) catch - error("A ShapedAmplitude shape must either be a vector of values or a callable") + msg = "A ShapedAmplitude shape that is a vector must be convertible to Vector{Float64}" + error(msg) end - return ShapedContinuousAmplitude(control, shape) end + if check + if !(control isa Vector{Float64}) + try + control(0.0) + catch + msg = "A ShapedAmplitude control must either be a Vector{Float64} or a callable" + error(msg) + end + end + if !(shape isa Vector{Float64}) + try + shape(0.0) + catch + msg = "A ShapedAmplitude shape must either be a Vector{Float64} or a callable" + error(msg) + end + end + if (control isa Vector{Float64}) && (shape isa Vector{Float64}) + if length(control) ≠ length(shape) + msg = "ShapedAmplitude control and shape vectors must have the same length" + error(msg) + end + end + end + return ShapedAmplitude{typeof(control),typeof(shape)}(control, shape) end + function Base.show(io::IO, ampl::ShapedAmplitude) print(io, "ShapedAmplitude(::$(typeof(ampl.control)); shape::$(typeof(ampl.shape)))") end -function ShapedAmplitude(control, tlist; shape) + +function ShapedAmplitude(control, tlist::Vector{Float64}; shape) control = discretize_on_midpoints(control, tlist) shape = discretize_on_midpoints(shape, tlist) - return ShapedPulseAmplitude(control, shape) + return ShapedAmplitude{Vector{Float64},Vector{Float64}}(control, shape) end -struct ShapedPulseAmplitude <: ShapedAmplitude - control::Vector{Float64} - shape::Vector{Float64} + +function substitute(ampl::ShapedAmplitude, replacements) + ampl in keys(replacements) && return replacements[ampl] + control = substitute(ampl.control, replacements) + return ShapedAmplitude(control; shape = ampl.shape, check = true) end -Base.Array(ampl::ShapedPulseAmplitude) = ampl.control .* ampl.shape + +Base.Array(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}) = + ampl.control .* ampl.shape + +(ampl::ShapedAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t) -struct ShapedContinuousAmplitude <: ShapedAmplitude - control - shape +# Vector shape: index directly; control may be vector or callable (evaluated via Controls.evaluate) +function evaluate( + ampl::ShapedAmplitude{CT,Vector{Float64}}, + tlist::Vector{Float64}, + n::Int; + vals_dict = IdDict() +) where {CT} + S_t = ampl.shape[n] + ϵ_t = evaluate(ampl.control, tlist, n; vals_dict) + return S_t * ϵ_t end -(ampl::ShapedContinuousAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t) +# Callable shape: call directly with t_mid; control may be vector or callable +function evaluate( + ampl::ShapedAmplitude, + tlist::Vector{Float64}, + n::Int; + vals_dict = IdDict() +) + S_t = ampl.shape(t_mid(tlist, n)) + ϵ_t = evaluate(ampl.control, tlist, n; vals_dict) + return S_t * ϵ_t +end + +# Callable shape and callable control: call both directly at t +function evaluate(ampl::ShapedAmplitude, t::Float64; vals_dict = IdDict()) + S_t = ampl.shape(t) + ϵ_t = evaluate(ampl.control, t; vals_dict) + return S_t * ϵ_t +end + +function evaluate(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}, t::Float64; _...) + msg = "A ShapedAmplitude with vector control and shape can only be evaluated with (tlist, n)." + error(msg) +end + +function evaluate(ampl::ShapedAmplitude{Vector{Float64},ST}, t::Float64; _...) where {ST} + msg = "A ShapedAmplitude with a vector control can only be evaluated with (tlist, n)." + error(msg) +end + +function evaluate(ampl::ShapedAmplitude{CT,Vector{Float64}}, t::Float64; _...) where {CT} + msg = "A ShapedAmplitude with a vector shape can only be evaluated with (tlist, n)." + error(msg) +end function evaluate(ampl::ShapedAmplitude, args...; vals_dict = IdDict()) diff --git a/src/controls.jl b/src/controls.jl index 309e98f..6707c2a 100644 --- a/src/controls.jl +++ b/src/controls.jl @@ -45,17 +45,17 @@ function discretize(control::Function, tlist; via_midpoints = true) vals_on_midpoints = discretize_on_midpoints(control, tlist) return discretize(vals_on_midpoints, tlist) else - return [control(t) for t in tlist] + return Float64[control(t) for t in tlist] end end function discretize(control::Vector, tlist) if length(control) == length(tlist) - return copy(control) + return Vector{Float64}(control) elseif length(control) == length(tlist) - 1 # convert `control` on intervals to values on `tlist` # cf. pulse_onto_tlist in Python krotov package - vals = zeros(eltype(control), length(control) + 1) + vals = zeros(Float64, length(control) + 1) vals[1] = control[1] vals[end] = control[end] for i = 2:(length(vals)-1) @@ -193,9 +193,9 @@ end function discretize_on_midpoints(control::Vector, tlist) if length(control) == length(tlist) - 1 - return copy(control) + return Vector{Float64}(control) elseif length(control) == length(tlist) - vals = Vector{eltype(control)}(undef, length(tlist) - 1) + vals = Vector{Float64}(undef, length(tlist) - 1) vals[1] = control[1] vals[end] = control[end] for i = 2:(length(vals)-1) diff --git a/test/test_amplitudes.jl b/test/test_amplitudes.jl index fc02601..1a9abed 100644 --- a/test/test_amplitudes.jl +++ b/test/test_amplitudes.jl @@ -1,9 +1,10 @@ using Test +using IOCapture: IOCapture using QuantumPropagators.Interfaces: check_amplitude, check_control using QuantumPropagators.Shapes: flattop using QuantumPropagators.Amplitudes: LockedAmplitude, ShapedAmplitude -using QuantumPropagators.Controls: evaluate, get_controls, t_mid -using QuantumPropagators.Controls: discretize_on_midpoints # DEBUG +using QuantumPropagators.Controls: evaluate, get_controls, substitute, t_mid +using QuantumPropagators.Controls: discretize_on_midpoints @testset "LockedAmplitude" begin @@ -25,6 +26,11 @@ using QuantumPropagators.Controls: discretize_on_midpoints # DEBUG @test length(get_controls(ampl)) == 0 t = t_mid(tlist, 20) @test evaluate(ampl, tlist, 20) ≈ S(t) + captured = IOCapture.capture(rethrow = Union{}) do + evaluate(ampl, t) + end + @test captured.error + @test contains(captured.value.msg, "can only be evaluated with (tlist, n)") end @@ -48,7 +54,6 @@ end t = t_mid(tlist, 20) @test evaluate(ampl, tlist, 20) ≈ ϵ(t) * S(t) - ampl = ShapedAmplitude(ϵ, tlist; shape = S) @test startswith("$ampl", "ShapedAmplitude(") controls = get_controls(ampl) @@ -57,5 +62,196 @@ end @test check_amplitude(ampl; tlist) t = t_mid(tlist, 20) @test evaluate(ampl, tlist, 20) ≈ ϵ(t) * S(t) + captured = IOCapture.capture(rethrow = Union{}) do + evaluate(ampl, t) + end + @test captured.error + @test contains( + captured.value.msg, + "vector control and shape can only be evaluated with (tlist, n)" + ) + +end + + +@testset "Vector{Float64} conversion" begin + tlist = collect(range(0, 10, length = 101)) + S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman) + ϵ(t) = 1.0 + + # LockedAmplitude: callable returning Int is discretized to Vector{Float64} + ampl = LockedAmplitude(t -> 1, tlist) + @test ampl.shape isa Vector{Float64} + @test check_amplitude(ampl; tlist) + + # LockedAmplitude: Vector{Int} is converted to Vector{Float64} + S_int = ones(Int, length(tlist) - 1) + ampl = LockedAmplitude(S_int) + @test ampl.shape isa Vector{Float64} + @test check_amplitude(ampl; tlist) + + # ShapedAmplitude: callable shape returning Int + ampl = ShapedAmplitude(ϵ; shape = t -> 1) + @test check_amplitude(ampl; tlist) + t = t_mid(tlist, 20) + @test evaluate(ampl, tlist, 20) ≈ ϵ(t) * 1 + + # ShapedAmplitude: Vector{Int} control and shape are converted + ϵ_int = ones(Int, length(tlist) - 1) + S_int = ones(Int, length(tlist) - 1) + ampl = ShapedAmplitude(ϵ_int; shape = S_int) + @test ampl.control isa Vector{Float64} + @test ampl.shape isa Vector{Float64} + @test check_amplitude(ampl; tlist) + + # ShapedAmplitude tlist constructor: callable returning Int is discretized to Vector{Float64} + ampl = ShapedAmplitude(t -> 1, tlist; shape = t -> 1) + @test ampl.control isa Vector{Float64} + @test ampl.shape isa Vector{Float64} + +end + + +@testset "ShapedAmplitude mixed" begin + tlist = collect(range(0, 10, length = 101)) + S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman) + ϵ(t) = 1.0 + S_vec = discretize_on_midpoints(S, tlist) + ϵ_vec = discretize_on_midpoints(ϵ, tlist) + + # callable control, vector shape + ampl = ShapedAmplitude(ϵ; shape = S_vec) + @test startswith("$ampl", "ShapedAmplitude(") + controls = get_controls(ampl) + @test length(controls) == 1 + @test check_amplitude(ampl; tlist) + t = t_mid(tlist, 20) + @test evaluate(ampl, tlist, 20) ≈ ϵ(t) * S(t) + captured = IOCapture.capture(rethrow = Union{}) do + evaluate(ampl, t) + end + @test captured.error + @test contains(captured.value.msg, "vector shape can only be evaluated with (tlist, n)") + + # vector control, callable shape + ampl = ShapedAmplitude(ϵ_vec; shape = S) + @test startswith("$ampl", "ShapedAmplitude(") + controls = get_controls(ampl) + @test length(controls) == 1 + @test check_amplitude(ampl; tlist) + t = t_mid(tlist, 20) + @test evaluate(ampl, tlist, 20) ≈ ϵ(t) * S(t) + captured = IOCapture.capture(rethrow = Union{}) do + evaluate(ampl, t) + end + @test captured.error + @test contains( + captured.value.msg, + "vector control can only be evaluated with (tlist, n)" + ) + +end + + +@testset "Invalid constructions" begin + tlist = collect(range(0, 10, length = 101)) + S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman) + ϵ(t) = 1.0 + n = length(tlist) - 1 + + # LockedAmplitude: complex vector is not convertible to Vector{Float64} + captured = IOCapture.capture(rethrow = Union{}) do + LockedAmplitude(fill(1.0 + 1.0im, n)) + end + @test captured.error + @test contains( + captured.value.msg, + "shape that is a vector must be convertible to Vector{Float64}" + ) + + # LockedAmplitude: non-vector, non-callable shape + captured = IOCapture.capture(rethrow = Union{}) do + LockedAmplitude(42) + end + @test captured.error + @test contains( + captured.value.msg, + "shape must either be a Vector{Float64} or a callable" + ) + + # ShapedAmplitude: complex vector control + captured = IOCapture.capture(rethrow = Union{}) do + ShapedAmplitude(fill(1.0 + 1.0im, n); shape = S) + end + @test captured.error + @test contains( + captured.value.msg, + "control that is a vector must be convertible to Vector{Float64}" + ) + + # ShapedAmplitude: complex vector shape + captured = IOCapture.capture(rethrow = Union{}) do + ShapedAmplitude(ϵ; shape = fill(1.0 + 1.0im, n)) + end + @test captured.error + @test contains( + captured.value.msg, + "shape that is a vector must be convertible to Vector{Float64}" + ) + + # ShapedAmplitude: non-vector, non-callable control + captured = IOCapture.capture(rethrow = Union{}) do + ShapedAmplitude(42; shape = S) + end + @test captured.error + @test contains( + captured.value.msg, + "control must either be a Vector{Float64} or a callable" + ) + + # ShapedAmplitude: non-vector, non-callable shape + captured = IOCapture.capture(rethrow = Union{}) do + ShapedAmplitude(ϵ; shape = 42) + end + @test captured.error + @test contains( + captured.value.msg, + "shape must either be a Vector{Float64} or a callable" + ) + + # ShapedAmplitude: control and shape vectors of different lengths + captured = IOCapture.capture(rethrow = Union{}) do + ShapedAmplitude(ones(n); shape = ones(n - 1)) + end + @test captured.error + @test contains( + captured.value.msg, + "control and shape vectors must have the same length" + ) + +end + + +@testset "ShapedAmplitude substitute" begin + tlist = collect(range(0, 10, length = 101)) + S(t) = flattop(t, T = 10, t_rise = 2, func = :blackman) + ϵ(t) = 1.0 + ϵ_vec = discretize_on_midpoints(ϵ, tlist) + + ampl = ShapedAmplitude(ϵ; shape = S) + ampl2 = substitute(ampl, IdDict(ϵ => ϵ_vec)) + @test ampl2 isa ShapedAmplitude + @test get_controls(ampl2)[1] === ϵ_vec + @test check_amplitude(ampl2; tlist) + t = t_mid(tlist, 20) + @test evaluate(ampl2, tlist, 20) ≈ ϵ(t) * S(t) + captured = IOCapture.capture(rethrow = Union{}) do + evaluate(ampl2, t) + end + @test captured.error + @test contains( + captured.value.msg, + "vector control can only be evaluated with (tlist, n)" + ) end