Skip to content

Commit cce6c92

Browse files
committed
Dispatch numpy ufuncs for expr functions
Replace per-function PyNumber_Check lambdas with a unified _dispatch_ufunc helper that vectorizes conversion of Python numbers to Constant via _to_const and np.frompyfunc. Removed the direct cimport of PyNumber_Check and updated exp/log/sqrt/sin/cos to call _dispatch_ufunc. _dispatch_ufunc also coerces ndarray results to MatrixExpr, centralizing numpy ufunc handling and improving array support for expression construction.
1 parent 590859d commit cce6c92

1 file changed

Lines changed: 15 additions & 6 deletions

File tree

src/pyscipopt/expr.pxi

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ from typing import TYPE_CHECKING, Literal
4848
import numpy as np
4949

5050
from cpython.dict cimport PyDict_Next, PyDict_GetItem
51-
from cpython.number cimport PyNumber_Check
5251
from cpython.object cimport Py_TYPE
5352
from cpython.ref cimport PyObject
5453
from cpython.tuple cimport PyTuple_GET_ITEM
@@ -857,11 +856,21 @@ cdef class Constant(GenExpr):
857856
return self.number
858857

859858

860-
exp = lambda x: Constant(x) if PyNumber_Check(x) else np.exp(x)
861-
log = lambda x: Constant(x) if PyNumber_Check(x) else np.log(x)
862-
sqrt = lambda x: Constant(x) if PyNumber_Check(x) else np.sqrt(x)
863-
sin = lambda x: Constant(x) if PyNumber_Check(x) else np.sin(x)
864-
cos = lambda x: Constant(x) if PyNumber_Check(x) else np.cos(x)
859+
exp = lambda x: _dispatch_ufunc(x, np.exp)
860+
log = lambda x: _dispatch_ufunc(x, np.log)
861+
log = lambda x: _dispatch_ufunc(x, np.log)
862+
sqrt = lambda x: _dispatch_ufunc(x, np.sqrt)
863+
sin = lambda x: _dispatch_ufunc(x, np.sin)
864+
cos = lambda x: _dispatch_ufunc(x, np.cos)
865+
866+
cdef inline object _to_const(object x):
867+
return Constant(<double>x) if _is_number(x) else x
868+
869+
cdef object _vec_const = np.frompyfunc(_to_const, 1, 1)
870+
871+
cdef inline object _dispatch_ufunc(object x, object ufunc):
872+
res = ufunc(_vec_const(x))
873+
return res.view(MatrixExpr) if isinstance(res, np.ndarray) else res
865874

866875

867876
def expr_to_nodes(expr):

0 commit comments

Comments
 (0)