Skip to content

Commit 4f24f31

Browse files
committed
Refactor callback implementation to use outputs in trace functions
1 parent 60d6d75 commit 4f24f31

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

include/pyoptinterface/knitro_model.hpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,50 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
674674
m_is_dirty = true;
675675
}
676676

677+
template <typename F, typename G, typename H>
678+
void _register_callback(CallbackEvaluator<double> *evaluator, const F f, const G g, const H h)
679+
{
680+
CB_context *cb = nullptr;
681+
auto p = evaluator->get_callback_pattern();
682+
int error;
683+
error = knitro::KN_add_eval_callback(m_kc.get(), p.indexCons.empty(), p.indexCons.size(),
684+
p.indexCons.data(), f, &cb);
685+
_check_error(error);
686+
error = knitro::KN_set_cb_user_params(m_kc.get(), cb, evaluator);
687+
_check_error(error);
688+
error = knitro::KN_set_cb_grad(m_kc.get(), cb, p.objGradIndexVars.size(),
689+
p.objGradIndexVars.data(), p.jacIndexCons.size(),
690+
p.jacIndexCons.data(), p.jacIndexVars.data(), g);
691+
_check_error(error);
692+
error = knitro::KN_set_cb_hess(m_kc.get(), cb, p.hessIndexVars1.size(),
693+
p.hessIndexVars1.data(), p.hessIndexVars2.data(), h);
694+
_check_error(error);
695+
}
696+
697+
template <typename T, typename F, typename G, typename H>
698+
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<size_t> &rows,
699+
const std::vector<ConstraintIndex> cons, const T &trace, const F f,
700+
const G g, const H h)
701+
{
702+
auto evaluator_ptr = std::make_unique<CallbackEvaluator<double>>();
703+
auto *evaluator = evaluator_ptr.get();
704+
evaluator->fun = trace(graph);
705+
evaluator->fun_rows = rows;
706+
evaluator->indexVars.resize(graph.n_variables());
707+
for (size_t i = 0; i < graph.n_variables(); i++)
708+
{
709+
evaluator->indexVars[i] = _variable_index(graph.m_variables[i]);
710+
}
711+
evaluator->indexCons.resize(cons.size());
712+
for (size_t i = 0; i < cons.size(); i++)
713+
{
714+
evaluator->indexCons[i] = _constraint_index(cons[i]);
715+
}
716+
evaluator->setup();
717+
_register_callback(evaluator, f, g, h);
718+
m_evaluators.push_back(std::move(evaluator_ptr));
719+
}
720+
677721
template <typename V>
678722
using Getter = std::function<int(KN_context *, V *)>;
679723
template <typename V>

0 commit comments

Comments
 (0)