@@ -161,93 +161,60 @@ struct CallbackEvaluator
161161 void setup ()
162162 {
163163 fun.optimize ();
164- CppAD::sparse_rc<std::vector<size_t >> jac_pattern_in (fun.Range (), fun.Range (), fun.Range ());
165- for (size_t k = 0 ; k < fun.Range (); k++)
166- {
167- jac_pattern_in.set (k, k, k);
168- }
169- fun.rev_jac_sparsity (jac_pattern_in, false , false , true , jac_pattern_);
170- jac_pattern_in.resize (fun.Domain (), fun.Domain (), fun.Domain ());
171- for (size_t i = 0 ; i < fun.Domain (); i++)
164+ auto nx = fun.Domain ();
165+ auto ny = fun.Range ();
166+ CppAD::sparse_rc<std::vector<size_t >> jac_pattern_in (nx, nx, nx);
167+ for (size_t i = 0 ; i < nx; i++)
172168 {
173169 jac_pattern_in.set (i, i, i);
174170 }
175- CppAD::sparse_rc<std::vector<size_t >> jac_pattern_out;
176- fun.for_jac_sparsity (jac_pattern_in, false , false , true , jac_pattern_out);
177- std::vector<bool > select_rows (fun.Range (), true );
171+ fun.for_jac_sparsity (jac_pattern_in, false , false , true , jac_pattern_);
172+ std::vector<bool > select_rows (ny, true );
178173 fun.rev_hes_sparsity (select_rows, false , true , hess_pattern_);
174+ auto &hess_rows = hess_pattern_.row ();
175+ auto &hess_cols = hess_pattern_.col ();
179176 for (size_t k = 0 ; k < hess_pattern_.nnz (); k++)
180177 {
181- size_t row = hess_pattern_. row () [k];
182- size_t col = hess_pattern_. col () [k];
178+ size_t row = hess_rows [k];
179+ size_t col = hess_cols [k];
183180 if (row <= col)
184181 {
185182 hess_pattern_symm_.push_back (row, col);
186183 }
187184 }
188- x.resize (fun. Domain () , 0.0 );
189- w.resize (fun. Range () , 0.0 );
185+ x.resize (nx , 0.0 );
186+ w.resize (ny , 0.0 );
190187 jac_ = CppAD::sparse_rcv<std::vector<size_t >, std::vector<V>>(jac_pattern_);
191188 hess_ = CppAD::sparse_rcv<std::vector<size_t >, std::vector<V>>(hess_pattern_symm_);
192189 }
193190
194- void eval_fun ( const V *req_x, V *res_y, bool aggregate = false )
191+ bool is_objective () const
195192 {
196- for (size_t i = 0 ; i < indexVars.size (); i++)
197- {
198- x[i] = req_x[indexVars[i]];
199- }
193+ return indexCons.empty ();
194+ }
195+
196+ void eval_fun (const V *req_x, V *res_y)
197+ {
198+ copy_ptr (req_x, indexVars.data (), x);
200199 auto y = fun.Forward (0 , x);
201- for (size_t k = 0 ; k < fun.Range (); k++)
202- {
203- if (aggregate)
204- {
205- res_y[0 ] += y[k];
206- }
207- else
208- {
209- res_y[k] = y[k];
210- }
211- }
200+ copy_vec (y, res_y, is_objective ());
212201 }
213202
214203 void eval_jac (const V *req_x, V *res_jac)
215204 {
216- for (size_t i = 0 ; i < indexVars.size (); i++)
217- {
218- x[i] = req_x[indexVars[i]];
219- }
205+ copy_ptr (req_x, indexVars.data (), x);
220206 fun.sparse_jac_rev (x, jac_, jac_pattern_, jac_coloring_, jac_work_);
221207 auto &jac = jac_.val ();
222- for (size_t i = 0 ; i < jac_.nnz (); i++)
223- {
224- res_jac[i] = jac[i];
225- }
208+ copy_vec (jac, res_jac);
226209 }
227210
228- void eval_hess (const V *req_x, const V *req_w, V *res_hess, bool aggregate = false )
211+ void eval_hess (const V *req_x, const V *req_w, V *res_hess)
229212 {
230- for (size_t i = 0 ; i < indexVars.size (); i++)
231- {
232- x[i] = req_x[indexVars[i]];
233- }
234- for (size_t k = 0 ; k < fun.Range (); k++)
235- {
236- if (aggregate)
237- {
238- w[k] = req_w[0 ];
239- }
240- else
241- {
242- w[k] = req_w[indexCons[k]];
243- }
244- }
213+ copy_ptr (req_x, indexVars.data (), x);
214+ copy_ptr (req_w, indexCons.data (), w, is_objective ());
245215 fun.sparse_hes (x, w, hess_, hess_pattern_, hess_coloring_, hess_work_);
246216 auto &hess = hess_.val ();
247- for (size_t i = 0 ; i < hess_.nnz (); i++)
248- {
249- res_hess[i] = hess[i];
250- }
217+ copy_vec (hess, res_hess);
251218 }
252219
253220 CallbackPattern get_callback_pattern () const
@@ -283,13 +250,50 @@ struct CallbackEvaluator
283250
284251 return pattern;
285252 }
253+
254+ private:
255+ template <typename T, typename I>
256+ static void copy_ptr (const T *src, const I *idx, std::vector<V> &dst, bool duplicate = false )
257+ {
258+ for (size_t i = 0 ; i < dst.size (); i++)
259+ {
260+ if (duplicate)
261+ {
262+ dst[i] = src[0 ];
263+ }
264+ else
265+ {
266+ dst[i] = src[idx[i]];
267+ }
268+ }
269+ }
270+
271+ template <typename T>
272+ static void copy_vec (const std::vector<T> &src, T *dst, bool aggregate = false )
273+ {
274+ if (aggregate)
275+ {
276+ dst[0 ] = 0.0 ;
277+ }
278+ for (size_t i = 0 ; i < src.size (); i++)
279+ {
280+ if (aggregate)
281+ {
282+ dst[0 ] += src[i];
283+ }
284+ else
285+ {
286+ dst[i] = src[i];
287+ }
288+ }
289+ }
286290};
287291
288292struct Outputs
289293{
290- std::vector<size_t > obj_idxs ;
291- std::vector<size_t > con_idxs ;
292- std::vector<ConstraintIndex> cons ;
294+ std::vector<size_t > objective_outputs ;
295+ std::vector<size_t > constraint_outputs ;
296+ std::vector<ConstraintIndex> constraints ;
293297};
294298
295299inline bool is_name_empty (const char *name)
@@ -575,7 +579,7 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
575579
576580 std::unordered_map<ExpressionGraph *, Outputs> m_pending_outputs;
577581 std::vector<std::unique_ptr<CallbackEvaluator<double >>> m_evaluators;
578- bool m_need_to_add_callbacks = false ;
582+ bool m_has_pending_callbacks = false ;
579583 int m_solve_status = 0 ;
580584 bool m_is_dirty = true ;
581585
@@ -596,9 +600,11 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
596600 void _set_quadratic_objective (const ScalarQuadraticFunction &f);
597601 void _reset_objective ();
598602 void _add_graph (ExpressionGraph &graph);
599- void _add_callbacks ();
600- void _add_constraint_callback (ExpressionGraph *graph, const Outputs &outputs);
601- void _add_objective_callback (ExpressionGraph *graph, const Outputs &outputs);
603+ void _add_pending_callbacks ();
604+ void _add_callbacks (const ExpressionGraph &graph, const Outputs &outputs);
605+ void _add_callback (const ExpressionGraph &graph, const std::vector<size_t > &outputs,
606+ const std::vector<ConstraintIndex> &constraints);
607+ void _register_callback (CallbackEvaluator<double > *evaluator);
602608 void _update ();
603609 void _pre_solve ();
604610 void _solve ();
@@ -668,48 +674,6 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
668674 m_is_dirty = true ;
669675 }
670676
671- template <typename F, typename G, typename H>
672- void _register_callback (CallbackEvaluator<double > *evaluator, const F f, const G g, const H h)
673- {
674- CB_context *cb = nullptr ;
675- auto p = evaluator->get_callback_pattern ();
676- int error;
677- error = knitro::KN_add_eval_callback (m_kc.get (), p.indexCons .empty (), p.indexCons .size (),
678- p.indexCons .data (), f, &cb);
679- _check_error (error);
680- error = knitro::KN_set_cb_user_params (m_kc.get (), cb, evaluator);
681- _check_error (error);
682- error = knitro::KN_set_cb_grad (m_kc.get (), cb, p.objGradIndexVars .size (),
683- p.objGradIndexVars .data (), p.jacIndexCons .size (),
684- p.jacIndexCons .data (), p.jacIndexVars .data (), g);
685- _check_error (error);
686- error = knitro::KN_set_cb_hess (m_kc.get (), cb, p.hessIndexVars1 .size (),
687- p.hessIndexVars1 .data (), p.hessIndexVars2 .data (), h);
688- _check_error (error);
689- }
690-
691- template <typename T, typename F, typename G, typename H>
692- void _add_callback_impl (const ExpressionGraph &graph, const std::vector<ConstraintIndex> cons,
693- const T &trace, const F f, const G g, const H h)
694- {
695- auto evaluator_ptr = std::make_unique<CallbackEvaluator<double >>();
696- auto *evaluator = evaluator_ptr.get ();
697- evaluator->fun = trace (graph);
698- evaluator->indexVars .resize (graph.n_variables ());
699- for (size_t i = 0 ; i < graph.n_variables (); i++)
700- {
701- evaluator->indexVars [i] = _variable_index (graph.m_variables [i]);
702- }
703- evaluator->indexCons .resize (cons.size ());
704- for (size_t i = 0 ; i < cons.size (); i++)
705- {
706- evaluator->indexCons [i] = _constraint_index (cons[i]);
707- }
708- evaluator->setup ();
709- _register_callback (evaluator, f, g, h);
710- m_evaluators.push_back (std::move (evaluator_ptr));
711- }
712-
713677 template <typename V>
714678 using Getter = std::function<int (KN_context *, V *)>;
715679 template <typename V>
0 commit comments