Commit 4c3391d

Speed up MatrixExpr.add.reduce via quicksum (#1157)
* Speed up `np.add.reduce`. Implements the `__array_ufunc__` method in MatrixExpr to handle numpy ufuncs, specifically enabling correct behavior for reductions like `np.add.reduce` by delegating to the sum method. This improves compatibility with numpy operations.
* BUG: `MatrixExpr.mean(axis=1)` crashes the kernel. Moved the sum computation logic from `MatrixExpr.sum` and `__array_ufunc__` into a new `_core_sum` function for better code reuse and maintainability.
* Add tests for matrix mean performance and type. Introduces tests that compare the performance of the mean operation on matrix variables and check the return types of mean with and without the axis argument.
* Update changelog for `MatrixExpr.add.reduce` optimization. Documented the speed improvement for `MatrixExpr.add.reduce` using quicksum in the changelog.
* Remove `MatrixExpr.sum` method and update `_core_sum` docstring. The sum method was removed from the MatrixExpr class, consolidating summation logic in the `_core_sum` function. The docstring for `_core_sum` was expanded with detailed parameter and return value descriptions, improving code clarity and maintainability.
* Remove sum method from MatrixExpr stub. Deleted the type stub for the sum method in MatrixExpr to reflect the changed implementation and correct the type information.
* Improve MatrixExpr ufunc handling and remove unused include. Enhanced the `__array_ufunc__` method in MatrixExpr to ensure proper array conversion and consistent return types. Added the `_ensure_array` helper for argument handling. Also removed an unused include of matrix.pxi from expr.pxi.
* Fix MatrixExpr matmul return type and update tests. Updated `MatrixExpr.__matmul__` to return the correct type when the result is not an ndarray. Adjusted tests to reflect the expected return type for 1D matrix multiplication and improved performance-test timing logic.
* Refactor matrix sum performance tests timing logic. Simplifies timing measurement by directly calculating elapsed time instead of storing start and end times separately, improving readability and reducing variable usage.
* Update matmul return type assertion in test. Changed the expected type of 1D @ 1D matrix multiplication from MatrixExpr to Expr in `test_matrix_matmul_return_type` to reflect the updated behavior.
* Refactor matrix variable tests to use view casting. Updated tests to use `x.view(MatrixExpr)` and `x.view(np.ndarray)` instead of direct subclass method calls. This clarifies the intent and ensures the correct method resolution for sum and mean operations in performance and result-comparison tests.
* Define the variable.
* Refactor `__array_ufunc__` in MatrixExpr for clarity. Reorganizes the logic so the 'reduce' method and argument conversion are handled more clearly, ensuring correct handling of the 'out' keyword and converting arguments only when necessary.
* Add type hints and docstring to `__array_ufunc__` in MatrixExpr. Enhanced the method with detailed type annotations and a comprehensive docstring, improving clarity when working with custom NumPy ufunc behavior.
* Add comment clarifying 'reduce' method handling. A comment was added to explain that the 'reduce' method handles reductions like `np.sum(a)`.
* Refactor matrix sum test for clarity and type consistency. Updated `test_matrix_sum_result` to extract and cast the result of the MatrixExpr sum operation before the assertions, ensuring type consistency between the expected and actual results.
* Fix handling of 'out' kwarg in `__array_ufunc__`. Ensures all elements of the 'out' tuple are unboxed with `_ensure_array`, preventing recursion issues when 'out' is provided as a tuple.
* Fix comments to reference mean instead of sum. Updated comments in `test_matrix_mean_performance` to correctly refer to 'mean' instead of 'sum', reflecting the operations actually performed.
* Pass 'False' to `_ensure_array`.
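The core mechanism of this change can be sketched with a toy `np.ndarray` subclass. The class name `FastSumArray` and the delegation target are illustrative only (PySCIPOpt's real code delegates to `_core_sum`/`quicksum`); the point is how `__array_ufunc__` intercepts `method == "reduce"` for `np.add`, which is the path `np.sum`/`arr.sum()` takes:

```python
import numpy as np


class FastSumArray(np.ndarray):
    """Toy ndarray subclass intercepting np.add.reduce (the ufunc
    behind np.sum / arr.sum())."""

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        # Unbox subclass instances so the base implementation does not
        # dispatch back into this method (mirrors PySCIPOpt's _ensure_array).
        args = tuple(
            a.view(np.ndarray) if isinstance(a, FastSumArray) else a
            for a in args
        )
        if method == "reduce" and ufunc is np.add:
            kwargs.pop("out", None)  # 'out' is irrelevant for this sketch
            # Delegate the whole reduction in one call; in PySCIPOpt this
            # is where _core_sum / quicksum takes over.
            return np.add.reduce(args[0], **kwargs)
        res = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        return res.view(FastSumArray) if isinstance(res, np.ndarray) else res


a = np.arange(6).reshape(2, 3).view(FastSumArray)
total = np.sum(a)  # routed through the "reduce" branch above
```

Elementwise calls (`method == "__call__"`) fall through to the base implementation and are re-wrapped via `.view(...)`, matching the pattern in the MatrixExpr diff below.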
1 parent 149b057 commit 4c3391d

File tree

4 files changed (+116 / -82 lines)


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 - Fixed incorrect getVal() result when _bestSol.sol was outdated
 ### Changed
 - changed default value of enablepricing flag to True
+- Speed up MatrixExpr.add.reduce via quicksum
 - Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr
 - Minimum numpy version increased from 1.16.0 to 1.19.0
 ### Removed

src/pyscipopt/matrix.pxi

Lines changed: 71 additions & 61 deletions
@@ -2,7 +2,6 @@
 # TODO Cythonize things. Improve performance.
 # TODO Add tests
 """
-
 from typing import Literal, Optional, Tuple, Union
 import numpy as np
 try:
@@ -71,7 +70,7 @@ class MatrixExpr(np.ndarray):
             The ufunc object that was called.

         method : {"__call__", "reduce", "reduceat", "accumulate", "outer", "at"}
-            A string indicating which UFunc method was called.
+            A string indicating which ufunc method was called.

         *args : tuple
             The input arguments to the ufunc.
@@ -81,77 +80,28 @@ class MatrixExpr(np.ndarray):

         Returns
         -------
-        Expr, GenExpr, MatrixExpr
+        Expr, MatrixExpr
            The result of the ufunc operation is wrapped back into a MatrixExpr if
            applicable.

        """
        res = NotImplemented
+        # Unboxing MatrixExpr to stop __array_ufunc__ recursion
+        args = tuple(_ensure_array(arg) for arg in args)
+        if method == "reduce":  # Handle reduction operations, e.g., np.sum(a)
+            if ufunc is np.add:
+                res = _core_sum(args[0], **kwargs)
+
        if method == "__call__":  # Standard ufunc call, e.g., np.add(a, b)
            if ufunc in {np.matmul, np.dot}:
-                res = _core_dot(_ensure_array(args[0]), _ensure_array(args[1]))
+                res = _core_dot(args[0], args[1])

        if res is NotImplemented:
-            # Unboxing MatrixExpr to stop __array_ufunc__ recursion
-            args = tuple(_ensure_array(arg) for arg in args)
+            if "out" in kwargs:  # Unboxing MatrixExpr to stop __array_ufunc__ recursion
+                kwargs["out"] = tuple(_ensure_array(arg, False) for arg in kwargs["out"])
            res = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        return res.view(MatrixExpr) if isinstance(res, np.ndarray) else res

-    def sum(
-        self,
-        axis: Optional[Union[int, Tuple[int, ...]]] = None,
-        keepdims: bool = False,
-        **kwargs,
-    ) -> Union[Expr, MatrixExpr]:
-        """
-        Return the sum of the array elements over the given axis.
-
-        Parameters
-        ----------
-        axis : None or int or tuple of ints, optional
-            Axis or axes along which a sum is performed. The default, axis=None, will
-            sum all of the elements of the input array. If axis is negative it counts
-            from the last to the first axis. If axis is a tuple of ints, a sum is
-            performed on all of the axes specified in the tuple instead of a single axis
-            or all the axes as before.
-
-        keepdims : bool, optional
-            If this is set to True, the axes which are reduced are left in the result as
-            dimensions with size one. With this option, the result will broadcast
-            correctly against the input array.
-
-        **kwargs : ignored
-            Additional keyword arguments are ignored. They exist for compatibility
-            with `numpy.ndarray.sum`.
-
-        Returns
-        -------
-        Expr or MatrixExpr
-            If the sum is performed over all axes, return an Expr, otherwise return
-            a MatrixExpr.
-
-        """
-        axis: Tuple[int, ...] = normalize_axis_tuple(
-            range(self.ndim) if axis is None else axis, self.ndim
-        )
-        if len(axis) == self.ndim:
-            res = quicksum(self.flat)
-            return (
-                np.array([res], dtype=object).reshape([1] * self.ndim).view(MatrixExpr)
-                if keepdims
-                else res
-            )
-
-        keep_axes = tuple(i for i in range(self.ndim) if i not in axis)
-        shape = (
-            tuple(1 if i in axis else self.shape[i] for i in range(self.ndim))
-            if keepdims
-            else tuple(self.shape[i] for i in keep_axes)
-        )
-        return np.apply_along_axis(
-            quicksum, -1, self.transpose(keep_axes + axis).reshape(shape + (-1,))
-        ).view(MatrixExpr)
-
     def __le__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
         return _matrixexpr_richcmp(self, other, 1)

@@ -195,6 +145,7 @@ class MatrixExpr(np.ndarray):
 class MatrixGenExpr(MatrixExpr):
     pass

+
 class MatrixExprCons(np.ndarray):

     def __le__(self, other: Union[float, int, np.ndarray]) -> MatrixExprCons:
@@ -288,3 +239,62 @@ def _core_dot_2d(cnp.ndarray a, cnp.ndarray x) -> np.ndarray:
             res[i, j] = quicksum(a_view[i, idx] * x[idx, j] for idx in nonzero)

     return res
+
+
+def _core_sum(
+    cnp.ndarray a,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    **kwargs,
+) -> Union[Expr, np.ndarray]:
+    """
+    Return the sum of the array elements over the given axis.
+
+    Parameters
+    ----------
+    a : cnp.ndarray
+        A `np.ndarray` of type `object` and containing `Expr` objects.
+
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which a sum is performed. The default, axis=None, will
+        sum all of the elements of the input array. If axis is negative it counts
+        from the last to the first axis. If axis is a tuple of ints, a sum is
+        performed on all of the axes specified in the tuple instead of a single axis
+        or all the axes as before.
+
+    keepdims : bool, optional
+        If this is set to True, the axes which are reduced are left in the result as
+        dimensions with size one. With this option, the result will broadcast
+        correctly against the input array.
+
+    **kwargs : ignored
+        Additional keyword arguments are ignored. They exist for compatibility
+        with `numpy.ndarray.sum`.
+
+    Returns
+    -------
+    Expr or np.ndarray
+        If the sum is performed over all axes, return an Expr, otherwise return
+        a np.ndarray.
+
+    """
+    axis: Tuple[int, ...] = normalize_axis_tuple(
+        range(a.ndim) if axis is None else axis, a.ndim
+    )
+    if len(axis) == a.ndim:
+        res = quicksum(a.flat)
+        return (
+            np.array([res], dtype=object).reshape([1] * a.ndim)
+            if keepdims
+            else res
+        )
+
+    keep_axes = tuple(i for i in range(a.ndim) if i not in axis)
+    shape = (
+        tuple(1 if i in axis else a.shape[i] for i in range(a.ndim))
+        if keepdims
+        else tuple(a.shape[i] for i in keep_axes)
+    )
+    return np.apply_along_axis(
+        quicksum, -1, a.transpose(keep_axes + axis).reshape(shape + (-1,))
+    )
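The axis handling in `_core_sum` can be exercised in pure NumPy. In the sketch below, Python's built-in `sum` stands in for `quicksum`, and `_normalize_axis_tuple` is a local stand-in for NumPy's `normalize_axis_tuple` (whose import path varies across NumPy versions); the names `core_sum` and `reducer` are illustrative, not the library's API:

```python
import numpy as np


def _normalize_axis_tuple(axis, ndim):
    # Local stand-in for numpy's normalize_axis_tuple.
    if axis is None:
        axis = tuple(range(ndim))
    elif isinstance(axis, int):
        axis = (axis,)
    return tuple(a % ndim for a in axis)  # resolve negative axes


def core_sum(a, axis=None, keepdims=False, reducer=sum):
    """Sketch of _core_sum's strategy: `reducer` stands in for quicksum."""
    axis = _normalize_axis_tuple(axis, a.ndim)
    if len(axis) == a.ndim:  # full reduction: one reducer call over all elements
        res = reducer(a.flat)
        return np.array([res]).reshape([1] * a.ndim) if keepdims else res

    # Partial reduction: move reduced axes last, flatten them into one
    # trailing axis, then make one reducer call per output element.
    keep_axes = tuple(i for i in range(a.ndim) if i not in axis)
    shape = (
        tuple(1 if i in axis else a.shape[i] for i in range(a.ndim))
        if keepdims
        else tuple(a.shape[i] for i in keep_axes)
    )
    return np.apply_along_axis(
        reducer, -1, a.transpose(keep_axes + axis).reshape(shape + (-1,))
    )


a = np.arange(24).reshape(2, 3, 4)
assert core_sum(a) == a.sum()
assert (core_sum(a, axis=(0, 2)) == a.sum(axis=(0, 2))).all()
assert core_sum(a, axis=1, keepdims=True).shape == (2, 1, 4)
```

The payoff of this layout is that the number of (potentially expensive) reducer calls equals the number of output elements, rather than one call per pairwise addition as `np.add.reduce` over an object array would perform.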

src/pyscipopt/scip.pyi

Lines changed: 0 additions & 3 deletions
@@ -509,9 +509,6 @@ class MatrixConstraint(numpy.ndarray):
     def isStickingAtNode(self) -> Incomplete: ...

 class MatrixExpr(numpy.ndarray):
-    def sum(  # type: ignore[override]
-        self, axis: Incomplete = ..., keepdims: Incomplete = ..., **kwargs: Incomplete
-    ) -> Incomplete: ...
     def __add__(self, other: Incomplete) -> Incomplete: ...
     def __eq__(self, other: Incomplete) -> Incomplete: ...
     def __ge__(self, other: Incomplete) -> MatrixExprCons: ...

tests/test_matrix_variable.py

Lines changed: 44 additions & 18 deletions
@@ -248,49 +248,75 @@ def test_matrix_sum_axis():
 )
 def test_matrix_sum_result(axis, keepdims):
     # directly compare the result of np.sum and MatrixExpr.sum
-    _getVal = np.vectorize(lambda e: e.terms[CONST])
+    _getVal = np.vectorize(lambda e: e[CONST])
     a = np.arange(6).reshape((1, 2, 3))

     np_res = a.sum(axis, keepdims=keepdims)
-    scip_res = MatrixExpr.sum(a, axis, keepdims=keepdims)
-    assert (np_res == _getVal(scip_res)).all()
-    assert np_res.shape == _getVal(scip_res).shape
+    scip_res = _getVal(a.view(MatrixExpr).sum(axis, keepdims=keepdims)).view(np.ndarray)
+    assert (np_res == scip_res).all()
+    assert np_res.shape == scip_res.shape


 @pytest.mark.parametrize("n", [50, 100])
 def test_matrix_sum_axis_is_none_performance(n):
     model = Model()
     x = model.addMatrixVar((n, n))

-    # Original sum via `np.ndarray.sum`, `np.sum` will call subclass method
-    start_orig = time()
-    np.ndarray.sum(x)
-    end_orig = time()
+    # Original sum via `np.ndarray.sum`
+    start = time()
+    x.view(np.ndarray).sum()
+    orig = time() - start

     # Optimized sum via `quicksum`
-    start_matrix = time()
+    start = time()
     x.sum()
-    end_matrix = time()
+    matrix = time() - start

-    assert model.isGT(end_orig - start_orig, end_matrix - start_matrix)
+    assert model.isGT(orig, matrix)


 @pytest.mark.parametrize("n", [50, 100])
 def test_matrix_sum_axis_not_none_performance(n):
     model = Model()
     x = model.addMatrixVar((n, n))

-    # Original sum via `np.ndarray.sum`, `np.sum` will call subclass method
-    start_orig = time()
-    np.ndarray.sum(x, axis=0)
-    end_orig = time()
+    # Original sum via `np.ndarray.sum`
+    start = time()
+    x.view(np.ndarray).sum(axis=0)
+    orig = time() - start

     # Optimized sum via `quicksum`
-    start_matrix = time()
+    start = time()
     x.sum(axis=0)
-    end_matrix = time()
+    matrix = time() - start
+
+    assert model.isGT(orig, matrix)
+
+
+@pytest.mark.parametrize("n", [50, 100])
+def test_matrix_mean_performance(n):
+    model = Model()
+    x = model.addMatrixVar((n, n))
+
+    # Original mean via `np.ndarray.mean`
+    start = time()
+    x.view(np.ndarray).mean(axis=0)
+    orig = time() - start
+
+    # Optimized mean via `quicksum`
+    start = time()
+    x.mean(axis=0)
+    matrix = time() - start
+
+    assert model.isGT(orig, matrix)
+
+
+def test_matrix_mean():
+    model = Model()
+    x = model.addMatrixVar((2, 2))

-    assert model.isGT(end_orig - start_orig, end_matrix - start_matrix)
+    assert isinstance(x.mean(), Expr)
+    assert isinstance(x.mean(1), MatrixExpr)


 @pytest.mark.parametrize("n", [50, 100])

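The tests above rely on view casting to select which `sum`/`mean` implementation runs. A minimal illustration of that mechanism (the `LoudSum` class is a toy, not part of PySCIPOpt):

```python
import numpy as np


class LoudSum(np.ndarray):
    """Toy subclass with an overridden sum, to show what view casting does."""

    def sum(self, *args, **kwargs):
        # Tag the result so we can tell which implementation ran.
        return ("subclass", np.ndarray.sum(self.view(np.ndarray), *args, **kwargs))


x = np.ones((2, 2)).view(LoudSum)

# x.sum() resolves to the subclass method ...
tagged = x.sum()
# ... while view-casting to the base class picks plain ndarray.sum,
# which is why the tests time x.view(np.ndarray).sum() as the baseline.
plain = x.view(np.ndarray).sum()
```

Calling `np.ndarray.sum(x)` directly (the old style) would still dispatch through the subclass's `__array_ufunc__`, so `x.view(np.ndarray)` is the reliable way to benchmark the unoptimized path.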