Skip to content

Commit 89d0391

Browse files
Zeroto521CopilotJoao-Dionisio
authored
BUG: can't call getVal from GenExpr (#1148)
* Add tests for expression evaluation in Model Introduces the test_evaluate function to verify correct evaluation of various expressions using Model.getVal, including arithmetic and trigonometric operations, and checks for TypeError on invalid input. * Add unified expression evaluation with _evaluate methods Introduced _evaluate methods to Term, Expr, and all GenExpr subclasses for consistent evaluation of expressions and variables. Refactored Solution and Model to use these methods, simplifying and unifying value retrieval for variables and expressions. Cleaned up class definitions and improved hashing and equality for Term and Variable. * Optimize matrix expression evaluation in Solution Replaced explicit loop with flat iterator for evaluating matrix expressions in Solution.getVal, improving performance and code clarity. * Add test for matrix variable evaluation Introduces a new test 'test_evaluate' to verify correct evaluation of matrix variable division and summation in the model. * Remove matrix expression handling in Solution __getitem__ Eliminated special handling for MatrixExpr and MatrixGenExpr in Solution.__getitem__, simplifying the method to only evaluate single expressions. This change streamlines the code and removes unnecessary numpy array construction. * Refactor Term class variable naming and usage Renamed 'vars' to 'vartuple' and '_hash' to 'hashval' in the Term class for clarity. Updated all references and methods to use the new attribute names, improving code readability and consistency. * Use double precision for expression evaluation Changed internal evaluation methods in expression classes from float to double for improved numerical precision. Updated type declarations and imports accordingly. * Add type annotations and improve Term class attributes Added type annotations to the Term constructor and specified types for class attributes. This improves code clarity and type safety. * Make _evaluate methods public and refactor Expr/Variable Changed all _evaluate methods in expression classes from cdef to cpdef to make them accessible from Python. Moved the definition of terms and _evaluate to the Expr class in scip.pxd, and refactored Variable to inherit from Expr in scip.pxd, consolidating class member definitions. These changes improve the interface and maintainability of expression evaluation. * Update CHANGELOG.md * Remove hash method from Variable * back to old behavior * Fix MatrixExpr _evaluate to return ndarray type Ensures that MatrixExpr._evaluate returns a numpy ndarray when appropriate. Adds a test to verify the return type of getVal for matrix variables. * Remove unused _evaluate method from Solution stub Deleted the _evaluate method from the Solution class in the scip.pyi stub file, as it is no longer needed. Also added TYPE_CHECKING to the typing imports. * Add @disjoint_base decorator to Term and UnaryExpr Applied the @disjoint_base decorator to the Term and UnaryExpr classes in scip.pyi to clarify their base class relationships. This may improve type checking or class hierarchy handling. * Add noqa to suppress unused import warning Appended '# noqa: F401' to the 'ClassVar' import to suppress linter warnings about unused imports in scip.pyi. * cache `_evaluate` function for matrix * Refactor _evaluate to use np.frompyfunc Replaces the @np.vectorize-decorated _evaluate function with an np.frompyfunc-based implementation for evaluating expressions with solutions. This change may improve compatibility and performance when applying _evaluate to arrays. * Simplify _evaluate return in MatrixExpr Refactored the _evaluate method in MatrixExpr to always return the result as a NumPy ndarray, removing the conditional type check. * Expand test_evaluate with additional variable and cases Added a second variable 'y' to test_evaluate and included new assertions for expressions involving both variables. This improves coverage of expression evaluation, including addition and division of variables. * Update expected value in test_evaluate assertion Changed the expected result of m.getVal(x + y + 1) from 3 to 4 in test_evaluate to reflect updated logic or correct the test expectation. * Optimize Expr._evaluate by iterating dict with PyDict_Next Replaces the standard Python dict iteration in Expr._evaluate with a more efficient C-level iteration using PyDict_Next. This change improves performance by avoiding Python-level overhead when evaluating expressions. * Optimize Term evaluation with early exit on zero Improved the _evaluate method in Term to return early if the product becomes zero, enhancing performance. Added a test case to verify correct evaluation when a variable is zero. * Fix loop variable usage in Term evaluation Replaces iteration over self.vartuple with a range-based loop to correctly access variables by index, preventing potential errors in variable assignment during term evaluation. * Refactor variable names in Term class evaluation Renamed local variables 'scip' and 'sol' to 'scip_ptr' and 'sol_ptr' in the Term class to improve clarity and avoid shadowing the input parameter 'sol'. * Refactor variable initialization in _evaluate methods Updated variable initialization in Term._evaluate and Expr._evaluate for clarity and consistency. Changed res initialization to 1.0 and consolidated variable declarations to improve code readability. * Add exception specification to _evaluate methods Appended 'except *' to all cpdef double _evaluate methods in expression classes to ensure proper exception handling in Cython. This change improves error propagation and consistency across the codebase. * Fix _evaluate method calls in expression classes Corrects the _evaluate method implementations in VarExpr, PowExpr, and UnaryExpr to explicitly cast children[0] to GenExpr before calling _evaluate. This ensures proper method resolution and avoids potential attribute errors. * Optimize SumExpr and ProdExpr evaluation loops Refactored the _evaluate methods in SumExpr and ProdExpr to use indexed loops for better performance and type safety. Added early exit in ProdExpr when result becomes zero to improve efficiency. * Fix evaluation of UnaryExpr in expression module Corrects the _evaluate method in UnaryExpr to properly evaluate the child expression before applying the math operation. This ensures that the math function receives a numeric value instead of an expression object. * Fix error message grammar Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Replace _evaluate with _vec_evaluate in matrix.pxi Removed the unused _evaluate function and replaced its usage with _vec_evaluate in the MatrixExpr class. This change clarifies the evaluation logic and ensures consistency in function naming. * Change hashval type to Py_ssize_t in Term class Updated the Term class in expr.pxi to use Py_ssize_t for the hashval attribute instead of int, ensuring compatibility with Python's hash function return type and improving type safety. * Handle 'abs' operator in UnaryExpr evaluation Special-cases the 'abs' operator in UnaryExpr._evaluate to use math.fabs, ensuring correct evaluation for absolute value expressions. * Update getSolVal return type to support NDArray Modified the getSolVal method in the Model class to return either a float or a numpy NDArray of float64, reflecting support for vectorized solution values. Also imported NDArray from numpy.typing to enable this type annotation. * Expand getVal to support GenExpr in Model Updated the getVal method in the Model class to accept GenExpr in addition to Expr and MatrixExpr. This broadens the method's applicability to more expression types. * Fix type cast in VarExpr _evaluate method Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove unused include of matrix.pxi The include statement for matrix.pxi was removed from expr.pxi as it is no longer needed. * Update tests/test_expr.py --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: João Dionísio <57299939+Joao-Dionisio@users.noreply.github.com>
1 parent 4c3391d commit 89d0391

File tree

8 files changed

+167
-44
lines changed

8 files changed

+167
-44
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- all fundamental callbacks now raise an error if not implemented
1111
- Fixed the type of MatrixExpr.sum(axis=...) result from MatrixVariable to MatrixExpr.
1212
- Updated IIS result in PyiisfinderExec()
13+
- Model.getVal now supports GenExpr type
1314
- Fixed lotsizing_lazy example
1415
- Fixed incorrect getVal() result when _bestSol.sol was outdated
1516
### Changed

src/pyscipopt/expr.pxi

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,20 @@
4242
# which should, in princple, modify the expr. However, since we do not implement __isub__, __sub__
4343
# gets called (I guess) and so a copy is returned.
4444
# Modifying the expression directly would be a bug, given that the expression might be re-used by the user. </pre>
45+
import math
46+
from typing import TYPE_CHECKING
47+
48+
from pyscipopt.scip cimport Variable, Solution
49+
from cpython.dict cimport PyDict_Next
50+
from cpython.ref cimport PyObject
51+
4552
import numpy as np
4653

4754

55+
if TYPE_CHECKING:
56+
double = float
57+
58+
4859
def _is_number(e):
4960
try:
5061
f = float(e)
@@ -87,23 +98,25 @@ def _expr_richcmp(self, other, op):
8798
raise NotImplementedError("Can only support constraints with '<=', '>=', or '=='.")
8899

89100

90-
class Term:
101+
cdef class Term:
91102
'''This is a monomial term'''
92103

93-
__slots__ = ('vartuple', 'ptrtuple', 'hashval')
104+
cdef readonly tuple vartuple
105+
cdef readonly tuple ptrtuple
106+
cdef Py_ssize_t hashval
94107

95-
def __init__(self, *vartuple):
108+
def __init__(self, *vartuple: Variable):
96109
self.vartuple = tuple(sorted(vartuple, key=lambda v: v.ptr()))
97110
self.ptrtuple = tuple(v.ptr() for v in self.vartuple)
98-
self.hashval = sum(self.ptrtuple)
111+
self.hashval = <Py_ssize_t>hash(self.ptrtuple)
99112

100113
def __getitem__(self, idx):
101114
return self.vartuple[idx]
102115

103-
def __hash__(self):
116+
def __hash__(self) -> Py_ssize_t:
104117
return self.hashval
105118

106-
def __eq__(self, other):
119+
def __eq__(self, other: Term):
107120
return self.ptrtuple == other.ptrtuple
108121

109122
def __len__(self):
@@ -116,6 +129,20 @@ class Term:
116129
def __repr__(self):
117130
return 'Term(%s)' % ', '.join([str(v) for v in self.vartuple])
118131

132+
cpdef double _evaluate(self, Solution sol) except *:
133+
cdef double res = 1.0
134+
cdef SCIP* scip_ptr = sol.scip
135+
cdef SCIP_SOL* sol_ptr = sol.sol
136+
cdef int i = 0, n = len(self)
137+
cdef Variable var
138+
139+
for i in range(n):
140+
var = <Variable>self.vartuple[i]
141+
res *= SCIPgetSolVal(scip_ptr, sol_ptr, var.scip_var)
142+
if res == 0: # early stop
143+
return 0.0
144+
return res
145+
119146

120147
CONST = Term()
121148

@@ -157,7 +184,7 @@ def buildGenExprObj(expr):
157184
##@details Polynomial expressions of variables with operator overloading. \n
158185
#See also the @ref ExprDetails "description" in the expr.pxi.
159186
cdef class Expr:
160-
187+
161188
def __init__(self, terms=None):
162189
'''terms is a dict of variables to coefficients.
163190
@@ -318,6 +345,20 @@ cdef class Expr:
318345
else:
319346
return max(len(v) for v in self.terms)
320347

348+
cpdef double _evaluate(self, Solution sol) except *:
349+
cdef double res = 0
350+
cdef Py_ssize_t pos = <Py_ssize_t>0
351+
cdef PyObject* key_ptr
352+
cdef PyObject* val_ptr
353+
cdef Term term
354+
cdef double coef
355+
356+
while PyDict_Next(self.terms, &pos, &key_ptr, &val_ptr):
357+
term = <Term>key_ptr
358+
coef = <double>(<object>val_ptr)
359+
res += coef * term._evaluate(sol)
360+
return res
361+
321362

322363
cdef class ExprCons:
323364
'''Constraints with a polynomial expressions and lower/upper bounds.'''
@@ -427,10 +468,10 @@ Operator = Op()
427468
#
428469
#See also the @ref ExprDetails "description" in the expr.pxi.
429470
cdef class GenExpr:
471+
430472
cdef public _op
431473
cdef public children
432474

433-
434475
def __init__(self): # do we need it
435476
''' '''
436477

@@ -625,44 +666,88 @@ cdef class SumExpr(GenExpr):
625666
def __repr__(self):
626667
return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")"
627668

669+
cpdef double _evaluate(self, Solution sol) except *:
670+
cdef double res = self.constant
671+
cdef int i = 0, n = len(self.children)
672+
cdef list children = self.children
673+
cdef list coefs = self.coefs
674+
for i in range(n):
675+
res += <double>coefs[i] * (<GenExpr>children[i])._evaluate(sol)
676+
return res
677+
678+
628679
# Prod Expressions
629680
cdef class ProdExpr(GenExpr):
681+
630682
cdef public constant
683+
631684
def __init__(self):
632685
self.constant = 1.0
633686
self.children = []
634687
self._op = Operator.prod
688+
635689
def __repr__(self):
636690
return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")"
637691

692+
cpdef double _evaluate(self, Solution sol) except *:
693+
cdef double res = self.constant
694+
cdef list children = self.children
695+
cdef int i = 0, n = len(children)
696+
for i in range(n):
697+
res *= (<GenExpr>children[i])._evaluate(sol)
698+
if res == 0: # early stop
699+
return 0.0
700+
return res
701+
702+
638703
# Var Expressions
639704
cdef class VarExpr(GenExpr):
705+
640706
cdef public var
707+
641708
def __init__(self, var):
642709
self.children = [var]
643710
self._op = Operator.varidx
711+
644712
def __repr__(self):
645713
return self.children[0].__repr__()
646714

715+
cpdef double _evaluate(self, Solution sol) except *:
716+
return (<Expr>self.children[0])._evaluate(sol)
717+
718+
647719
# Pow Expressions
648720
cdef class PowExpr(GenExpr):
721+
649722
cdef public expo
723+
650724
def __init__(self):
651725
self.expo = 1.0
652726
self.children = []
653727
self._op = Operator.power
728+
654729
def __repr__(self):
655730
return self._op + "(" + self.children[0].__repr__() + "," + str(self.expo) + ")"
656731

732+
cpdef double _evaluate(self, Solution sol) except *:
733+
return (<GenExpr>self.children[0])._evaluate(sol) ** self.expo
734+
735+
657736
# Exp, Log, Sqrt, Sin, Cos Expressions
658737
cdef class UnaryExpr(GenExpr):
659738
def __init__(self, op, expr):
660739
self.children = []
661740
self.children.append(expr)
662741
self._op = op
742+
663743
def __repr__(self):
664744
return self._op + "(" + self.children[0].__repr__() + ")"
665745

746+
cpdef double _evaluate(self, Solution sol) except *:
747+
cdef double res = (<GenExpr>self.children[0])._evaluate(sol)
748+
return math.fabs(res) if self._op == "abs" else getattr(math, self._op)(res)
749+
750+
666751
# class for constant expressions
667752
cdef class Constant(GenExpr):
668753
cdef public number
@@ -673,6 +758,10 @@ cdef class Constant(GenExpr):
673758
def __repr__(self):
674759
return str(self.number)
675760

761+
cpdef double _evaluate(self, Solution sol) except *:
762+
return self.number
763+
764+
676765
def exp(expr):
677766
"""returns expression with exp-function"""
678767
if isinstance(expr, MatrixExpr):

src/pyscipopt/matrix.pxi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
from typing import Literal, Optional, Tuple, Union
66
import numpy as np
7+
from numpy.typing import NDArray
78
try:
89
# NumPy 2.x location
910
from numpy.lib.array_utils import normalize_axis_tuple
@@ -12,6 +13,7 @@ except ImportError:
1213
from numpy.core.numeric import normalize_axis_tuple
1314

1415
cimport numpy as cnp
16+
from pyscipopt.scip cimport Expr, Solution
1517

1618
cnp.import_array()
1719

@@ -142,6 +144,10 @@ class MatrixExpr(np.ndarray):
142144
return super().__rsub__(other).view(MatrixExpr)
143145

144146

147+
def _evaluate(self, Solution sol) -> NDArray[np.float64]:
148+
return _vec_evaluate(self, sol).view(np.ndarray)
149+
150+
145151
class MatrixGenExpr(MatrixExpr):
146152
pass
147153

@@ -166,6 +172,9 @@ cdef inline _ensure_array(arg, bool convert_scalar = True):
166172
return np.array(arg, dtype=object) if convert_scalar else arg
167173

168174

175+
_vec_evaluate = np.frompyfunc(lambda expr, sol: expr._evaluate(sol), 2, 1)
176+
177+
169178
def _core_dot(cnp.ndarray a, cnp.ndarray b) -> Union[Expr, np.ndarray]:
170179
"""
171180
Perform matrix multiplication between a N-Demension constant array and a N-Demension

src/pyscipopt/scip.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,6 +2110,8 @@ cdef extern from "tpi/tpi.h":
21102110
cdef class Expr:
21112111
cdef public terms
21122112

2113+
cpdef double _evaluate(self, Solution sol)
2114+
21132115
cdef class Event:
21142116
cdef SCIP_EVENT* event
21152117
# can be used to store problem data

src/pyscipopt/scip.pxi

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ from dataclasses import dataclass
2121
from typing import Union
2222

2323
import numpy as np
24+
from numpy.typing import NDArray
2425

2526
include "expr.pxi"
2627
include "lp.pxi"
@@ -1099,29 +1100,8 @@ cdef class Solution:
10991100
return sol
11001101

11011102
def __getitem__(self, expr: Union[Expr, MatrixExpr]):
1102-
if isinstance(expr, MatrixExpr):
1103-
result = np.zeros(expr.shape, dtype=np.float64)
1104-
for idx in np.ndindex(expr.shape):
1105-
result[idx] = self.__getitem__(expr[idx])
1106-
return result
1107-
1108-
# fast track for Variable
1109-
cdef SCIP_Real coeff
1110-
cdef _VarArray wrapper
1111-
if isinstance(expr, Variable):
1112-
wrapper = _VarArray(expr)
1113-
self._checkStage("SCIPgetSolVal")
1114-
return SCIPgetSolVal(self.scip, self.sol, wrapper.ptr[0])
1115-
return sum(self._evaluate(term)*coeff for term, coeff in expr.terms.items() if coeff != 0)
1116-
1117-
def _evaluate(self, term):
11181103
self._checkStage("SCIPgetSolVal")
1119-
result = 1
1120-
cdef _VarArray wrapper
1121-
wrapper = _VarArray(term.vartuple)
1122-
for i in range(len(term.vartuple)):
1123-
result *= SCIPgetSolVal(self.scip, self.sol, wrapper.ptr[i])
1124-
return result
1104+
return expr._evaluate(self)
11251105

11261106
def __setitem__(self, Variable var, value):
11271107
PY_SCIP_CALL(SCIPsetSolVal(self.scip, self.sol, var.scip_var, value))
@@ -10747,7 +10727,11 @@ cdef class Model:
1074710727

1074810728
return self.getSolObjVal(self._bestSol, original)
1074910729

10750-
def getSolVal(self, Solution sol, Expr expr):
10730+
def getSolVal(
10731+
self,
10732+
Solution sol,
10733+
expr: Union[Expr, GenExpr],
10734+
) -> Union[float, NDArray[np.float64]]:
1075110735
"""
1075210736
Retrieve value of given variable or expression in the given solution or in
1075310737
the LP/pseudo solution if sol == None
@@ -10767,24 +10751,22 @@ cdef class Model:
1076710751
A variable is also an expression.
1076810752

1076910753
"""
10754+
if not isinstance(expr, (Expr, GenExpr)):
10755+
raise TypeError(
10756+
"Argument 'expr' has incorrect type (expected 'Expr' or 'GenExpr', "
10757+
f"got {type(expr)})"
10758+
)
1077010759
# no need to create a NULL solution wrapper in case we have a variable
10771-
cdef _VarArray wrapper
10772-
if sol == None and isinstance(expr, Variable):
10773-
wrapper = _VarArray(expr)
10774-
return SCIPgetSolVal(self._scip, NULL, wrapper.ptr[0])
10775-
if sol == None:
10776-
sol = Solution.create(self._scip, NULL)
10777-
return sol[expr]
10760+
return (sol or Solution.create(self._scip, NULL))[expr]
1077810761

10779-
def getVal(self, expr: Union[Expr, MatrixExpr] ):
10762+
def getVal(self, expr: Union[Expr, GenExpr, MatrixExpr] ):
1078010763
"""
1078110764
Retrieve the value of the given variable or expression in the best known solution.
1078210765
Can only be called after solving is completed.
1078310766
1078410767
Parameters
1078510768
----------
10786-
expr : Expr ot MatrixExpr
10787-
polynomial expression to query the value of
10769+
expr : Expr, GenExpr or MatrixExpr
1078810770
1078910771
Returns
1079010772
-------

src/pyscipopt/scip.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import ClassVar
2+
from typing import TYPE_CHECKING, ClassVar # noqa: F401
33

44
import numpy
55
from _typeshed import Incomplete
@@ -2062,7 +2062,6 @@ class Solution:
20622062
data: Incomplete
20632063
def __init__(self, *args: Incomplete, **kwargs: Incomplete) -> None: ...
20642064
def _checkStage(self, method: Incomplete) -> Incomplete: ...
2065-
def _evaluate(self, term: Incomplete) -> Incomplete: ...
20662065
def getOrigin(self) -> Incomplete: ...
20672066
def retransform(self) -> Incomplete: ...
20682067
def translate(self, target: Incomplete) -> Incomplete: ...
@@ -2122,6 +2121,7 @@ class SumExpr(GenExpr):
21222121
constant: Incomplete
21232122
def __init__(self, *args: Incomplete, **kwargs: Incomplete) -> None: ...
21242123

2124+
@disjoint_base
21252125
class Term:
21262126
hashval: Incomplete
21272127
ptrtuple: Incomplete
@@ -2138,6 +2138,7 @@ class Term:
21382138
def __lt__(self, other: object) -> bool: ...
21392139
def __ne__(self, other: object) -> bool: ...
21402140

2141+
@disjoint_base
21412142
class UnaryExpr(GenExpr):
21422143
def __init__(self, *args: Incomplete, **kwargs: Incomplete) -> None: ...
21432144

0 commit comments

Comments
 (0)