Skip to content

Commit 0f889af

Browse files
Copilotgoerz
andauthored
Fix normalization, comment typo, and add numerical gradient checks in Zygote tests
Agent-Logs-Url: https://github.com/JuliaQuantumControl/QuantumControl.jl/sessions/8d24cbf1-8f4b-49f7-9008-ac22384e7772 Co-authored-by: goerz <112306+goerz@users.noreply.github.com>
1 parent c611f03 commit 0f889af

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

test/test_traj_zygote.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,16 @@ end
2323
N = 4
2424
H = nothing
2525
Ψ = rand(rng, ComplexF64, N)
26-
Ψ ./ norm(Ψ)
26+
Ψ ./= norm(Ψ)
2727
Ψtgt = zeros(ComplexF64, N)
2828
Ψtgt[1] = 1.0
2929
traj = Trajectory(Ψ, H)
3030
@test f(traj; Ψtgt, N) > 0.0
3131
grad = Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1]
3232
@test grad isa NamedTuple
3333
@test grad.initial_state isa Vector
34+
expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N
35+
@test grad.initial_state expected_grad
3436

3537
end
3638

@@ -45,14 +47,14 @@ end
4547
N = 4
4648
H = nothing
4749
Ψ = rand(rng, ComplexF64, N)
48-
Ψ ./ norm(Ψ)
50+
Ψ ./= norm(Ψ)
4951
Ψtgt = zeros(ComplexF64, N)
5052
Ψtgt[1] = 1.0
5153
x = Ψ
5254
traj = Trajectory(Ψ, H; x)
5355
@test f(traj; Ψtgt, N) > 0.0
5456
captured = IOCapture.capture(rethrow = Union{}) do
55-
# Without the custom `rrule` in `QuantumControlchainRulesCoreExt`, this
57+
# Without the custom `rrule` in `QuantumControlChainRulesCoreExt`, this
5658
# test would show a potentially very confusing error, and throw an
5759
# `UndefRefError`. See also: https://discourse.julialang.org/t/136704/
5860
Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1]
@@ -62,6 +64,8 @@ end
6264
if grad isa NamedTuple
6365
@test grad.initial_state isa Nothing
6466
@test grad.kwargs[:x] isa Vector
67+
expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N
68+
@test grad.kwargs[:x] expected_grad
6569
end
6670

6771
end

0 commit comments

Comments
 (0)