Skip to content

Commit cb542e8

Browse files
simplify TR
1 parent 4a44919 commit cb542e8

1 file changed

Lines changed: 14 additions & 67 deletions

File tree

src/TR_alg.jl

Lines changed: 14 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ mutable struct TRSolver{
1818
s::V
1919
v0::V
2020
v1::V
21-
has_bnds::Bool
22-
l_bound::V
23-
u_bound::V
24-
l_bound_m_x::V
25-
u_bound_m_x::V
2621
m_fh_hist::V
2722
subsolver::ST
2823
subpb::PB
@@ -36,39 +31,20 @@ function TRSolver(
3631
m_monotone::Int = 1,
3732
) where {T, V, X}
3833
x0 = reg_nlp.model.meta.x0
39-
l_bound = reg_nlp.model.meta.lvar
40-
u_bound = reg_nlp.model.meta.uvar
4134

4235
xk = similar(x0)
4336
∇fk = similar(x0)
4437
∇fk⁻ = similar(x0)
4538
mν∇fk = similar(x0)
4639
xkn = similar(x0)
4740
s = similar(x0)
48-
has_bnds = any(l_bound .!= T(-Inf)) || any(u_bound .!= T(Inf))
49-
if has_bnds || subsolver == TRDHSolver
50-
l_bound_m_x = similar(xk)
51-
u_bound_m_x = similar(xk)
52-
@. l_bound_m_x = l_bound - x0
53-
@. u_bound_m_x = u_bound - x0
54-
else
55-
l_bound_m_x = similar(xk, 0)
56-
u_bound_m_x = similar(xk, 0)
57-
end
5841

5942
m_fh_hist = fill(T(-Inf), m_monotone - 1)
6043

6144
v0 = [(-1.0)^i for i = 0:(reg_nlp.model.meta.nvar - 1)]
6245
v0 ./= sqrt(reg_nlp.model.meta.nvar)
6346
v1 = similar(v0)
6447

65-
ψ =
66-
has_bnds || subsolver == TRDHSolver ?
67-
shifted(reg_nlp.h, xk, l_bound_m_x, u_bound_m_x, reg_nlp.selected) :
68-
shifted(reg_nlp.h, xk, T(1), χ)
69-
70-
Bk = hess_op(reg_nlp, xk)
71-
sub_nlp = QuadraticModel(∇fk, Bk, x0 = x0)
7248
subpb = ShiftedProximableQuadraticNLPModel(reg_nlp, xk, ∇f = ∇fk, χ = χ)
7349
substats = RegularizedExecutionStats(subpb)
7450
subsolver = subsolver(subpb)
@@ -83,11 +59,6 @@ function TRSolver(
8359
s,
8460
v0,
8561
v1,
86-
has_bnds,
87-
l_bound,
88-
u_bound,
89-
l_bound_m_x,
90-
u_bound_m_x,
9162
m_fh_hist,
9263
subsolver,
9364
subpb,
@@ -247,16 +218,12 @@ function SolverCore.solve!(
247218
s = solver.s
248219
χ = solver.χ
249220
m_fh_hist = solver.m_fh_hist .= T(-Inf)
250-
has_bnds = solver.has_bnds
221+
has_bnds = has_bounds(nlp)
251222

252223
m_monotone = length(m_fh_hist) + 1
253224

254225
set_radius!(solver.subpb, Δk)
255226
if has_bnds || isa(solver.subsolver, TRDHSolver) #TODO elsewhere ?
256-
#l_bound_m_x, u_bound_m_x = solver.l_bound_m_x, solver.u_bound_m_x
257-
#l_bound, u_bound = solver.l_bound, solver.u_bound
258-
#update_bounds!(l_bound_m_x, u_bound_m_x, false, l_bound, u_bound, xk, Δk)
259-
#set_bounds!(ψ, l_bound_m_x, u_bound_m_x)
260227
set_bounds!(solver.subsolver.ψ, ψ.l, ψ.u)
261228
else
262229
set_radius!(solver.subsolver.ψ, Δk)
@@ -328,20 +295,8 @@ function SolverCore.solve!(
328295
set_solver_specific!(stats, :prox_evals, prox_evals + 1)
329296
m_monotone > 1 && (m_fh_hist[stats.iter % (m_monotone - 1) + 1] = fk + hk)
330297

331-
# models
332-
φ1 = let ∇fk = ∇fk
333-
d -> dot(∇fk, d)
334-
end
335-
mk1 = let ψ = ψ, φ1 = φ1
336-
d -> φ1(d) + ψ(d)
337-
end
338-
339-
mk = let ψ = ψ, solver = solver
340-
d -> obj(solver.subpb.model, d) + ψ(d)::T
341-
end
342-
343298
prox!(s, ψ, mν∇fk, ν₁)
344-
ξ1 = hk - mk1(s) + max(1, abs(hk)) * 10 * eps()
299+
ξ1 = hk - obj(mk, s, cauchy = true) + max(1, abs(hk)) * 10 * eps()
345300
ξ1 > 0 || error("TR: first prox-gradient step should produce a decrease but ξ1 = $(ξ1)")
346301
sqrt_ξ1_νInv = sqrt(ξ1 / ν₁)
347302

@@ -372,14 +327,13 @@ function SolverCore.solve!(
372327
sub_atol = stats.iter == 0 ? 1e-5 : max(sub_atol, min(1e-2, sqrt_ξ1_νInv))
373328
∆_effective = min* χ(s), Δk)
374329

375-
if has_bnds || isa(solver.subsolver, TRDHSolver) #TODO elsewhere ?
376-
update_bounds!(l_bound_m_x, u_bound_m_x, false, l_bound, u_bound, xk, Δk)
377-
set_bounds!(ψ, l_bound_m_x, u_bound_m_x)
378-
set_bounds!(solver.subsolver.ψ, l_bound_m_x, u_bound_m_x)
330+
set_radius!(solver.subpb, ∆_effective)
331+
if has_bnds || isa(solver.subsolver, TRDHSolver)
332+
set_bounds!(solver.subsolver.ψ, ψ.l, ψ.u)
379333
else
380334
set_radius!(solver.subsolver.ψ, ∆_effective)
381-
set_radius!(ψ, ∆_effective)
382335
end
336+
383337
with_logger(subsolver_logger) do
384338
if isa(solver.subsolver, TRDHSolver) #FIXME
385339
solver.subsolver.D.d[1] = 1/ν₁
@@ -415,7 +369,7 @@ function SolverCore.solve!(
415369

416370
fhmax = m_monotone > 1 ? maximum(m_fh_hist) : fk + hk
417371
Δobj = fhmax - (fkn + hkn) + max(1, abs(fk + hk)) * 10 * eps()
418-
ξ = hk - mk(s) + max(1, abs(hk)) * 10 * eps()
372+
ξ = hk - obj(mk, s) + max(1, abs(hk)) * 10 * eps()
419373

420374
if 0 || isnan(ξ))
421375
error("TR: failed to compute a step: ξ = ")
@@ -444,25 +398,20 @@ function SolverCore.solve!(
444398

445399
if η2 ρk < Inf
446400
Δk = max(Δk, γ * sNorm)
401+
set_radius!(solver.subpb, Δk)
447402
if !(has_bnds || isa(solver.subsolver, TRDHSolver))
448-
set_radius!(ψ, Δk)
449403
set_radius!(solver.subsolver.ψ, Δk)
450404
end
451405
end
452406

453407
if η1 ρk < Inf
454408
xk .= xkn
455-
if has_bnds || isa(solver.subsolver, TRDHSolver)
456-
update_bounds!(l_bound_m_x, u_bound_m_x, false, l_bound, u_bound, xk, Δk)
457-
set_bounds!(ψ, l_bound_m_x, u_bound_m_x)
458-
set_bounds!(solver.subsolver.ψ, l_bound_m_x, u_bound_m_x)
459-
end
409+
410+
shift!(solver.subpb, xk, compute_grad = compute_grad)
411+
460412
fk = fkn
461413
hk = hkn
462414

463-
shift!(ψ, xk)
464-
grad!(nlp, xk, ∇fk)
465-
466415
if quasiNewtTest
467416
@. ∇fk⁻ = ∇fk - ∇fk⁻
468417
push!(nlp, s, ∇fk⁻) # update QN operator
@@ -481,12 +430,10 @@ function SolverCore.solve!(
481430

482431
if ρk < η1 || ρk == Inf
483432
Δk = Δk / 2
433+
set_radius!(solver.subpb, Δk)
484434
if has_bnds || isa(solver.subsolver, TRDHSolver)
485-
update_bounds!(l_bound_m_x, u_bound_m_x, false, l_bound, u_bound, xk, Δk)
486-
set_bounds!(ψ, l_bound_m_x, u_bound_m_x)
487-
set_bounds!(solver.subsolver.ψ, l_bound_m_x, u_bound_m_x)
435+
set_bounds!(solver.subsolver.ψ, ψ.l, ψ.u)
488436
else
489-
set_radius!(ψ, Δk)
490437
set_radius!(solver.subsolver.ψ, Δk)
491438
end
492439
set_step_status!(stats, :rejected)
@@ -505,7 +452,7 @@ function SolverCore.solve!(
505452
@. mν∇fk = -ν₁ * ∇fk
506453

507454
prox!(s, ψ, mν∇fk, ν₁)
508-
ξ1 = hk - mk1(s) + max(1, abs(hk)) * 10 * eps()
455+
ξ1 = hk - obj(mk, s, cauchy = true) + max(1, abs(hk)) * 10 * eps()
509456
sqrt_ξ1_νInv = sqrt(ξ1 / ν₁)
510457

511458
solved = (ξ1 < 0 && sqrt_ξ1_νInv neg_tol) || (ξ1 0 && sqrt_ξ1_νInv atol)

0 commit comments

Comments
 (0)