Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 138 additions & 70 deletions src/amplitudes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,46 @@ using ..Controls: t_mid
ampl = LockedAmplitude(shape)
```

wraps around `shape`, which must be either a vector of values defined on the
midpoints of a time grid or a callable `shape(t)`.
wraps around `shape`, which must be either a `Vector{Float64}` of values defined
on the midpoints of a time grid or a callable `shape(t)`.

```julia
ampl = LockedAmplitude(shape, tlist)
```

discretizes `shape` to the midpoints of `tlist`.
"""
abstract type LockedAmplitude end


function LockedAmplitude(shape)
if shape isa Vector{Float64}
return LockedPulseAmplitude(shape)
else
return LockedContinuousAmplitude(shape)
struct LockedAmplitude{ST}
shape::ST

function LockedAmplitude(shape::ST; check = true) where {ST}
if shape isa AbstractVector
if !(shape isa Vector{Float64})
try
shape = Vector{Float64}(shape)
catch
msg = "A LockedAmplitude shape that is a vector must be convertible to Vector{Float64}"
error(msg)
end
end
return new{Vector{Float64}}(shape)
else
if check
try
shape(0.0)
catch
msg = "A LockedAmplitude shape must either be a Vector{Float64} or a callable"
error(msg)
end
end
return new{ST}(shape)
end
end
end


function LockedAmplitude(shape, tlist)
return LockedPulseAmplitude(discretize_on_midpoints(shape, tlist))
return LockedAmplitude(discretize_on_midpoints(shape, tlist); check = false)
Comment thread
goerz marked this conversation as resolved.
end


Expand All @@ -46,53 +63,29 @@ function Base.show(io::IO, ampl::LockedAmplitude)
end


struct LockedPulseAmplitude <: LockedAmplitude
shape::Vector{Float64}
end


Base.Array(ampl::LockedPulseAmplitude) = ampl.shape


struct LockedContinuousAmplitude <: LockedAmplitude

shape

function LockedContinuousAmplitude(shape)
try
S_t = shape(0.0)
catch
error("A LockedAmplitude shape must either be a vector of values or a callable")
end
return new(shape)
end

end

(ampl::LockedContinuousAmplitude)(t::Float64) = ampl.shape(t)
Base.Array(ampl::LockedAmplitude{Vector{Float64}}) = ampl.shape

get_controls(ampl::LockedAmplitude) = ()

function substitute(ampl::LockedAmplitude, replacements)
return get(replacements, ampl, ampl)
end

function evaluate(ampl::LockedPulseAmplitude, tlist, n; _...)
function evaluate(ampl::LockedAmplitude{Vector{Float64}}, tlist, n::Int; _...)
return ampl.shape[n]
end

function evaluate(ampl::LockedPulseAmplitude, t; _...)
error(
"A LockedAmplitude initialized on a `tlist` can only be evaluated with arguments `tlist, n`."
)
function evaluate(ampl::LockedAmplitude{Vector{Float64}}, t::Float64; _...)
msg = "A LockedAmplitude initialized from a vector can only be evaluated with (tlist, n)."
error(msg)
end

function evaluate(ampl::LockedContinuousAmplitude, tlist, n; _...)
return ampl(t_mid(tlist, n))
function evaluate(ampl::LockedAmplitude, tlist, n::Int; _...)
return ampl.shape(t_mid(tlist, n))
end

function evaluate(ampl::LockedContinuousAmplitude, t; _...)
return ampl(t)
function evaluate(ampl::LockedAmplitude, t::Float64; _...)
return ampl.shape(t)
end


Expand Down Expand Up @@ -124,63 +117,138 @@ ampl = ShapedAmplitude(control; shape=shape)
```

produces an amplitude ``a(t) = S(t) ϵ(t)``, where ``S(t)`` corresponds to
`shape` and ``ϵ(t)`` corresponds to `control`. Both `control` and `shape`
should be either a vector of values defined on the midpoints of a time grid or
a callable `control(t)`, respectively `shape(t)`. In the latter case, `ampl`
will also be callable.
`shape` and ``ϵ(t)`` corresponds to `control`. Each of `control` and `shape`
must be either a `Vector{Float64}` of values defined on the midpoints of a time
grid or a callable `control(t)`, respectively `shape(t)`. If both are callables,
`ampl` will also be callable. If both are vectors, they must have the same length.

```julia
ampl = ShapedAmplitude(control, tlist; shape=shape)
```

discretizes `control` and `shape` to the midpoints of `tlist`.
"""
abstract type ShapedAmplitude <: ControlAmplitude end
struct ShapedAmplitude{CT,ST} <: ControlAmplitude
control::CT
shape::ST
end

function ShapedAmplitude(control; shape)
if (control isa Vector{Float64}) && (shape isa Vector{Float64})
return ShapedPulseAmplitude(control, shape)
else

function ShapedAmplitude(control; shape, check = true)
if control isa AbstractVector && !(control isa Vector{Float64})
try
ϵ_t = control(0.0)
control = Vector{Float64}(control)
catch
error(
"A ShapedAmplitude control must either be a vector of values or a callable"
)
msg = "A ShapedAmplitude control that is a vector must be convertible to Vector{Float64}"
error(msg)
end
end
if shape isa AbstractVector && !(shape isa Vector{Float64})
try
S_t = shape(0.0)
shape = Vector{Float64}(shape)
catch
error("A ShapedAmplitude shape must either be a vector of values or a callable")
msg = "A ShapedAmplitude shape that is a vector must be convertible to Vector{Float64}"
error(msg)
end
return ShapedContinuousAmplitude(control, shape)
end
if check
if !(control isa Vector{Float64})
try
control(0.0)
catch
msg = "A ShapedAmplitude control must either be a Vector{Float64} or a callable"
error(msg)
end
end
if !(shape isa Vector{Float64})
try
shape(0.0)
catch
msg = "A ShapedAmplitude shape must either be a Vector{Float64} or a callable"
error(msg)
end
end
if (control isa Vector{Float64}) && (shape isa Vector{Float64})
if length(control) ≠ length(shape)
msg = "ShapedAmplitude control and shape vectors must have the same length"
error(msg)
end
end
end
return ShapedAmplitude{typeof(control),typeof(shape)}(control, shape)
end


function Base.show(io::IO, ampl::ShapedAmplitude)
print(io, "ShapedAmplitude(::$(typeof(ampl.control)); shape::$(typeof(ampl.shape)))")
end

function ShapedAmplitude(control, tlist; shape)

function ShapedAmplitude(control, tlist::Vector{Float64}; shape)
control = discretize_on_midpoints(control, tlist)
shape = discretize_on_midpoints(shape, tlist)
return ShapedPulseAmplitude(control, shape)
return ShapedAmplitude{Vector{Float64},Vector{Float64}}(control, shape)
end

struct ShapedPulseAmplitude <: ShapedAmplitude
control::Vector{Float64}
shape::Vector{Float64}

function substitute(ampl::ShapedAmplitude, replacements)
ampl in keys(replacements) && return replacements[ampl]
control = substitute(ampl.control, replacements)
return ShapedAmplitude(control; shape = ampl.shape, check = true)
end

Base.Array(ampl::ShapedPulseAmplitude) = ampl.control .* ampl.shape

Base.Array(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}) =
ampl.control .* ampl.shape

(ampl::ShapedAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t)


Comment thread
goerz marked this conversation as resolved.
struct ShapedContinuousAmplitude <: ShapedAmplitude
control
shape
# Vector shape: index directly; control may be vector or callable (evaluated via Controls.evaluate)
function evaluate(
ampl::ShapedAmplitude{CT,Vector{Float64}},
tlist::Vector{Float64},
n::Int;
vals_dict = IdDict()
) where {CT}
S_t = ampl.shape[n]
ϵ_t = evaluate(ampl.control, tlist, n; vals_dict)
return S_t * ϵ_t
end

(ampl::ShapedContinuousAmplitude)(t::Float64) = ampl.shape(t) * ampl.control(t)
# Callable shape: call directly with t_mid; control may be vector or callable
function evaluate(
ampl::ShapedAmplitude,
tlist::Vector{Float64},
n::Int;
vals_dict = IdDict()
)
S_t = ampl.shape(t_mid(tlist, n))
ϵ_t = evaluate(ampl.control, tlist, n; vals_dict)
return S_t * ϵ_t
end

# Callable shape and callable control: call both directly at t
function evaluate(ampl::ShapedAmplitude, t::Float64; vals_dict = IdDict())
S_t = ampl.shape(t)
ϵ_t = evaluate(ampl.control, t; vals_dict)
return S_t * ϵ_t
end

function evaluate(ampl::ShapedAmplitude{Vector{Float64},Vector{Float64}}, t::Float64; _...)
msg = "A ShapedAmplitude with vector control and shape can only be evaluated with (tlist, n)."
error(msg)
end

function evaluate(ampl::ShapedAmplitude{Vector{Float64},ST}, t::Float64; _...) where {ST}
msg = "A ShapedAmplitude with a vector control can only be evaluated with (tlist, n)."
error(msg)
end

function evaluate(ampl::ShapedAmplitude{CT,Vector{Float64}}, t::Float64; _...) where {CT}
msg = "A ShapedAmplitude with a vector shape can only be evaluated with (tlist, n)."
error(msg)
end


function evaluate(ampl::ShapedAmplitude, args...; vals_dict = IdDict())
Expand Down
10 changes: 5 additions & 5 deletions src/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ function discretize(control::Function, tlist; via_midpoints = true)
vals_on_midpoints = discretize_on_midpoints(control, tlist)
return discretize(vals_on_midpoints, tlist)
else
return [control(t) for t in tlist]
return Float64[control(t) for t in tlist]
end
end

function discretize(control::Vector, tlist)
if length(control) == length(tlist)
return copy(control)
return Vector{Float64}(control)
elseif length(control) == length(tlist) - 1
# convert `control` on intervals to values on `tlist`
# cf. pulse_onto_tlist in Python krotov package
vals = zeros(eltype(control), length(control) + 1)
vals = zeros(Float64, length(control) + 1)
vals[1] = control[1]
vals[end] = control[end]
for i = 2:(length(vals)-1)
Expand Down Expand Up @@ -193,9 +193,9 @@ end

function discretize_on_midpoints(control::Vector, tlist)
if length(control) == length(tlist) - 1
return copy(control)
return Vector{Float64}(control)
elseif length(control) == length(tlist)
vals = Vector{eltype(control)}(undef, length(tlist) - 1)
vals = Vector{Float64}(undef, length(tlist) - 1)
vals[1] = control[1]
vals[end] = control[end]
for i = 2:(length(vals)-1)
Expand Down
Loading