Skip to content

Commit ac6ebad

Browse files
committed
Move CopyMode enum and copy function into CallbackEvaluator for better encapsulation
1 parent 300cfed commit ac6ebad

1 file changed

Lines changed: 50 additions & 46 deletions

File tree

include/pyoptinterface/knitro_model.hpp

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -139,51 +139,6 @@ struct CallbackPattern
139139
std::vector<I> hessIndexVars2;
140140
};
141141

142-
enum class CopyMode
143-
{
144-
Normal,
145-
Aggregate,
146-
Duplicate
147-
};
148-
149-
template <typename T, typename I>
150-
static void copy(const size_t n, const T *src, const I *idx, T *dst,
151-
CopyMode mode = CopyMode::Normal)
152-
{
153-
if (mode == CopyMode::Duplicate)
154-
{
155-
for (size_t i = 0; i < n; i++)
156-
{
157-
dst[i] = src[0];
158-
}
159-
}
160-
else if (mode == CopyMode::Aggregate)
161-
{
162-
dst[0] = T(0.0);
163-
for (size_t i = 0; i < n; i++)
164-
{
165-
dst[0] += src[i];
166-
}
167-
}
168-
else
169-
{
170-
if (idx == nullptr)
171-
{
172-
for (size_t i = 0; i < n; i++)
173-
{
174-
dst[i] = src[i];
175-
}
176-
}
177-
else
178-
{
179-
for (size_t i = 0; i < n; i++)
180-
{
181-
dst[i] = src[idx[i]];
182-
}
183-
}
184-
}
185-
}
186-
187142
using namespace CppAD;
188143

189144
template <typename V, typename S, typename I>
@@ -278,7 +233,7 @@ struct CallbackEvaluator
278233
copy(fun.Domain(), req_x, indexVars.data(), x.data());
279234
auto y = fun.Forward(0, x);
280235
CopyMode mode = is_objective() ? CopyMode::Aggregate : CopyMode::Normal;
281-
copy(fun.Range(), y.data(), (const KNINT *)nullptr, res_y, mode);
236+
copy(fun.Range(), y.data(), (const I *)nullptr, res_y, mode);
282237
}
283238

284239
void eval_jac(const V *req_x, V *res_jac)
@@ -330,6 +285,55 @@ struct CallbackEvaluator
330285

331286
return p;
332287
}
288+
289+
private:
290+
enum class CopyMode
291+
{
292+
Normal,
293+
Aggregate,
294+
Duplicate
295+
};
296+
297+
static void copy(const size_t n, const V *src, const I *idx, V *dst,
298+
CopyMode mode = CopyMode::Normal)
299+
{
300+
if (mode == CopyMode::Duplicate)
301+
{
302+
for (size_t i = 0; i < n; i++)
303+
{
304+
dst[i] = src[0];
305+
}
306+
}
307+
else if (mode == CopyMode::Aggregate)
308+
{
309+
if (n == 0)
310+
{
311+
return;
312+
}
313+
dst[0] = src[0];
314+
for (size_t i = 1; i < n; i++)
315+
{
316+
dst[0] += src[i];
317+
}
318+
}
319+
else
320+
{
321+
if (idx == nullptr)
322+
{
323+
for (size_t i = 0; i < n; i++)
324+
{
325+
dst[i] = src[i];
326+
}
327+
}
328+
else
329+
{
330+
for (size_t i = 0; i < n; i++)
331+
{
332+
dst[i] = src[idx[i]];
333+
}
334+
}
335+
}
336+
}
333337
};
334338

335339
inline bool is_name_empty(const char *name)

0 commit comments

Comments
 (0)