Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### Fixed
- Raised an error when an expression is used when a variable is required
### Changed
- MatrixExpr.sum() could return MatrixExpr
Comment thread
Joao-Dionisio marked this conversation as resolved.
Outdated
### Removed

## 5.5.0 - 2025.05.06
Expand Down
9 changes: 7 additions & 2 deletions src/pyscipopt/matrix.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ def _is_number(e):

class MatrixExpr(np.ndarray):
def sum(self, **kwargs):
return super().sum(**kwargs).item()

"""
Based on `numpy.ndarray.sum`, but returns a scalar if the result is a single value.
This is useful for matrix expressions where the sum might reduce to a single value.
"""
res = super().sum(**kwargs)
return res if res.size > 1 else res.item()
Comment thread
Zeroto521 marked this conversation as resolved.

def __le__(self, other: Union[float, int, Variable, np.ndarray, 'MatrixExpr']) -> np.ndarray:

expr_cons_matrix = np.empty(self.shape, dtype=object)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_matrix_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,22 @@ def test_expr_from_matrix_vars():
for term, coeff in expr_list:
assert len(term) == 3

def test_matrix_sum_argument():
Comment thread
Zeroto521 marked this conversation as resolved.
m = Model()

# sum with 2d array
x = m.addMatrixVar((2, 3), "x", "I", ub=10)
m.addMatrixCons(x.sum(axis=1) == np.zeros(2))

# sum with 3d array, set axis=2
y = m.addMatrixVar((2, 3, 4), "y", "I", ub=10)
m.addMatrixCons(y.sum(axis=2) == np.zeros((2, 3)))

m.setObjective(x.sum() + y.sum(), "maximize")
m.optimize()

assert (m.getVal(x) == np.zeros((2, 3))).all().all()
assert (m.getVal(y) == np.zeros((2, 3, 4))).all().all().all()

def test_add_cons_matrixVar():
m = Model()
Expand Down
Loading