Skip to content

Commit 75feed0

Browse files
authored
add step_status and step_status_reliable to stats (#130)
Each solver can now set the step status (accepted or rejected) for use in a callback. This is particularly useful to perform quasi-Newton updates.
1 parent e02ce0f commit 75feed0

3 files changed

Lines changed: 63 additions & 1 deletion

File tree

ext/SolverCoreNLPModelsExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ function SolverCore.GenericExecutionStats(
7878
multipliers_L::V = similar(nlp.meta.y0, has_bounds(nlp) ? nlp.meta.nvar : 0),
7979
multipliers_U::V = similar(nlp.meta.y0, has_bounds(nlp) ? nlp.meta.nvar : 0),
8080
iter::Int = -1,
81+
step_status::Symbol = :unknown,
8182
elapsed_time::Real = Inf,
8283
solver_specific::Dict{Symbol, Tsp} = Dict{Symbol, Any}(),
8384
) where {T, S, V, Tsp}
@@ -101,6 +102,8 @@ function SolverCore.GenericExecutionStats(
101102
false,
102103
iter,
103104
false,
105+
step_status,
106+
false,
104107
elapsed_time,
105108
false,
106109
solver_specific,

src/stats.jl

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@ export AbstractExecutionStats,
1111
set_constraint_multipliers!,
1212
set_bounds_multipliers!,
1313
set_iter!,
14+
set_step_status!,
1415
set_time!,
1516
broadcast_solver_specific!,
1617
set_solver_specific!,
1718
statsgetfield,
1819
statshead,
1920
statsline,
2021
getStatus,
21-
show_statuses
22+
show_statuses,
23+
show_step_statuses
2224

2325
const STATUSES = Dict(
2426
:exception => "unhandled exception",
@@ -62,6 +64,32 @@ function show_statuses()
6264
end
6365
end
6466

67+
const STEP_STATUSES =
68+
Dict(:unknown => "unknown", :accepted => "step accepted", :rejected => "step rejected")
69+
70+
function check_step_status(step_status::Symbol)
71+
if !(step_status in keys(STEP_STATUSES))
72+
@error "step_status $step_status is not a valid step status. Use one of the following: " join(
73+
keys(STEP_STATUSES),
74+
", ",
75+
)
76+
throw(KeyError(step_status))
77+
end
78+
end
79+
80+
"""
81+
show_step_statuses()
82+
83+
Show the list of available step statuses to use with `GenericExecutionStats`.
84+
"""
85+
function show_step_statuses()
86+
println("STEP_STATUSES:")
87+
for k in keys(STEP_STATUSES) |> collect |> sort
88+
v = STEP_STATUSES[k]
89+
@printf(" :%-10s => %s\n", k, v)
90+
end
91+
end
92+
6593
abstract type AbstractExecutionStats end
6694

6795
"""
@@ -79,6 +107,7 @@ It contains the following fields:
79107
- `multipliers_L`: The Lagrange multipliers wrt to the lower bounds on the variables (default: an uninitialized vector like `nlp.meta.x0` if there are bounds, or a zero-length vector if not);
80108
- `multipliers_U`: The Lagrange multipliers wrt to the upper bounds on the variables (default: an uninitialized vector like `nlp.meta.x0` if there are bounds, or a zero-length vector if not);
81109
- `iter`: The number of iterations computed by the solver (default: `-1`);
110+
- `step_status`: The status of the most recently computed step. Use show_step_statuses() for the full list (default: `:unknown`);
82111
- `elapsed_time`: The elapsed time computed by the solver (default: `Inf`);
83112
- `solver_specific::Dict{Symbol,Any}`: A solver specific dictionary.
84113
@@ -94,6 +123,7 @@ The following fields indicate whether the information above has been updated and
94123
- `multipliers_reliable` (for `multipliers`)
95124
- `bounds_multipliers_reliable` (for `multipliers_L` and `multipliers_U`)
96125
- `iter_reliable`
126+
- `step_status_reliable`
97127
- `time_reliable`
98128
- `solver_specific_reliable`.
99129
@@ -127,6 +157,8 @@ mutable struct GenericExecutionStats{T, S, V, Tsp} <: AbstractExecutionStats
127157
multipliers_U::V # zU
128158
iter_reliable::Bool
129159
iter::Int
160+
step_status_reliable::Bool
161+
step_status::Symbol
130162
time_reliable::Bool
131163
elapsed_time::Float64
132164
solver_specific_reliable::Bool
@@ -143,6 +175,7 @@ function GenericExecutionStats{T, S, V, Tsp}(;
143175
multipliers_L::V = V(),
144176
multipliers_U::V = V(),
145177
iter::Int = -1,
178+
step_status::Symbol = :unknown,
146179
elapsed_time::Real = Inf,
147180
solver_specific::Dict{Symbol, Tsp} = Dict{Symbol, Any}(),
148181
) where {T, S, V, Tsp}
@@ -165,6 +198,8 @@ function GenericExecutionStats{T, S, V, Tsp}(;
165198
false,
166199
iter,
167200
false,
201+
step_status,
202+
false,
168203
elapsed_time,
169204
false,
170205
solver_specific,
@@ -187,6 +222,7 @@ function reset!(stats::GenericExecutionStats{T, S, V, Tsp}) where {T, S, V, Tsp}
187222
stats.multipliers_reliable = false
188223
stats.bounds_multipliers_reliable = false
189224
stats.iter_reliable = false
225+
stats.step_status_reliable = false
190226
stats.time_reliable = false
191227
stats.solver_specific_reliable = false
192228
stats
@@ -314,6 +350,18 @@ function set_iter!(stats::GenericExecutionStats, iter::Int)
314350
stats
315351
end
316352

353+
"""
354+
set_step_status!(stats::GenericExecutionStats, step_status::Symbol)
355+
356+
Register `step_status` as most recent step status in `stats` and mark it as reliable.
357+
"""
358+
function set_step_status!(stats::GenericExecutionStats, step_status::Symbol)
359+
check_step_status(step_status)
360+
stats.step_status = step_status
361+
stats.step_status_reliable = true
362+
stats
363+
end
364+
317365
"""
318366
set_time!(stats::GenericExecutionStats, time::Float64)
319367
@@ -431,6 +479,9 @@ function statsgetfield(stats::AbstractExecutionStats, name::Symbol)
431479
if name == :status
432480
v = getStatus(stats)
433481
t = String
482+
elseif name == :step_status
483+
v = getStepStatus(stats)
484+
t = String
434485
elseif name in fieldnames(typeof(stats))
435486
v = getfield(stats, name)
436487
t = fieldtype(typeof(stats), name)
@@ -458,6 +509,10 @@ function getStatus(stats::AbstractExecutionStats)
458509
return STATUSES[stats.status]
459510
end
460511

512+
function getStepStatus(stats::AbstractExecutionStats)
513+
return STEP_STATUSES[stats.step_status]
514+
end
515+
461516
"""
462517
get_status(problem, kwargs...)
463518

test/test-stats.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ function test_stats()
7474

7575
stats = GenericExecutionStats{Float64, Vector{Float64}, Vector{Float64}, Any}()
7676
@test_throws Exception set_status!(stats, :bad)
77+
@test_throws Exception set_step_status!(stats, :medium_well)
7778
end
7879

7980
@testset "Testing Dummy Solver with multi-precision" begin
@@ -120,6 +121,7 @@ function test_stats()
120121
"multipliers",
121122
"bounds_multipliers",
122123
"iter",
124+
"step_status",
123125
"time",
124126
"solver_specific",
125127
)
@@ -143,6 +145,8 @@ function test_stats()
143145
@test stats.bounds_multipliers_reliable
144146
set_iter!(stats, 2)
145147
@test stats.iter_reliable
148+
set_step_status!(stats, :accepted)
149+
@test stats.step_status_reliable
146150
set_time!(stats, 0.1)
147151
@test stats.time_reliable
148152
set_solver_specific!(stats, :bla, "boo!")

0 commit comments

Comments
 (0)