Skip to content

Commit 48dceda

Browse files
committed
Add support for gradients w.r.t. trajectories
This allows to call `Zygote.gradient` for a function where a `Trajectory` is the argument. See also https://discourse.julialang.org/t/136704 While this is probably not something that people would do _directly_, the lack of a custom `rrule` was causing issues with Zygote constructing the derivative of state-dependent running costs where information relevant to the running cost was stored in a custom property of the relevant `Trajectory`.
1 parent a10fb29 commit 48dceda

5 files changed

Lines changed: 107 additions & 0 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ QuantumPropagators = "7bf12567-5742-4b91-a078-644e72a65fc1"
1515
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1616

1717
[weakdeps]
18+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1819
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1920
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2021

2122
[extensions]
23+
QuantumControlChainRulesCoreExt = "ChainRulesCore"
2224
QuantumControlFiniteDifferencesExt = "FiniteDifferences"
2325
QuantumControlZygoteExt = "Zygote"
2426

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module QuantumControlChainRulesCoreExt
2+
3+
using ChainRulesCore: ChainRulesCore, NoTangent
4+
using QuantumControl: Trajectory
5+
6+
7+
# Allow to differentiate w.r.t. to a trajectory. See `test_traj_zygote.jl` for
8+
# an example. Evaluating a gradient with Zygote returns a NamedTuple with
9+
# the fields of the trajectory. Unfortunately, Zygote gets confused about the
10+
# custom `getproperty` method that is defined for a Trajectory, and we need a
11+
# special method to differential through `getproperty`
12+
function ChainRulesCore.rrule(::typeof(getproperty), traj::Trajectory, name::Symbol)
13+
val = getproperty(traj, name)
14+
if name in (:initial_state, :generator, :target_state, :weight)
15+
function field_pullback(Δ)
16+
dt = ChainRulesCore.Tangent{typeof(traj)}(; (name => Δ,)...)
17+
return NoTangent(), dt, NoTangent()
18+
end
19+
return val, field_pullback
20+
else
21+
# kwargs-stored property: route gradient back into the kwargs Dict
22+
function kwargs_pullback(Δ)
23+
dkwargs = Dict{Symbol,Any}(name => Δ)
24+
dt = ChainRulesCore.Tangent{typeof(traj)}(; kwargs = dkwargs)
25+
return NoTangent(), dt, NoTangent()
26+
end
27+
return val, kwargs_pullback
28+
end
29+
end
30+
31+
32+
end

src/trajectories.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ function Base.setproperty!(traj::Trajectory, name::Symbol, value)
167167
end
168168

169169

170+
# Transparently access properties stored in the `kwargs` field.
171+
# Note: This also requires a custom ChainRulesCore.rrule for certain operations
172+
# in Zygote (or other AD frameworks using ChainRules). This is implemented in
173+
# the QuantumControlChainRulesCoreExt extension module.
174+
# See `test_traj_zygote.jl` for an example
170175
function Base.getproperty(traj::Trajectory, name::Symbol)
171176
if name in (:initial_state, :generator, :target_state, :weight)
172177
return getfield(traj, name)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ end
5959
println("* Trajectories (test_trajectories.jl):")
6060
@time @safetestset "Trajectories" begin
6161
include("test_trajectories.jl")
62+
include("test_traj_zygote.jl")
6263
end
6364

6465
println("* Adjoint Trajectories (test_adjoint_trajectory.jl):")

test/test_traj_zygote.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using Test
2+
using StableRNGs
3+
using IOCapture
4+
using QuantumControl: Trajectory
5+
using LinearAlgebra: dot, norm
6+
using Random: rand
7+
using Zygote
8+
9+
10+
function J_T(Ψ; Ψtgt, N)
11+
return 1 - (abs2(dot(Ψ, Ψtgt)) / N)
12+
end
13+
14+
15+
16+
@testset "Gradient w.r.t. trajectory.initial_state" begin
17+
18+
function f(traj; Ψtgt, N)
19+
return J_T(traj.initial_state; Ψtgt, N)
20+
end
21+
22+
rng = StableRNG(3143162815)
23+
N = 4
24+
H = nothing
25+
Ψ = rand(rng, ComplexF64, N)
26+
Ψ ./ norm(Ψ)
27+
Ψtgt = zeros(ComplexF64, N)
28+
Ψtgt[1] = 1.0
29+
traj = Trajectory(Ψ, H)
30+
@test f(traj; Ψtgt, N) > 0.0
31+
grad = Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1]
32+
@test grad isa NamedTuple
33+
@test grad.initial_state isa Vector
34+
35+
end
36+
37+
38+
@testset "Gradient w.r.t. trajectory.x" begin
39+
40+
function f(traj; Ψtgt, N)
41+
return J_T(traj.x; Ψtgt, N)
42+
end
43+
44+
rng = StableRNG(3143162816)
45+
N = 4
46+
H = nothing
47+
Ψ = rand(rng, ComplexF64, N)
48+
Ψ ./ norm(Ψ)
49+
Ψtgt = zeros(ComplexF64, N)
50+
Ψtgt[1] = 1.0
51+
x = Ψ
52+
traj = Trajectory(Ψ, H; x)
53+
@test f(traj; Ψtgt, N) > 0.0
54+
captured = IOCapture.capture(rethrow = Union{}) do
55+
# Without the custom `rrule` in `QuantumControlchainRulesCoreExt`, this
56+
# test would show a potentially very confusing error, and throw an
57+
# `UndefRefError`. See also: https://discourse.julialang.org/t/136704/
58+
Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1]
59+
end
60+
grad = captured.value
61+
@test grad isa NamedTuple
62+
if grad isa NamedTuple
63+
@test grad.initial_state isa Nothing
64+
@test grad.kwargs[:x] isa Vector
65+
end
66+
67+
end

0 commit comments

Comments
 (0)