Skip to content

Commit 300cfed

Browse files
committed
Refactor CallbackEvaluator to use template parameters for improved flexibility
1 parent 209b722 commit 300cfed

2 files changed

Lines changed: 110 additions & 98 deletions

File tree

include/pyoptinterface/knitro_model.hpp

Lines changed: 105 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -128,58 +128,107 @@ enum ConstraintSenseFlags
128128
CON_UPBND = 1 << 1, // 0x02
129129
};
130130

131+
template <typename I>
131132
struct CallbackPattern
132133
{
133-
std::vector<KNINT> indexCons;
134-
std::vector<KNINT> objGradIndexVars;
135-
std::vector<KNINT> jacIndexCons;
136-
std::vector<KNINT> jacIndexVars;
137-
std::vector<KNINT> hessIndexVars1;
138-
std::vector<KNINT> hessIndexVars2;
134+
std::vector<I> indexCons;
135+
std::vector<I> objGradIndexVars;
136+
std::vector<I> jacIndexCons;
137+
std::vector<I> jacIndexVars;
138+
std::vector<I> hessIndexVars1;
139+
std::vector<I> hessIndexVars2;
139140
};
140141

141-
template <typename V>
142+
enum class CopyMode
143+
{
144+
Normal,
145+
Aggregate,
146+
Duplicate
147+
};
148+
149+
template <typename T, typename I>
150+
static void copy(const size_t n, const T *src, const I *idx, T *dst,
151+
CopyMode mode = CopyMode::Normal)
152+
{
153+
if (mode == CopyMode::Duplicate)
154+
{
155+
for (size_t i = 0; i < n; i++)
156+
{
157+
dst[i] = src[0];
158+
}
159+
}
160+
else if (mode == CopyMode::Aggregate)
161+
{
162+
dst[0] = T(0.0);
163+
for (size_t i = 0; i < n; i++)
164+
{
165+
dst[0] += src[i];
166+
}
167+
}
168+
else
169+
{
170+
if (idx == nullptr)
171+
{
172+
for (size_t i = 0; i < n; i++)
173+
{
174+
dst[i] = src[i];
175+
}
176+
}
177+
else
178+
{
179+
for (size_t i = 0; i < n; i++)
180+
{
181+
dst[i] = src[idx[i]];
182+
}
183+
}
184+
}
185+
}
186+
187+
using namespace CppAD;
188+
189+
template <typename V, typename S, typename I>
142190
struct CallbackEvaluator
143191
{
192+
144193
static inline constexpr const char *JAC_CLRNG = "cppad";
145194
static inline constexpr const char *HES_CLRNG = "cppad.symmetric";
146-
std::vector<KNINT> indexVars;
147-
std::vector<KNINT> indexCons;
195+
std::vector<I> indexVars;
196+
std::vector<I> indexCons;
148197

149-
CppAD::ADFun<V> fun; /// < CppAD tape.
150-
CppAD::ADFun<V> jfun; /// < CppAD tape for Jacobian
198+
ADFun<V> fun; /// < CppAD tape.
199+
ADFun<V> jfun; /// < CppAD tape for Jacobian
151200

152201
/// Sparsity patterns
153-
CppAD::sparse_rc<std::vector<size_t>> jp;
154-
CppAD::sparse_rc<std::vector<size_t>> hp;
202+
sparse_rc<vector<S>> jp;
203+
sparse_rc<vector<S>> hp;
155204

156205
/// Workspaces for sparse Jacobian and Hessian calculations
157-
CppAD::sparse_jac_work jw;
158-
CppAD::sparse_jac_work hw;
206+
sparse_jac_work jw;
207+
sparse_jac_work hw;
159208

160209
/// Temporary vectors for evaluations
161-
std::vector<V> x;
162-
std::vector<V> w;
163-
std::vector<V> xw;
164-
CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>> jac;
165-
CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>> hes;
210+
vector<V> x;
211+
vector<V> w;
212+
vector<V> xw;
213+
sparse_rcv<vector<S>, vector<V>> jac;
214+
sparse_rcv<vector<S>, vector<V>> hes;
166215

167216
void setup()
168217
{
169218
fun.optimize();
170219
size_t nx = fun.Domain();
171220
size_t ny = fun.Range();
172221

173-
std::vector<bool> dom(nx, true);
174-
std::vector<bool> rng(ny, true);
222+
vector<bool> dom(nx, true);
223+
vector<bool> rng(ny, true);
175224
fun.subgraph_sparsity(dom, rng, false, jp);
176225

177226
auto af = fun.base2ad();
178-
std::vector<CppAD::AD<V>> jaxw(nx + ny);
179-
CppAD::Independent(jaxw);
180-
std::vector<CppAD::AD<V>> jax(nx);
181-
std::vector<CppAD::AD<V>> jaw(ny);
182-
std::vector<CppAD::AD<V>> jaz(nx);
227+
vector<AD<V>> jaxw(nx + ny);
228+
Independent(jaxw);
229+
vector<AD<V>> jax(nx);
230+
vector<AD<V>> jaw(ny);
231+
vector<AD<V>> jaz(nx);
183232
for (size_t i = 0; i < nx; i++)
184233
{
185234
jax[i] = jaxw[i];
@@ -192,21 +241,21 @@ struct CallbackEvaluator
192241
jaz = af.Reverse(1, jaw);
193242
jfun.Dependent(jaxw, jaz);
194243
jfun.optimize();
195-
std::vector<bool> jdom(nx + ny, false);
244+
vector<bool> jdom(nx + ny, false);
196245
for (size_t i = 0; i < nx; i++)
197246
{
198247
jdom[i] = true;
199248
}
200-
std::vector<bool> jrng(nx, true);
201-
CppAD::sparse_rc<std::vector<size_t>> hsp;
249+
vector<bool> jrng(nx, true);
250+
sparse_rc<vector<S>> hsp;
202251
jfun.subgraph_sparsity(jdom, jrng, false, hsp);
203252

204253
auto &hrow = hsp.row();
205254
auto &hcol = hsp.col();
206255
for (size_t k = 0; k < hsp.nnz(); k++)
207256
{
208-
size_t row = hrow[k];
209-
size_t col = hcol[k];
257+
S row = hrow[k];
258+
S col = hcol[k];
210259
if (row <= col)
211260
{
212261
hp.push_back(row, col);
@@ -215,8 +264,8 @@ struct CallbackEvaluator
215264
x.resize(nx);
216265
w.resize(ny);
217266
xw.resize(nx + ny);
218-
jac = CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>>(jp);
219-
hes = CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>>(hp);
267+
jac = sparse_rcv<vector<S>, vector<V>>(jp);
268+
hes = sparse_rcv<vector<S>, vector<V>>(hp);
220269
}
221270

222271
bool is_objective() const
@@ -226,34 +275,31 @@ struct CallbackEvaluator
226275

227276
void eval_fun(const V *req_x, V *res_y)
228277
{
229-
size_t nx = fun.Domain();
230-
size_t ny = fun.Range();
231-
copy(nx, req_x, indexVars.data(), x.data());
278+
copy(fun.Domain(), req_x, indexVars.data(), x.data());
232279
auto y = fun.Forward(0, x);
233-
copy(ny, y.data(), nullptr, res_y, is_objective());
280+
CopyMode mode = is_objective() ? CopyMode::Aggregate : CopyMode::Normal;
281+
copy(fun.Range(), y.data(), (const KNINT *)nullptr, res_y, mode);
234282
}
235283

236284
void eval_jac(const V *req_x, V *res_jac)
237285
{
238-
size_t nx = fun.Domain();
239-
copy(nx, req_x, indexVars.data(), x.data());
286+
copy(fun.Domain(), req_x, indexVars.data(), x.data());
240287
fun.sparse_jac_rev(x, jac, jp, JAC_CLRNG, jw);
241-
copy_vec(jac.nnz(), jac.val().data(), nullptr, res_jac);
288+
copy(jac.nnz(), jac.val().data(), (const I *)nullptr, res_jac);
242289
}
243290

244291
void eval_hess(const V *req_x, const V *req_w, V *res_hess)
245292
{
246-
size_t nx = fun.Domain();
247-
size_t ny = fun.Range();
248-
copy(nx, req_x, indexVars.data(), xw.data());
249-
copy(ny, req_w, indexCons.data(), xw.data() + nx, false, is_objective());
293+
copy(fun.Domain(), req_x, indexVars.data(), xw.data());
294+
CopyMode mode = is_objective() ? CopyMode::Duplicate : CopyMode::Normal;
295+
copy(fun.Range(), req_w, indexCons.data(), xw.data() + fun.Domain(), mode);
250296
jfun.sparse_jac_rev(xw, hes, hp, JAC_CLRNG, hw);
251-
copy_vec(hes.nnz(), hes.val().data(), nullptr, res_hess);
297+
copy(hes.nnz(), hes.val().data(), (const I *)nullptr, res_hess);
252298
}
253299

254-
CallbackPattern get_callback_pattern() const
300+
CallbackPattern<I> get_callback_pattern() const
255301
{
256-
CallbackPattern p;
302+
CallbackPattern<I> p;
257303
p.indexCons = indexCons;
258304

259305
auto &jrow = jp.row();
@@ -284,49 +330,6 @@ struct CallbackEvaluator
284330

285331
return p;
286332
}
287-
288-
private:
289-
template <typename T, typename I>
290-
static void copy(const size_t n, const T *src, const I *idx, V *dst, bool aggregate = false,
291-
bool duplicate = false)
292-
{
293-
if (duplicate)
294-
{
295-
for (size_t i = 0; i < n; i++)
296-
{
297-
dst[i] = src[0];
298-
}
299-
}
300-
else if (aggregate)
301-
{
302-
dst[0] = 0.0;
303-
for (size_t i = 0; i < n; i++)
304-
{
305-
dst[0] += src[i];
306-
}
307-
}
308-
else if (idx == nullptr)
309-
{
310-
for (size_t i = 0; i < n; i++)
311-
{
312-
dst[i] = src[i];
313-
}
314-
}
315-
else
316-
{
317-
for (size_t i = 0; i < n; i++)
318-
{
319-
dst[i] = src[idx[i]];
320-
}
321-
}
322-
}
323-
};
324-
325-
struct Outputs
326-
{
327-
std::vector<size_t> objective_outputs;
328-
std::vector<size_t> constraint_outputs;
329-
std::vector<ConstraintIndex> constraints;
330333
};
331334

332335
inline bool is_name_empty(const char *name)
@@ -610,8 +613,17 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
610613
std::unordered_map<KNINT, uint8_t> m_con_sense_flags;
611614
uint8_t m_obj_flag = 0;
612615

616+
struct Outputs
617+
{
618+
std::vector<size_t> objective_outputs;
619+
std::vector<size_t> constraint_outputs;
620+
std::vector<ConstraintIndex> constraints;
621+
};
622+
623+
using Evaluator = CallbackEvaluator<double, size_t, KNINT>;
624+
613625
std::unordered_map<ExpressionGraph *, Outputs> m_pending_outputs;
614-
std::vector<std::unique_ptr<CallbackEvaluator<double>>> m_evaluators;
626+
std::vector<std::unique_ptr<Evaluator>> m_evaluators;
615627
bool m_has_pending_callbacks = false;
616628
int m_solve_status = 0;
617629
bool m_is_dirty = true;
@@ -637,7 +649,7 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
637649
void _add_callbacks(const ExpressionGraph &graph, const Outputs &outputs);
638650
void _add_callback(const ExpressionGraph &graph, const std::vector<size_t> &outputs,
639651
const std::vector<ConstraintIndex> &constraints);
640-
void _register_callback(CallbackEvaluator<double> *evaluator);
652+
void _register_callback(Evaluator *evaluator);
641653
void _update();
642654
void _pre_solve();
643655
void _solve();

lib/knitro_model.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,11 @@ double KNITROModel::get_obj_value() const
836836
return _get_value<double>(knitro::KN_get_obj_value);
837837
}
838838

839-
void KNITROModel::_register_callback(CallbackEvaluator<double> *evaluator)
839+
void KNITROModel::_register_callback(Evaluator *evaluator)
840840
{
841841
auto f = [](KN_context *, CB_context *cb, KN_eval_request *req, KN_eval_result *res,
842842
void *data) -> int {
843-
auto evaluator = static_cast<CallbackEvaluator<double> *>(data);
843+
auto evaluator = static_cast<Evaluator *>(data);
844844
if (evaluator->is_objective())
845845
{
846846
evaluator->eval_fun(req->x, res->obj);
@@ -854,7 +854,7 @@ void KNITROModel::_register_callback(CallbackEvaluator<double> *evaluator)
854854

855855
auto g = [](KN_context *, CB_context *cb, KN_eval_request *req, KN_eval_result *res,
856856
void *data) -> int {
857-
auto evaluator = static_cast<CallbackEvaluator<double> *>(data);
857+
auto evaluator = static_cast<Evaluator *>(data);
858858
if (evaluator->is_objective())
859859
{
860860
evaluator->eval_jac(req->x, res->objGrad);
@@ -868,7 +868,7 @@ void KNITROModel::_register_callback(CallbackEvaluator<double> *evaluator)
868868

869869
auto h = [](KN_context *, CB_context *cb, KN_eval_request *req, KN_eval_result *res,
870870
void *data) -> int {
871-
auto evaluator = static_cast<CallbackEvaluator<double> *>(data);
871+
auto evaluator = static_cast<Evaluator *>(data);
872872
if (evaluator->is_objective())
873873
{
874874
evaluator->eval_hess(req->x, req->sigma, res->hess);
@@ -900,7 +900,7 @@ void KNITROModel::_register_callback(CallbackEvaluator<double> *evaluator)
900900
void KNITROModel::_add_callback(const ExpressionGraph &graph, const std::vector<size_t> &outputs,
901901
const std::vector<ConstraintIndex> &constraints)
902902
{
903-
auto evaluator_ptr = std::make_unique<CallbackEvaluator<double>>();
903+
auto evaluator_ptr = std::make_unique<Evaluator>();
904904
auto *evaluator = evaluator_ptr.get();
905905
evaluator->indexVars.resize(graph.n_variables());
906906
for (size_t i = 0; i < graph.n_variables(); i++)

0 commit comments

Comments
 (0)