Skip to content

Commit e7e9394

Browse files
committed
refresh diff cache if hessian mode changes between solves
1 parent b15d1f7 commit e7e9394

3 files changed

Lines changed: 50 additions & 10 deletions

File tree

cvxpy/reductions/solvers/nlp_solvers/ipopt_nlpif.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,15 @@ def solve_via_data(self, data, warm_start: bool, verbose: bool, solver_opts, sol
151151
hessian_approx = solver_opts.get('hessian_approximation', 'exact')
152152
use_hessian = (hessian_approx == 'exact')
153153

154-
if solver_cache is None:
155-
oracles = Oracles(bounds.new_problem, verbose=verbose, use_hessian=use_hessian)
156-
elif 'oracles' in solver_cache:
157-
oracles = solver_cache['oracles']
154+
cached = solver_cache.get('oracles') if solver_cache is not None else None
155+
if cached is not None and cached.use_hessian == use_hessian:
156+
oracles = cached
158157
if bounds.new_problem.parameters():
159158
oracles.update_params(bounds.new_problem)
160159
else:
161160
oracles = Oracles(bounds.new_problem, verbose=verbose, use_hessian=use_hessian)
162-
solver_cache['oracles'] = oracles
161+
if solver_cache is not None:
162+
solver_cache['oracles'] = oracles
163163

164164
nlp = cyipopt.Problem(
165165
n=len(data["x0"]),

cvxpy/reductions/solvers/nlp_solvers/knitro_nlpif.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,15 @@ def solve_via_data(self, data, warm_start: bool, verbose: bool, solver_opts, sol
176176
hessopt = solver_opts.get('hessopt', 1) if solver_opts else 1
177177
use_hessian = (hessopt == 1)
178178

179-
if solver_cache is None:
180-
oracles = Oracles(bounds.new_problem, verbose=verbose, use_hessian=use_hessian)
181-
elif 'oracles' in solver_cache:
182-
oracles = solver_cache['oracles']
179+
cached = solver_cache.get('oracles') if solver_cache is not None else None
180+
if cached is not None and cached.use_hessian == use_hessian:
181+
oracles = cached
183182
if bounds.new_problem.parameters():
184183
oracles.update_params(bounds.new_problem)
185184
else:
186185
oracles = Oracles(bounds.new_problem, verbose=verbose, use_hessian=use_hessian)
187-
solver_cache['oracles'] = oracles
186+
if solver_cache is not None:
187+
solver_cache['oracles'] = oracles
188188

189189
# Extract data from the data dictionary
190190
x0 = data["x0"]

cvxpy/tests/nlp_tests/test_quasi_newton.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,43 @@ def test_entropy_lbfgs(self):
197197
hessian_approximation='limited-memory')
198198
# Minimum entropy distribution is concentrated on one point
199199
assert np.sum(q.value > 1e-8) == 1
200+
201+
def test_hessian_mode_switch_rebuilds_cached_oracle(self):
202+
"""Switch hessian_approximation between solves must refresh the diff oracle"""
203+
n = 4
204+
x = cp.Variable(n, bounds=[-1, 1])
205+
np.random.seed(0)
206+
A = cp.Parameter((n, n), value=np.random.rand(n, n))
207+
prob = cp.Problem(cp.Minimize(cp.sum(A @ cp.exp(x))))
208+
209+
prob.solve(solver=cp.IPOPT, nlp=True,
210+
hessian_approximation='limited-memory')
211+
o1 = prob._solver_cache['NLP']['oracles']
212+
assert o1.use_hessian is False
213+
214+
A.value = np.random.rand(n, n)
215+
prob.solve(solver=cp.IPOPT, nlp=True,
216+
hessian_approximation='exact')
217+
o2 = prob._solver_cache['NLP']['oracles']
218+
assert o2 is not o1
219+
assert o2.use_hessian is True
220+
rows, _ = o2.hessianstructure()
221+
assert rows.size > 0
222+
223+
def test_same_hessian_mode_reuses_cached_oracle(self):
224+
"""Same Hessian mode across solves should not refresh cached diff oracle."""
225+
n = 4
226+
x = cp.Variable(n, bounds=[-1, 1])
227+
np.random.seed(0)
228+
A = cp.Parameter((n, n), value=np.random.rand(n, n))
229+
prob = cp.Problem(cp.Minimize(cp.sum(A @ cp.exp(x))))
230+
231+
prob.solve(solver=cp.IPOPT, nlp=True,
232+
hessian_approximation='exact')
233+
o1 = prob._solver_cache['NLP']['oracles']
234+
235+
A.value = np.random.rand(n, n)
236+
prob.solve(solver=cp.IPOPT, nlp=True,
237+
hessian_approximation='exact')
238+
o2 = prob._solver_cache['NLP']['oracles']
239+
assert o2 is o1

0 commit comments

Comments
 (0)