@@ -128,64 +128,95 @@ enum ConstraintSenseFlags
128128 CON_UPBND = 1 << 1 , // 0x02
129129};
130130
131+ template <typename I>
131132struct CallbackPattern
132133{
133- std::vector<KNINT > indexCons;
134- std::vector<KNINT > objGradIndexVars;
135- std::vector<KNINT > jacIndexCons;
136- std::vector<KNINT > jacIndexVars;
137- std::vector<KNINT > hessIndexVars1;
138- std::vector<KNINT > hessIndexVars2;
134+ std::vector<I > indexCons;
135+ std::vector<I > objGradIndexVars;
136+ std::vector<I > jacIndexCons;
137+ std::vector<I > jacIndexVars;
138+ std::vector<I > hessIndexVars1;
139+ std::vector<I > hessIndexVars2;
139140};
140141
141- template <typename V>
142+ using namespace CppAD ;
143+
144+ template <typename V, typename S, typename I>
142145struct CallbackEvaluator
143146{
144- static inline const std::string jac_coloring_ = " cppad" ;
145- static inline const std::string hess_coloring_ = " cppad.symmetric" ;
146- std::vector<KNINT> indexVars;
147- std::vector<KNINT> indexCons;
148-
149- CppAD::ADFun<V> fun;
150- CppAD::sparse_rc<std::vector<size_t >> jac_pattern_;
151- CppAD::sparse_rcv<std::vector<size_t >, std::vector<V>> jac_;
152- CppAD::sparse_jac_work jac_work_;
153- CppAD::sparse_rc<std::vector<size_t >> hess_pattern_;
154- CppAD::sparse_rc<std::vector<size_t >> hess_pattern_symm_;
155- CppAD::sparse_rcv<std::vector<size_t >, std::vector<V>> hess_;
156- CppAD::sparse_hes_work hess_work_;
157-
158- std::vector<V> x;
159- std::vector<V> w;
147+
148+ static inline constexpr const char *CLRNG = " cppad" ;
149+
150+ std::vector<I> indexVars;
151+ std::vector<I> indexCons;
152+
153+ ADFun<V> fun; // / < CppAD tape.
154+ ADFun<V> jfun; // / < CppAD tape for Aggregated Jacobian
155+
156+ // / Sparsity patterns
157+ sparse_rc<vector<S>> jp;
158+ sparse_rc<vector<S>> hp;
159+
160+ // / Workspaces for Jacobian and Hessian calculations
161+ sparse_jac_work jw;
162+ sparse_jac_work hw;
163+
164+ // / Temporary vectors for evaluations
165+ vector<V> x;
166+ vector<V> xw;
167+ sparse_rcv<vector<S>, vector<V>> jac;
168+ sparse_rcv<vector<S>, vector<V>> hes;
160169
161170 void setup ()
162171 {
163172 fun.optimize ();
164- auto nx = fun.Domain ();
165- auto ny = fun.Range ();
166- CppAD::sparse_rc<std::vector<size_t >> jac_pattern_in (nx, nx, nx);
173+ size_t nx = fun.Domain ();
174+ size_t ny = fun.Range ();
175+
176+ vector<bool > dom (nx, true );
177+ vector<bool > rng (ny, true );
178+ fun.subgraph_sparsity (dom, rng, false , jp);
179+
180+ ADFun<AD<V>, V> af = fun.base2ad ();
181+ vector<AD<V>> jaxw (nx + ny);
182+ Independent (jaxw);
183+ vector<AD<V>> jax (nx);
184+ vector<AD<V>> jaw (ny);
185+ vector<AD<V>> jaz (nx);
167186 for (size_t i = 0 ; i < nx; i++)
168187 {
169- jac_pattern_in. set (i, i, i) ;
188+ jax[i] = jaxw[i] ;
170189 }
171- fun.for_jac_sparsity (jac_pattern_in, false , false , false , jac_pattern_);
172- std::vector<bool > select_rows (ny, true );
173- fun.rev_hes_sparsity (select_rows, false , false , hess_pattern_);
174- auto &hess_rows = hess_pattern_.row ();
175- auto &hess_cols = hess_pattern_.col ();
176- for (size_t k = 0 ; k < hess_pattern_.nnz (); k++)
190+ for (size_t i = 0 ; i < ny; i++)
177191 {
178- size_t row = hess_rows[k];
179- size_t col = hess_cols[k];
180- if (row <= col)
192+ jaw[i] = jaxw[nx + i];
193+ }
194+ af.Forward (0 , jax);
195+ jaz = af.Reverse (1 , jaw);
196+ jfun.Dependent (jaxw, jaz);
197+ jfun.optimize ();
198+ vector<bool > jdom (nx + ny, false );
199+ for (size_t i = 0 ; i < nx; i++)
200+ {
201+ jdom[i] = true ;
202+ }
203+ vector<bool > jrng (nx, true );
204+ sparse_rc<vector<S>> hsp;
205+ jfun.subgraph_sparsity (jdom, jrng, false , hsp);
206+
207+ auto &hrow = hsp.row ();
208+ auto &hcol = hsp.col ();
209+ for (size_t k = 0 ; k < hsp.nnz (); k++)
210+ {
211+ if (hrow[k] <= hcol[k])
181212 {
182- hess_pattern_symm_ .push_back (row, col );
213+ hp .push_back (hrow[k], hcol[k] );
183214 }
184215 }
185- x.resize (nx, 0.0 );
186- w .resize (ny, 0.0 );
187- jac_ = CppAD:: sparse_rcv<std:: vector<size_t >, std:: vector<V>>(jac_pattern_ );
188- hess_ = CppAD:: sparse_rcv<std:: vector<size_t >, std:: vector<V>>(hess_pattern_symm_ );
216+ x.resize (nx);
217+ xw .resize (nx + ny );
218+ jac = sparse_rcv<vector<S >, vector<V>>(jp );
219+ hes = sparse_rcv<vector<S >, vector<V>>(hp );
189220 }
190221
191222 bool is_objective () const
@@ -195,107 +226,108 @@ struct CallbackEvaluator
195226
196227 void eval_fun (const V *req_x, V *res_y)
197228 {
198- copy_ptr ( req_x, indexVars.data (), x);
229+ copy (fun. Domain (), req_x, indexVars.data (), x. data () );
199230 auto y = fun.Forward (0 , x);
200- copy_vec (y, res_y, is_objective ());
231+ int mode = is_objective () ? 2 : 0 ;
232+ copy (fun.Range (), y.data (), (const I *)nullptr , res_y, mode);
201233 }
202234
203235 void eval_jac (const V *req_x, V *res_jac)
204236 {
205- copy_ptr (req_x, indexVars.data (), x);
206- fun.sparse_jac_rev (x, jac_, jac_pattern_, jac_coloring_, jac_work_);
207- auto &jac = jac_.val ();
208- copy_vec (jac, res_jac);
237+ copy (fun.Domain (), req_x, indexVars.data (), x.data ());
238+ fun.sparse_jac_rev (x, jac, jp, CLRNG, jw);
239+ copy (jac.nnz (), jac.val ().data (), (const I *)nullptr , res_jac);
209240 }
210241
211242 void eval_hess (const V *req_x, const V *req_w, V *res_hess)
212243 {
213- copy_ptr ( req_x, indexVars.data (), x );
214- copy_ptr (req_w, indexCons. data (), w, is_objective ()) ;
215- fun.sparse_hes (x, w, hess_, hess_pattern_, hess_coloring_, hess_work_ );
216- auto &hess = hess_. val ( );
217- copy_vec (hess , res_hess);
244+ copy (fun. Domain (), req_x, indexVars.data (), xw. data () );
245+ int mode = is_objective () ? 1 : 0 ;
246+ copy ( fun.Range (), req_w, indexCons. data (), xw. data () + fun. Domain (), mode );
247+ jfun. sparse_jac_rev (xw, hes, hp, CLRNG, hw );
248+ copy (hes. nnz (), hes. val (). data (), ( const I *) nullptr , res_hess);
218249 }
219250
220- CallbackPattern get_callback_pattern () const
251+ CallbackPattern<I> get_callback_pattern () const
221252 {
222- CallbackPattern pattern ;
223- pattern .indexCons = indexCons;
253+ CallbackPattern<I> p ;
254+ p .indexCons = indexCons;
224255
225- auto &jac_rows = jac_pattern_ .row ();
226- auto &jac_cols = jac_pattern_ .col ();
256+ auto &jrow = jp .row ();
257+ auto &jcol = jp .col ();
227258 if (indexCons.empty ())
228259 {
229- for (size_t k = 0 ; k < jac_pattern_ .nnz (); k++)
260+ for (size_t k = 0 ; k < jp .nnz (); k++)
230261 {
231- pattern .objGradIndexVars .push_back (indexVars[jac_cols [k]]);
262+ p .objGradIndexVars .push_back (indexVars[jcol [k]]);
232263 }
233264 }
234265 else
235266 {
236- for (size_t k = 0 ; k < jac_pattern_ .nnz (); k++)
267+ for (size_t k = 0 ; k < jp .nnz (); k++)
237268 {
238- pattern .jacIndexCons .push_back (indexCons[jac_rows [k]]);
239- pattern .jacIndexVars .push_back (indexVars[jac_cols [k]]);
269+ p .jacIndexCons .push_back (indexCons[jrow [k]]);
270+ p .jacIndexVars .push_back (indexVars[jcol [k]]);
240271 }
241272 }
242273
243- auto &hess_rows = hess_pattern_symm_ .row ();
244- auto &hess_cols = hess_pattern_symm_ .col ();
245- for (size_t k = 0 ; k < hess_pattern_symm_ .nnz (); k++)
274+ auto &hrow = hp .row ();
275+ auto &hcol = hp .col ();
276+ for (size_t k = 0 ; k < hp .nnz (); k++)
246277 {
247- pattern .hessIndexVars1 .push_back (indexVars[hess_rows [k]]);
248- pattern .hessIndexVars2 .push_back (indexVars[hess_cols [k]]);
278+ p .hessIndexVars1 .push_back (indexVars[hrow [k]]);
279+ p .hessIndexVars2 .push_back (indexVars[hcol [k]]);
249280 }
250281
251- return pattern ;
282+ return p ;
252283 }
253284
254285 private:
255- template <typename T, typename I>
256- static void copy_ptr (const T *src, const I *idx, std::vector<V> &dst, bool duplicate = false )
286+ // Copy mode:
287+ // - 0: normal copy
288+ // - 1: duplicate (copy first element of src to all elements of dst)
289+ // - 2: aggregate (sum all elements of src and copy to all elements of dst)
290+ static void copy (const size_t n, const V *src, const I *idx, V *dst, int mode = 0 )
257291 {
258- for ( size_t i = 0 ; i < dst. size (); i++ )
292+ if (mode == 1 )
259293 {
260- if (duplicate )
294+ for ( size_t i = 0 ; i < n; i++ )
261295 {
262296 dst[i] = src[0 ];
263297 }
264- else
265- {
266- dst[i] = src[idx[i]];
267- }
268298 }
269- }
270-
271- template <typename T>
272- static void copy_vec (const std::vector<T> &src, T *dst, bool aggregate = false )
273- {
274- if (aggregate)
299+ else if (mode == 2 )
275300 {
276- dst[0 ] = 0.0 ;
301+ if (n == 0 )
302+ {
303+ return ;
304+ }
305+ dst[0 ] = src[0 ];
306+ for (size_t i = 1 ; i < n; i++)
307+ {
308+ dst[0 ] += src[i];
309+ }
277310 }
278- for ( size_t i = 0 ; i < src. size (); i++)
311+ else
279312 {
280- if (aggregate )
313+ if (idx == nullptr )
281314 {
282- dst[0 ] += src[i];
315+ for (size_t i = 0 ; i < n; i++)
316+ {
317+ dst[i] = src[i];
318+ }
283319 }
284320 else
285321 {
286- dst[i] = src[i];
322+ for (size_t i = 0 ; i < n; i++)
323+ {
324+ dst[i] = src[idx[i]];
325+ }
287326 }
288327 }
289328 }
290329};
291330
292- struct Outputs
293- {
294- std::vector<size_t > objective_outputs;
295- std::vector<size_t > constraint_outputs;
296- std::vector<ConstraintIndex> constraints;
297- };
298-
299331inline bool is_name_empty (const char *name)
300332{
301333 return name == nullptr || name[0 ] == ' \0 ' ;
@@ -577,8 +609,17 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
577609 std::unordered_map<KNINT, uint8_t > m_con_sense_flags;
578610 uint8_t m_obj_flag = 0 ;
579611
612+ struct Outputs
613+ {
614+ std::vector<size_t > objective_outputs;
615+ std::vector<size_t > constraint_outputs;
616+ std::vector<ConstraintIndex> constraints;
617+ };
618+
619+ using Evaluator = CallbackEvaluator<double , size_t , KNINT>;
620+
580621 std::unordered_map<ExpressionGraph *, Outputs> m_pending_outputs;
581- std::vector<std::unique_ptr<CallbackEvaluator< double > >> m_evaluators;
622+ std::vector<std::unique_ptr<Evaluator >> m_evaluators;
582623 bool m_has_pending_callbacks = false ;
583624 int m_solve_status = 0 ;
584625 bool m_is_dirty = true ;
@@ -604,7 +645,7 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
604645 void _add_callbacks (const ExpressionGraph &graph, const Outputs &outputs);
605646 void _add_callback (const ExpressionGraph &graph, const std::vector<size_t > &outputs,
606647 const std::vector<ConstraintIndex> &constraints);
607- void _register_callback (CallbackEvaluator< double > *evaluator);
648+ void _register_callback (Evaluator *evaluator);
608649 void _update ();
609650 void _pre_solve ();
610651 void _solve ();
0 commit comments