Skip to content

Commit a82bf15

Browse files
Zeroto521Copilot
andauthored
API: use __array_ufunc__ as numpy ufunc enter (#1163)
* Refactor matrix comparison logic using __array_ufunc__ Replaces custom rich comparison methods for MatrixExpr and MatrixExprCons with __array_ufunc__ implementations, leveraging numpy's ufuncs for elementwise comparisons. Removes legacy helper functions and streamlines operator handling, improving maintainability and performance. * Remove unused matrix.pxi include from expr.pxi Deleted the 'include "matrix.pxi"' line from expr.pxi as it is no longer needed. This helps clean up the code and avoid unnecessary dependencies. * Replace MatrixExpr checks with np.ndarray in expr.pxi Updated type checks from MatrixExpr to np.ndarray throughout expr.pxi to improve compatibility with NumPy arrays and simplify matrix expression handling. * Fix logic error in MatrixExpr __array_ufunc__ handling Corrects the condition to check for NotImplemented in MatrixExpr's __array_ufunc__ method, ensuring proper delegation to the superclass implementation. * Fix matmul return type assertion in test Update the test to assert that the result of 1D @ 1D matrix multiplication is of type Expr instead of MatrixExpr. * Restrict supported ufuncs in MatrixExprCons MatrixExprCons now only supports the '<=' and '>=' ufuncs, raising NotImplementedError for all others. This clarifies and enforces the intended usage of the class. * Update matrix sum test to use MatrixExpr view Refactored test_matrix_sum_result to use a view of MatrixExpr for sum operation, aligning the test with the expected usage pattern and simplifying result comparison. * Refactor __array_ufunc__ return handling in MatrixExpr Simplifies the __array_ufunc__ method by directly returning a MatrixExpr view when the result is a numpy ndarray, improving code clarity. * Replace np.ndarray checks with MatrixExpr in unary ops Updated type checks in exp, log, sqrt, sin, and cos functions to use MatrixExpr instead of np.ndarray. This change ensures that matrix-specific expression handling is applied only to MatrixExpr instances. * Remove method stubs from MatrixExpr and MatrixExprCons Eliminated explicit method declarations from the MatrixExpr and MatrixExprCons classes in the type stub, leaving only ellipses. This simplifies the type hints and may reflect a change in how these classes are intended to be used or documented. * Update CHANGELOG for MatrixExpr __array_ufunc__ support Added entry noting that MatrixExpr and MatrixExprCons now use the `__array_ufunc__` protocol to control all numpy.ufunc inputs and outputs. * Import Expr from pyscipopt.scip in matrix.pxi Added a cimport for Expr from pyscipopt.scip to enable usage of the Expr type in this module. * Fix error message grammar Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix error message grammar Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix argument unpacking in ufunc handling for MatrixExpr Replaces use of argument unpacking (*args) with explicit indexing (args[0], args[1]) in calls to _vec_le, _vec_ge, and _vec_eq within MatrixExpr and MatrixExprCons. This change clarifies argument passing and may prevent potential issues with argument order or count. * Move _vec_evaluate definition above _ensure_array Relocated the definition of _vec_evaluate to group all np.frompyfunc vectorized operator functions together for better code organization and readability. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b3f298d commit a82bf15

File tree

3 files changed

+25
-106
lines changed

3 files changed

+25
-106
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- Speed up MatrixExpr.add.reduce via quicksum
2121
- Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr
2222
- Minimum numpy version increased from 1.16.0 to 1.19.0
23+
- MatrixExpr and MatrixExprCons use `__array_ufunc__` protocol to control all numpy.ufunc inputs and outputs
2324
### Removed
2425

2526
## 6.0.0 - 2025.xx.yy

src/pyscipopt/matrix.pxi

Lines changed: 22 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
"""
2-
# TODO Cythonize things. Improve performance.
3-
# TODO Add tests
4-
"""
1+
import operator
52
from typing import Literal, Optional, Tuple, Union
63
import numpy as np
74
from numpy.typing import NDArray
@@ -18,42 +15,6 @@ from pyscipopt.scip cimport Expr, Solution
1815
cnp.import_array()
1916

2017

21-
def _is_number(e):
22-
try:
23-
f = float(e)
24-
return True
25-
except ValueError: # for malformed strings
26-
return False
27-
except TypeError: # for other types (Variable, Expr)
28-
return False
29-
30-
31-
def _matrixexpr_richcmp(self, other, op):
32-
def _richcmp(self, other, op):
33-
if op == 1: # <=
34-
return self.__le__(other)
35-
elif op == 5: # >=
36-
return self.__ge__(other)
37-
elif op == 2: # ==
38-
return self.__eq__(other)
39-
else:
40-
raise NotImplementedError("Can only support constraints with '<=', '>=', or '=='.")
41-
42-
if _is_number(other) or isinstance(other, Expr):
43-
res = np.empty(self.shape, dtype=object)
44-
res.flat = [_richcmp(i, other, op) for i in self.flat]
45-
46-
elif isinstance(other, np.ndarray):
47-
out = np.broadcast(self, other)
48-
res = np.empty(out.shape, dtype=object)
49-
res.flat = [_richcmp(i, j, op) for i, j in out]
50-
51-
else:
52-
raise TypeError(f"Unsupported type {type(other)}")
53-
54-
return res.view(MatrixExprCons)
55-
56-
5718
class MatrixExpr(np.ndarray):
5819

5920
def __array_ufunc__(
@@ -97,52 +58,21 @@ class MatrixExpr(np.ndarray):
9758
if method == "__call__": # Standard ufunc call, e.g., np.add(a, b)
9859
if ufunc in {np.matmul, np.dot}:
9960
res = _core_dot(args[0], args[1])
61+
elif ufunc is np.less_equal:
62+
return _vec_le(args[0], args[1]).view(MatrixExprCons)
63+
elif ufunc is np.greater_equal:
64+
return _vec_ge(args[0], args[1]).view(MatrixExprCons)
65+
elif ufunc is np.equal:
66+
return _vec_eq(args[0], args[1]).view(MatrixExprCons)
67+
elif ufunc in {np.less, np.greater, np.not_equal}:
68+
raise NotImplementedError("can only support '<=', '>=', or '=='")
10069

10170
if res is NotImplemented:
10271
if "out" in kwargs: # Unboxing MatrixExpr to stop __array_ufunc__ recursion
10372
kwargs["out"] = tuple(_ensure_array(arg, False) for arg in kwargs["out"])
10473
res = super().__array_ufunc__(ufunc, method, *args, **kwargs)
10574
return res.view(MatrixExpr) if isinstance(res, np.ndarray) else res
10675

107-
def __le__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
108-
return _matrixexpr_richcmp(self, other, 1)
109-
110-
def __ge__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
111-
return _matrixexpr_richcmp(self, other, 5)
112-
113-
def __eq__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
114-
return _matrixexpr_richcmp(self, other, 2)
115-
116-
def __add__(self, other):
117-
return super().__add__(other).view(MatrixExpr)
118-
119-
def __iadd__(self, other):
120-
return super().__iadd__(other).view(MatrixExpr)
121-
122-
def __mul__(self, other):
123-
return super().__mul__(other).view(MatrixExpr)
124-
125-
def __truediv__(self, other):
126-
return super().__truediv__(other).view(MatrixExpr)
127-
128-
def __rtruediv__(self, other):
129-
return super().__rtruediv__(other).view(MatrixExpr)
130-
131-
def __pow__(self, other):
132-
return super().__pow__(other).view(MatrixExpr)
133-
134-
def __sub__(self, other):
135-
return super().__sub__(other).view(MatrixExpr)
136-
137-
def __radd__(self, other):
138-
return super().__radd__(other).view(MatrixExpr)
139-
140-
def __rmul__(self, other):
141-
return super().__rmul__(other).view(MatrixExpr)
142-
143-
def __rsub__(self, other):
144-
return super().__rsub__(other).view(MatrixExpr)
145-
14676

14777
def _evaluate(self, Solution sol) -> NDArray[np.float64]:
14878
return _vec_evaluate(self, sol).view(np.ndarray)
@@ -154,14 +84,20 @@ class MatrixGenExpr(MatrixExpr):
15484

15585
class MatrixExprCons(np.ndarray):
15686

157-
def __le__(self, other: Union[float, int, np.ndarray]) -> MatrixExprCons:
158-
return _matrixexpr_richcmp(self, other, 1)
87+
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
88+
if method == "__call__":
89+
args = tuple(_ensure_array(arg) for arg in args)
90+
if ufunc is np.less_equal:
91+
return _vec_le(args[0], args[1]).view(MatrixExprCons)
92+
elif ufunc is np.greater_equal:
93+
return _vec_ge(args[0], args[1]).view(MatrixExprCons)
94+
raise NotImplementedError("can only support '<=' or '>='")
15995

160-
def __ge__(self, other: Union[float, int, np.ndarray]) -> MatrixExprCons:
161-
return _matrixexpr_richcmp(self, other, 5)
16296

163-
def __eq__(self, other):
164-
raise NotImplementedError("Cannot compare MatrixExprCons with '=='.")
97+
_vec_le = np.frompyfunc(operator.le, 2, 1)
98+
_vec_ge = np.frompyfunc(operator.ge, 2, 1)
99+
_vec_eq = np.frompyfunc(operator.eq, 2, 1)
100+
_vec_evaluate = np.frompyfunc(lambda expr, sol: expr._evaluate(sol), 2, 1)
165101

166102

167103
cdef inline _ensure_array(arg, bool convert_scalar = True):
@@ -172,9 +108,6 @@ cdef inline _ensure_array(arg, bool convert_scalar = True):
172108
return np.array(arg, dtype=object) if convert_scalar else arg
173109

174110

175-
_vec_evaluate = np.frompyfunc(lambda expr, sol: expr._evaluate(sol), 2, 1)
176-
177-
178111
def _core_dot(cnp.ndarray a, cnp.ndarray b) -> Union[Expr, np.ndarray]:
179112
"""
180113
Perform matrix multiplication between a N-Demension constant array and a N-Demension
@@ -261,7 +194,7 @@ def _core_sum(
261194

262195
Parameters
263196
----------
264-
a : cnp.ndarray
197+
a : np.ndarray
265198
A `np.ndarray` of type `object` and containing `Expr` objects.
266199

267200
axis : None or int or tuple of ints, optional

src/pyscipopt/scip.pyi

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -510,25 +510,10 @@ class MatrixConstraint(numpy.ndarray):
510510
def isStickingAtNode(self) -> Incomplete: ...
511511

512512
class MatrixExpr(numpy.ndarray):
513-
def __add__(self, other: Incomplete) -> Incomplete: ...
514-
def __eq__(self, other: Incomplete)-> Incomplete: ...
515-
def __ge__(self, other: Incomplete) -> MatrixExprCons: ...
516-
def __iadd__(self, other: Incomplete) -> Incomplete: ... # noqa: PYI034
517-
def __le__(self, other: Incomplete) -> MatrixExprCons: ...
518-
def __matmul__(self, other: Incomplete) -> Incomplete: ...
519-
def __mul__(self, other: Incomplete) -> Incomplete: ...
520-
def __pow__(self, other: Incomplete) -> Incomplete: ... # type: ignore[override]
521-
def __radd__(self, other: Incomplete) -> Incomplete: ...
522-
def __rmul__(self, other: Incomplete) -> Incomplete: ...
523-
def __rsub__(self, other: Incomplete) -> Incomplete: ...
524-
def __rtruediv__(self, other: Incomplete) -> Incomplete: ...
525-
def __sub__(self, other: Incomplete) -> Incomplete: ...
526-
def __truediv__(self, other: Incomplete) -> Incomplete: ...
513+
...
527514

528515
class MatrixExprCons(numpy.ndarray):
529-
def __eq__(self, other: Incomplete)-> Incomplete: ...
530-
def __ge__(self, other: Incomplete) -> MatrixExprCons: ...
531-
def __le__(self, other: Incomplete) -> MatrixExprCons: ...
516+
...
532517

533518
class MatrixGenExpr(MatrixExpr):
534519
...

0 commit comments

Comments
 (0)