Skip to content

Commit 766148b

Browse files
committed
Equality refacto
1 parent 2178cfb commit 766148b

3 files changed

Lines changed: 22 additions & 15 deletions

File tree

src/gems/model/constraint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ def __post_init__(
7575
if is_unbounded(self.upper_bound) and not is_non_negative(self.upper_bound):
7676
raise ValueError("Upper bound should not be -Inf")
7777

78+
@property
79+
def is_equality(self) -> bool:
80+
return (
81+
not is_unbounded(self.lower_bound)
82+
and not is_unbounded(self.upper_bound)
83+
and expressions_equal(self.lower_bound, self.upper_bound)
84+
)
85+
7886
def replicate(self, /, **changes: Any) -> "Constraint":
7987
return replace(self, **changes)
8088

src/gems/simulation/optimization.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -691,23 +691,22 @@ def _create_constraints_for_model(
691691
# Sanitize constraint name for LP format (spaces → underscores)
692692
safe_name = constraint.name.replace(" ", "_").replace("-", "_")
693693

694-
# Lower bound constraint: lhs >= lb (if lb != -inf)
695-
if not is_unbounded(constraint.lower_bound):
694+
if constraint.is_equality:
696695
lb = visit(constraint.lower_bound, builder)
697696
if validity_mask is not None:
698697
lb = _apply_validity_mask(lb, validity_mask)
699-
name = f"{prefix}__{safe_name}__lb"
700-
con_lb = lhs >= lb # type: ignore[operator]
701-
self.linopy_model.add_constraints(con_lb, name=name) # type: ignore[arg-type]
702-
703-
# Upper bound constraint: lhs <= ub (if ub != +inf)
704-
if not is_unbounded(constraint.upper_bound):
705-
ub = visit(constraint.upper_bound, builder)
706-
if validity_mask is not None:
707-
ub = _apply_validity_mask(ub, validity_mask)
708-
name = f"{prefix}__{safe_name}__ub"
709-
con_ub = lhs <= ub # type: ignore[operator]
710-
self.linopy_model.add_constraints(con_ub, name=name) # type: ignore[arg-type]
698+
self.linopy_model.add_constraints(lhs == lb, name=f"{prefix}__{safe_name}__eq") # type: ignore[operator,arg-type]
699+
else:
700+
if not is_unbounded(constraint.lower_bound):
701+
lb = visit(constraint.lower_bound, builder)
702+
if validity_mask is not None:
703+
lb = _apply_validity_mask(lb, validity_mask)
704+
self.linopy_model.add_constraints(lhs >= lb, name=f"{prefix}__{safe_name}__lb") # type: ignore[operator,arg-type]
705+
if not is_unbounded(constraint.upper_bound):
706+
ub = visit(constraint.upper_bound, builder)
707+
if validity_mask is not None:
708+
ub = _apply_validity_mask(ub, validity_mask)
709+
self.linopy_model.add_constraints(lhs <= ub, name=f"{prefix}__{safe_name}__ub") # type: ignore[operator,arg-type]
711710

712711
def _add_objectives_for_model(
713712
self,

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)