Skip to content

Commit 30c6c74

Browse files
authored
[KNITRO] Enhance implementation (#83)
* [CPPAD] Enhance tracing functions to support selected output indices * Refactor callback implementation to use outputs in trace functions
1 parent b9f035f commit 30c6c74

5 files changed

Lines changed: 68 additions & 38 deletions

File tree

include/pyoptinterface/cppad_interface.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ ADFunDouble sparse_hessian(const ADFunDouble &f, const sparsity_pattern_t &patte
3232
const std::vector<double> &p_values);
3333

3434
// Transform ExpressionGraph to CppAD function
35-
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph);
36-
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate = true);
35+
// selected_outputs: indices of outputs to trace, empty means all outputs
36+
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
37+
const std::vector<size_t> &selected_outputs = {});
38+
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate = true,
39+
const std::vector<size_t> &selected_outputs = {});
3740

3841
struct CppADAutodiffGraph
3942
{

include/pyoptinterface/knitro_model.hpp

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ struct CallbackEvaluator
147147
std::vector<KNINT> indexCons;
148148

149149
CppAD::ADFun<V> fun;
150-
151-
std::vector<size_t> fun_rows;
152150
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_;
153151
CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>> jac_;
154152
CppAD::sparse_jac_work jac_work_;
@@ -163,11 +161,10 @@ struct CallbackEvaluator
163161
void setup()
164162
{
165163
fun.optimize();
166-
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(fun.Range(), fun_rows.size(),
167-
fun_rows.size());
168-
for (size_t k = 0; k < fun_rows.size(); k++)
164+
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(fun.Range(), fun.Range(), fun.Range());
165+
for (size_t k = 0; k < fun.Range(); k++)
169166
{
170-
jac_pattern_in.set(k, fun_rows[k], fun_rows[k]);
167+
jac_pattern_in.set(k, k, k);
171168
}
172169
fun.rev_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_);
173170
jac_pattern_in.resize(fun.Domain(), fun.Domain(), fun.Domain());
@@ -177,11 +174,7 @@ struct CallbackEvaluator
177174
}
178175
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_out;
179176
fun.for_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_out);
180-
std::vector<bool> select_rows(fun.Range(), false);
181-
for (size_t k = 0; k < fun_rows.size(); k++)
182-
{
183-
select_rows[fun_rows[k]] = true;
184-
}
177+
std::vector<bool> select_rows(fun.Range(), true);
185178
fun.rev_hes_sparsity(select_rows, false, true, hess_pattern_);
186179
for (size_t k = 0; k < hess_pattern_.nnz(); k++)
187180
{
@@ -205,15 +198,15 @@ struct CallbackEvaluator
205198
x[i] = req_x[indexVars[i]];
206199
}
207200
auto y = fun.Forward(0, x);
208-
for (size_t k = 0; k < fun_rows.size(); k++)
201+
for (size_t k = 0; k < fun.Range(); k++)
209202
{
210203
if (aggregate)
211204
{
212-
res_y[0] += y[fun_rows[k]];
205+
res_y[0] += y[k];
213206
}
214207
else
215208
{
216-
res_y[k] = y[fun_rows[k]];
209+
res_y[k] = y[k];
217210
}
218211
}
219212
}
@@ -238,15 +231,15 @@ struct CallbackEvaluator
238231
{
239232
x[i] = req_x[indexVars[i]];
240233
}
241-
for (size_t k = 0; k < fun_rows.size(); k++)
234+
for (size_t k = 0; k < fun.Range(); k++)
242235
{
243236
if (aggregate)
244237
{
245-
w[fun_rows[k]] = req_w[0];
238+
w[k] = req_w[0];
246239
}
247240
else
248241
{
249-
w[fun_rows[k]] = req_w[indexCons[k]];
242+
w[k] = req_w[indexCons[k]];
250243
}
251244
}
252245
fun.sparse_hes(x, w, hess_, hess_pattern_, hess_coloring_, hess_work_);
@@ -696,14 +689,12 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
696689
}
697690

698691
template <typename T, typename F, typename G, typename H>
699-
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<size_t> &rows,
700-
const std::vector<ConstraintIndex> cons, const T &trace, const F f,
701-
const G g, const H h)
692+
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<ConstraintIndex> cons,
693+
const T &trace, const F f, const G g, const H h)
702694
{
703695
auto evaluator_ptr = std::make_unique<CallbackEvaluator<double>>();
704696
auto *evaluator = evaluator_ptr.get();
705697
evaluator->fun = trace(graph);
706-
evaluator->fun_rows = rows;
707698
evaluator->indexVars.resize(graph.n_variables());
708699
for (size_t i = 0; i < graph.n_variables(); i++)
709700
{

lib/cppad_interface.cpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,8 @@ CppAD::AD<double> cppad_trace_expression(
427427
return result;
428428
}
429429

430-
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph)
430+
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
431+
const std::vector<size_t> &selected_outputs)
431432
{
432433
ankerl::unordered_dense::map<ExpressionHandle, CppAD::AD<double>> seen_expressions;
433434

@@ -453,13 +454,29 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph)
453454
}
454455

455456
auto &outputs = graph.m_constraint_outputs;
456-
auto N_outputs = outputs.size();
457+
458+
std::vector<size_t> indices;
459+
if (selected_outputs.empty())
460+
{
461+
indices.reserve(outputs.size());
462+
for (size_t i = 0; i < outputs.size(); i++)
463+
{
464+
indices.push_back(i);
465+
}
466+
}
467+
else
468+
{
469+
indices = selected_outputs;
470+
}
471+
472+
auto N_outputs = indices.size();
457473
std::vector<CppAD::AD<double>> y(N_outputs);
458474

459-
// Trace the outputs
475+
// Trace the selected outputs
460476
for (size_t i = 0; i < N_outputs; i++)
461477
{
462-
auto &output = outputs[i];
478+
auto idx = indices[i];
479+
auto &output = outputs[idx];
463480
y[i] = cppad_trace_expression(graph, output, x, p, seen_expressions);
464481
}
465482

@@ -469,7 +486,8 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph)
469486
return f;
470487
}
471488

472-
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate)
489+
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate,
490+
const std::vector<size_t> &selected_outputs)
473491
{
474492
ankerl::unordered_dense::map<ExpressionHandle, CppAD::AD<double>> seen_expressions;
475493

@@ -493,13 +511,28 @@ ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggre
493511
}
494512

495513
auto &outputs = graph.m_objective_outputs;
496-
auto N_outputs = outputs.size();
514+
515+
std::vector<size_t> indices;
516+
if (selected_outputs.empty())
517+
{
518+
indices.reserve(outputs.size());
519+
for (size_t i = 0; i < outputs.size(); i++)
520+
{
521+
indices.push_back(i);
522+
}
523+
}
524+
else
525+
{
526+
indices = selected_outputs;
527+
}
528+
529+
auto N_outputs = indices.size();
497530
std::vector<CppAD::AD<double>> y(N_outputs);
498531

499-
// Trace the outputs
500532
for (size_t i = 0; i < N_outputs; i++)
501533
{
502-
auto &output = outputs[i];
534+
auto idx = indices[i];
535+
auto &output = outputs[idx];
503536
y[i] = cppad_trace_expression(graph, output, x, p, seen_expressions);
504537
}
505538

lib/cppad_interface_ext.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ NB_MODULE(cppad_interface_ext, m)
182182
.def_ro("jacobian", &CppADAutodiffGraph::jacobian_graph)
183183
.def_ro("hessian", &CppADAutodiffGraph::hessian_graph);
184184

185-
m.def("cppad_trace_graph_constraints", cppad_trace_graph_constraints);
185+
m.def("cppad_trace_graph_constraints", cppad_trace_graph_constraints, nb::arg("graph"),
186+
nb::arg("selected_outputs") = std::vector<size_t>{});
186187
m.def("cppad_trace_graph_objective", cppad_trace_graph_objective, nb::arg("graph"),
187-
nb::arg("aggregate") = true);
188+
nb::arg("aggregate") = true, nb::arg("selected_outputs") = std::vector<size_t>{});
188189
m.def("cppad_autodiff", &cppad_autodiff);
189190
}

lib/knitro_model.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -856,8 +856,10 @@ void KNITROModel::_add_constraint_callback(ExpressionGraph *graph, const Outputs
856856
evaluator->eval_hess(req->x, req->lambda, res->hess);
857857
return 0;
858858
};
859-
auto trace = cppad_trace_graph_constraints;
860-
_add_callback_impl(*graph, outputs.con_idxs, outputs.cons, trace, f, g, h);
859+
auto trace = [outputs](const ExpressionGraph &graph) {
860+
return cppad_trace_graph_constraints(graph, outputs.con_idxs);
861+
};
862+
_add_callback_impl(*graph, outputs.cons, trace, f, g, h);
861863
}
862864

863865
void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs &outputs)
@@ -881,10 +883,10 @@ void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs
881883
evaluator->eval_hess(req->x, req->sigma, res->hess, true);
882884
return 0;
883885
};
884-
auto trace = [](const ExpressionGraph &graph) {
885-
return cppad_trace_graph_objective(graph, false);
886+
auto trace = [outputs](const ExpressionGraph &graph) {
887+
return cppad_trace_graph_objective(graph, true, outputs.obj_idxs);
886888
};
887-
_add_callback_impl(*graph, outputs.obj_idxs, {}, trace, f, g, h);
889+
_add_callback_impl(*graph, {}, trace, f, g, h);
888890
}
889891

890892
void KNITROModel::_add_callbacks()

0 commit comments

Comments
 (0)