Commit 0a63048

Zeroto521, Copilot, and Joao-Dionisio authored
Speed up MatrixExpr.sum(axis=...) via quicksum (#1135)
* Add test for matrix sum return type
  Adds a test to ensure that summing a matrix variable along an axis returns a MatrixExpr type instead of MatrixVariable, addressing issue #1117.

* return `MatrixExpr` type

* Update CHANGELOG.md

* Speed up `MatrixExpr.sum(axis=...)`
  Enhanced the MatrixExpr.sum method to accept axis as int or tuple, handle keepdims, and provide better error checking for axis bounds. This improves compatibility with numpy's sum behavior and allows more flexible summation over matrix expressions.

* Test `MatrixExpr.sum(axis=...)`
  Removed n=200 from the sum performance test to limit test size. Added additional performance assertions comparing np.ndarray.sum and the optimized sum method for matrix variables.

* Test `MatrixExpr.sum(axis=tuple(range(ndim))`
  Renamed test_matrix_sum_argument to test_matrix_sum_axis and updated tests to use explicit axis arguments in sum operations. This clarifies the behavior when summing over all axes and improves test coverage for axis handling.

* Refactor MatrixExpr.sum axis handling with numpy utility
  Replaces manual axis validation and normalization in MatrixExpr.sum with numpy's normalize_axis_tuple for improved reliability and code clarity. Updates type hints and simplifies logic for summing across all axes.

* Add tests for matrix sum error
  Added tests to verify error handling in matrix variable sum operations for invalid axis types, out-of-range values, and duplicate axes.

* Add tests for matrix sum with keepdims parameter
  Introduces test cases to verify the shape of matrix variable sums when using the keepdims argument, ensuring correct behavior for both full and axis-specific summation.

* call `.sum` via positional argument
  Replaces the use of the 'axis' keyword argument with a positional argument in the z.sum() method call to align with the expected function signature.

* Refactor sum method to use np.apply_along_axis
  Replaces np.fromiter with np.apply_along_axis for summing along specified axes in MatrixExpr. This simplifies the code and improves readability.

* Directly test the `.sum` result
  Added a comment to clarify the purpose of the test_matrix_sum_axis function, indicating it compares the result of summing a matrix variable after optimization.

* Update CHANGELOG.md

* Expand docstring for MatrixExpr.sum method
  The docstring for the MatrixExpr.sum method was updated to provide detailed information about parameters, return values, and behavior, improving clarity and alignment with numpy conventions.

* Clarify MatrixExpr.sum docstring and note quicksum usage
  Updated the docstring for MatrixExpr.sum to specify that it uses quicksum for speed optimization instead of numpy.ndarray.sum. Added a detailed note explaining the difference between quicksum (using __iadd__) and numpy's sum (using __add__).

* Split up two costing time test cases
  Renamed the existing performance test to clarify it tests the case where axis is None. Added a new test to measure performance when summing along a specific axis.

* Supports numpy 1.x
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Supports Python 3.8

* Simplify a bit
  Refactored the sum method in MatrixExpr to clarify axis typing and simplify the application of np.apply_along_axis. The new implementation improves readability and maintains the intended behavior.

* suggestion

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Joao-Dionisio <up201606210@up.pt>
Co-authored-by: João Dionísio <57299939+Joao-Dionisio@users.noreply.github.com>
1 parent ac1465a commit 0a63048
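
The commit message contrasts quicksum (using `__iadd__`) with numpy's sum (using `__add__`). The following is a minimal sketch of why that difference can matter, assuming an expression type that stores its terms in a dict. `ToyExpr`, `slow_sum`, and `quick_sum` are hypothetical names for illustration only and are not PySCIPOpt's `Expr` or `quicksum`.

```python
class ToyExpr:
    """Toy linear expression: maps a variable name to its coefficient.

    Hypothetical stand-in for illustration; not PySCIPOpt's Expr.
    """

    copies = 0  # how many times `+` had to copy an accumulated expression

    def __init__(self, terms=None):
        self.terms = dict(terms or {})

    def __add__(self, other):
        # Out-of-place: build a new object that copies every term so far.
        ToyExpr.copies += 1
        result = ToyExpr(self.terms)
        result += other
        return result

    def __iadd__(self, other):
        # In-place: only the terms of `other` are visited.
        for name, coeff in other.terms.items():
            self.terms[name] = self.terms.get(name, 0.0) + coeff
        return self


def slow_sum(exprs):
    """Accumulate with `+`, as a generic reduction would."""
    total = ToyExpr()
    for e in exprs:
        total = total + e
    return total


def quick_sum(exprs):
    """Accumulate with `+=`, in the spirit of quicksum."""
    total = ToyExpr()
    for e in exprs:
        total += e
    return total


xs = [ToyExpr({f"x{i}": 1.0}) for i in range(100)]

ToyExpr.copies = 0
a = slow_sum(xs)
copies_slow = ToyExpr.copies

ToyExpr.copies = 0
b = quick_sum(xs)
copies_quick = ToyExpr.copies

print(copies_slow, copies_quick, a.terms == b.terms)
```

In this toy model the `+` loop copies the growing accumulator once per element, while the `+=` loop never copies it, so the two produce the same terms at very different cost.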

3 files changed

Lines changed: 141 additions & 16 deletions

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -2,8 +2,10 @@
 
 ## Unreleased
 ### Added
+- Speed up MatrixExpr.sum(axis=...) via quicksum
 ### Fixed
 - all fundamental callbacks now raise an error if not implemented
+- Fixed the type of MatrixExpr.sum(axis=...) result from MatrixVariable to MatrixExpr.
 ### Changed
 - changed default value of enablepricing flag to True
 ### Removed

src/pyscipopt/matrix.pxi

Lines changed: 60 additions & 9 deletions
@@ -3,8 +3,14 @@
 # TODO Add tests
 """
 
+from typing import Optional, Tuple, Union
 import numpy as np
-from typing import Union
+try:
+    # NumPy 2.x location
+    from numpy.lib.array_utils import normalize_axis_tuple
+except ImportError:
+    # Fallback for NumPy 1.x
+    from numpy.core.numeric import normalize_axis_tuple
 
 
 def _is_number(e):
@@ -44,16 +50,61 @@ def _matrixexpr_richcmp(self, other, op):
 
 
 class MatrixExpr(np.ndarray):
-    def sum(self, **kwargs):
-        """
-        Based on `numpy.ndarray.sum`, but returns a scalar if `axis=None`.
-        This is useful for matrix expressions to compare with a matrix or a scalar.
+
+    def sum(
+        self,
+        axis: Optional[Union[int, Tuple[int, ...]]] = None,
+        keepdims: bool = False,
+        **kwargs,
+    ) -> Union[Expr, MatrixExpr]:
         """
+        Return the sum of the array elements over the given axis.
+
+        Parameters
+        ----------
+        axis : None or int or tuple of ints, optional
+            Axis or axes along which a sum is performed. The default, axis=None, will
+            sum all of the elements of the input array. If axis is negative it counts
+            from the last to the first axis. If axis is a tuple of ints, a sum is
+            performed on all of the axes specified in the tuple instead of a single axis
+            or all the axes as before.
+
+        keepdims : bool, optional
+            If this is set to True, the axes which are reduced are left in the result as
+            dimensions with size one. With this option, the result will broadcast
+            correctly against the input array.
+
+        **kwargs : ignored
+            Additional keyword arguments are ignored. They exist for compatibility
+            with `numpy.ndarray.sum`.
+
+        Returns
+        -------
+        Expr or MatrixExpr
+            If the sum is performed over all axes, return an Expr, otherwise return
+            a MatrixExpr.
 
-        if kwargs.get("axis") is None:
-            # Speed up `.sum()` #1070
-            return quicksum(self.flat)
-        return super().sum(**kwargs)
+        """
+        axis: Tuple[int, ...] = normalize_axis_tuple(
+            range(self.ndim) if axis is None else axis, self.ndim
+        )
+        if len(axis) == self.ndim:
+            res = quicksum(self.flat)
+            return (
+                np.array([res], dtype=object).reshape([1] * self.ndim).view(MatrixExpr)
+                if keepdims
+                else res
+            )
+
+        keep_axes = tuple(i for i in range(self.ndim) if i not in axis)
+        shape = (
+            tuple(1 if i in axis else self.shape[i] for i in range(self.ndim))
+            if keepdims
+            else tuple(self.shape[i] for i in keep_axes)
+        )
+        return np.apply_along_axis(
+            quicksum, -1, self.transpose(keep_axes + axis).reshape(shape + (-1,))
+        ).view(MatrixExpr)
 
     def __le__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
         return _matrixexpr_richcmp(self, other, 1)
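
The transpose-and-reshape step in the new `sum` can be tried on plain integers. This is a minimal sketch, assuming NumPy is installed. `sum_along_axes` is a hypothetical helper name, Python's built-in `sum` stands in for `quicksum`, and `axis` is assumed to be an already normalized tuple that does not cover every axis.

```python
import numpy as np


def sum_along_axes(arr, axis, keepdims=False, reducer=sum):
    """Reduce `arr` over the axes in `axis` (tuple of unique, non-negative ints).

    Hypothetical helper mirroring the reshaping idea in the diff, with the
    built-in `sum` used in place of quicksum.
    """
    ndim = arr.ndim
    keep_axes = tuple(i for i in range(ndim) if i not in axis)
    shape = (
        tuple(1 if i in axis else arr.shape[i] for i in range(ndim))
        if keepdims
        else tuple(arr.shape[i] for i in keep_axes)
    )
    # Move the reduced axes to the end, then flatten them into one trailing
    # axis so that a single 1-D reduction handles all of them at once.
    moved = arr.transpose(keep_axes + axis).reshape(shape + (-1,))
    return np.apply_along_axis(reducer, -1, moved)


a = np.arange(24).reshape(2, 3, 4)
partial = sum_along_axes(a, (0, 2))
kept = sum_along_axes(a, (0, 2), keepdims=True)
print(partial.shape, kept.shape)
```

Flattening the reduced axes into one trailing axis means the reducer is called once per output element on a 1-D slice, which is the shape of input that `quicksum` expects.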

tests/test_matrix_variable.py

Lines changed: 79 additions & 7 deletions
@@ -19,7 +19,7 @@
     sin,
     sqrt,
 )
-from pyscipopt.scip import GenExpr
+from pyscipopt.scip import CONST, GenExpr
 
 
 def test_catching_errors():
@@ -181,7 +181,30 @@ def test_expr_from_matrix_vars():
         for term, coeff in expr_list:
             assert len(term) == 3
 
-def test_matrix_sum_argument():
+
+def test_matrix_sum_error():
+    m = Model()
+    x = m.addMatrixVar((2, 3), "x", "I", ub=4)
+
+    # test axis type
+    with pytest.raises(TypeError):
+        x.sum("0")
+
+    # test axis value (out of range)
+    with pytest.raises(ValueError):
+        x.sum(2)
+
+    # test axis value (out of range)
+    with pytest.raises(ValueError):
+        x.sum((-3,))
+
+    # test axis value (duplicate)
+    with pytest.raises(ValueError):
+        x.sum((0, 0))
+
+
+def test_matrix_sum_axis():
+    # compare the result of summing matrix variable after optimization
     m = Model()
 
     # Return a array when axis isn't None
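
The error cases exercised by `test_matrix_sum_error` (wrong type, out-of-range value, duplicate axis) can be sketched without NumPy. `normalize_axes` below is a hypothetical pure-Python stand-in written for illustration; it is not NumPy's `normalize_axis_tuple`, and its error messages are made up.

```python
import operator


def normalize_axes(axis, ndim):
    """Return `axis` as a tuple of unique, non-negative ints below `ndim`.

    Hypothetical illustration of the checks the tests rely on.
    """
    if axis is None:
        return tuple(range(ndim))
    if isinstance(axis, (tuple, list)):
        raw = [operator.index(a) for a in axis]
    else:
        # operator.index raises TypeError for non-integers such as "0" or 1.5
        raw = [operator.index(axis)]
    out = []
    for a in raw:
        if not -ndim <= a < ndim:
            raise ValueError(f"axis {a} is out of bounds for {ndim} dimensions")
        out.append(a % ndim)  # map negative axes to their positive index
    if len(set(out)) != len(out):
        raise ValueError("repeated axis")
    return tuple(out)


print(normalize_axes(-1, 3))
print(normalize_axes((0, -1), 3))
```

With `ndim=2`, the inputs `"0"`, `2`, `(-3,)`, and `(0, 0)` from the test are all rejected by this sketch with the exception types the test expects.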
@@ -190,29 +213,52 @@ def test_matrix_sum_argument():
 
     # compare the result of summing 2d array to a scalar with a scalar
     x = m.addMatrixVar((2, 3), "x", "I", ub=4)
-    m.addMatrixCons(x.sum() == 24)
+    # `axis=tuple(range(x.ndim))` is `axis=None`
+    m.addMatrixCons(x.sum(axis=tuple(range(x.ndim))) == 24)
 
     # compare the result of summing 2d array to 1d array
     y = m.addMatrixVar((2, 4), "y", "I", ub=4)
     m.addMatrixCons(x.sum(axis=1) == y.sum(axis=1))
 
     # compare the result of summing 3d array to a 2d array with a 2d array
     z = m.addMatrixVar((2, 3, 4), "z", "I", ub=4)
-    m.addMatrixCons(z.sum(axis=2) == x)
+    m.addMatrixCons(z.sum(2) == x)
     m.addMatrixCons(z.sum(axis=1) == y)
 
     # to fix the element values
     m.addMatrixCons(z == np.ones((2, 3, 4)))
 
-    m.setObjective(x.sum() + y.sum() + z.sum(), "maximize")
+    m.setObjective(x.sum() + y.sum() + z.sum(tuple(range(z.ndim))), "maximize")
     m.optimize()
 
     assert (m.getVal(x) == np.full((2, 3), 4)).all().all()
     assert (m.getVal(y) == np.full((2, 4), 3)).all().all()
 
 
-@pytest.mark.parametrize("n", [50, 100, 200])
-def test_sum_performance(n):
+@pytest.mark.parametrize(
+    "axis, keepdims",
+    [
+        (0, False),
+        (0, True),
+        (1, False),
+        (1, True),
+        ((0, 2), False),
+        ((0, 2), True),
+    ],
+)
+def test_matrix_sum_result(axis, keepdims):
+    # directly compare the result of np.sum and MatrixExpr.sum
+    _getVal = np.vectorize(lambda e: e.terms[CONST])
+    a = np.arange(6).reshape((1, 2, 3))
+
+    np_res = a.sum(axis, keepdims=keepdims)
+    scip_res = MatrixExpr.sum(a, axis, keepdims=keepdims)
+    assert (np_res == _getVal(scip_res)).all()
+    assert np_res.shape == _getVal(scip_res).shape
+
+
+@pytest.mark.parametrize("n", [50, 100])
+def test_matrix_sum_axis_is_none_performance(n):
     model = Model()
     x = model.addMatrixVar((n, n))
 
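
The shapes that `test_matrix_sum_result` compares can be worked out by hand for the `(1, 2, 3)` input. A small sketch in pure Python follows; `reduced_shape` is a hypothetical helper that assumes `axis` is a tuple of non-negative ints.

```python
def reduced_shape(shape, axis, keepdims=False):
    """Shape left after summing an array of `shape` over the axes in `axis`.

    Hypothetical helper; `axis` is a tuple of non-negative ints.
    """
    if keepdims:
        # Reduced axes stay in place with length one.
        return tuple(1 if i in axis else n for i, n in enumerate(shape))
    # Reduced axes are dropped entirely.
    return tuple(n for i, n in enumerate(shape) if i not in axis)


for axis in [(0,), (1,), (0, 2)]:
    print(axis, reduced_shape((1, 2, 3), axis), reduced_shape((1, 2, 3), axis, keepdims=True))
```

With `keepdims=True` the result always has as many dimensions as the input, which is what lets it broadcast against the original array.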
@@ -229,6 +275,24 @@ def test_sum_performance(n):
     assert model.isGT(end_orig - start_orig, end_matrix - start_matrix)
 
 
+@pytest.mark.parametrize("n", [50, 100])
+def test_matrix_sum_axis_not_none_performance(n):
+    model = Model()
+    x = model.addMatrixVar((n, n))
+
+    # Original sum via `np.ndarray.sum`, `np.sum` will call subclass method
+    start_orig = time()
+    np.ndarray.sum(x, axis=0)
+    end_orig = time()
+
+    # Optimized sum via `quicksum`
+    start_matrix = time()
+    x.sum(axis=0)
+    end_matrix = time()
+
+    assert model.isGT(end_orig - start_orig, end_matrix - start_matrix)
+
+
 def test_add_cons_matrixVar():
     m = Model()
     matrix_variable = m.addMatrixVar(shape=(3, 3), vtype="B", name="A", obj=1)
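
The performance tests time each call once with `time()`. One possible way to reduce timing noise, shown only as a sketch and not as what the commit does, is to take the best of several runs with `time.perf_counter`. `best_of` is a hypothetical helper name.

```python
from time import perf_counter


def best_of(func, repeats=5):
    """Return the smallest wall-clock time over `repeats` calls to `func`."""
    best = float("inf")
    for _ in range(repeats):
        start = perf_counter()
        func()
        best = min(best, perf_counter() - start)
    return best


calls = []
elapsed = best_of(lambda: calls.append(sum(range(10_000))), repeats=3)
print(len(calls), elapsed >= 0.0)
```

A comparison such as `best_of(slow) > best_of(fast)` would still depend on the machine, so it remains a heuristic check.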
@@ -521,6 +585,14 @@ def test_matrix_matmul_return_type():
     assert type(y @ z) is MatrixExpr
 
 
+def test_matrix_sum_return_type():
+    # test #1117, require returning type is MatrixExpr not MatrixVariable
+    m = Model()
+
+    x = m.addMatrixVar((3, 2))
+    assert type(x.sum(axis=1)) is MatrixExpr
+
+
 def test_broadcast():
     # test #1065
     m = Model()
