Skip to content

Commit a10fb29

Browse files
committed
Merge #100 from branch optimize-check-kwargs
2 parents eeb52e0 + 802cbb5 commit a10fb29

2 files changed

Lines changed: 47 additions & 15 deletions

File tree

src/optimize.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@ result = optimize(problem; method=Krotov)
3131
3232
If `check` is true (default), the `initial_state` and `generator` of each
3333
trajectory is checked with [`check_state`](@ref) and [`check_generator`](@ref).
34-
Any other keyword argument temporarily overrides the corresponding keyword
35-
argument in [`problem`](@ref ControlProblem). These arguments are available to
36-
the optimizer, see each optimization package's documentation for details.
34+
Additional keyword arguments can be passed to these two `check` routines by
35+
passing a named-tuple as `check_state_kwargs` and `check_generator_kwargs`.
3736
3837
The `callback` can be given as a function to be called after each iteration in
3938
order to analyze the progress of the optimization or to modify the state of
@@ -54,7 +53,8 @@ the method-specific [`make_print_iters`](@ref) to print the progress of the
5453
optimization after each iteration. This automatic callback runs after any
5554
manually given `callback`.
5655
57-
All remaining keyword argument are method-specific.
56+
All remaining keyword argument are method-specific and temporarily overrides
57+
the corresponding keyword argument in [`problem`](@ref ControlProblem).
5858
To obtain the documentation for which options a particular method uses, run,
5959
e.g.,
6060
@@ -75,10 +75,8 @@ function optimize(
7575
check = get(problem.kwargs, :check, true),
7676
print_iters = get(problem.kwargs, :print_iters, true),
7777
callback = get(problem.kwargs, :callback, nothing),
78-
for_expval = true, # undocumented
79-
for_pwc = true, # undocumented
80-
for_time_continuous = false, # undocumented
81-
for_parameterization = false, # undocumented
78+
check_state_kwargs = get(problem.kwargs, :check_state_kwargs, (;)),
79+
check_generator_kwargs = get(problem.kwargs, :check_generator_kwargs, (;)),
8280
kwargs...
8381
)
8482

@@ -113,20 +111,16 @@ function optimize(
113111
)
114112

115113
if check
116-
# TODO: checks will have to be method-dependent, and then we may not
117-
# need all the `for_...` keyword arguments
114+
# TODO: checks maybe should be method-dependent
118115
for (i, traj) in enumerate(problem.trajectories)
119-
if !check_state(traj.initial_state)
116+
if !check_state(traj.initial_state; check_state_kwargs...)
120117
error("The `initial_state` of trajectory $i is not valid")
121118
end
122119
if !check_generator(
123120
traj.generator;
121+
check_generator_kwargs...,
124122
state = traj.initial_state,
125123
tlist = problem.tlist,
126-
for_expval,
127-
for_pwc,
128-
for_time_continuous,
129-
for_parameterization,
130124
)
131125
error("The `generator` of trajectory $i is not valid")
132126
end

test/test_optimize_or_load.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,44 @@ using QuantumControl: @optimize_or_load, load_optimization, save_optimization, o
55
using QuantumControlTestUtils.DummyOptimization:
66
dummy_control_problem, DummyOptimizationResult
77

8+
9+
@testset "check_state_kwargs and check_generator_kwargs" begin
10+
11+
problem = dummy_control_problem()
12+
outdir = mktempdir()
13+
outfile = joinpath(outdir, "optimization_check_kwargs.jld2")
14+
15+
# Smoke test: extra kwargs are accepted
16+
captured = IOCapture.capture(passthrough = false) do
17+
@optimize_or_load(
18+
outfile,
19+
problem;
20+
method = :dummymethod,
21+
force = true,
22+
check_state_kwargs = (; atol = 1e-10),
23+
check_generator_kwargs = (; atol = 1e-10),
24+
)
25+
end
26+
@test captured.value isa DummyOptimizationResult
27+
@test captured.value.converged
28+
29+
# Test with teeth: for_time_continuous=true fails check_generator for the
30+
# pwc dummy controls (array-based controls cannot be evaluated at a
31+
# continuous time t)
32+
captured2 = IOCapture.capture(rethrow = Union{}, passthrough = false) do
33+
@optimize_or_load(
34+
outfile,
35+
problem;
36+
method = :dummymethod,
37+
force = true,
38+
check_generator_kwargs = (; for_time_continuous = true),
39+
)
40+
end
41+
@test captured2.value isa Exception
42+
43+
end
44+
45+
846
@testset "metadata" begin
947

1048
problem = dummy_control_problem()

0 commit comments

Comments
 (0)