diff --git a/ext/QuantumControlFiniteDifferencesExt.jl b/ext/QuantumControlFiniteDifferencesExt.jl index 3c97e698..e6163f20 100644 --- a/ext/QuantumControlFiniteDifferencesExt.jl +++ b/ext/QuantumControlFiniteDifferencesExt.jl @@ -4,15 +4,10 @@ using LinearAlgebra import FiniteDifferences import QuantumControl.Functionals: - _default_chi_via, make_gate_chi, make_automatic_chi, make_automatic_grad_J_a + make_gate_chi, make_automatic_chi, make_automatic_grad_J_a -function make_automatic_chi( - J_T, - trajectories, - ::Val{:FiniteDifferences}; - via=_default_chi_via(trajectories) -) +function make_automatic_chi(J_T, trajectories, ::Val{:FiniteDifferences}; via=:states) # TODO: Benchmark if χ should be closure, see QuantumControlZygoteExt.jl diff --git a/ext/QuantumControlZygoteExt.jl b/ext/QuantumControlZygoteExt.jl index 858d28af..3e4bf22a 100644 --- a/ext/QuantumControlZygoteExt.jl +++ b/ext/QuantumControlZygoteExt.jl @@ -4,15 +4,10 @@ using LinearAlgebra import Zygote import QuantumControl.Functionals: - _default_chi_via, make_gate_chi, make_automatic_chi, make_automatic_grad_J_a + make_gate_chi, make_automatic_chi, make_automatic_grad_J_a -function make_automatic_chi( - J_T, - trajectories, - ::Val{:Zygote}; - via=_default_chi_via(trajectories) -) +function make_automatic_chi(J_T, trajectories, ::Val{:Zygote}; via=:states) # TODO: At some point, for a large system, we could benchmark if there is # any benefit to making χ a closure and using LinearAlgebra.axpby! to @@ -26,7 +21,14 @@ function make_automatic_chi( χ = Vector{eltype(Ψ)}(undef, length(Ψ)) ∇J = Zygote.gradient(_J_T, Ψ...) for (k, ∇Jₖ) ∈ enumerate(∇J) - χ[k] = 0.5 * ∇Jₖ # ½ corrects for gradient vs Wirtinger deriv + if isnothing(∇Jₖ) + # Functional does not depend on Ψₖ. That probably means a buggy + # J_T, but who knows: maybe there are situations where that + # makes sense. It would be extremely noisy to warn here. + χ[k] = zero(χ[k]) + else + χ[k] = 0.5 * ∇Jₖ # ½ corrects for gradient vs Wirtinger deriv + end # axpby!(0.5, ∇Jₖ, false, χ[k]) end return χ @@ -43,7 +45,12 @@ function make_automatic_chi( χ = Vector{eltype(Ψ)}(undef, length(Ψ)) ∇J = Zygote.gradient(_J_T, τ...) for (k, traj) ∈ enumerate(trajectories) - ∂J╱∂τ̄ₖ = 0.5 * ∇J[k] # ½ corrects for gradient vs Wirtinger deriv + if isnothing(∇J[k]) + # Functional does not depend on τₖ + ∂J╱∂τ̄ₖ = zero(ComplexF64) + else + ∂J╱∂τ̄ₖ = 0.5 * ∇J[k] # ½ corrects for gradient vs Wirtinger deriv + end χ[k] = ∂J╱∂τ̄ₖ * traj.target_state # axpby!(∂J╱∂τ̄ₖ, traj.target_state, false, χ[k]) end diff --git a/src/functionals.jl b/src/functionals.jl index d3c9c00c..c262027f 100644 --- a/src/functionals.jl +++ b/src/functionals.jl @@ -9,12 +9,21 @@ export make_grad_J_a, make_chi using LinearAlgebra: axpy!, dot -# default for `via` argument of `make_chi` -function _default_chi_via(trajectories) - if any(isnothing(traj.target_state) for traj in trajectories) - return :states - else - return :tau +function _check_chi(chi; states, trajectories, tau, via) + try + if via == :tau + chi_states = chi(states, trajectories; tau) + else + chi_states = chi(states, trajectories) + end + if typeof(chi_states) ≠ typeof(states) + msg = "`chi` must return a vector of states" + error(msg) + end + catch exception + msg = "The chi generated by `make_chi` does not have the required interface" + @error msg exception + error("Cannot make chi") end end @@ -86,15 +95,25 @@ chi = make_chi( trajectories; mode=:any, automatic=:default, - via=(any(isnothing(t.target_state) for t in trajectories) ? :states : :tau), + via=:automatic, # one of :automatic, :tau, :states ) ``` -creates a function `chi(Ψ, trajectories; τ)` that returns -a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩`` is the -k'th element of `Ψ`. These are the states used as the boundary condition for -the backward propagation propagation in Krotov's method and GRAPE. Each -``|χₖ⟩`` is defined as a matrix calculus +creates a function `chi(Ψ, trajectories)` or `chi(Ψ, trajectories; tau)` that +returns a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩`` +is the k'th element of `Ψ`. These are the states used as the boundary condition +for the backward propagation propagation in Krotov's method and GRAPE. + +The resulting `chi` function takes the keyword argument `tau` +if and only if `via=:tau` or `via=:automatic` (default) if the following +conditions are met: + +* All `trajectories` have a defined `target_state` component (not `nothing`) +* `J_T` takes `tau` as a keyword argument (determined via introspection) + +Both of these conditions are _requirements_ for `via=:tau`. + +Each ``|χₖ⟩`` is defined as a matrix calculus [Wirtinger derivative](https://www.ekinakyurek.me/complex-derivatives-wirtinger/), ```math @@ -193,25 +212,53 @@ and the definition of the Zygote gradient with respect to a complex scalar, gradients). Always test automatic derivatives against finite differences and/or other automatic differentiation frameworks. """ -function make_chi( - J_T, - trajectories; - mode=:any, - automatic=:default, - via=_default_chi_via(trajectories), -) +function make_chi(J_T, trajectories; mode=:any, automatic=:default, via=:automatic,) + states = [traj.initial_state for traj in trajectories] + tau = [zero(ComplexF64) for _ in states] + J_T_takes_tau = hasmethod(J_T, Tuple{typeof(states),typeof(trajectories)}, (:tau,)) + has_target_states = all((traj.target_state ≢ nothing) for traj in trajectories) + if (via == :tau) && !J_T_takes_tau + msg = "Called `make_chi` with `via=:tau`, but given J_T does not take `tau` keyword argument" + error(msg) + end + if (via == :tau) && !has_target_states + msg = "Called `make_chi` with `via=:tau`, but not all `trajectories` define a `target_state`" + error(msg) + end + if via == :automatic + via = :states + if J_T_takes_tau && has_target_states + via = :tau + end + end + chi = nothing + try + if via == :tau + J_T_val = J_T(states, trajectories; tau) + else + J_T_val = J_T(states, trajectories) + end + if !(J_T_val isa Float64) + msg = "J_T passed to `make_chi` must return a Float64, not $(typeof(J_T_val))" + error(msg) + end + catch exception + msg = "The J_T passed to `make_chi` does not have the required interface" + @error msg exception + error("Cannot make chi") + end if mode == :any try chi = make_analytic_chi(J_T, trajectories) @debug "make_chi for J_T=$(J_T) -> analytic" - # TODO: call chi to compile it and ensure required properties + _check_chi(chi; states, trajectories, tau, via) return chi catch exception if exception isa MethodError @info "make_chi for J_T=$(J_T): fallback to mode=:automatic" try chi = make_automatic_chi(J_T, trajectories, automatic; via) - # TODO: call chi to compile it and ensure required properties + _check_chi(chi; states, trajectories, tau, via) return chi catch exception if exception isa MethodError @@ -228,7 +275,7 @@ function make_chi( elseif mode == :analytic try chi = make_analytic_chi(J_T, trajectories) - # TODO: call chi to compile it and ensure required properties + _check_chi(chi; states, trajectories, tau, via) return chi catch exception if exception isa MethodError @@ -241,7 +288,7 @@ function make_chi( elseif mode == :automatic try chi = make_automatic_chi(J_T, trajectories, automatic; via) - # TODO: call chi to compile it and ensure required properties + _check_chi(chi; states, trajectories, tau, via) return chi catch exception if exception isa MethodError diff --git a/test/test_functionals.jl b/test/test_functionals.jl index 8675bed9..0e501af2 100644 --- a/test/test_functionals.jl +++ b/test/test_functionals.jl @@ -299,13 +299,13 @@ end throw(DomainError("XXX")) end - @test_throws DomainError begin + @test_throws Exception begin IOCapture.capture() do make_chi(J_T_xxx, trajectories) end end - @test_throws DomainError begin + @test_throws Exception begin IOCapture.capture() do make_chi(J_T_xxx, trajectories; mode=:automatic) end