Skip to content

Commit 9444ad7

Browse files
committed
Refactor CallbackEvaluator to unify constant names for clarity and improve code readability
1 parent d510a54 commit 9444ad7

1 file changed

Lines changed: 9 additions & 11 deletions

File tree

include/pyoptinterface/knitro_model.hpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,19 +145,19 @@ template <typename V, typename S, typename I>
145145
struct CallbackEvaluator
146146
{
147147

148-
static inline constexpr const char *JAC_CLRNG = "cppad";
149-
static inline constexpr const char *HES_CLRNG = "cppad.symmetric";
148+
static inline constexpr const char *CLRNG = "cppad";
149+
150150
std::vector<I> indexVars;
151151
std::vector<I> indexCons;
152152

153153
ADFun<V> fun; /// < CppAD tape.
154-
ADFun<V> jfun; /// < CppAD tape for Jacobian
154+
ADFun<V> jfun; /// < CppAD tape for Aggregated Jacobian
155155

156156
/// Sparsity patterns
157157
sparse_rc<vector<S>> jp;
158158
sparse_rc<vector<S>> hp;
159159

160-
/// Workspaces for sparse Jacobian and Hessian calculations
160+
/// Workspaces for Jacobian and Hessian calculations
161161
sparse_jac_work jw;
162162
sparse_jac_work hw;
163163

@@ -177,7 +177,7 @@ struct CallbackEvaluator
177177
vector<bool> rng(ny, true);
178178
fun.subgraph_sparsity(dom, rng, false, jp);
179179

180-
auto af = fun.base2ad();
180+
ADFun<AD<V>, V> af = fun.base2ad();
181181
vector<AD<V>> jaxw(nx + ny);
182182
Independent(jaxw);
183183
vector<AD<V>> jax(nx);
@@ -208,11 +208,9 @@ struct CallbackEvaluator
208208
auto &hcol = hsp.col();
209209
for (size_t k = 0; k < hsp.nnz(); k++)
210210
{
211-
S row = hrow[k];
212-
S col = hcol[k];
213-
if (row <= col)
211+
if (hrow[k] <= hcol[k])
214212
{
215-
hp.push_back(row, col);
213+
hp.push_back(hrow[k], hcol[k]);
216214
}
217215
}
218216
x.resize(nx);
@@ -237,7 +235,7 @@ struct CallbackEvaluator
237235
void eval_jac(const V *req_x, V *res_jac)
238236
{
239237
copy(fun.Domain(), req_x, indexVars.data(), x.data());
240-
fun.sparse_jac_rev(x, jac, jp, JAC_CLRNG, jw);
238+
fun.sparse_jac_rev(x, jac, jp, CLRNG, jw);
241239
copy(jac.nnz(), jac.val().data(), (const I *)nullptr, res_jac);
242240
}
243241

@@ -246,7 +244,7 @@ struct CallbackEvaluator
246244
copy(fun.Domain(), req_x, indexVars.data(), xw.data());
247245
int mode = is_objective() ? 1 : 0;
248246
copy(fun.Range(), req_w, indexCons.data(), xw.data() + fun.Domain(), mode);
249-
jfun.sparse_jac_rev(xw, hes, hp, JAC_CLRNG, hw);
247+
jfun.sparse_jac_rev(xw, hes, hp, CLRNG, hw);
250248
copy(hes.nnz(), hes.val().data(), (const I *)nullptr, res_hess);
251249
}
252250

0 commit comments

Comments
 (0)