Skip to content

Commit 3208a49

Browse files
committed
Allow default_counters macro to skip functions
Skipping certain functions is useful to avoid overwriting methods. If, e.g., neval_hprod is forwarded by @default_counters Model inner and the user redefines NLPModels.neval_hprod(m::Model) = ... then, neval_hprod is being overwritten, and that prevents precompilation. That is the case in quasi-Newton models.
1 parent 3731d22 commit 3208a49

3 files changed

Lines changed: 34 additions & 4 deletions

File tree

src/NLPModels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Base type for a nonlinear least-squares model.
3737
"""
3838
abstract type AbstractNLSModel{T, S} <: AbstractNLPModel{T, S} end
3939

40-
for f in ["utils", "api", "counters", "meta", "show", "tools"]
40+
for f in ["counters", "utils", "api", "meta", "show", "tools"]
4141
include("nlp/$f.jl")
4242
include("nls/$f.jl")
4343
end

src/nlp/utils.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,42 @@ function coo_sym_prod!(
116116
end
117117

118118
"""
119-
@default_counters Model inner
119+
@default_counters Model inner [excluded]
120120
121121
Define functions relating counters of `Model` to counters of `Model.inner`.
122+
Any function listed in `excluded` (which is an empty list by default), will
123+
not be forwarded.
124+
125+
Examples:
126+
127+
@default_counters MyModel inner (sum_counters, neval_hprod,)
128+
@default_counters MyModel inner (neval_hprod,)
129+
130+
Excluding a method from forwarding allows the user to redefine it without
131+
overwriting an existing method. Note that a generic method will still be
132+
defined as, e.g.,
133+
134+
neval_hprod(model) = model.inner.counters.neval_hprod
135+
136+
because the `counters` attribute itself is forwarded by `@default_counters`.
122137
"""
123-
macro default_counters(Model, inner)
138+
macro default_counters(Model, inner, excluded = :())
139+
140+
# Normalize excluded to a set of symbols
141+
println("excluded: ", excluded)
142+
println("excluded.args: ", excluded.args)
143+
excluded_set = if excluded == :()
144+
Set{Symbol}()
145+
elseif excluded isa Expr && excluded.head == :tuple
146+
Set{Symbol}([excluded.args...])
147+
else
148+
throw(ArgumentError("`@default_counters`: third argument must be a tuple of functions"))
149+
end
150+
println("excluded_set: ", excluded_set)
151+
124152
ex = Expr(:block)
125153
for foo in fieldnames(Counters) [:sum_counters]
154+
Symbol(foo) in excluded_set && continue
126155
push!(ex.args, :(NLPModels.$foo(nlp::$(esc(Model))) = $foo(nlp.$inner)))
127156
end
128157
push!(ex.args, :(NLPModels.reset!(nlp::$(esc(Model))) = begin

test/nlp/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ end
1818
end
1919

2020
@testset "Increase coverage of default_NLPcounters" begin
21-
@default_counters SuperNLPModel model
21+
@default_counters SuperNLPModel model (neval_hprod,)
2222
nlp = SuperNLPModel{Float64, Vector{Float64}}(SimpleNLPModel())
2323
increment!(nlp, :neval_obj)
2424
@test neval_obj(nlp.model) == 1
2525
@test nlp.counters == nlp.model.counters
26+
@test neval_hprod(nlp) == 0 # because counters are forwarded, even though neval_hprod has not been forwarded
2627
end

0 commit comments

Comments
 (0)