Skip to content

Commit 60f18f7

Browse files
committed
Refactor callback handling in KNITROModel to improve clarity and maintainability
1 parent d3b358f commit 60f18f7

2 files changed

Lines changed: 173 additions & 167 deletions

File tree

include/pyoptinterface/knitro_model.hpp

Lines changed: 72 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -161,93 +161,60 @@ struct CallbackEvaluator
161161
void setup()
162162
{
163163
fun.optimize();
164-
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(fun.Range(), fun.Range(), fun.Range());
165-
for (size_t k = 0; k < fun.Range(); k++)
166-
{
167-
jac_pattern_in.set(k, k, k);
168-
}
169-
fun.rev_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_);
170-
jac_pattern_in.resize(fun.Domain(), fun.Domain(), fun.Domain());
171-
for (size_t i = 0; i < fun.Domain(); i++)
164+
auto nx = fun.Domain();
165+
auto ny = fun.Range();
166+
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(nx, nx, nx);
167+
for (size_t i = 0; i < nx; i++)
172168
{
173169
jac_pattern_in.set(i, i, i);
174170
}
175-
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_out;
176-
fun.for_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_out);
177-
std::vector<bool> select_rows(fun.Range(), true);
171+
fun.for_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_);
172+
std::vector<bool> select_rows(ny, true);
178173
fun.rev_hes_sparsity(select_rows, false, true, hess_pattern_);
174+
auto &hess_rows = hess_pattern_.row();
175+
auto &hess_cols = hess_pattern_.col();
179176
for (size_t k = 0; k < hess_pattern_.nnz(); k++)
180177
{
181-
size_t row = hess_pattern_.row()[k];
182-
size_t col = hess_pattern_.col()[k];
178+
size_t row = hess_rows[k];
179+
size_t col = hess_cols[k];
183180
if (row <= col)
184181
{
185182
hess_pattern_symm_.push_back(row, col);
186183
}
187184
}
188-
x.resize(fun.Domain(), 0.0);
189-
w.resize(fun.Range(), 0.0);
185+
x.resize(nx, 0.0);
186+
w.resize(ny, 0.0);
190187
jac_ = CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>>(jac_pattern_);
191188
hess_ = CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>>(hess_pattern_symm_);
192189
}
193190

194-
void eval_fun(const V *req_x, V *res_y, bool aggregate = false)
191+
bool is_objective() const
195192
{
196-
for (size_t i = 0; i < indexVars.size(); i++)
197-
{
198-
x[i] = req_x[indexVars[i]];
199-
}
193+
return indexCons.empty();
194+
}
195+
196+
void eval_fun(const V *req_x, V *res_y)
197+
{
198+
copy_ptr(req_x, indexVars.data(), x);
200199
auto y = fun.Forward(0, x);
201-
for (size_t k = 0; k < fun.Range(); k++)
202-
{
203-
if (aggregate)
204-
{
205-
res_y[0] += y[k];
206-
}
207-
else
208-
{
209-
res_y[k] = y[k];
210-
}
211-
}
200+
copy_vec(y, res_y, is_objective());
212201
}
213202

214203
void eval_jac(const V *req_x, V *res_jac)
215204
{
216-
for (size_t i = 0; i < indexVars.size(); i++)
217-
{
218-
x[i] = req_x[indexVars[i]];
219-
}
205+
copy_ptr(req_x, indexVars.data(), x);
220206
fun.sparse_jac_rev(x, jac_, jac_pattern_, jac_coloring_, jac_work_);
221207
auto &jac = jac_.val();
222-
for (size_t i = 0; i < jac_.nnz(); i++)
223-
{
224-
res_jac[i] = jac[i];
225-
}
208+
copy_vec(jac, res_jac);
226209
}
227210

228-
void eval_hess(const V *req_x, const V *req_w, V *res_hess, bool aggregate = false)
211+
void eval_hess(const V *req_x, const V *req_w, V *res_hess)
229212
{
230-
for (size_t i = 0; i < indexVars.size(); i++)
231-
{
232-
x[i] = req_x[indexVars[i]];
233-
}
234-
for (size_t k = 0; k < fun.Range(); k++)
235-
{
236-
if (aggregate)
237-
{
238-
w[k] = req_w[0];
239-
}
240-
else
241-
{
242-
w[k] = req_w[indexCons[k]];
243-
}
244-
}
213+
copy_ptr(req_x, indexVars.data(), x);
214+
copy_ptr(req_w, indexCons.data(), w, is_objective());
245215
fun.sparse_hes(x, w, hess_, hess_pattern_, hess_coloring_, hess_work_);
246216
auto &hess = hess_.val();
247-
for (size_t i = 0; i < hess_.nnz(); i++)
248-
{
249-
res_hess[i] = hess[i];
250-
}
217+
copy_vec(hess, res_hess);
251218
}
252219

253220
CallbackPattern get_callback_pattern() const
@@ -283,13 +250,50 @@ struct CallbackEvaluator
283250

284251
return pattern;
285252
}
253+
254+
private:
255+
template <typename T, typename I>
256+
static void copy_ptr(const T *src, const I *idx, std::vector<V> &dst, bool duplicate = false)
257+
{
258+
for (size_t i = 0; i < dst.size(); i++)
259+
{
260+
if (duplicate)
261+
{
262+
dst[i] = src[0];
263+
}
264+
else
265+
{
266+
dst[i] = src[idx[i]];
267+
}
268+
}
269+
}
270+
271+
template <typename T>
272+
static void copy_vec(const std::vector<T> &src, T *dst, bool aggregate = false)
273+
{
274+
if (aggregate)
275+
{
276+
dst[0] = 0.0;
277+
}
278+
for (size_t i = 0; i < src.size(); i++)
279+
{
280+
if (aggregate)
281+
{
282+
dst[0] += src[i];
283+
}
284+
else
285+
{
286+
dst[i] = src[i];
287+
}
288+
}
289+
}
286290
};
287291

288292
struct Outputs
289293
{
290-
std::vector<size_t> obj_idxs;
291-
std::vector<size_t> con_idxs;
292-
std::vector<ConstraintIndex> cons;
294+
std::vector<size_t> objective_outputs;
295+
std::vector<size_t> constraint_outputs;
296+
std::vector<ConstraintIndex> constraints;
293297
};
294298

295299
inline bool is_name_empty(const char *name)
@@ -575,7 +579,7 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
575579

576580
std::unordered_map<ExpressionGraph *, Outputs> m_pending_outputs;
577581
std::vector<std::unique_ptr<CallbackEvaluator<double>>> m_evaluators;
578-
bool m_need_to_add_callbacks = false;
582+
bool m_has_pending_callbacks = false;
579583
int m_solve_status = 0;
580584
bool m_is_dirty = true;
581585

@@ -596,9 +600,11 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
596600
void _set_quadratic_objective(const ScalarQuadraticFunction &f);
597601
void _reset_objective();
598602
void _add_graph(ExpressionGraph &graph);
599-
void _add_callbacks();
600-
void _add_constraint_callback(ExpressionGraph *graph, const Outputs &outputs);
601-
void _add_objective_callback(ExpressionGraph *graph, const Outputs &outputs);
603+
void _add_pending_callbacks();
604+
void _add_callbacks(const ExpressionGraph &graph, const Outputs &outputs);
605+
void _add_callback(const ExpressionGraph &graph, const std::vector<size_t> &outputs,
606+
const std::vector<ConstraintIndex> &constraints);
607+
void _register_callback(CallbackEvaluator<double> *evaluator);
602608
void _update();
603609
void _pre_solve();
604610
void _solve();
@@ -668,48 +674,6 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
668674
m_is_dirty = true;
669675
}
670676

671-
template <typename F, typename G, typename H>
672-
void _register_callback(CallbackEvaluator<double> *evaluator, const F f, const G g, const H h)
673-
{
674-
CB_context *cb = nullptr;
675-
auto p = evaluator->get_callback_pattern();
676-
int error;
677-
error = knitro::KN_add_eval_callback(m_kc.get(), p.indexCons.empty(), p.indexCons.size(),
678-
p.indexCons.data(), f, &cb);
679-
_check_error(error);
680-
error = knitro::KN_set_cb_user_params(m_kc.get(), cb, evaluator);
681-
_check_error(error);
682-
error = knitro::KN_set_cb_grad(m_kc.get(), cb, p.objGradIndexVars.size(),
683-
p.objGradIndexVars.data(), p.jacIndexCons.size(),
684-
p.jacIndexCons.data(), p.jacIndexVars.data(), g);
685-
_check_error(error);
686-
error = knitro::KN_set_cb_hess(m_kc.get(), cb, p.hessIndexVars1.size(),
687-
p.hessIndexVars1.data(), p.hessIndexVars2.data(), h);
688-
_check_error(error);
689-
}
690-
691-
template <typename T, typename F, typename G, typename H>
692-
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<ConstraintIndex> cons,
693-
const T &trace, const F f, const G g, const H h)
694-
{
695-
auto evaluator_ptr = std::make_unique<CallbackEvaluator<double>>();
696-
auto *evaluator = evaluator_ptr.get();
697-
evaluator->fun = trace(graph);
698-
evaluator->indexVars.resize(graph.n_variables());
699-
for (size_t i = 0; i < graph.n_variables(); i++)
700-
{
701-
evaluator->indexVars[i] = _variable_index(graph.m_variables[i]);
702-
}
703-
evaluator->indexCons.resize(cons.size());
704-
for (size_t i = 0; i < cons.size(); i++)
705-
{
706-
evaluator->indexCons[i] = _constraint_index(cons[i]);
707-
}
708-
evaluator->setup();
709-
_register_callback(evaluator, f, g, h);
710-
m_evaluators.push_back(std::move(evaluator_ptr));
711-
}
712-
713677
template <typename V>
714678
using Getter = std::function<int(KN_context *, V *)>;
715679
template <typename V>

0 commit comments

Comments
 (0)