Skip to content

Commit d510a54

Browse files
committed
Refactor CallbackEvaluator to replace CopyMode enum with integer constants for copy modes
1 parent f22659d commit d510a54

1 file changed

Lines changed: 9 additions & 15 deletions

File tree

include/pyoptinterface/knitro_model.hpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ struct CallbackEvaluator
163163

164164
/// Temporary vectors for evaluations
165165
vector<V> x;
166-
vector<V> w;
167166
vector<V> xw;
168167
sparse_rcv<vector<S>, vector<V>> jac;
169168
sparse_rcv<vector<S>, vector<V>> hes;
@@ -217,7 +216,6 @@ struct CallbackEvaluator
217216
}
218217
}
219218
x.resize(nx);
220-
w.resize(ny);
221219
xw.resize(nx + ny);
222220
jac = sparse_rcv<vector<S>, vector<V>>(jp);
223221
hes = sparse_rcv<vector<S>, vector<V>>(hp);
@@ -232,7 +230,7 @@ struct CallbackEvaluator
232230
{
233231
copy(fun.Domain(), req_x, indexVars.data(), x.data());
234232
auto y = fun.Forward(0, x);
235-
CopyMode mode = is_objective() ? CopyMode::Aggregate : CopyMode::Normal;
233+
int mode = is_objective() ? 2 : 0;
236234
copy(fun.Range(), y.data(), (const I *)nullptr, res_y, mode);
237235
}
238236

@@ -246,7 +244,7 @@ struct CallbackEvaluator
246244
void eval_hess(const V *req_x, const V *req_w, V *res_hess)
247245
{
248246
copy(fun.Domain(), req_x, indexVars.data(), xw.data());
249-
CopyMode mode = is_objective() ? CopyMode::Duplicate : CopyMode::Normal;
247+
int mode = is_objective() ? 1 : 0;
250248
copy(fun.Range(), req_w, indexCons.data(), xw.data() + fun.Domain(), mode);
251249
jfun.sparse_jac_rev(xw, hes, hp, JAC_CLRNG, hw);
252250
copy(hes.nnz(), hes.val().data(), (const I *)nullptr, res_hess);
@@ -287,24 +285,20 @@ struct CallbackEvaluator
287285
}
288286

289287
private:
290-
enum class CopyMode
288+
// Copy mode:
289+
// - 0: normal copy
290+
// - 1: duplicate (copy first element of src to all elements of dst)
291+
// - 2: aggregate (sum all elements of src and copy to all elements of dst)
292+
static void copy(const size_t n, const V *src, const I *idx, V *dst, int mode = 0)
291293
{
292-
Normal,
293-
Aggregate,
294-
Duplicate
295-
};
296-
297-
static void copy(const size_t n, const V *src, const I *idx, V *dst,
298-
CopyMode mode = CopyMode::Normal)
299-
{
300-
if (mode == CopyMode::Duplicate)
294+
if (mode == 1)
301295
{
302296
for (size_t i = 0; i < n; i++)
303297
{
304298
dst[i] = src[0];
305299
}
306300
}
307-
else if (mode == CopyMode::Aggregate)
301+
else if (mode == 2)
308302
{
309303
if (n == 0)
310304
{

0 commit comments

Comments
 (0)