Skip to content

Commit 2843e56

Browse files
committed
refactor: replace raw constraint and variable counts with a map and getter methods
1 parent edb4b78 commit 2843e56

4 files changed

Lines changed: 51 additions & 71 deletions

File tree

include/pyoptinterface/knitro_model.hpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -555,15 +555,13 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
555555
KNINT _variable_index(const VariableIndex &variable) const;
556556
KNINT _constraint_index(const ConstraintIndex &constraint) const;
557557

558-
size_t n_vars = 0;
559-
size_t n_cons = 0;
560-
size_t n_lincons = 0;
561-
size_t n_quadcons = 0;
562-
size_t n_soccons = 0;
563-
size_t n_nlcons = 0;
558+
size_t get_num_vars() const;
559+
size_t get_num_cons(std::optional<ConstraintType> type = std::nullopt) const;
564560

565561
private:
566562
// Member variables
563+
size_t m_n_vars = 0;
564+
std::unordered_map<ConstraintType, size_t> m_n_cons_map;
567565
std::shared_ptr<LM_context> m_lm = nullptr;
568566
std::unique_ptr<KN_context, KNITROFreeProblemT> m_kc = nullptr;
569567

@@ -605,7 +603,7 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
605603
template <typename F>
606604
ConstraintIndex _add_constraint_impl(ConstraintType type,
607605
const std::tuple<double, double> &interval,
608-
const char *name, size_t *np, const F &setter)
606+
const char *name, const F &setter)
609607
{
610608
KNINT indexCon;
611609
int error = knitro::KN_add_con(m_kc.get(), &indexCon);
@@ -632,11 +630,16 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
632630

633631
m_con_sense_flags[indexCon] = CON_UPBND;
634632

635-
n_cons++;
636-
if (np != nullptr)
633+
auto it = m_n_cons_map.find(type);
634+
if (it != m_n_cons_map.end())
637635
{
638-
(*np)++;
636+
it->second++;
639637
}
638+
else
639+
{
640+
m_n_cons_map[type] = 1;
641+
}
642+
640643
m_is_dirty = true;
641644

642645
return constraint;

lib/knitro_model.cpp

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ VariableIndex KNITROModel::add_variable(VariableDomain domain, double lb, double
203203
_set_name(knitro::KN_set_var_name, indexVar, name);
204204
}
205205

206-
n_vars++;
206+
m_n_vars++;
207207
_mark_dirty();
208208

209209
return variable;
@@ -304,7 +304,7 @@ void KNITROModel::delete_variable(const VariableIndex &variable)
304304
_set_value<KNINT, int>(knitro::KN_set_var_type, indexVar, KN_VARTYPE_CONTINUOUS);
305305
_set_value<KNINT, double>(knitro::KN_set_var_lobnd, indexVar, -get_infinity());
306306
_set_value<KNINT, double>(knitro::KN_set_var_upbnd, indexVar, get_infinity());
307-
n_vars--;
307+
m_n_vars--;
308308
_mark_dirty();
309309
}
310310

@@ -330,7 +330,7 @@ ConstraintIndex KNITROModel::add_linear_constraint(const ScalarAffineFunction &f
330330
auto setter = [this, &function](const ConstraintIndex &constraint) {
331331
_set_linear_constraint(constraint, function);
332332
};
333-
return _add_constraint_impl(ConstraintType::Linear, interval, name, &n_lincons, setter);
333+
return _add_constraint_impl(ConstraintType::Linear, interval, name, setter);
334334
}
335335

336336
ConstraintIndex KNITROModel::add_quadratic_constraint(const ScalarQuadraticFunction &function,
@@ -349,7 +349,7 @@ ConstraintIndex KNITROModel::add_quadratic_constraint(const ScalarQuadraticFunct
349349
auto setter = [this, &function](const ConstraintIndex &constraint) {
350350
_set_quadratic_constraint(constraint, function);
351351
};
352-
return _add_constraint_impl(ConstraintType::Quadratic, interval, name, &n_quadcons, setter);
352+
return _add_constraint_impl(ConstraintType::Quadratic, interval, name, setter);
353353
}
354354

355355
ConstraintIndex KNITROModel::add_second_order_cone_constraint(
@@ -362,17 +362,15 @@ ConstraintIndex KNITROModel::add_second_order_cone_constraint(
362362
_set_second_order_cone_constraint_rotated(constraint, variables);
363363
};
364364
std::pair<double, double> interval = {0.0, get_infinity()};
365-
return _add_constraint_impl(ConstraintType::SecondOrderCone, interval, name, &n_soccons,
366-
setter);
365+
return _add_constraint_impl(ConstraintType::SecondOrderCone, interval, name, setter);
367366
}
368367
else
369368
{
370369
auto setter = [this, &variables](const ConstraintIndex &constraint) {
371370
_set_second_order_cone_constraint(constraint, variables);
372371
};
373372
std::pair<double, double> interval = {0.0, get_infinity()};
374-
return _add_constraint_impl(ConstraintType::SecondOrderCone, interval, name, &n_soccons,
375-
setter);
373+
return _add_constraint_impl(ConstraintType::SecondOrderCone, interval, name, setter);
376374
}
377375
}
378376

@@ -389,7 +387,7 @@ ConstraintIndex KNITROModel::add_single_nl_constraint(ExpressionGraph &graph,
389387
m_pending_outputs[&graph].cons.push_back(constraint);
390388
m_need_to_add_callbacks = true;
391389
};
392-
return _add_constraint_impl(ConstraintType::NL, interval, name, &n_nlcons, setter);
390+
return _add_constraint_impl(ConstraintType::NL, interval, name, setter);
393391
}
394392

395393
ConstraintIndex KNITROModel::add_single_nl_constraint_sense_rhs(ExpressionGraph &graph,
@@ -514,24 +512,7 @@ void KNITROModel::delete_constraint(const ConstraintIndex &constraint)
514512
_set_value<KNINT, double>(knitro::KN_set_con_lobnd, indexCon, -get_infinity());
515513
_set_value<KNINT, double>(knitro::KN_set_con_upbnd, indexCon, get_infinity());
516514

517-
n_cons--;
518-
switch (constraint.type)
519-
{
520-
case ConstraintType::Linear:
521-
n_lincons--;
522-
break;
523-
case ConstraintType::Quadratic:
524-
n_quadcons--;
525-
break;
526-
case ConstraintType::SecondOrderCone:
527-
n_soccons--;
528-
break;
529-
case ConstraintType::NL:
530-
n_nlcons--;
531-
break;
532-
default:
533-
break;
534-
}
515+
m_n_cons_map[constraint.type]--;
535516

536517
auto it = m_soc_aux_cons.find(indexCon);
537518
if (it != m_soc_aux_cons.end())
@@ -1007,6 +988,29 @@ bool KNITROModel::empty() const
1007988
return m_kc == nullptr;
1008989
}
1009990

991+
size_t KNITROModel::get_num_vars() const
992+
{
993+
return m_n_vars;
994+
}
995+
996+
size_t KNITROModel::get_num_cons(std::optional<ConstraintType> type) const
997+
{
998+
if (!type.has_value())
999+
{
1000+
size_t total = 0;
1001+
for (const auto &[_, count] : m_n_cons_map)
1002+
{
1003+
total += count;
1004+
}
1005+
return total;
1006+
}
1007+
else
1008+
{
1009+
auto it = m_n_cons_map.find(type.value());
1010+
return it != m_n_cons_map.end() ? it->second : 0;
1011+
}
1012+
}
1013+
10101014
int KNITROModel::get_solve_status() const
10111015
{
10121016
_check_dirty();
@@ -1025,12 +1029,8 @@ void KNITROModel::_check_dirty() const
10251029
void KNITROModel::_reset_state()
10261030
{
10271031
m_kc.reset();
1028-
n_vars = 0;
1029-
n_cons = 0;
1030-
n_lincons = 0;
1031-
n_quadcons = 0;
1032-
n_soccons = 0;
1033-
n_nlcons = 0;
1032+
m_n_vars = 0;
1033+
m_n_cons_map.clear();
10341034
m_soc_aux_cons.clear();
10351035
m_con_sense_flags.clear();
10361036
m_obj_flag = 0;

lib/knitro_model_ext.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <nanobind/nanobind.h>
2+
#include <nanobind/stl/optional.h>
23
#include <nanobind/stl/string.h>
34
#include <nanobind/stl/vector.h>
45
#include <nanobind/stl/tuple.h>
@@ -31,6 +32,7 @@ NB_MODULE(knitro_model_ext, m)
3132
.def(nb::init<const KNITROEnv &>())
3233
.def("init", nb::overload_cast<>(&KNITROModel::init))
3334
.def("init", nb::overload_cast<const KNITROEnv &>(&KNITROModel::init))
35+
3436
// clang-format off
3537
BIND_F(close)
3638
BIND_F(get_infinity)
@@ -43,12 +45,8 @@ NB_MODULE(knitro_model_ext, m)
4345
BIND_F(get_release)
4446
// clang-format on
4547

46-
.def_ro("n_vars", &KNITROModel::n_vars)
47-
.def_ro("n_cons", &KNITROModel::n_cons)
48-
.def_ro("n_lincons", &KNITROModel::n_lincons)
49-
.def_ro("n_quadcons", &KNITROModel::n_quadcons)
50-
.def_ro("n_soccons", &KNITROModel::n_soccons)
51-
.def_ro("n_nlcons", &KNITROModel::n_nlcons)
48+
.def("number_of_variables", &KNITROModel::get_num_vars)
49+
.def("number_of_constraints", &KNITROModel::get_num_cons, nb::arg("type") = nb::none())
5250

5351
.def("add_variable", &KNITROModel::add_variable,
5452
nb::arg("domain") = VariableDomain::Continuous, nb::arg("lb") = -KN_INFINITY,
@@ -162,10 +160,9 @@ NB_MODULE(knitro_model_ext, m)
162160
nb::arg("expr"), nb::arg("sense") = ObjectiveSense::Minimize)
163161
.def("_add_single_nl_objective", &KNITROModel::add_single_nl_objective, nb::arg("graph"),
164162
nb::arg("result"))
165-
.def("set_objective_coefficient", &KNITROModel::set_objective_coefficient,
166-
nb::arg("variable"), nb::arg("coefficient"))
167163

168164
// clang-format off
165+
BIND_F(set_objective_coefficient)
169166
BIND_F(get_obj_value)
170167
BIND_F(set_obj_sense)
171168
BIND_F(get_obj_sense)

src/pyoptinterface/_src/knitro.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .core_ext import (
1818
ConstraintIndex,
1919
ConstraintSense,
20-
ConstraintType,
2120
ExprBuilder,
2221
ScalarAffineFunction,
2322
ScalarQuadraticFunction,
@@ -334,25 +333,6 @@ def supports_model_attribute(
334333
else:
335334
return attribute in model_attribute_get_func_map
336335

337-
def number_of_variables(self) -> int:
338-
return self.n_vars
339-
340-
def number_of_constraints(
341-
self, constraint_type: Union[ConstraintType, None] = None
342-
) -> int:
343-
if constraint_type is None:
344-
return self.n_cons
345-
elif constraint_type == ConstraintType.Linear:
346-
return self.n_lincons
347-
elif constraint_type == ConstraintType.Quadratic:
348-
return self.n_quadcons
349-
elif constraint_type == ConstraintType.SecondOrderCone:
350-
return self.n_soccons
351-
elif constraint_type == ConstraintType.KNITRO_NL:
352-
return self.n_nlcons
353-
else:
354-
raise ValueError(f"Unknown constraint type: {constraint_type}")
355-
356336
@overload
357337
def add_linear_constraint(
358338
self,

0 commit comments

Comments
 (0)