Skip to content

Commit f48d68c

Browse files
eminyousknmetab0t
authored andcommitted
Refactor cppad_trace_graph functions to use 'selected' parameter for output indices
1 parent 6a455a5 commit f48d68c

File tree

4 files changed

+16
-14
lines changed

4 files changed

+16
-14
lines changed

include/pyoptinterface/cppad_interface.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ ADFunDouble sparse_hessian(const ADFunDouble &f, const sparsity_pattern_t &patte
3232
const std::vector<double> &p_values);
3333

3434
// Transform ExpressionGraph to CppAD function
35-
// selected_outputs: indices of outputs to trace, empty means all outputs
35+
// selected: indices of outputs to trace, empty means all outputs
3636
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
37-
const std::vector<size_t> &selected_outputs = {});
38-
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate = true,
39-
const std::vector<size_t> &selected_outputs = {});
37+
const std::vector<size_t> &selected = {});
38+
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph,
39+
const std::vector<size_t> &selected = {},
40+
bool aggregate = true);
4041

4142
struct CppADAutodiffGraph
4243
{

lib/cppad_interface.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ CppAD::AD<double> cppad_trace_expression(
428428
}
429429

430430
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
431-
const std::vector<size_t> &selected_outputs)
431+
const std::vector<size_t> &selected)
432432
{
433433
ankerl::unordered_dense::map<ExpressionHandle, CppAD::AD<double>> seen_expressions;
434434

@@ -456,7 +456,7 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
456456
auto &outputs = graph.m_constraint_outputs;
457457

458458
std::vector<size_t> indices;
459-
if (selected_outputs.empty())
459+
if (selected.empty())
460460
{
461461
indices.reserve(outputs.size());
462462
for (size_t i = 0; i < outputs.size(); i++)
@@ -466,7 +466,7 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
466466
}
467467
else
468468
{
469-
indices = selected_outputs;
469+
indices = selected;
470470
}
471471

472472
auto N_outputs = indices.size();
@@ -486,8 +486,8 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
486486
return f;
487487
}
488488

489-
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate,
490-
const std::vector<size_t> &selected_outputs)
489+
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph,
490+
const std::vector<size_t> &selected, bool aggregate)
491491
{
492492
ankerl::unordered_dense::map<ExpressionHandle, CppAD::AD<double>> seen_expressions;
493493

@@ -513,7 +513,7 @@ ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggre
513513
auto &outputs = graph.m_objective_outputs;
514514

515515
std::vector<size_t> indices;
516-
if (selected_outputs.empty())
516+
if (selected.empty())
517517
{
518518
indices.reserve(outputs.size());
519519
for (size_t i = 0; i < outputs.size(); i++)
@@ -523,7 +523,7 @@ ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggre
523523
}
524524
else
525525
{
526-
indices = selected_outputs;
526+
indices = selected;
527527
}
528528

529529
auto N_outputs = indices.size();

lib/cppad_interface_ext.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <nanobind/make_iterator.h>
33
#include <nanobind/stl/vector.h>
44
#include <nanobind/stl/string.h>
5+
#include <nanobind/stl/optional.h>
56

67
namespace nb = nanobind;
78

@@ -183,8 +184,8 @@ NB_MODULE(cppad_interface_ext, m)
183184
.def_ro("hessian", &CppADAutodiffGraph::hessian_graph);
184185

185186
m.def("cppad_trace_graph_constraints", cppad_trace_graph_constraints, nb::arg("graph"),
186-
nb::arg("selected_outputs") = std::vector<size_t>{});
187+
nb::arg("selected") = std::vector<size_t>{});
187188
m.def("cppad_trace_graph_objective", cppad_trace_graph_objective, nb::arg("graph"),
188-
nb::arg("aggregate") = true, nb::arg("selected_outputs") = std::vector<size_t>{});
189+
nb::arg("selected") = std::vector<size_t>{}, nb::arg("aggregate") = true);
189190
m.def("cppad_autodiff", &cppad_autodiff);
190191
}

lib/knitro_model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ void KNITROModel::_add_callback(const ExpressionGraph &graph, const std::vector<
914914
}
915915
if (evaluator->is_objective())
916916
{
917-
evaluator->fun = cppad_trace_graph_objective(graph, true, outputs);
917+
evaluator->fun = cppad_trace_graph_objective(graph, outputs, true);
918918
}
919919
else
920920
{

0 commit comments

Comments
 (0)