Skip to content

Commit b9f035f

Browse files
authored
[CORE] Clean ConstraintType Enum (metab0t#82)
* refactor: use specific cone type instead of general cone * refactor: unify non-linear constraint types under a single enum value * refactor: replace raw constraint and variable counts with a map and getter methods
1 parent c8e7e15 commit b9f035f

12 files changed

Lines changed: 101 additions & 113 deletions

File tree

include/pyoptinterface/core.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,10 @@ enum class ConstraintType
269269
Linear,
270270
Quadratic,
271271
SOS,
272-
Cone,
273-
Gurobi_General,
274-
COPT_ExpCone,
275-
COPT_NL,
276-
IPOPT_NL,
277-
Xpress_Nlp,
278-
KNITRO_NL,
272+
SecondOrderCone,
273+
ExponentialCone,
274+
NL,
275+
SolverDefined,
279276
};
280277

281278
enum class SOSType

include/pyoptinterface/knitro_model.hpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -566,15 +566,13 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
566566
KNINT _variable_index(const VariableIndex &variable) const;
567567
KNINT _constraint_index(const ConstraintIndex &constraint) const;
568568

569-
size_t n_vars = 0;
570-
size_t n_cons = 0;
571-
size_t n_lincons = 0;
572-
size_t n_quadcons = 0;
573-
size_t n_coniccons = 0;
574-
size_t n_nlcons = 0;
569+
size_t get_num_vars() const;
570+
size_t get_num_cons(std::optional<ConstraintType> type = std::nullopt) const;
575571

576572
private:
577573
// Member variables
574+
size_t m_n_vars = 0;
575+
std::unordered_map<ConstraintType, size_t> m_n_cons_map;
578576
std::shared_ptr<LM_context> m_lm = nullptr;
579577
std::unique_ptr<KN_context, KNITROFreeProblemT> m_kc = nullptr;
580578

@@ -616,7 +614,7 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
616614
template <typename F>
617615
ConstraintIndex _add_constraint_impl(ConstraintType type,
618616
const std::tuple<double, double> &interval,
619-
const char *name, size_t *np, const F &setter)
617+
const char *name, const F &setter)
620618
{
621619
KNINT indexCon;
622620
int error = knitro::KN_add_con(m_kc.get(), &indexCon);
@@ -643,11 +641,16 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
643641

644642
m_con_sense_flags[indexCon] = CON_UPBND;
645643

646-
n_cons++;
647-
if (np != nullptr)
644+
auto it = m_n_cons_map.find(type);
645+
if (it != m_n_cons_map.end())
648646
{
649-
(*np)++;
647+
it->second++;
650648
}
649+
else
650+
{
651+
m_n_cons_map[type] = 1;
652+
}
653+
651654
m_is_dirty = true;
652655

653656
return constraint;

include/pyoptinterface/mosek_model.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ class MOSEKModel : public OnesideLinearConstraintMixin<MOSEKModel>,
169169

170170
std::vector<MSKint64t> add_variables_as_afe(const Vector<VariableIndex> &variables);
171171
ConstraintIndex add_variables_in_cone_constraint(const Vector<VariableIndex> &variables,
172-
MSKint64t domain_index, const char *name);
172+
MSKint64t domain_index, ConstraintType type,
173+
const char *name);
173174

174175
ConstraintIndex add_second_order_cone_constraint(const Vector<VariableIndex> &variables,
175176
const char *name, bool rotated = false);

lib/copt_model.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ ConstraintIndex COPTModel::add_second_order_cone_constraint(const Vector<Variabl
420420
const char *name, bool rotated)
421421
{
422422
IndexT index = m_cone_constraint_index.add_index();
423-
ConstraintIndex constraint_index(ConstraintType::Cone, index);
423+
ConstraintIndex constraint_index(ConstraintType::SecondOrderCone, index);
424424

425425
int N = variables.size();
426426
std::vector<int> ind_v(N);
@@ -452,7 +452,7 @@ ConstraintIndex COPTModel::add_exp_cone_constraint(const Vector<VariableIndex> &
452452
const char *name, bool dual)
453453
{
454454
IndexT index = m_exp_cone_constraint_index.add_index();
455-
ConstraintIndex constraint_index(ConstraintType::COPT_ExpCone, index);
455+
ConstraintIndex constraint_index(ConstraintType::ExponentialCone, index);
456456

457457
int N = variables.size();
458458
if (N != 3)
@@ -704,7 +704,7 @@ ConstraintIndex COPTModel::add_single_nl_constraint(ExpressionGraph &graph,
704704

705705
IndexT constraint_index = m_nl_constraint_index.add_index();
706706

707-
ConstraintIndex constraint(ConstraintType::COPT_NL, constraint_index);
707+
ConstraintIndex constraint(ConstraintType::NL, constraint_index);
708708

709709
return constraint;
710710
}
@@ -729,11 +729,11 @@ void COPTModel::delete_constraint(const ConstraintIndex &constraint)
729729
m_sos_constraint_index.delete_index(constraint.index);
730730
error = copt::COPT_DelSOSs(m_model.get(), 1, &constraint_row);
731731
break;
732-
case ConstraintType::Cone:
732+
case ConstraintType::SecondOrderCone:
733733
m_cone_constraint_index.delete_index(constraint.index);
734734
error = copt::COPT_DelCones(m_model.get(), 1, &constraint_row);
735735
break;
736-
case ConstraintType::COPT_ExpCone:
736+
case ConstraintType::ExponentialCone:
737737
m_exp_cone_constraint_index.delete_index(constraint.index);
738738
error = copt::COPT_DelExpCones(m_model.get(), 1, &constraint_row);
739739
break;
@@ -754,9 +754,9 @@ bool COPTModel::is_constraint_active(const ConstraintIndex &constraint)
754754
return m_quadratic_constraint_index.has_index(constraint.index);
755755
case ConstraintType::SOS:
756756
return m_sos_constraint_index.has_index(constraint.index);
757-
case ConstraintType::Cone:
757+
case ConstraintType::SecondOrderCone:
758758
return m_cone_constraint_index.has_index(constraint.index);
759-
case ConstraintType::COPT_ExpCone:
759+
case ConstraintType::ExponentialCone:
760760
return m_exp_cone_constraint_index.has_index(constraint.index);
761761
default:
762762
throw std::runtime_error("Unknown constraint type");
@@ -1036,7 +1036,7 @@ double COPTModel::get_constraint_info(const ConstraintIndex &constraint, const c
10361036
case ConstraintType::Quadratic:
10371037
error = copt::COPT_GetQConstrInfo(m_model.get(), info_name, num, &row, &retval);
10381038
break;
1039-
case ConstraintType::COPT_NL:
1039+
case ConstraintType::NL:
10401040
error = copt::COPT_GetNLConstrInfo(m_model.get(), info_name, num, &row, &retval);
10411041
break;
10421042
default:
@@ -1059,7 +1059,7 @@ std::string COPTModel::get_constraint_name(const ConstraintIndex &constraint)
10591059
case ConstraintType::Quadratic:
10601060
error = copt::COPT_GetQConstrName(m_model.get(), row, NULL, 0, &reqsize);
10611061
break;
1062-
case ConstraintType::COPT_NL:
1062+
case ConstraintType::NL:
10631063
error = copt::COPT_GetNLConstrName(m_model.get(), row, NULL, 0, &reqsize);
10641064
break;
10651065
default:
@@ -1093,7 +1093,7 @@ void COPTModel::set_constraint_name(const ConstraintIndex &constraint, const cha
10931093
case ConstraintType::Quadratic:
10941094
error = copt::COPT_SetQConstrNames(m_model.get(), 1, &row, names);
10951095
break;
1096-
case ConstraintType::COPT_NL:
1096+
case ConstraintType::NL:
10971097
error = copt::COPT_SetNLConstrNames(m_model.get(), 1, &row, names);
10981098
break;
10991099
default:
@@ -1300,11 +1300,11 @@ int COPTModel::_constraint_index(const ConstraintIndex &constraint)
13001300
return m_quadratic_constraint_index.get_index(constraint.index);
13011301
case ConstraintType::SOS:
13021302
return m_sos_constraint_index.get_index(constraint.index);
1303-
case ConstraintType::Cone:
1303+
case ConstraintType::SecondOrderCone:
13041304
return m_cone_constraint_index.get_index(constraint.index);
1305-
case ConstraintType::COPT_ExpCone:
1305+
case ConstraintType::ExponentialCone:
13061306
return m_exp_cone_constraint_index.get_index(constraint.index);
1307-
case ConstraintType::COPT_NL:
1307+
case ConstraintType::NL:
13081308
return m_nl_constraint_index.get_index(constraint.index);
13091309
default:
13101310
throw std::runtime_error("Unknown constraint type");

lib/core_ext.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ NB_MODULE(core_ext, m)
3737
.value("Linear", ConstraintType::Linear)
3838
.value("Quadratic", ConstraintType::Quadratic)
3939
.value("SOS", ConstraintType::SOS)
40-
.value("Cone", ConstraintType::Cone);
40+
.value("SecondOrderCone", ConstraintType::SecondOrderCone)
41+
.value("ExponentialCone", ConstraintType::ExponentialCone);
4142

4243
nb::enum_<SOSType>(m, "SOSType").value("SOS1", SOSType::SOS1).value("SOS2", SOSType::SOS2);
4344

lib/gurobi_model.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ ConstraintIndex GurobiModel::add_single_nl_constraint(const ExpressionGraph &gra
629629

630630
IndexT constraint_index = m_general_constraint_index.add_index();
631631

632-
ConstraintIndex constraint(ConstraintType::Gurobi_General, constraint_index);
632+
ConstraintIndex constraint(ConstraintType::NL, constraint_index);
633633
m_nlcon_resvar_map.emplace(constraint_index, resvar.index);
634634

635635
m_update_flag |= m_general_constraint_creation;
@@ -661,7 +661,7 @@ void GurobiModel::delete_constraint(const ConstraintIndex &constraint)
661661
error = gurobi::GRBdelsos(m_model.get(), 1, &constraint_row);
662662
m_update_flag |= m_sos_constraint_deletion;
663663
break;
664-
case ConstraintType::Gurobi_General: {
664+
case ConstraintType::NL: {
665665
m_general_constraint_index.delete_index(constraint.index);
666666
error = gurobi::GRBdelgenconstrs(m_model.get(), 1, &constraint_row);
667667
// delete the corresponding resvar variable as well
@@ -687,7 +687,7 @@ bool GurobiModel::is_constraint_active(const ConstraintIndex &constraint)
687687
return m_quadratic_constraint_index.has_index(constraint.index);
688688
case ConstraintType::SOS:
689689
return m_sos_constraint_index.has_index(constraint.index);
690-
case ConstraintType::Gurobi_General:
690+
case ConstraintType::NL:
691691
return m_general_constraint_index.has_index(constraint.index);
692692
default:
693693
throw std::runtime_error("Unknown constraint type");
@@ -1253,7 +1253,7 @@ int GurobiModel::_constraint_index(const ConstraintIndex &constraint)
12531253
return m_quadratic_constraint_index.get_index(constraint.index);
12541254
case ConstraintType::SOS:
12551255
return m_sos_constraint_index.get_index(constraint.index);
1256-
case ConstraintType::Gurobi_General:
1256+
case ConstraintType::NL:
12571257
return m_general_constraint_index.get_index(constraint.index);
12581258
default:
12591259
throw std::runtime_error("Unknown constraint type");
@@ -1310,7 +1310,7 @@ void GurobiModel::_update_for_constraint_index(ConstraintType type)
13101310
case ConstraintType::SOS:
13111311
need_update = m_update_flag & (m_sos_constraint_creation | m_sos_constraint_deletion);
13121312
break;
1313-
case ConstraintType::Gurobi_General:
1313+
case ConstraintType::NL:
13141314
need_update =
13151315
m_update_flag & (m_general_constraint_creation | m_general_constraint_deletion);
13161316
break;

lib/ipopt_model.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ int IpoptModel::_constraint_internal_index(const ConstraintIndex &constraint)
162162
return constraint.index;
163163
case ConstraintType::Quadratic:
164164
return m_linear_con_evaluator.n_constraints + constraint.index;
165-
case ConstraintType::IPOPT_NL: {
165+
case ConstraintType::NL: {
166166
auto base = m_linear_con_evaluator.n_constraints + m_quadratic_con_evaluator.n_constraints;
167167
auto internal_nl_index = nl_constraint_map_ext2int[constraint.index];
168168
return base + internal_nl_index;
@@ -380,7 +380,7 @@ ConstraintIndex IpoptModel::add_single_nl_constraint(size_t graph_index,
380380

381381
m_is_dirty = true;
382382

383-
return ConstraintIndex(ConstraintType::IPOPT_NL, constraint_index);
383+
return ConstraintIndex(ConstraintType::NL, constraint_index);
384384
}
385385

386386
static bool eval_f(ipindex n, ipnumber *x, bool new_x, ipnumber *obj_value, UserDataPtr user_data)

lib/knitro_model.cpp

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ VariableIndex KNITROModel::add_variable(VariableDomain domain, double lb, double
203203
_set_name(knitro::KN_set_var_name, indexVar, name);
204204
}
205205

206-
n_vars++;
206+
m_n_vars++;
207207
_mark_dirty();
208208

209209
return variable;
@@ -311,7 +311,7 @@ void KNITROModel::delete_variable(const VariableIndex &variable)
311311
_set_value<KNINT, int>(knitro::KN_set_var_type, indexVar, KN_VARTYPE_CONTINUOUS);
312312
_set_value<KNINT, double>(knitro::KN_set_var_lobnd, indexVar, -get_infinity());
313313
_set_value<KNINT, double>(knitro::KN_set_var_upbnd, indexVar, get_infinity());
314-
n_vars--;
314+
m_n_vars--;
315315
_mark_dirty();
316316
}
317317

@@ -337,7 +337,7 @@ ConstraintIndex KNITROModel::add_linear_constraint(const ScalarAffineFunction &f
337337
auto setter = [this, &function](const ConstraintIndex &constraint) {
338338
_set_linear_constraint(constraint, function);
339339
};
340-
return _add_constraint_impl(ConstraintType::Linear, interval, name, &n_lincons, setter);
340+
return _add_constraint_impl(ConstraintType::Linear, interval, name, setter);
341341
}
342342

343343
ConstraintIndex KNITROModel::add_quadratic_constraint(const ScalarQuadraticFunction &function,
@@ -356,7 +356,7 @@ ConstraintIndex KNITROModel::add_quadratic_constraint(const ScalarQuadraticFunct
356356
auto setter = [this, &function](const ConstraintIndex &constraint) {
357357
_set_quadratic_constraint(constraint, function);
358358
};
359-
return _add_constraint_impl(ConstraintType::Quadratic, interval, name, &n_quadcons, setter);
359+
return _add_constraint_impl(ConstraintType::Quadratic, interval, name, setter);
360360
}
361361

362362
ConstraintIndex KNITROModel::add_second_order_cone_constraint(
@@ -369,15 +369,15 @@ ConstraintIndex KNITROModel::add_second_order_cone_constraint(
369369
_set_second_order_cone_constraint_rotated(constraint, variables);
370370
};
371371
std::pair<double, double> interval = {0.0, get_infinity()};
372-
return _add_constraint_impl(ConstraintType::Cone, interval, name, &n_coniccons, setter);
372+
return _add_constraint_impl(ConstraintType::SecondOrderCone, interval, name, setter);
373373
}
374374
else
375375
{
376376
auto setter = [this, &variables](const ConstraintIndex &constraint) {
377377
_set_second_order_cone_constraint(constraint, variables);
378378
};
379379
std::pair<double, double> interval = {0.0, get_infinity()};
380-
return _add_constraint_impl(ConstraintType::Cone, interval, name, &n_coniccons, setter);
380+
return _add_constraint_impl(ConstraintType::SecondOrderCone, interval, name, setter);
381381
}
382382
}
383383

@@ -394,7 +394,7 @@ ConstraintIndex KNITROModel::add_single_nl_constraint(ExpressionGraph &graph,
394394
m_pending_outputs[&graph].cons.push_back(constraint);
395395
m_need_to_add_callbacks = true;
396396
};
397-
return _add_constraint_impl(ConstraintType::KNITRO_NL, interval, name, &n_nlcons, setter);
397+
return _add_constraint_impl(ConstraintType::NL, interval, name, setter);
398398
}
399399

400400
ConstraintIndex KNITROModel::add_single_nl_constraint_sense_rhs(ExpressionGraph &graph,
@@ -519,24 +519,7 @@ void KNITROModel::delete_constraint(const ConstraintIndex &constraint)
519519
_set_value<KNINT, double>(knitro::KN_set_con_lobnd, indexCon, -get_infinity());
520520
_set_value<KNINT, double>(knitro::KN_set_con_upbnd, indexCon, get_infinity());
521521

522-
n_cons--;
523-
switch (constraint.type)
524-
{
525-
case ConstraintType::Linear:
526-
n_lincons--;
527-
break;
528-
case ConstraintType::Quadratic:
529-
n_quadcons--;
530-
break;
531-
case ConstraintType::Cone:
532-
n_coniccons--;
533-
break;
534-
case ConstraintType::KNITRO_NL:
535-
n_nlcons--;
536-
break;
537-
default:
538-
break;
539-
}
522+
m_n_cons_map[constraint.type]--;
540523

541524
auto it = m_soc_aux_cons.find(indexCon);
542525
if (it != m_soc_aux_cons.end())
@@ -1012,6 +995,29 @@ bool KNITROModel::empty() const
1012995
return m_kc == nullptr;
1013996
}
1014997

998+
size_t KNITROModel::get_num_vars() const
999+
{
1000+
return m_n_vars;
1001+
}
1002+
1003+
size_t KNITROModel::get_num_cons(std::optional<ConstraintType> type) const
1004+
{
1005+
if (!type.has_value())
1006+
{
1007+
size_t total = 0;
1008+
for (const auto &[_, count] : m_n_cons_map)
1009+
{
1010+
total += count;
1011+
}
1012+
return total;
1013+
}
1014+
else
1015+
{
1016+
auto it = m_n_cons_map.find(type.value());
1017+
return it != m_n_cons_map.end() ? it->second : 0;
1018+
}
1019+
}
1020+
10151021
int KNITROModel::get_solve_status() const
10161022
{
10171023
_check_dirty();
@@ -1030,12 +1036,8 @@ void KNITROModel::_check_dirty() const
10301036
void KNITROModel::_reset_state()
10311037
{
10321038
m_kc.reset();
1033-
n_vars = 0;
1034-
n_cons = 0;
1035-
n_lincons = 0;
1036-
n_quadcons = 0;
1037-
n_coniccons = 0;
1038-
n_nlcons = 0;
1039+
m_n_vars = 0;
1040+
m_n_cons_map.clear();
10391041
m_soc_aux_cons.clear();
10401042
m_con_sense_flags.clear();
10411043
m_obj_flag = 0;

0 commit comments

Comments
 (0)