@@ -674,6 +674,50 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
674674 m_is_dirty = true ;
675675 }
676676
677+ template <typename F, typename G, typename H>
678+ void _register_callback (CallbackEvaluator<double > *evaluator, const F f, const G g, const H h)
679+ {
680+ CB_context *cb = nullptr ;
681+ auto p = evaluator->get_callback_pattern ();
682+ int error;
683+ error = knitro::KN_add_eval_callback (m_kc.get (), p.indexCons .empty (), p.indexCons .size (),
684+ p.indexCons .data (), f, &cb);
685+ _check_error (error);
686+ error = knitro::KN_set_cb_user_params (m_kc.get (), cb, evaluator);
687+ _check_error (error);
688+ error = knitro::KN_set_cb_grad (m_kc.get (), cb, p.objGradIndexVars .size (),
689+ p.objGradIndexVars .data (), p.jacIndexCons .size (),
690+ p.jacIndexCons .data (), p.jacIndexVars .data (), g);
691+ _check_error (error);
692+ error = knitro::KN_set_cb_hess (m_kc.get (), cb, p.hessIndexVars1 .size (),
693+ p.hessIndexVars1 .data (), p.hessIndexVars2 .data (), h);
694+ _check_error (error);
695+ }
696+
697+ template <typename T, typename F, typename G, typename H>
698+ void _add_callback_impl (const ExpressionGraph &graph, const std::vector<size_t > &rows,
699+ const std::vector<ConstraintIndex> cons, const T &trace, const F f,
700+ const G g, const H h)
701+ {
702+ auto evaluator_ptr = std::make_unique<CallbackEvaluator<double >>();
703+ auto *evaluator = evaluator_ptr.get ();
704+ evaluator->fun = trace (graph);
705+ evaluator->fun_rows = rows;
706+ evaluator->indexVars .resize (graph.n_variables ());
707+ for (size_t i = 0 ; i < graph.n_variables (); i++)
708+ {
709+ evaluator->indexVars [i] = _variable_index (graph.m_variables [i]);
710+ }
711+ evaluator->indexCons .resize (cons.size ());
712+ for (size_t i = 0 ; i < cons.size (); i++)
713+ {
714+ evaluator->indexCons [i] = _constraint_index (cons[i]);
715+ }
716+ evaluator->setup ();
717+ _register_callback (evaluator, f, g, h);
718+ m_evaluators.push_back (std::move (evaluator_ptr));
719+ }
720+
677721 template <typename V>
678722 using Getter = std::function<int (KN_context *, V *)>;
679723 template <typename V>
0 commit comments