Skip to content

Commit c8cf47a

Browse files
committed
Fix analytical derivative in test
1 parent 62cd8e4 commit c8cf47a

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

test/test_traj_zygote.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
grad = Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1]
3232
@test grad isa NamedTuple
3333
@test grad.initial_state isa Vector
34-
expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N
34+
expected_grad = -2 .* Ψtgt .* conj(dot(Ψ, Ψtgt)) / N
3535
@test norm(grad.initial_state - expected_grad) < 1e-14
3636

3737
end
@@ -64,7 +64,7 @@ end
6464
if grad isa NamedTuple
6565
@test grad.initial_state isa Nothing
6666
@test grad.kwargs[:x] isa Vector
67-
expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N
67+
expected_grad = -2 .* Ψtgt .* conj(dot(Ψ, Ψtgt)) / N
6868
@test norm(grad.kwargs[:x] - expected_grad) < 1e-14
6969
end
7070

0 commit comments

Comments
 (0)