Skip to content

Commit d3b358f

Browse files
committed
Refactor callback implementation to use outputs in trace functions
1 parent 9e212a5 commit d3b358f

File tree

2 files changed

+19
-26
lines changed

2 files changed

+19
-26
lines changed

include/pyoptinterface/knitro_model.hpp

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ struct CallbackEvaluator
147147
std::vector<KNINT> indexCons;
148148

149149
CppAD::ADFun<V> fun;
150-
151-
std::vector<size_t> fun_rows;
152150
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_;
153151
CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>> jac_;
154152
CppAD::sparse_jac_work jac_work_;
@@ -163,11 +161,10 @@ struct CallbackEvaluator
163161
void setup()
164162
{
165163
fun.optimize();
166-
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(fun.Range(), fun_rows.size(),
167-
fun_rows.size());
168-
for (size_t k = 0; k < fun_rows.size(); k++)
164+
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(fun.Range(), fun.Range(), fun.Range());
165+
for (size_t k = 0; k < fun.Range(); k++)
169166
{
170-
jac_pattern_in.set(k, fun_rows[k], fun_rows[k]);
167+
jac_pattern_in.set(k, k, k);
171168
}
172169
fun.rev_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_);
173170
jac_pattern_in.resize(fun.Domain(), fun.Domain(), fun.Domain());
@@ -177,11 +174,7 @@ struct CallbackEvaluator
177174
}
178175
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_out;
179176
fun.for_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_out);
180-
std::vector<bool> select_rows(fun.Range(), false);
181-
for (size_t k = 0; k < fun_rows.size(); k++)
182-
{
183-
select_rows[fun_rows[k]] = true;
184-
}
177+
std::vector<bool> select_rows(fun.Range(), true);
185178
fun.rev_hes_sparsity(select_rows, false, true, hess_pattern_);
186179
for (size_t k = 0; k < hess_pattern_.nnz(); k++)
187180
{
@@ -205,15 +198,15 @@ struct CallbackEvaluator
205198
x[i] = req_x[indexVars[i]];
206199
}
207200
auto y = fun.Forward(0, x);
208-
for (size_t k = 0; k < fun_rows.size(); k++)
201+
for (size_t k = 0; k < fun.Range(); k++)
209202
{
210203
if (aggregate)
211204
{
212-
res_y[0] += y[fun_rows[k]];
205+
res_y[0] += y[k];
213206
}
214207
else
215208
{
216-
res_y[k] = y[fun_rows[k]];
209+
res_y[k] = y[k];
217210
}
218211
}
219212
}
@@ -238,15 +231,15 @@ struct CallbackEvaluator
238231
{
239232
x[i] = req_x[indexVars[i]];
240233
}
241-
for (size_t k = 0; k < fun_rows.size(); k++)
234+
for (size_t k = 0; k < fun.Range(); k++)
242235
{
243236
if (aggregate)
244237
{
245-
w[fun_rows[k]] = req_w[0];
238+
w[k] = req_w[0];
246239
}
247240
else
248241
{
249-
w[fun_rows[k]] = req_w[indexCons[k]];
242+
w[k] = req_w[indexCons[k]];
250243
}
251244
}
252245
fun.sparse_hes(x, w, hess_, hess_pattern_, hess_coloring_, hess_work_);
@@ -696,14 +689,12 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
696689
}
697690

698691
template <typename T, typename F, typename G, typename H>
699-
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<size_t> &rows,
700-
const std::vector<ConstraintIndex> cons, const T &trace, const F f,
701-
const G g, const H h)
692+
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<ConstraintIndex> cons,
693+
const T &trace, const F f, const G g, const H h)
702694
{
703695
auto evaluator_ptr = std::make_unique<CallbackEvaluator<double>>();
704696
auto *evaluator = evaluator_ptr.get();
705697
evaluator->fun = trace(graph);
706-
evaluator->fun_rows = rows;
707698
evaluator->indexVars.resize(graph.n_variables());
708699
for (size_t i = 0; i < graph.n_variables(); i++)
709700
{

lib/knitro_model.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -856,8 +856,10 @@ void KNITROModel::_add_constraint_callback(ExpressionGraph *graph, const Outputs
856856
evaluator->eval_hess(req->x, req->lambda, res->hess);
857857
return 0;
858858
};
859-
auto trace = cppad_trace_graph_constraints;
860-
_add_callback_impl(*graph, outputs.con_idxs, outputs.cons, trace, f, g, h);
859+
auto trace = [outputs](const ExpressionGraph &graph) {
860+
return cppad_trace_graph_constraints(graph, outputs.con_idxs);
861+
};
862+
_add_callback_impl(*graph, outputs.cons, trace, f, g, h);
861863
}
862864

863865
void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs &outputs)
@@ -881,10 +883,10 @@ void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs
881883
evaluator->eval_hess(req->x, req->sigma, res->hess, true);
882884
return 0;
883885
};
884-
auto trace = [](const ExpressionGraph &graph) {
885-
return cppad_trace_graph_objective(graph, false);
886+
auto trace = [outputs](const ExpressionGraph &graph) {
887+
return cppad_trace_graph_objective(graph, true, outputs.obj_idxs);
886888
};
887-
_add_callback_impl(*graph, outputs.obj_idxs, {}, trace, f, g, h);
889+
_add_callback_impl(*graph, {}, trace, f, g, h);
888890
}
889891

890892
void KNITROModel::_add_callbacks()

0 commit comments

Comments
 (0)