|
45 | 45 | import math |
46 | 46 | from typing import TYPE_CHECKING |
47 | 47 |
|
48 | | -from cpython.dict cimport PyDict_Next |
| 48 | +from cpython.dict cimport PyDict_Next, PyDict_GetItem |
49 | 49 | from cpython.object cimport Py_TYPE |
50 | 50 | from cpython.ref cimport PyObject |
| 51 | +from cpython.tuple cimport PyTuple_GET_ITEM |
51 | 52 | from pyscipopt.scip cimport Variable, Solution |
52 | 53 |
|
53 | 54 | import numpy as np |
@@ -123,9 +124,41 @@ cdef class Term: |
123 | 124 | def __len__(self): |
124 | 125 | return len(self.vartuple) |
125 | 126 |
|
126 | | - def __add__(self, other): |
127 | | - both = self.vartuple + other.vartuple |
128 | | - return Term(*both) |
| 127 | + def __mul__(self, Term other): |
| 128 | + # NOTE: This merge algorithm requires a sorted `Term.vartuple`. |
| 129 | + # This should be ensured in the constructor of Term. |
| 130 | + cdef int n1 = len(self) |
| 131 | + cdef int n2 = len(other) |
| 132 | + if n1 == 0: return other |
| 133 | + if n2 == 0: return self |
| 134 | + |
| 135 | + cdef list vartuple = [None] * (n1 + n2) |
| 136 | + cdef int i = 0, j = 0, k = 0 |
| 137 | + cdef Variable var1, var2 |
| 138 | + while i < n1 and j < n2: |
| 139 | + var1 = <Variable>PyTuple_GET_ITEM(self.vartuple, i) |
| 140 | + var2 = <Variable>PyTuple_GET_ITEM(other.vartuple, j) |
| 141 | + if var1.ptr() <= var2.ptr(): |
| 142 | + vartuple[k] = var1 |
| 143 | + i += 1 |
| 144 | + else: |
| 145 | + vartuple[k] = var2 |
| 146 | + j += 1 |
| 147 | + k += 1 |
| 148 | + while i < n1: |
| 149 | + vartuple[k] = <Variable>PyTuple_GET_ITEM(self.vartuple, i) |
| 150 | + i += 1 |
| 151 | + k += 1 |
| 152 | + while j < n2: |
| 153 | + vartuple[k] = <Variable>PyTuple_GET_ITEM(other.vartuple, j) |
| 154 | + j += 1 |
| 155 | + k += 1 |
| 156 | + |
| 157 | + cdef Term res = Term.__new__(Term) |
| 158 | + res.vartuple = tuple(vartuple) |
| 159 | + res.ptrtuple = tuple(v.ptr() for v in res.vartuple) |
| 160 | + res.hashval = <Py_ssize_t>hash(res.ptrtuple) |
| 161 | + return res |
129 | 162 |
|
130 | 163 | def __repr__(self): |
131 | 164 | return 'Term(%s)' % ', '.join([str(v) for v in self.vartuple]) |
@@ -248,16 +281,32 @@ cdef class Expr: |
248 | 281 | if isinstance(other, np.ndarray): |
249 | 282 | return other * self |
250 | 283 |
|
| 284 | + cdef dict res = {} |
| 285 | + cdef Py_ssize_t pos1 = <Py_ssize_t>0, pos2 = <Py_ssize_t>0 |
| 286 | + cdef PyObject *k1_ptr = NULL |
| 287 | + cdef PyObject *v1_ptr = NULL |
| 288 | + cdef PyObject *k2_ptr = NULL |
| 289 | + cdef PyObject *v2_ptr = NULL |
| 290 | + cdef PyObject *old_v_ptr = NULL |
| 291 | + cdef Term child |
| 292 | + cdef double prod_v |
| 293 | + |
251 | 294 | if _is_number(other): |
252 | 295 | f = float(other) |
253 | 296 | return Expr({v:f*c for v,c in self.terms.items()}) |
| 297 | + |
254 | 298 | elif isinstance(other, Expr): |
255 | | - terms = {} |
256 | | - for v1, c1 in self.terms.items(): |
257 | | - for v2, c2 in other.terms.items(): |
258 | | - v = v1 + v2 |
259 | | - terms[v] = terms.get(v, 0.0) + c1 * c2 |
260 | | - return Expr(terms) |
| 299 | + while PyDict_Next(self.terms, &pos1, &k1_ptr, &v1_ptr): |
| 300 | + pos2 = <Py_ssize_t>0 |
| 301 | + while PyDict_Next(other.terms, &pos2, &k2_ptr, &v2_ptr): |
| 302 | + child = (<Term>k1_ptr) * (<Term>k2_ptr) |
| 303 | + prod_v = (<double>(<object>v1_ptr)) * (<double>(<object>v2_ptr)) |
| 304 | + if (old_v_ptr := PyDict_GetItem(res, child)) != NULL: |
| 305 | + res[child] = <double>(<object>old_v_ptr) + prod_v |
| 306 | + else: |
| 307 | + res[child] = prod_v |
| 308 | + return Expr(res) |
| 309 | + |
261 | 310 | elif isinstance(other, GenExpr): |
262 | 311 | return buildGenExprObj(self) * other |
263 | 312 | else: |
|
0 commit comments