Commit a03b00e

Speed up Expr * Expr (#1175)
* Refactor Expr multiplication logic and Term operator

  Replaces Term.__add__ with Term.__mul__ and updates Expr.__mul__ to use more efficient Cython dict iteration and item access. This improves performance and correctness when multiplying expressions, especially for large term dictionaries.

* Optimize Term multiplication in expr.pxi

  Replaces the simple concatenation in Term.__mul__ with an efficient merge that maintains variable order based on pointer values. This improves performance and correctness when multiplying Term objects.

* Update CHANGELOG to reorder quicksum optimization entry

  Moved the 'Speed up MatrixExpr.sum(axis=...) via quicksum' entry from the Added section to the Changed section for better categorization and clarity.

* Update changelog with Expr multiplication speedup

  Added a new entry to the changelog noting the performance improvement for Expr * Expr operations.

* Fix Term class operator signature in type stub

  Corrects the Term class in scip.pyi to define __mul__ instead of __add__, updating the method signature to accept and return Term objects.

* Add tests for expression multiplication

  Introduces test_mul to verify correct string representations of multiplied expressions involving variables and constants.

* Update tests for Expr multiplication behavior

  Replaces Term with CONST import from pyscipopt.scip and adds new assertions in test_mul to verify multiplication involving constants and variables. Removes redundant CONST definition.

* Add test for commutativity in multiplication with zero

  Added an assertion to test that multiplying y by (x - x) yields the same zero term as (x - x) * y. This ensures correct handling of zero expressions in multiplication.

* Update changelog for Expr and Term multiplication improvements

  Documented performance enhancements for Expr * Expr and Term * Term operations, including use of C-level API and an O(n) algorithm. Also clarified method renaming from Term.__add__ to Term.__mul__.

* Add note about sorted vartuple requirement in Term

  Added a comment in the Term.__mul__ method to highlight that Term.vartuple must be sorted for correct merging. Suggests ensuring sorting in the Term constructor to avoid potential issues.

* Clarify algorithm complexity in changelog

  Updated the description of the Term * Term speedup to specify the use of an O(n) sort algorithm instead of Python's O(log(n)) sorted function.

* Correct complexity notation in changelog

  Updated the changelog to fix the time complexity notation for the Term * Term sort algorithm from O(log(n)) to O(n log(n)).

* Apply suggestion from @Joao-Dionisio

* Apply suggestion from @Joao-Dionisio

* Fix indentation in Expr multiplication logic

  Corrected the indentation of the isinstance(other, Expr) block in the Expr class to ensure proper execution flow during multiplication operations.

* Preserve zero-coefficient terms in Expr mul

  Do not skip terms with 0.0 coefficients when multiplying Expr objects: remove earlier zero-checks and compute product values inline in src/pyscipopt/expr.pxi. This causes zero-product terms to be retained in the resulting expression. Update tests (tests/test_expr.py) to expect the preserved zero-coefficient terms for cases like (x - x) * y and y * (x - x).

---------

Co-authored-by: João Dionísio <57299939+Joao-Dionisio@users.noreply.github.com>
1 parent ac47132 commit a03b00e
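
The commit message describes replacing tuple concatenation in `Term.__mul__` with a linear-time merge of two already-sorted variable tuples. A minimal pure-Python sketch of that merge step follows. It is illustrative only: `merge_sorted` and its `key` parameter are hypothetical names, and plain integers stand in for the pointer values that the real Cython code compares via `var.ptr()`.

```python
def merge_sorted(left, right, key=lambda v: v):
    """Merge two tuples, each already sorted by `key`, in O(len(left) + len(right))."""
    n1, n2 = len(left), len(right)
    # An empty operand contributes nothing, so the other side is the result.
    if n1 == 0:
        return tuple(right)
    if n2 == 0:
        return tuple(left)

    merged = [None] * (n1 + n2)
    i = j = k = 0
    while i < n1 and j < n2:
        # Take the smaller head; `<=` means ties are taken from `left` first.
        if key(left[i]) <= key(right[j]):
            merged[k] = left[i]
            i += 1
        else:
            merged[k] = right[j]
            j += 1
        k += 1
    # At most one of the two tails is non-empty; copy whatever remains.
    while i < n1:
        merged[k] = left[i]
        i += 1
        k += 1
    while j < n2:
        merged[k] = right[j]
        j += 1
        k += 1
    return tuple(merged)


# Integers stand in for pointer values (an assumption made for this sketch).
print(merge_sorted((1, 4), (1, 2, 9)))  # (1, 1, 2, 4, 9)
```

The merge only yields a sorted result when both inputs are sorted, which is the precondition the commit documents for `Term.vartuple`.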

4 files changed: +84 -14 lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,6 @@
 ### Added
 - Added automated script for generating type stubs
 - Include parameter names in type stubs
-- Speed up MatrixExpr.sum(axis=...) via quicksum
 - Added pre-commit hook for automatic stub regeneration (see .pre-commit-config.yaml)
 - Wrapped isObjIntegral() and test
 - Added structured_optimization_trace recipe for structured optimization progress tracking
@@ -20,8 +19,12 @@
 - Fixed segmentation fault when using Variable or Constraint objects after freeTransform() or Model destruction
 ### Changed
 - changed default value of enablepricing flag to True
+- Speed up MatrixExpr.sum(axis=...) via quicksum
 - Speed up MatrixExpr.add.reduce via quicksum
 - Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr
+- Speed up Expr * Expr via C-level API and Term * Term
+- Speed up Term * Term via a $O(n)$ sort algorithm instead of Python $O(n\log(n))$ sorted function. `Term.__mul__` requires that Term.vartuple is sorted.
+- Rename from `Term.__add__` to `Term.__mul__`, due to this method only working with Expr * Expr.
 - MatrixExpr and MatrixExprCons use `__array_ufunc__` protocol to control all numpy.ufunc inputs and outputs
 - Set `__array_priority__` for MatrixExpr and MatrixExprCons
 - changed addConsNode() and addConsLocal() to mirror addCons() and accept ExprCons instead of Constraint

src/pyscipopt/expr.pxi

Lines changed: 59 additions & 10 deletions
@@ -45,9 +45,10 @@
 import math
 from typing import TYPE_CHECKING
 
-from cpython.dict cimport PyDict_Next
+from cpython.dict cimport PyDict_Next, PyDict_GetItem
 from cpython.object cimport Py_TYPE
 from cpython.ref cimport PyObject
+from cpython.tuple cimport PyTuple_GET_ITEM
 from pyscipopt.scip cimport Variable, Solution
 
 import numpy as np
@@ -123,9 +124,41 @@ cdef class Term:
     def __len__(self):
         return len(self.vartuple)
 
-    def __add__(self, other):
-        both = self.vartuple + other.vartuple
-        return Term(*both)
+    def __mul__(self, Term other):
+        # NOTE: This merge algorithm requires a sorted `Term.vartuple`.
+        # This should be ensured in the constructor of Term.
+        cdef int n1 = len(self)
+        cdef int n2 = len(other)
+        if n1 == 0: return other
+        if n2 == 0: return self
+
+        cdef list vartuple = [None] * (n1 + n2)
+        cdef int i = 0, j = 0, k = 0
+        cdef Variable var1, var2
+        while i < n1 and j < n2:
+            var1 = <Variable>PyTuple_GET_ITEM(self.vartuple, i)
+            var2 = <Variable>PyTuple_GET_ITEM(other.vartuple, j)
+            if var1.ptr() <= var2.ptr():
+                vartuple[k] = var1
+                i += 1
+            else:
+                vartuple[k] = var2
+                j += 1
+            k += 1
+        while i < n1:
+            vartuple[k] = <Variable>PyTuple_GET_ITEM(self.vartuple, i)
+            i += 1
+            k += 1
+        while j < n2:
+            vartuple[k] = <Variable>PyTuple_GET_ITEM(other.vartuple, j)
+            j += 1
+            k += 1
+
+        cdef Term res = Term.__new__(Term)
+        res.vartuple = tuple(vartuple)
+        res.ptrtuple = tuple(v.ptr() for v in res.vartuple)
+        res.hashval = <Py_ssize_t>hash(res.ptrtuple)
+        return res
 
     def __repr__(self):
         return 'Term(%s)' % ', '.join([str(v) for v in self.vartuple])
@@ -248,16 +281,32 @@ cdef class Expr:
         if isinstance(other, np.ndarray):
             return other * self
 
+        cdef dict res = {}
+        cdef Py_ssize_t pos1 = <Py_ssize_t>0, pos2 = <Py_ssize_t>0
+        cdef PyObject *k1_ptr = NULL
+        cdef PyObject *v1_ptr = NULL
+        cdef PyObject *k2_ptr = NULL
+        cdef PyObject *v2_ptr = NULL
+        cdef PyObject *old_v_ptr = NULL
+        cdef Term child
+        cdef double prod_v
+
         if _is_number(other):
             f = float(other)
             return Expr({v:f*c for v,c in self.terms.items()})
+
         elif isinstance(other, Expr):
-            terms = {}
-            for v1, c1 in self.terms.items():
-                for v2, c2 in other.terms.items():
-                    v = v1 + v2
-                    terms[v] = terms.get(v, 0.0) + c1 * c2
-            return Expr(terms)
+            while PyDict_Next(self.terms, &pos1, &k1_ptr, &v1_ptr):
+                pos2 = <Py_ssize_t>0
+                while PyDict_Next(other.terms, &pos2, &k2_ptr, &v2_ptr):
+                    child = (<Term>k1_ptr) * (<Term>k2_ptr)
+                    prod_v = (<double>(<object>v1_ptr)) * (<double>(<object>v2_ptr))
+                    if (old_v_ptr := PyDict_GetItem(res, child)) != NULL:
+                        res[child] = <double>(<object>old_v_ptr) + prod_v
+                    else:
+                        res[child] = prod_v
+            return Expr(res)
+
         elif isinstance(other, GenExpr):
             return buildGenExprObj(self) * other
         else:
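
The `Expr * Expr` branch above pairs every term of one expression with every term of the other, multiplies the coefficients, and accumulates products that land on the same monomial. The sketch below models that accumulation in plain Python, under loudly stated assumptions: expressions are plain dicts, a term is a tuple of variable names (not a `Term` object), and `sorted` on names replaces the pointer-ordered merge. `multiply_terms` is a hypothetical helper, not part of PySCIPOpt.

```python
def multiply_terms(a, b):
    """Multiply two polynomials stored as {tuple_of_names: coefficient} dicts."""
    res = {}
    for t1, c1 in a.items():
        for t2, c2 in b.items():
            # Assumption: sorting by name stands in for the pointer-ordered merge.
            child = tuple(sorted(t1 + t2))
            prod = c1 * c2
            # Accumulate on an existing monomial, otherwise create it.
            # Zero products are stored too, matching the commit's behaviour.
            if child in res:
                res[child] = res[child] + prod
            else:
                res[child] = prod
    return res


# (x + 1) * (y - 1); the empty tuple plays the role of the constant term.
x_plus_1 = {("x",): 1.0, (): 1.0}
y_minus_1 = {("y",): 1.0, (): -1.0}
print(multiply_terms(x_plus_1, y_minus_1))
# {('x', 'y'): 1.0, ('x',): -1.0, ('y',): 1.0, (): -1.0}

# (x - x) * y keeps the zero-coefficient term instead of dropping it.
print(multiply_terms({("x",): 0.0}, {("y",): 1.0}))
# {('x', 'y'): 0.0}
```

Retaining the `0.0` entry mirrors what the updated tests expect for `(x - x) * y`; a variant that skipped zero products would return an empty dict for the second call.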

src/pyscipopt/scip.pyi

Lines changed: 1 addition & 1 deletion
@@ -2187,7 +2187,7 @@ class Term:
     ptrtuple: Incomplete
     vartuple: Incomplete
     def __init__(self, *vartuple: Incomplete) -> None: ...
-    def __add__(self, other: Incomplete) -> Incomplete: ...
+    def __mul__(self, other: Term) -> Term: ...
     def __eq__(self, other: object) -> bool: ...
     def __ge__(self, other: object) -> bool: ...
     def __getitem__(self, index: Incomplete) -> Incomplete: ...

tests/test_expr.py

Lines changed: 20 additions & 2 deletions
@@ -3,7 +3,7 @@
 import pytest
 
 from pyscipopt import Model, sqrt, log, exp, sin, cos
-from pyscipopt.scip import Expr, GenExpr, ExprCons, Term
+from pyscipopt.scip import Expr, GenExpr, ExprCons, CONST
 
 
 @pytest.fixture(scope="module")
@@ -14,7 +14,6 @@ def model():
     z = m.addVar("z")
     return m, x, y, z
 
-CONST = Term()
 
 def test_upgrade(model):
     m, x, y, z = model
@@ -220,6 +219,25 @@ def test_getVal_with_GenExpr():
         m.getVal(1 / z)
 
 
+def test_mul():
+    m = Model()
+    x = m.addVar(name="x")
+    y = m.addVar(name="y")
+
+    assert str(Expr({CONST: 1.0}) * x) == "Expr({Term(x): 1.0})"
+    assert str(y * Expr({CONST: -1.0})) == "Expr({Term(y): -1.0})"
+    assert str((x - x) * y) == "Expr({Term(x, y): 0.0})"
+    assert str(y * (x - x)) == "Expr({Term(x, y): 0.0})"
+    assert (
+        str((x + 1) * (y - 1))
+        == "Expr({Term(x, y): 1.0, Term(x): -1.0, Term(y): 1.0, Term(): -1.0})"
+    )
+    assert (
+        str((x + 1) * (x + 1) * y)
+        == "Expr({Term(x, x, y): 1.0, Term(x, y): 2.0, Term(y): 1.0})"
+    )
+
+
 def test_abs_abs_expr():
     m = Model()
     x = m.addVar(name="x")
