Skip to content

Commit f3c0f82

Browse files
authored
Speed up Constant @ MatrixExpr (#1159)
* Add numpy as build dependency and enhance matrix operations Added numpy to build requirements in pyproject.toml and setup.py, ensuring numpy headers are included during compilation. Refactored matrix.pxi to improve matrix operation support, including custom __array_ufunc__ handling for dot/matmul, utility functions for type checking and array conversion, and a vectorized _core_dot implementation for efficient matrix multiplication. * Add and update matrix dot product tests Introduces a new parameterized test for matrix dot product performance and updates an assertion in test_matrix_matmul_return_type to expect Expr instead of MatrixExpr for 1D @ 1D operations. * Replace MatrixExpr type checks with np.ndarray Updated type checks throughout expr.pxi to use np.ndarray instead of MatrixExpr, improving compatibility with numpy arrays. Also adjusted matrix.pxi to ensure ufunc results are returned as MatrixExpr views when appropriate. * Remove redundant 'out' handling in __array_ufunc__ Deleted unnecessary conversion of the 'out' keyword argument to an array in MatrixExpr.__array_ufunc__, as it is not required for correct operation. * Update CHANGELOG.md * Remove custom __matmul__ from MatrixExpr Deleted the overridden __matmul__ method in MatrixExpr, reverting to the default numpy ndarray behavior for matrix multiplication. This simplifies the class and avoids unnecessary type casting. * Refactor matrix multiplication logic in MatrixExpr Introduces a new _core_dot function to handle matrix multiplication between constant arrays and arrays of Expr objects, supporting both 1-D and N-D cases. The original _core_dot is renamed to _core_dot_2d and improved for clarity and efficiency. Updates __array_ufunc__ to use the new logic, ensuring correct handling of mixed-type matrix operations. * Update matrix matmul return type tests Adjusted tests in test_matrix_matmul_return_type to check the return type when performing matrix multiplication with numpy arrays and matrix variables, including new cases for ND arrays. Ensures correct type inference for various matrix multiplication scenarios. * Remove `_is_num_dt` Replaces the _is_num_dt helper function with direct dtype.kind checks for numeric types in _core_dot. This simplifies the code and removes an unnecessary inline function. * Enhance MatrixExpr __array_ufunc__ with type hints and docs Added detailed type annotations and a docstring to the MatrixExpr.__array_ufunc__ method for improved clarity and maintainability. Also clarified handling of ufunc methods and argument unboxing. * Rename test_matrix_dot to test_matrix_dot_performance Renamed the test function to better reflect its purpose of testing performance for matrix dot operations. * Add test for matrix dot value retrieval Introduces test_matrix_dot_value to verify correct value retrieval from matrix variable dot products using getVal. Ensures expected results for both 1D and higher-dimensional dot operations.
1 parent fd2e0c6 commit f3c0f82

File tree

6 files changed

+179
-17
lines changed

6 files changed

+179
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
- Fixed incorrect getVal() result when _bestSol.sol was outdated
1515
### Changed
1616
- changed default value of enablepricing flag to True
17+
- Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr
1718
### Removed
1819

1920
## 6.0.0 - 2025.xx.yy

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[build-system]
2-
requires = ['setuptools', 'cython >=0.21']
2+
requires = ["setuptools", "cython >=0.21", "numpy"]
33
build-backend = "setuptools.build_meta"
44

55
[project]

setup.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from setuptools import find_packages, setup, Extension
2-
import os, platform, sys
1+
import os
2+
import platform
3+
import sys
4+
5+
import numpy as np
6+
from setuptools import Extension, find_packages, setup
37

48
# look for environment variable that specifies path to SCIP
59
scipoptdir = os.environ.get("SCIPOPTDIR", "").strip('"')
@@ -112,7 +116,7 @@
112116
Extension(
113117
"pyscipopt.scip",
114118
[os.path.join(packagedir, "scip%s" % ext)],
115-
include_dirs=includedirs,
119+
include_dirs=includedirs + [np.get_include()],
116120
library_dirs=[libdir],
117121
libraries=[libname],
118122
extra_compile_args=extra_compile_args,

src/pyscipopt/expr.pxi

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
# which should, in princple, modify the expr. However, since we do not implement __isub__, __sub__
4343
# gets called (I guess) and so a copy is returned.
4444
# Modifying the expression directly would be a bug, given that the expression might be re-used by the user. </pre>
45-
include "matrix.pxi"
45+
import numpy as np
4646

4747

4848
def _is_number(e):
@@ -61,7 +61,7 @@ def _expr_richcmp(self, other, op):
6161
return (self - other) <= 0.0
6262
elif _is_number(other):
6363
return ExprCons(self, rhs=float(other))
64-
elif isinstance(other, MatrixExpr):
64+
elif isinstance(other, np.ndarray):
6565
return _expr_richcmp(other, self, 5)
6666
else:
6767
raise TypeError(f"Unsupported type {type(other)}")
@@ -70,7 +70,7 @@ def _expr_richcmp(self, other, op):
7070
return (self - other) >= 0.0
7171
elif _is_number(other):
7272
return ExprCons(self, lhs=float(other))
73-
elif isinstance(other, MatrixExpr):
73+
elif isinstance(other, np.ndarray):
7474
return _expr_richcmp(other, self, 1)
7575
else:
7676
raise TypeError(f"Unsupported type {type(other)}")
@@ -79,7 +79,7 @@ def _expr_richcmp(self, other, op):
7979
return (self - other) == 0.0
8080
elif _is_number(other):
8181
return ExprCons(self, lhs=float(other), rhs=float(other))
82-
elif isinstance(other, MatrixExpr):
82+
elif isinstance(other, np.ndarray):
8383
return _expr_richcmp(other, self, 2)
8484
else:
8585
raise TypeError(f"Unsupported type {type(other)}")
@@ -144,7 +144,7 @@ def buildGenExprObj(expr):
144144
sumexpr += coef * prodexpr
145145
return sumexpr
146146

147-
elif isinstance(expr, MatrixExpr):
147+
elif isinstance(expr, np.ndarray):
148148
GenExprs = np.empty(expr.shape, dtype=object)
149149
for idx in np.ndindex(expr.shape):
150150
GenExprs[idx] = buildGenExprObj(expr[idx])
@@ -200,7 +200,7 @@ cdef class Expr:
200200
terms[CONST] = terms.get(CONST, 0.0) + c
201201
elif isinstance(right, GenExpr):
202202
return buildGenExprObj(left) + right
203-
elif isinstance(right, MatrixExpr):
203+
elif isinstance(right, np.ndarray):
204204
return right + left
205205
else:
206206
raise TypeError(f"Unsupported type {type(right)}")
@@ -225,7 +225,7 @@ cdef class Expr:
225225
return self
226226

227227
def __mul__(self, other):
228-
if isinstance(other, MatrixExpr):
228+
if isinstance(other, np.ndarray):
229229
return other * self
230230

231231
if _is_number(other):
@@ -438,7 +438,7 @@ cdef class GenExpr:
438438
return UnaryExpr(Operator.fabs, self)
439439

440440
def __add__(self, other):
441-
if isinstance(other, MatrixExpr):
441+
if isinstance(other, np.ndarray):
442442
return other + self
443443

444444
left = buildGenExprObj(self)
@@ -496,7 +496,7 @@ cdef class GenExpr:
496496
# return self
497497

498498
def __mul__(self, other):
499-
if isinstance(other, MatrixExpr):
499+
if isinstance(other, np.ndarray):
500500
return other * self
501501

502502
left = buildGenExprObj(self)

src/pyscipopt/matrix.pxi

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# TODO Add tests
44
"""
55

6-
from typing import Optional, Tuple, Union
6+
from typing import Literal, Optional, Tuple, Union
77
import numpy as np
88
try:
99
# NumPy 2.x location
@@ -12,6 +12,10 @@ except ImportError:
1212
# Fallback for NumPy 1.x
1313
from numpy.core.numeric import normalize_axis_tuple
1414

15+
cimport numpy as cnp
16+
17+
cnp.import_array()
18+
1519

1620
def _is_number(e):
1721
try:
@@ -51,6 +55,48 @@ def _matrixexpr_richcmp(self, other, op):
5155

5256
class MatrixExpr(np.ndarray):
5357

58+
def __array_ufunc__(
59+
self,
60+
ufunc: np.ufunc,
61+
method: Literal["__call__", "reduce", "reduceat", "accumulate", "outer", "at"],
62+
*args,
63+
**kwargs,
64+
):
65+
"""
66+
Customizes the behavior of NumPy ufuncs for MatrixExpr.
67+
68+
Parameters
69+
----------
70+
ufunc : numpy.ufunc
71+
The ufunc object that was called.
72+
73+
method : {"__call__", "reduce", "reduceat", "accumulate", "outer", "at"}
74+
A string indicating which UFunc method was called.
75+
76+
*args : tuple
77+
The input arguments to the ufunc.
78+
79+
**kwargs : dict
80+
Additional keyword arguments to the ufunc.
81+
82+
Returns
83+
-------
84+
Expr, GenExpr, MatrixExpr
85+
The result of the ufunc operation is wrapped back into a MatrixExpr if
86+
applicable.
87+
88+
"""
89+
res = NotImplemented
90+
if method == "__call__": # Standard ufunc call, e.g., np.add(a, b)
91+
if ufunc in {np.matmul, np.dot}:
92+
res = _core_dot(_ensure_array(args[0]), _ensure_array(args[1]))
93+
94+
if res is NotImplemented:
95+
# Unboxing MatrixExpr to stop __array_ufunc__ recursion
96+
args = tuple(_ensure_array(arg) for arg in args)
97+
res = super().__array_ufunc__(ufunc, method, *args, **kwargs)
98+
return res.view(MatrixExpr) if isinstance(res, np.ndarray) else res
99+
54100
def sum(
55101
self,
56102
axis: Optional[Union[int, Tuple[int, ...]]] = None,
@@ -145,8 +191,6 @@ class MatrixExpr(np.ndarray):
145191
def __rsub__(self, other):
146192
return super().__rsub__(other).view(MatrixExpr)
147193

148-
def __matmul__(self, other):
149-
return super().__matmul__(other).view(MatrixExpr)
150194

151195
class MatrixGenExpr(MatrixExpr):
152196
pass
@@ -161,3 +205,86 @@ class MatrixExprCons(np.ndarray):
161205

162206
def __eq__(self, other):
163207
raise NotImplementedError("Cannot compare MatrixExprCons with '=='.")
208+
209+
210+
cdef inline _ensure_array(arg, bool convert_scalar = True):
211+
if isinstance(arg, np.ndarray):
212+
return arg.view(np.ndarray)
213+
elif isinstance(arg, (list, tuple)):
214+
return np.asarray(arg)
215+
return np.array(arg, dtype=object) if convert_scalar else arg
216+
217+
218+
def _core_dot(cnp.ndarray a, cnp.ndarray b) -> Union[Expr, np.ndarray]:
219+
"""
220+
Perform matrix multiplication between a N-Demension constant array and a N-Demension
221+
`np.ndarray` of type `object` and containing `Expr` objects.
222+
223+
Parameters
224+
----------
225+
a : np.ndarray
226+
A constant n-d `np.ndarray` of type `np.float64`.
227+
228+
b : np.ndarray
229+
A n-d `np.ndarray` of type `object` and containing `Expr` objects.
230+
231+
Returns
232+
-------
233+
Expr or np.ndarray
234+
If both `a` and `b` are 1-D arrays, return an `Expr`, otherwise return a
235+
`np.ndarray` of type `object` and containing `Expr` objects.
236+
"""
237+
cdef bool a_is_1d = a.ndim == 1
238+
cdef bool b_is_1d = b.ndim == 1
239+
cdef cnp.ndarray a_nd = a[..., np.newaxis, :] if a_is_1d else a
240+
cdef cnp.ndarray b_nd = b[..., :, np.newaxis] if b_is_1d else b
241+
cdef bool a_is_num = a_nd.dtype.kind in "fiub"
242+
243+
if a_is_num ^ (b_nd.dtype.kind in "fiub"):
244+
res = _core_dot_2d(a_nd, b_nd) if a_is_num else _core_dot_2d(b_nd.T, a_nd.T).T
245+
if a_is_1d and b_is_1d:
246+
return res.item()
247+
if a_is_1d:
248+
return res.reshape(np.delete(res.shape, -2))
249+
if b_is_1d:
250+
return res.reshape(np.delete(res.shape, -1))
251+
return res
252+
return NotImplemented
253+
254+
255+
@np.vectorize(otypes=[object], signature="(m,n),(n,p)->(m,p)")
256+
def _core_dot_2d(cnp.ndarray a, cnp.ndarray x) -> np.ndarray:
257+
"""
258+
Perform matrix multiplication between a 2-Demension constant array and a 2-Demension
259+
`np.ndarray` of type `object` and containing `Expr` objects.
260+
261+
Parameters
262+
----------
263+
a : np.ndarray
264+
A 2-D `np.ndarray` of type `np.float64`.
265+
266+
x : np.ndarray
267+
A 2-D `np.ndarray` of type `object` and containing `Expr` objects.
268+
269+
Returns
270+
-------
271+
np.ndarray
272+
A 2-D `np.ndarray` of type `object` and containing `Expr` objects.
273+
"""
274+
if not a.flags.c_contiguous or a.dtype != np.float64:
275+
a = np.ascontiguousarray(a, dtype=np.float64)
276+
277+
cdef const double[:, :] a_view = a
278+
cdef int m = a.shape[0], k = x.shape[1]
279+
cdef cnp.ndarray[object, ndim=2] res = np.zeros((m, k), dtype=object)
280+
cdef Py_ssize_t[:] nonzero
281+
cdef int i, j, idx
282+
283+
for i in range(m):
284+
if (nonzero := np.flatnonzero(a_view[i, :])).size == 0:
285+
continue
286+
287+
for j in range(k):
288+
res[i, j] = quicksum(a_view[i, idx] * x[idx, j] for idx in nonzero)
289+
290+
return res

tests/test_matrix_variable.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,33 @@ def test_matrix_sum_axis_not_none_performance(n):
293293
assert model.isGT(end_orig - start_orig, end_matrix - start_matrix)
294294

295295

296+
@pytest.mark.parametrize("n", [50, 100])
297+
def test_matrix_dot_performance(n):
298+
model = Model()
299+
x = model.addMatrixVar((n, n))
300+
a = np.random.rand(n, n)
301+
302+
start = time()
303+
a @ x.view(np.ndarray)
304+
orig = time() - start
305+
306+
start = time()
307+
a @ x
308+
matrix = time() - start
309+
310+
assert model.isGT(orig, matrix)
311+
312+
313+
def test_matrix_dot_value():
314+
model = Model()
315+
x = model.addMatrixVar(3, lb=[1, 2, 3], ub=[1, 2, 3])
316+
y = model.addMatrixVar((3, 2), lb=1, ub=1)
317+
model.optimize()
318+
319+
assert model.getVal(np.ones(3) @ x) == 6
320+
assert (model.getVal(np.ones((2, 2, 3)) @ y) == np.full((2, 2, 2), 3)).all()
321+
322+
296323
def test_add_cons_matrixVar():
297324
m = Model()
298325
matrix_variable = m.addMatrixVar(shape=(3, 3), vtype="B", name="A", obj=1)
@@ -574,7 +601,7 @@ def test_matrix_matmul_return_type():
574601

575602
# test 1D @ 1D → 0D
576603
x = m.addMatrixVar(3)
577-
assert type(x @ x) is MatrixExpr
604+
assert type(np.ones(3) @ x) is Expr
578605

579606
# test 1D @ 1D → 2D
580607
assert type(x[:, None] @ x[None, :]) is MatrixExpr
@@ -584,6 +611,9 @@ def test_matrix_matmul_return_type():
584611
z = m.addMatrixVar((3, 4))
585612
assert type(y @ z) is MatrixExpr
586613

614+
# test ND @ 2D → ND
615+
assert type(np.ones((2, 4, 3)) @ z) is MatrixExpr
616+
587617

588618
def test_matrix_sum_return_type():
589619
# test #1117, require returning type is MatrixExpr not MatrixVariable

0 commit comments

Comments
 (0)